From f2c8bab8531da961de09e4bf808695a4a269bb75 Mon Sep 17 00:00:00 2001 From: huidesheng <1832140001@qq.com> Date: Tue, 12 May 2026 17:45:29 +0800 Subject: [PATCH 1/4] add chunkprefill and prefill cuda graph --- csrc/engine/compiler/general_compiler.cpp | 12 ++- csrc/engine/compiler/general_compiler.hpp | 5 +- csrc/engine/infer_engine.cpp | 4 + csrc/engine/infer_engine.hpp | 2 + csrc/engine/rank_worker.cpp | 6 +- csrc/engine/rank_worker.hpp | 3 + csrc/pybind11/engine/engine.hpp | 6 ++ python/infinilm/base_config.py | 4 + python/infinilm/infer_engine.py | 2 + python/infinilm/llm/llm.py | 44 ++++++++++- python/infinilm/llm/request.py | 17 +++++ python/infinilm/llm/scheduler.py | 41 +++++++++- .../processors/basic_llm_processor.py | 40 +++++++--- python/infinilm/server/inference_server.py | 14 ++++ scripts/infer_task.py | 21 ++++++ scripts/launch_server.py | 75 +++++++++++++++---- 16 files changed, 266 insertions(+), 30 deletions(-) diff --git a/csrc/engine/compiler/general_compiler.cpp b/csrc/engine/compiler/general_compiler.cpp index 84ee670d..36c6420f 100644 --- a/csrc/engine/compiler/general_compiler.cpp +++ b/csrc/engine/compiler/general_compiler.cpp @@ -1,13 +1,18 @@ #include "general_compiler.hpp" namespace infinilm::engine { -GeneralCompiler::GeneralCompiler(const std::shared_ptr &model, RankBarrier *barrier) : GraphCompiler(model, barrier) { +GeneralCompiler::GeneralCompiler(const std::shared_ptr &model, RankBarrier *barrier, bool enable_chunk_prefill_graph) + : GraphCompiler(model, barrier), enable_chunk_prefill_graph_(enable_chunk_prefill_graph) { static_batching_compiler_ = std::make_unique(model_, barrier); + chunk_prefill_compiler_ = std::make_unique(model_, barrier); paged_compiler_ = std::make_unique(model_, barrier); } void GeneralCompiler::compile() { static_batching_compiler_->compile(); + if (enable_chunk_prefill_graph_) { + chunk_prefill_compiler_->compile(); + } paged_compiler_->compile(); } @@ -19,6 +24,11 @@ GeneralCompiler::Compiled GeneralCompiler::get_compiled(const InfinilmModel::Inp if (std::get<0>(result) != nullptr && std::get<1>(result) != nullptr) { return result; } + // chunk-prefill must be checked before decode (decode would also match if chunk_size==1) + result = chunk_prefill_compiler_.get()->get_compiled(input); + if (std::get<0>(result) != nullptr && std::get<1>(result) != nullptr) { + return result; + } result = paged_compiler_.get()->get_compiled(input); return result; } diff --git a/csrc/engine/compiler/general_compiler.hpp b/csrc/engine/compiler/general_compiler.hpp index e8b84b5d..3edbcea0 100644 --- a/csrc/engine/compiler/general_compiler.hpp +++ b/csrc/engine/compiler/general_compiler.hpp @@ -1,12 +1,13 @@ #pragma once +#include "chunk_prefill_compiler.hpp" #include "paged_compiler.hpp" #include "static_batching_compiler.hpp" namespace infinilm::engine { class GeneralCompiler : public GraphCompiler { public: - GeneralCompiler(const std::shared_ptr &model, RankBarrier *barrier); + GeneralCompiler(const std::shared_ptr &model, RankBarrier *barrier, bool enable_chunk_prefill_graph = false); void compile() override; @@ -15,5 +16,7 @@ class GeneralCompiler : public GraphCompiler { private: std::unique_ptr static_batching_compiler_; std::unique_ptr paged_compiler_; + std::unique_ptr chunk_prefill_compiler_; + bool enable_chunk_prefill_graph_; }; } // namespace infinilm::engine diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp index db0dfdd4..5b6ea143 100644 --- a/csrc/engine/infer_engine.cpp +++ b/csrc/engine/infer_engine.cpp @@ -25,6 +25,7 @@ InferEngine::InferEngine( infinicore::Device::Type device_type, const cache::CacheConfig *cache_config, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, backends::AttentionBackend attention_backend) // Changed parameter : communication_group_(distributed_config, device_type), legacy_model_config_(config), @@ -43,6 +44,7 @@ InferEngine::InferEngine( cache_config_ != nullptr ? cache_config_.get() : nullptr, barrier_.get(), enable_graph_compiling, + enable_chunk_prefill_graph, attention_backend_)); } @@ -56,6 +58,7 @@ InferEngine::InferEngine( infinicore::Device::Type device_type, const cache::CacheConfig *cache_config, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, backends::AttentionBackend attention_backend, std::optional kv_cache_dtype) // Changed parameter : communication_group_(distributed_config, device_type), attention_backend_(attention_backend) { @@ -82,6 +85,7 @@ InferEngine::InferEngine( cache_config_ != nullptr ? cache_config_.get() : nullptr, barrier_.get(), enable_graph_compiling, + enable_chunk_prefill_graph, attention_backend_)); } // Compile the model on all workers diff --git a/csrc/engine/infer_engine.hpp b/csrc/engine/infer_engine.hpp index e36ec369..153600c4 100644 --- a/csrc/engine/infer_engine.hpp +++ b/csrc/engine/infer_engine.hpp @@ -39,6 +39,7 @@ class InferEngine { infinicore::Device::Type device_type = infinicore::context::getDevice().getType(), const cache::CacheConfig *cache_config = nullptr, bool enable_graph_compiling = false, + bool enable_chunk_prefill_graph = false, backends::AttentionBackend attention_backend = backends::AttentionBackend::Default); InferEngine( @@ -47,6 +48,7 @@ class InferEngine { infinicore::Device::Type device_type = infinicore::context::getDevice().getType(), const cache::CacheConfig *cache_config = nullptr, bool enable_graph_compiling = false, + bool enable_chunk_prefill_graph = false, backends::AttentionBackend attention_backend = backends::AttentionBackend::Default, std::optional kv_cache_dtype = std::nullopt); diff --git a/csrc/engine/rank_worker.cpp b/csrc/engine/rank_worker.cpp index 8a94c441..e607c569 100644 --- a/csrc/engine/rank_worker.cpp +++ b/csrc/engine/rank_worker.cpp @@ -27,11 +27,13 @@ RankWorker::RankWorker(const InfinilmModel::Config &model_config, const cache::CacheConfig *cache_config, RankBarrier *barrier, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, backends::AttentionBackend attention_backend) : legacy_model_config_(model_config), rank_info_(rank_info), attention_backend_(attention_backend), enable_graph_compiling_(enable_graph_compiling), + enable_chunk_prefill_graph_(enable_chunk_prefill_graph), job_cmd_(Command::INIT), has_job_(false), job_done_(false), @@ -56,12 +58,14 @@ RankWorker::RankWorker( const cache::CacheConfig *cache_config, RankBarrier *barrier, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, backends::AttentionBackend attention_backend) : infinilm_config_(infinilm_config), model_config_(infinilm_config->model_config), rank_info_(rank_info), attention_backend_(attention_backend), enable_graph_compiling_(enable_graph_compiling), + enable_chunk_prefill_graph_(enable_chunk_prefill_graph), job_cmd_(Command::INIT), has_job_(false), job_done_(false), @@ -303,7 +307,7 @@ void RankWorker::thread_loop() { throw std::runtime_error("Failed to create model"); } if (enable_graph_compiling_) { - compiler_ = std::make_unique(model_, barrier_); + compiler_ = std::make_unique(model_, barrier_, enable_chunk_prefill_graph_); } init_done_ = true; diff --git a/csrc/engine/rank_worker.hpp b/csrc/engine/rank_worker.hpp index f6adcf47..b045adf6 100644 --- a/csrc/engine/rank_worker.hpp +++ b/csrc/engine/rank_worker.hpp @@ -75,6 +75,7 @@ class RankWorker { const cache::CacheConfig *cache_config, RankBarrier *barrier, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, backends::AttentionBackend attention_backend); RankWorker(std::shared_ptr infinilm_config, @@ -82,6 +83,7 @@ class RankWorker { const cache::CacheConfig *cache_config, RankBarrier *barrier, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, backends::AttentionBackend attention_backend); // Submit a parameter load job and wait until the load completes on the worker thread. @@ -131,6 +133,7 @@ class RankWorker { // Graph Compiling bool enable_graph_compiling_; + bool enable_chunk_prefill_graph_; std::unique_ptr compiler_; // Command for the pending job (protected by mutex_) diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index 2741c9cd..a479f66b 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -37,6 +37,7 @@ inline void bind_infer_engine(py::module &m) { infinicore::Device::Type dev, std::shared_ptr cache_cfg, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, const std::string &attention_backend) { return std::make_shared( cfg, @@ -44,6 +45,7 @@ inline void bind_infer_engine(py::module &m) { dev, cache_cfg ? cache_cfg.get() : nullptr, enable_graph_compiling, + enable_chunk_prefill_graph, infinilm::backends::parse_attention_backend(attention_backend)); }), py::arg("config"), @@ -51,6 +53,7 @@ inline void bind_infer_engine(py::module &m) { py::arg("device_type") = infinicore::context::getDevice().getType(), py::arg("cache_config") = py::none(), py::arg("enable_graph_compiling") = false, + py::arg("enable_chunk_prefill_graph") = false, py::arg("attention_backend") = "default") .def("load_param", &InferEngine::load_param, py::arg("name"), py::arg("param"), @@ -81,6 +84,7 @@ inline void bind_infer_engine(py::module &m) { infinicore::Device::Type dev, std::shared_ptr cache_cfg, bool enable_graph_compiling, + bool enable_chunk_prefill_graph, const std::string &attention_backend, std::optional kv_cache_dtype) { return std::make_shared( @@ -89,6 +93,7 @@ inline void bind_infer_engine(py::module &m) { dev, cache_cfg ? cache_cfg.get() : nullptr, enable_graph_compiling, + enable_chunk_prefill_graph, infinilm::backends::parse_attention_backend(attention_backend), kv_cache_dtype); }), @@ -97,6 +102,7 @@ inline void bind_infer_engine(py::module &m) { py::arg("device_type") = infinicore::context::getDevice().getType(), py::arg("cache_config") = py::none(), py::arg("enable_graph_compiling") = false, + py::arg("enable_chunk_prefill_graph") = false, py::arg("attention_backend") = "default", py::arg("kv_cache_dtype") = py::none()) .def("load_param", &InferEngine::load_param, diff --git a/python/infinilm/base_config.py b/python/infinilm/base_config.py index aab5dd45..d7c32568 100644 --- a/python/infinilm/base_config.py +++ b/python/infinilm/base_config.py @@ -61,6 +61,8 @@ def __init__(self): self.attn = self.args.attn self.enable_graph = self.args.enable_graph + self.enable_chunk_prefill_graph = self.args.enable_chunk_prefill_graph + self.chunk_size = self.args.chunk_size self.enable_paged_attn = self.args.enable_paged_attn self.num_blocks = self.args.num_blocks self.block_size = self.args.block_size @@ -122,6 +124,8 @@ def _add_common_args(self): choices=["default", "paged-attn", "flash-attn"], ) self.parser.add_argument("--enable-graph", action="store_true") + self.parser.add_argument("--enable-chunk-prefill-graph", action="store_true", help="enable chunk-prefill graph compiling") + self.parser.add_argument("--chunk-size", type=int, default=512, help="tokens per chunked-prefill slice (0 to disable)") self.parser.add_argument( "--enable-paged-attn", action="store_true", diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index 13bb18a1..2477bbc6 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -45,6 +45,7 @@ def __init__( distributed_config=DistConfig(1), cache_config=None, enable_graph_compiling=False, + enable_chunk_prefill_graph=False, attention_backend="default", kv_cache_dtype=None, ): @@ -60,6 +61,7 @@ def __init__( device._underlying.type, cache_config, enable_graph_compiling, + enable_chunk_prefill_graph, attention_backend, ( parse_dtype(kv_cache_dtype)._underlying diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py index cba3af83..90de3edc 100644 --- a/python/infinilm/llm/llm.py +++ b/python/infinilm/llm/llm.py @@ -72,6 +72,8 @@ class EngineConfig: top_p: float = 0.8 top_k: int = 1 enable_graph: bool = False + enable_chunk_prefill_graph: bool = False + chunk_size: int = 0 attn_backend: str = "default" skip_load: bool = False @@ -91,6 +93,7 @@ def __init__(self, config: EngineConfig): device=self.device, distributed_config=DistConfig(config.tensor_parallel_size), enable_graph_compiling=config.enable_graph, + enable_chunk_prefill_graph=config.enable_chunk_prefill_graph, attention_backend=config.attn_backend, ) @@ -167,6 +170,8 @@ def _init_device(self): def add_request(self, request: InferenceRequest): """Add a request to the scheduler.""" + if self.cache_type == "paged" and self.config.chunk_size > 0: + request.chunk_size = self.config.chunk_size self.scheduler.add_request(request) def step(self) -> tuple[list[InferenceRequest], list[tuple]]: @@ -210,7 +215,18 @@ def _update_requests( sampled_tokens: List[int], ) -> List[tuple]: """Update request status after inference step.""" - if is_prefill: + # Detect a chunked-prefill mid-step: single request, prefill phase, + # and this chunk does not yet cover the whole prompt. In that case + # we must NOT consume a sampled token, NOT commit prefill blocks, + # and re-enqueue the request to keep chunking. + chunk_mid_step = ( + is_prefill + and len(requests) == 1 + and requests[0].is_chunking() + and not requests[0].chunk_is_last() + ) + + if is_prefill and not chunk_mid_step: match self.cache_type: case "paged": self.scheduler.cache_manager.reset_req_blocks() @@ -218,6 +234,20 @@ def _update_requests( self.scheduler.update_cache() case _: raise ValueError(f"Unsupported cache_type: {self.cache_type}") + + if chunk_mid_step: + req = requests[0] + req.chunk_prefill_offset += req.chunk_size + # If this request was aborted while chunking, drop it. + if req.is_aborted(): + logger.info( + f"Request {req.request_id} aborted by client during chunked-prefill" + ) + return [] + # Re-enqueue to keep producing chunks; no token sampled yet. + self.scheduler.requeue_chunking(req) + return [] + pending = [] for req, token_id in zip(requests, sampled_tokens): if req.is_aborted(): @@ -227,6 +257,10 @@ def _update_requests( continue if req.is_prefill: + # Clean up chunked-prefill state on the final chunk so the + # next forward pass on this request takes the decode path. + req.chunk_prefill_offset = 0 + req.chunk_size = 0 req.is_prefill = False req.generated_token_ids.append(token_id) @@ -361,6 +395,8 @@ def __init__( top_p: float = 0.8, top_k: int = 1, enable_graph: bool = False, + enable_chunk_prefill_graph: bool = False, + chunk_size: int = 0, attn_backend: str = "default", skip_load: bool = False, ): @@ -398,6 +434,8 @@ def __init__( top_p=top_p, top_k=top_k, enable_graph=enable_graph, + enable_chunk_prefill_graph=enable_chunk_prefill_graph, + chunk_size=chunk_size, attn_backend=attn_backend, skip_load=skip_load, ) @@ -539,6 +577,8 @@ def __init__( top_p: float = 0.8, top_k: int = 1, enable_graph: bool = False, + enable_chunk_prefill_graph: bool = False, + chunk_size: int = 0, attn_backend: str = "default", ): """Initialize AsyncLLMEngine. @@ -575,6 +615,8 @@ def __init__( top_p=top_p, top_k=top_k, enable_graph=enable_graph, + enable_chunk_prefill_graph=enable_chunk_prefill_graph, + chunk_size=chunk_size, attn_backend=attn_backend, ) self.engine = LLMEngine(config) diff --git a/python/infinilm/llm/request.py b/python/infinilm/llm/request.py index 15bcf69f..679b6e4d 100644 --- a/python/infinilm/llm/request.py +++ b/python/infinilm/llm/request.py @@ -144,6 +144,11 @@ def __init__( self.num_cached_tokens: int = 0 self.num_blocks: int = 0 + # Chunked-prefill state (0 = disabled, otherwise tokens per chunk) + self.chunk_size: int = 0 + # Number of prompt tokens already fed through forward as chunked-prefill + self.chunk_prefill_offset: int = 0 + # For server use self.request_data: Optional[dict] = request_data self.http_request: Optional[Any] = http_request @@ -186,6 +191,18 @@ def get_num_blocks_required(self, block_size: int) -> int: def get_max_tokens(self) -> Optional[int]: return self.sampling_params.max_tokens + def is_chunking(self) -> bool: + """Return True if this request is in the middle of chunked-prefill.""" + return ( + self.chunk_size > 0 + and self.is_prefill + and self.prompt_length > self.chunk_size + ) + + def chunk_is_last(self) -> bool: + """Return True if the next chunk would finish the prompt.""" + return self.chunk_prefill_offset + self.chunk_size >= self.prompt_length + def is_finished(self) -> bool: return self.status in [ RequestStatus.FINISHED, diff --git a/python/infinilm/llm/scheduler.py b/python/infinilm/llm/scheduler.py index f9c11635..95a84480 100644 --- a/python/infinilm/llm/scheduler.py +++ b/python/infinilm/llm/scheduler.py @@ -42,6 +42,9 @@ def __init__( ): self.waiting_queue = janus.Queue() self.running_queue = janus.Queue() + # Requests in the middle of chunked-prefill — scheduled at high priority, + # single-request batches only (to match the C++ ChunkPrefillCompiler graph signature). + self.chunking_queue = janus.Queue() self.max_batch_size = max_batch_size self.cache_manager = BlockManager(num_blocks=num_blocks, block_size=block_size) @@ -53,7 +56,27 @@ def add_request(self, request: InferenceRequest): self.waiting_queue.sync_q.put(request) def schedule(self) -> Optional[SchedulerOutput]: - """Schedule and return batch of requests to execute.""" + """Schedule and return batch of requests to execute. + + Priority (mirrors launch_server.py chunked-prefill scheduling): + 1. Running queue (decode) — short / latency-sensitive + 2. Chunking queue (in-flight chunked-prefill) — single-request slice + 3. Waiting queue (new prefill) — may start chunking if prompt is long + """ + # 2) Continue an in-flight chunked-prefill request (single-request batch). + try: + req = self.chunking_queue.sync_q.get_nowait() + except queue.Empty: + req = None + if req is not None: + if req.is_finished(): + self.complete_requests([req]) + else: + return SchedulerOutput( + scheduled_requests=[req], + is_prefill=True, + ) + scheduled_requests = [] is_prefill = False @@ -91,6 +114,18 @@ def schedule(self) -> Optional[SchedulerOutput]: req.num_blocks = len(req.block_table) req.status = RequestStatus.RUNNING + + # Start chunked-prefill: enqueue into chunking_queue and emit a + # single-request batch immediately. We don't mix chunked-prefill + # with other requests in the same batch — the C++ ChunkPrefillCompiler + # graph is keyed on (batch_size, chunk_size). + if req.chunk_size > 0 and req.prompt_length > req.chunk_size: + req.chunk_prefill_offset = 0 + return SchedulerOutput( + scheduled_requests=[req], + is_prefill=True, + ) + scheduled_requests.append(req) # Return prefill batch if any waiting requests were scheduled @@ -135,6 +170,10 @@ def schedule(self) -> Optional[SchedulerOutput]: return None + def requeue_chunking(self, req: InferenceRequest): + """Put a request back into the chunking queue after a chunk has run.""" + self.chunking_queue.sync_q.put(req) + def complete_requests(self, requests: List[InferenceRequest]): """Handle completed requests and free their blocks.""" for req in requests: diff --git a/python/infinilm/processors/basic_llm_processor.py b/python/infinilm/processors/basic_llm_processor.py index 070a4062..f5e603ba 100644 --- a/python/infinilm/processors/basic_llm_processor.py +++ b/python/infinilm/processors/basic_llm_processor.py @@ -185,19 +185,39 @@ def _build_model_input_from_batch_scheduler_output( if scheduler_output.is_prefill: # Prefill phase req_tokens = req.get_input_tokens() - tokens_to_compute = req_tokens[num_cached:] - tokens.extend(tokens_to_compute) - compute_len = len(tokens_to_compute) - seq_len = len(req_tokens) - seq_lens.append(seq_len) + # Chunked-prefill: only feed [chunk_prefill_offset : +chunk_size). + # past_kv_lengths = chunk_prefill_offset (attention sees the prefix + # already committed); total_kv_lengths = chunk_prefill_offset + + # len(tokens_to_compute). This keeps batch_size=1 and total_tokens + # == chunk_size so the C++ ChunkPrefillCompiler graph hits. + if req.is_chunking(): + start = req.chunk_prefill_offset + end = min(start + req.chunk_size, len(req_tokens)) + tokens_to_compute = req_tokens[start:end] + compute_len = len(tokens_to_compute) + tokens.extend(tokens_to_compute) + seq_len = end # attention prefix length after this chunk + seq_lens.append(seq_len) + current_offset += compute_len + seq_offsets.append(current_offset) + slot_mapping.extend(req.slot_mapping[start:end]) + cached_lens.append(start) + position_ids.extend(range(start, end)) + else: + tokens_to_compute = req_tokens[num_cached:] + tokens.extend(tokens_to_compute) - current_offset += compute_len - seq_offsets.append(current_offset) + compute_len = len(tokens_to_compute) + seq_len = len(req_tokens) + seq_lens.append(seq_len) - slot_mapping.extend(req.slot_mapping) - cached_lens.append(num_cached) - position_ids.extend(range(num_cached, num_cached + compute_len)) + current_offset += compute_len + seq_offsets.append(current_offset) + + slot_mapping.extend(req.slot_mapping) + cached_lens.append(num_cached) + position_ids.extend(range(num_cached, num_cached + compute_len)) else: # Decode phase diff --git a/python/infinilm/server/inference_server.py b/python/infinilm/server/inference_server.py index 71e9c992..ac7e94e7 100644 --- a/python/infinilm/server/inference_server.py +++ b/python/infinilm/server/inference_server.py @@ -108,6 +108,8 @@ def __init__( host: str = "0.0.0.0", port: int = 8000, enable_graph: bool = False, + enable_chunk_prefill_graph: bool = False, + chunk_size: int = 0, attn_backend: str = "default", ignore_eos: bool = False, ): @@ -130,6 +132,10 @@ def __init__( host: Server host address. port: Server port number. enable_graph: Whether to enable graph compiling. + enable_chunk_prefill_graph: Whether to enable chunk-prefill graph compiling. + chunk_size: Tokens per chunked-prefill slice (0 = disabled). When > 0 and paged + cache is used, long prompts are sliced and each slice goes through forward + separately so the C++ ChunkPrefillCompiler precompiled graph can be reused. attn_backend: Attention backend to use ('default', 'flash-attn'). """ self.model_path = model_path @@ -150,6 +156,8 @@ def __init__( self.host = host self.port = port self.enable_graph = enable_graph + self.enable_chunk_prefill_graph = enable_chunk_prefill_graph + self.chunk_size = chunk_size self.attn_backend = attn_backend self.ignore_eos = ignore_eos @@ -182,11 +190,15 @@ async def lifespan(app: FastAPI): top_p=self.top_p, top_k=self.top_k, enable_graph=self.enable_graph, + enable_chunk_prefill_graph=self.enable_chunk_prefill_graph, + chunk_size=self.chunk_size, attn_backend=self.attn_backend, ) self.engine.start() logger.info(f"Engine initialized with model at {self.model_path}") logger.info(f" enable_graph: {self.enable_graph}") + logger.info(f" enable_chunk_prefill_graph: {self.enable_chunk_prefill_graph}") + logger.info(f" chunk_size: {self.chunk_size}") yield self.engine.stop() @@ -572,6 +584,8 @@ def main(): host=cfg.host, port=cfg.port, enable_graph=cfg.enable_graph, + enable_chunk_prefill_graph=cfg.enable_chunk_prefill_graph, + chunk_size=cfg.chunk_size, attn_backend=cfg.attn, ignore_eos=cfg.ignore_eos, ) diff --git a/scripts/infer_task.py b/scripts/infer_task.py index 0d1231b7..1851f0a0 100644 --- a/scripts/infer_task.py +++ b/scripts/infer_task.py @@ -10,6 +10,8 @@ def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens): self.end_tokens = end_tokens self._kv_cache = None self.pos = 0 + self._discard_output = False + self._remaining_tokens = None def bind_kvcache(self, kv_cache, pos=0): self._kv_cache = kv_cache @@ -24,6 +26,25 @@ def release_kvcache(self): def kvcache(self): return self._kv_cache + def setup_chunked_prefill(self, chunk_size): + if chunk_size <= 0 or len(self.tokens) <= chunk_size: + return + self._remaining_tokens = self.tokens[chunk_size:] + self.tokens = self.tokens[:chunk_size] + self._discard_output = True + + def advance_prefill_chunk(self, chunk_size): + self._kv_cache.update_tokens(self.tokens, self.pos) + self.pos += len(self.tokens) + + if len(self._remaining_tokens) <= chunk_size: + self.tokens = self._remaining_tokens + self._remaining_tokens = None + self._discard_output = False + else: + self.tokens = self._remaining_tokens[:chunk_size] + self._remaining_tokens = self._remaining_tokens[chunk_size:] + def next(self, out_token): self._kv_cache.update_tokens(self.tokens, self.pos) diff --git a/scripts/launch_server.py b/scripts/launch_server.py index d04d4f69..0639a28b 100644 --- a/scripts/launch_server.py +++ b/scripts/launch_server.py @@ -64,6 +64,13 @@ def parse_args(): default=None, help="Max token sequence length that model will handle (follows model config if not provided)", ) + parser.add_argument( + "--chunk-size", + type=int, + default=512, + help="Maximum number of tokens per prefill chunk (default: 512). " + "Set to 0 to disable chunked prefill.", + ) parser.add_argument( "--awq", action="store_true", @@ -86,8 +93,10 @@ def parse_args(): USE_AWQ = args.awq USE_GPTQ = args.gptq MAX_BATCH = args.max_batch +CHUNK_SIZE = args.chunk_size print( - f"Using MAX_BATCH={MAX_BATCH}. Try reduce this value if out of memory error occurs." + f"Using MAX_BATCH={MAX_BATCH}, CHUNK_SIZE={CHUNK_SIZE}. " + f"Try reduce these values if out of memory error occurs." ) @@ -163,32 +172,66 @@ async def lifespan(app: FastAPI): # App loop: take requests from the queue, do inference, and put unfinished requests back into the queue. +# Uses priority scheduling: decode/short tasks first, then prefill chunks. def worker_loop(app): + pending_prefill = [] # Low priority: chunked prefill tasks + while True: + # Drain all available tasks from the queue + incoming = [] try: task = app.state.request_queue.sync_q.get(timeout=0.01) + if task is None: + return + incoming.append(task) except queue.Empty: - continue - - if task is None: - return + pass - batch = [task] - while len(batch) < MAX_BATCH: + while True: try: - req = app.state.request_queue.sync_q.get_nowait() - if req is not None: - batch.append(req) + task = app.state.request_queue.sync_q.get_nowait() + if task is None: + return + incoming.append(task) except queue.Empty: break + + # Separate into high priority (decode/new short) and low priority (prefill chunks) + high_priority = [] + for t in incoming: + if t._discard_output: + pending_prefill.append(t) + else: + high_priority.append(t) + + # Build batch: high priority first, then fill with prefill chunks + batch = [] + while high_priority and len(batch) < MAX_BATCH: + batch.append(high_priority.pop(0)) + while pending_prefill and len(batch) < MAX_BATCH: + batch.append(pending_prefill.pop(0)) + + if not batch: + continue + output_tokens = app.state.model.batch_infer_one_round(batch) for task, token in zip(batch, output_tokens): - task.output(token) - if task.finish_reason is None: - app.state.request_queue.sync_q.put(task) + if task._discard_output: + task.advance_prefill_chunk(CHUNK_SIZE) + if task.finish_reason is None: + if task._discard_output: + pending_prefill.append(task) + else: + app.state.request_queue.sync_q.put(task) + else: + app.state.kv_cache_pool.release_sync(task) else: - print(f"[INFO] Task {task.id} finished infer.") - app.state.kv_cache_pool.release_sync(task) + task.output(token) + if task.finish_reason is None: + app.state.request_queue.sync_q.put(task) + else: + print(f"[INFO] Task {task.id} finished infer.") + app.state.kv_cache_pool.release_sync(task) def build_task(id_, request_data, request: Request): @@ -214,6 +257,7 @@ async def chat_stream(id_, request_data, request: Request): try: infer_task = build_task(id_, request_data, request) await request.app.state.kv_cache_pool.acquire(infer_task) + infer_task.setup_chunked_prefill(CHUNK_SIZE) # Initial empty content chunk = json.dumps( @@ -255,6 +299,7 @@ async def chat(id_, request_data, request: Request): try: infer_task = build_task(id_, request_data, request) await request.app.state.kv_cache_pool.acquire(infer_task) + infer_task.setup_chunked_prefill(CHUNK_SIZE) request.app.state.request_queue.sync_q.put(infer_task) output = [] while True: From bb68ca563604f079cfd35d3f5509b8e56ac654bf Mon Sep 17 00:00:00 2001 From: huidesheng <1832140001@qq.com> Date: Wed, 13 May 2026 11:15:15 +0800 Subject: [PATCH 2/4] add chunk_prefill_compiler.cpp/.hpp --- .../compiler/chunk_prefill_compiler.cpp | 186 ++++++++++++++++++ .../compiler/chunk_prefill_compiler.hpp | 42 ++++ 2 files changed, 228 insertions(+) create mode 100644 csrc/engine/compiler/chunk_prefill_compiler.cpp create mode 100644 csrc/engine/compiler/chunk_prefill_compiler.hpp diff --git a/csrc/engine/compiler/chunk_prefill_compiler.cpp b/csrc/engine/compiler/chunk_prefill_compiler.cpp new file mode 100644 index 00000000..266bd0e7 --- /dev/null +++ b/csrc/engine/compiler/chunk_prefill_compiler.cpp @@ -0,0 +1,186 @@ +#include "chunk_prefill_compiler.hpp" +#include "infinicore/context/context.hpp" + + +namespace { +inline void set_zeros(infinicore::Tensor &tensor) { + std::vector zeros(tensor->nbytes(), 0); + infinicore::context::memcpyH2D(tensor->data(), zeros.data(), tensor->nbytes(), false); +} +} // namespace + +namespace infinilm::engine { + +ChunkPrefillCompiler::ChunkPrefillCompiler(const std::shared_ptr &model, RankBarrier *barrier) + : GraphCompiler(model, barrier) { + // Enumerate chunk sizes for chunk-prefill + for (size_t cs : {64, 128, 256, 512, 1024, 2048}) { + chunk_sizes_.push_back(cs); + } + // Enumerate batch sizes for prefill (typically smaller than decode) + for (size_t b = 1; b < 32; b++) { + prefill_batch_sizes_.push_back(b); + } + for (size_t b = 32; b < 64; b += 8) { + prefill_batch_sizes_.push_back(b); + } + for (size_t b = 64; b < 128; b += 16) { + prefill_batch_sizes_.push_back(b); + } + for (size_t b = 128; b < 256; b += 32) { + prefill_batch_sizes_.push_back(b); + } + for (size_t b = 256; b <= 512; b += 64) { + prefill_batch_sizes_.push_back(b); + } +} + +void ChunkPrefillCompiler::compile() { + if (model_->get_cache_config() != nullptr && + dynamic_cast(model_->get_cache_config())) { + + const auto *paged_config = + dynamic_cast(model_->get_cache_config()); + size_t nblocks = paged_config->num_blocks(); + + compiled_map_prefill_.clear(); + + // Max total tokens to avoid OOM during graph recording + constexpr size_t MAX_TOTAL_TOKENS = 4096; + + // Pre-allocate a shared block_tables_holder for the largest (batch_size) we'll use + size_t max_batch = *std::max_element(prefill_batch_sizes_.begin(), prefill_batch_sizes_.end()); + size_t block_per_req = nblocks / max_batch; + block_tables_holder_ = infinicore::Tensor::empty( + {nblocks}, infinicore::DataType::I32, infinicore::context::getDevice()); + set_zeros(block_tables_holder_); + + for (size_t b : prefill_batch_sizes_) { + for (size_t cs : chunk_sizes_) { + size_t total_tokens = b * cs; + if (total_tokens > MAX_TOTAL_TOKENS) { + continue; + } + + size_t bpr = nblocks / b; // block_per_req for this batch size + + InfinilmModel::Input input; + + // input_ids: [1, total_tokens] — all tokens for this batch packed together + input.input_ids = infinicore::Tensor::empty( + {1, total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice()); + set_zeros(input.input_ids.value()); + + // position_ids: [total_tokens] + input.position_ids = infinicore::Tensor::empty( + {total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice()); + set_zeros(input.position_ids.value()); + + // total_sequence_lengths: [b], set to cs (first-chunk scenario) + input.total_sequence_lengths = infinicore::Tensor::empty( + {b}, infinicore::DataType::I32, infinicore::context::getDevice()); + { + std::vector tsl(b, static_cast(cs)); + infinicore::context::memcpyH2D( + input.total_sequence_lengths.value()->data(), + tsl.data(), b * sizeof(int32_t), false); + } + + // input_offsets: [b+1], stride = cs + input.input_offsets = infinicore::Tensor::empty( + {b + 1}, infinicore::DataType::I32, infinicore::context::getDevice()); + { + std::vector offsets(b + 1); + for (size_t i = 0; i <= b; i++) { + offsets[i] = static_cast(i * cs); + } + infinicore::context::memcpyH2D( + input.input_offsets.value()->data(), + offsets.data(), (b + 1) * sizeof(int32_t), false); + } + + // cu_seqlens: [b+1], same layout as input_offsets for prefill + input.cu_seqlens = infinicore::Tensor::empty( + {b + 1}, infinicore::DataType::I32, infinicore::context::getDevice()); + { + std::vector cu(b + 1); + for (size_t i = 0; i <= b; i++) { + cu[i] = static_cast(i * cs); + } + infinicore::context::memcpyH2D( + input.cu_seqlens.value()->data(), + cu.data(), (b + 1) * sizeof(int32_t), false); + } + + // block_tables: view into the shared holder [b, bpr] + input.block_tables = block_tables_holder_->as_strided( + {b, bpr}, {(ptrdiff_t)bpr, 1}); + + // slot_mapping: [total_tokens] + input.slot_mapping = infinicore::Tensor::empty( + {total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice()); + set_zeros(input.slot_mapping.value()); + + barrier_->wait(); + infinicore::context::startGraphRecording(); + auto output = model_->forward(input); + auto graph = infinicore::context::stopGraphRecording(); + barrier_->wait(); + + auto shared_output = std::shared_ptr( + new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)}); + + compiled_map_prefill_[std::make_tuple(b, cs)] = + CompiledResult{std::move(input), std::make_tuple(graph, shared_output)}; + } + } + } +} + +ChunkPrefillCompiler::Compiled ChunkPrefillCompiler::get_compiled(const InfinilmModel::Input &input) { + if (model_->get_cache_config() == nullptr || + !dynamic_cast(model_->get_cache_config())) { + return {nullptr, nullptr}; + } + + if (!input.block_tables.has_value() || !input.input_ids.has_value()) { + return {nullptr, nullptr}; + } + + size_t batch_size = input.block_tables.value()->size(0); + size_t block_per_req = input.block_tables.value()->size(1); + size_t total_tokens = input.input_ids.value()->size(1); + + // Prefill: total_tokens is a multiple of batch_size, and chunk_size > 1 + if (total_tokens == 0 || total_tokens % batch_size != 0) { + return {nullptr, nullptr}; + } + size_t chunk_size = total_tokens / batch_size; + if (chunk_size <= 1) { + // Single-token case belongs to decode + return {nullptr, nullptr}; + } + + auto result = compiled_map_prefill_.find(std::make_tuple(batch_size, chunk_size)); + if (result == compiled_map_prefill_.end()) { + return {nullptr, nullptr}; + } + + auto &graph_input = result->second.input; + + graph_input.input_ids.value()->copy_from(input.input_ids.value()); + graph_input.position_ids.value()->copy_from(input.position_ids.value()); + graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value()); + graph_input.input_offsets.value()->copy_from(input.input_offsets.value()); + graph_input.cu_seqlens.value()->copy_from(input.cu_seqlens.value()); + graph_input.block_tables.value()->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value()); + graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value()); + + auto graph = std::get<0>(result->second.compiled); + auto shared_output = std::shared_ptr( + new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()}); + + return std::make_tuple(graph, shared_output); +} + +} // namespace infinilm::engine diff --git a/csrc/engine/compiler/chunk_prefill_compiler.hpp b/csrc/engine/compiler/chunk_prefill_compiler.hpp new file mode 100644 index 00000000..bd701158 --- /dev/null +++ b/csrc/engine/compiler/chunk_prefill_compiler.hpp @@ -0,0 +1,42 @@ +#pragma once + +#include "graph_compiler.hpp" + +#include + +namespace infinilm::engine { +class ChunkPrefillCompiler : public GraphCompiler { +public: + ChunkPrefillCompiler(const std::shared_ptr &model, RankBarrier *barrier); + + void compile() override; + + Compiled get_compiled(const InfinilmModel::Input &input) override; + +private: + struct TupleHash { + size_t operator()(const std::tuple &t) const noexcept { + auto h1 = std::hash{}(std::get<0>(t)); + auto h2 = std::hash{}(std::get<1>(t)); + return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2)); + } + }; + + std::vector chunk_sizes_; + std::vector prefill_batch_sizes_; + + infinicore::Tensor block_tables_holder_; + + struct CompiledResult { + InfinilmModel::Input input; + Compiled compiled; + }; + + // Key: (batch_size, chunk_size) + std::unordered_map< + std::tuple, + CompiledResult, + TupleHash> + compiled_map_prefill_; +}; +} // namespace infinilm::engine From c596fff4b8bd452fb501517789ee1b19ad1fc2a9 Mon Sep 17 00:00:00 2001 From: huidesheng <1832140001@qq.com> Date: Fri, 15 May 2026 12:05:29 +0000 Subject: [PATCH 3/4] fix attn_metadata bug --- csrc/engine/compiler/chunk_prefill_compiler.cpp | 11 +++++++++++ scripts/test_perf.py | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/csrc/engine/compiler/chunk_prefill_compiler.cpp b/csrc/engine/compiler/chunk_prefill_compiler.cpp index 266bd0e7..2b867800 100644 --- a/csrc/engine/compiler/chunk_prefill_compiler.cpp +++ b/csrc/engine/compiler/chunk_prefill_compiler.cpp @@ -1,4 +1,5 @@ #include "chunk_prefill_compiler.hpp" +#include "../../global_state/global_state.hpp" #include "infinicore/context/context.hpp" @@ -121,6 +122,16 @@ void ChunkPrefillCompiler::compile() { {total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice()); set_zeros(input.slot_mapping.value()); + // Attention reads attn_metadata from thread-local forward context. + infinilm::global_state::get_forward_context().attn_metadata = { + input.past_sequence_lengths, + input.total_sequence_lengths, + input.input_offsets, + input.cu_seqlens, + input.block_tables, + input.slot_mapping, + }; + barrier_->wait(); infinicore::context::startGraphRecording(); auto output = model_->forward(input); diff --git a/scripts/test_perf.py b/scripts/test_perf.py index 6a33d8f0..74066ddc 100644 --- a/scripts/test_perf.py +++ b/scripts/test_perf.py @@ -29,7 +29,7 @@ NUM_REQUESTS = 64 CONCURRENCY = 20 -API_URL = "http://127.0.0.1:8000" +API_URL = "http://127.0.0.1:3456" MODEL = "FM9G-7B" From 50d8eb268d0b515646fe16f883a4171d350c84fd Mon Sep 17 00:00:00 2001 From: huidesheng <1832140001@qq.com> Date: Wed, 20 May 2026 11:46:49 +0000 Subject: [PATCH 4/4] fix schedule priority; add anti-starve mechanism; fix issue incompatible with Prefix Sharing --- .../compiler/chunk_prefill_compiler.cpp | 2 +- python/infinilm/base_config.py | 2 +- python/infinilm/llm/cache_manager.py | 41 +++- python/infinilm/llm/request.py | 2 +- python/infinilm/llm/scheduler.py | 203 +++++++++++++----- .../processors/basic_llm_processor.py | 40 ++-- 6 files changed, 207 insertions(+), 83 deletions(-) diff --git a/csrc/engine/compiler/chunk_prefill_compiler.cpp b/csrc/engine/compiler/chunk_prefill_compiler.cpp index 2b867800..55ad56f3 100644 --- a/csrc/engine/compiler/chunk_prefill_compiler.cpp +++ b/csrc/engine/compiler/chunk_prefill_compiler.cpp @@ -15,7 +15,7 @@ namespace infinilm::engine { ChunkPrefillCompiler::ChunkPrefillCompiler(const std::shared_ptr &model, RankBarrier *barrier) : GraphCompiler(model, barrier) { // Enumerate chunk sizes for chunk-prefill - for (size_t cs : {64, 128, 256, 512, 1024, 2048}) { + for (size_t cs : {256}) { chunk_sizes_.push_back(cs); } // Enumerate batch sizes for prefill (typically smaller than decode) diff --git a/python/infinilm/base_config.py b/python/infinilm/base_config.py index d7c32568..5ef2a8ff 100644 --- a/python/infinilm/base_config.py +++ b/python/infinilm/base_config.py @@ -125,7 +125,7 @@ def _add_common_args(self): ) self.parser.add_argument("--enable-graph", action="store_true") self.parser.add_argument("--enable-chunk-prefill-graph", action="store_true", help="enable chunk-prefill graph compiling") - self.parser.add_argument("--chunk-size", type=int, default=512, help="tokens per chunked-prefill slice (0 to disable)") + self.parser.add_argument("--chunk-size", type=int, default=0, help="tokens per chunked-prefill slice (0 to disable)") self.parser.add_argument( "--enable-paged-attn", action="store_true", diff --git a/python/infinilm/llm/cache_manager.py b/python/infinilm/llm/cache_manager.py index 44ca1376..df9f1957 100644 --- a/python/infinilm/llm/cache_manager.py +++ b/python/infinilm/llm/cache_manager.py @@ -119,24 +119,51 @@ def allocate_blocks( ) -> tuple[List[int], List[int], int]: """Allocate cache blocks for new request with prefix caching support. - Args: - token_ids: Input token sequence - block_table: Existing block_table (for decode phase) + Idempotent: if block_table already fully covers token_ids with valid + (still-active) blocks, returns a consistent (block_table, slot_mapping, + num_cached_tokens=0) without re-allocating. - Returns: - Tuple of (block_table, slot_mapping, num_cached_tokens) + Convention: len(slot_mapping) == num_tokens - num_cached_tokens + (one slot per token that needs to be (re)computed). """ if block_table is None: block_table = [] num_tokens = len(token_ids) - num_blocks = (num_tokens + self.block_size - 1) // self.block_size + if num_tokens == 0: + return [], [], 0 + + num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size + + # -------------------------------------------------------------- # + # Idempotent re-entry path # + # -------------------------------------------------------------- # + # If block_table already covers the prompt AND all those blocks + # are still alive (ref_count > 0), reconstruct slot_mapping from + # block_table and return num_cached_tokens=0 (i.e., the forward + # will recompute everything into the same slots — wasteful but + # always correct, and keeps the slot_mapping length convention). + if block_table and len(block_table) >= num_blocks_needed: + bt = list(block_table[:num_blocks_needed]) + if all(self.blocks[bid].ref_count > 0 for bid in bt): + slot_mapping = [ + bt[i // self.block_size] * self.block_size + (i % self.block_size) + for i in range(num_tokens) + ] + # length = num_tokens = num_tokens - 0 ✓ matches convention + return bt, slot_mapping, 0 + # Otherwise the block_table is stale — drop it and re-allocate. + block_table = [] + + # -------------------------------------------------------------- # + # Below: original code unchanged # + # -------------------------------------------------------------- # slot_mapping = [] num_cached_tokens = 0 prefix_hash = -1 cache_miss = False - for block_idx in range(num_blocks): + for block_idx in range(num_blocks_needed): start_idx = block_idx * self.block_size end_idx = min(start_idx + self.block_size, num_tokens) block_tokens = token_ids[start_idx:end_idx] diff --git a/python/infinilm/llm/request.py b/python/infinilm/llm/request.py index 679b6e4d..ef5c8cd2 100644 --- a/python/infinilm/llm/request.py +++ b/python/infinilm/llm/request.py @@ -196,7 +196,7 @@ def is_chunking(self) -> bool: return ( self.chunk_size > 0 and self.is_prefill - and self.prompt_length > self.chunk_size + and (self.prompt_length - self.num_cached_tokens) > self.chunk_size ) def chunk_is_last(self) -> bool: diff --git a/python/infinilm/llm/scheduler.py b/python/infinilm/llm/scheduler.py index 95a84480..5794de88 100644 --- a/python/infinilm/llm/scheduler.py +++ b/python/infinilm/llm/scheduler.py @@ -28,10 +28,17 @@ def __init__( class Scheduler: """Request scheduler with integrated BlockManager for KV cache management. - Scheduling logic: - 1. Running queue: Check for new blocks needed, update slot_mapping - 2. Waiting queue: Try block reuse (prefix caching), allocate new blocks - 3. Reference counting: Free blocks when requests complete + Scheduling priority (option A + B): + 1. Decode (running_queue) — latency-sensitive, never starves anyone. + 2. New prefill (waiting_queue) — preempts in-flight chunking so newly + arrived short requests don't wait for an entire long prefill. + 3. Continue chunked-prefill (chunking_queue) — single-request batch. + + Anti-starvation (option B): + After `max_waiting_yields` consecutive steps where waiting_queue won + over a non-empty chunking_queue, the next step is forced onto the + chunking_queue. This bounds the worst-case delay a long-prompt request + can suffer when there is a steady inflow of new short requests. """ def __init__( @@ -39,67 +46,137 @@ def __init__( max_batch_size: int = 16, num_blocks: int = 512, block_size: int = 256, + max_waiting_yields: int = 4, ): self.waiting_queue = janus.Queue() self.running_queue = janus.Queue() - # Requests in the middle of chunked-prefill — scheduled at high priority, - # single-request batches only (to match the C++ ChunkPrefillCompiler graph signature). + # Requests in the middle of chunked-prefill — single-request batch only + # (matches the C++ ChunkPrefillCompiler graph signature). self.chunking_queue = janus.Queue() self.max_batch_size = max_batch_size self.cache_manager = BlockManager(num_blocks=num_blocks, block_size=block_size) self.block_size = block_size + # ---- Anti-starvation state ---- + # How many times waiting_queue has won over a non-empty chunking_queue + # since the last time chunking actually ran. Reset to 0 every time we + # run a chunking step. + self._waiting_yields_in_a_row: int = 0 + # Upper bound on _waiting_yields_in_a_row before chunking is forced. + self.max_waiting_yields: int = max_waiting_yields + def add_request(self, request: InferenceRequest): if request is not None: request.status = RequestStatus.WAITING self.waiting_queue.sync_q.put(request) + # ------------------------------------------------------------------ # + # Main scheduling entrypoint # + # ------------------------------------------------------------------ # def schedule(self) -> Optional[SchedulerOutput]: """Schedule and return batch of requests to execute. - Priority (mirrors launch_server.py chunked-prefill scheduling): - 1. Running queue (decode) — short / latency-sensitive - 2. Chunking queue (in-flight chunked-prefill) — single-request slice - 3. Waiting queue (new prefill) — may start chunking if prompt is long + Priority (revised so chunk-prefill actually interleaves): + 1. Decode + 2. New prefill (waiting_queue) + 3. Continue chunked-prefill (chunking_queue) + + Order each call: + 1. Try decode (running_queue). + 2. If we've yielded to waiting `max_waiting_yields` times in a row + AND chunking_queue is non-empty → force chunking this step. + 3. Otherwise try waiting (may start a new chunking session and + emit a single-request batch immediately). + 4. Otherwise try chunking (continue an in-flight chunked-prefill). + 5. Otherwise (rare) — try waiting again as a final fallback. """ - # 2) Continue an in-flight chunked-prefill request (single-request batch). - try: - req = self.chunking_queue.sync_q.get_nowait() - except queue.Empty: - req = None - if req is not None: + # 1) Decode first — cheap and latency-sensitive. + decode_out = self._try_schedule_decode() + if decode_out is not None: + return decode_out + + # 2) Forced chunking after too many consecutive yields. + if self._waiting_yields_in_a_row >= self.max_waiting_yields: + chunking_out = self._try_schedule_chunking() + if chunking_out is not None: + self._waiting_yields_in_a_row = 0 + return chunking_out + # chunking_queue was actually empty — fall through to normal path. + + # 3) Waiting queue — newly arrived prefill preempts in-flight chunking. + # Snapshot whether chunking had anything BEFORE we drain waiting, + # so we can decide whether this counts as a "yield over chunking". + # (qsize() is racy but only affects the counter by ±1 in rare edge cases — not correctness) + chunking_was_nonempty = self.chunking_queue.sync_q.qsize() > 0 + + waiting_out = self._try_schedule_waiting() + if waiting_out is not None: + if chunking_was_nonempty: + # We took a step that COULD have been chunking — count it. + self._waiting_yields_in_a_row += 1 + else: + # No chunking to yield from; nothing was actually deferred. + self._waiting_yields_in_a_row = 0 + return waiting_out + + # 4) Continue an in-flight chunked-prefill request. + chunking_out = self._try_schedule_chunking() + if chunking_out is not None: + self._waiting_yields_in_a_row = 0 + return chunking_out + + return None + + # ------------------------------------------------------------------ # + # Per-queue schedulers # + # ------------------------------------------------------------------ # + def _try_schedule_chunking(self) -> Optional[SchedulerOutput]: + """Pull one in-flight chunked-prefill request and emit a single-request batch. + + The C++ ChunkPrefillCompiler graph is keyed on (batch_size, chunk_size). + Python currently sends batch=1 — see chunk_prefill_compiler.cpp. + """ + while True: + try: + req = self.chunking_queue.sync_q.get_nowait() + except queue.Empty: + return None if req.is_finished(): + # Drain finished entries silently and keep looking. self.complete_requests([req]) - else: - return SchedulerOutput( - scheduled_requests=[req], - is_prefill=True, - ) + continue + return SchedulerOutput( + scheduled_requests=[req], + is_prefill=True, + ) - scheduled_requests = [] - is_prefill = False + def _try_schedule_waiting(self) -> Optional[SchedulerOutput]: + """Pull new prefill requests from waiting_queue and form a prefill batch. + + If any request triggers chunked-prefill (prompt_length > chunk_size > 0), + it's emitted alone as a single-request batch (the chunking graph requires + a uniform chunk_size across the batch, and we don't mix chunking with + regular prefill in the same batch). + """ + scheduled_requests: List[InferenceRequest] = [] - # Process Waiting queue (prefill phase) while len(scheduled_requests) < self.max_batch_size: try: req = self.waiting_queue.sync_q.get_nowait() except queue.Empty: break - # Skip requests that were already finished (e.g., timed out/canceled while waiting) + + # Skip requests that were already finished (timed out / canceled while waiting). if req.is_finished(): self.complete_requests([req]) continue if not self.can_accept_request(req): + # Put it back; we'll retry next tick when cache pressure eases. self.waiting_queue.sync_q.put(req) break - # Skip requests that were already finished (e.g., timed out/canceled while waiting) - if req.is_finished(): - self.complete_requests([req]) - continue - req_tokens = req.get_input_tokens() num_required_blocks = req.get_num_blocks_required(self.block_size) @@ -107,47 +184,53 @@ def schedule(self) -> Optional[SchedulerOutput]: if not self.cache_manager.try_free_blocks(num_required_blocks): raise RuntimeError("No available cache blocks for new request") - # Allocate blocks with automatic prefix caching support - req.block_table, req.slot_mapping, req.num_cached_tokens = ( - self.cache_manager.allocate_blocks(req_tokens, req.block_table) - ) - + # Allocate blocks (prefix caching applied automatically). + if not req.block_table: + req.block_table, req.slot_mapping, req.num_cached_tokens = ( + self.cache_manager.allocate_blocks(req_tokens, req.block_table) + ) + req.num_blocks = len(req.block_table) req.status = RequestStatus.RUNNING - # Start chunked-prefill: enqueue into chunking_queue and emit a - # single-request batch immediately. We don't mix chunked-prefill - # with other requests in the same batch — the C++ ChunkPrefillCompiler - # graph is keyed on (batch_size, chunk_size). - if req.chunk_size > 0 and req.prompt_length > req.chunk_size: - req.chunk_prefill_offset = 0 - return SchedulerOutput( - scheduled_requests=[req], - is_prefill=True, - ) + # Start chunked-prefill: emit a single-request batch immediately + # to keep the C++ graph signature stable. The request will be + # requeued into chunking_queue by llm._update_requests after each + # chunk runs. + remaining = req.prompt_length - req.num_cached_tokens + if req.chunk_size > 0 and remaining > req.chunk_size: + req.chunk_prefill_offset = req.num_cached_tokens + if scheduled_requests: + for already in scheduled_requests: + already.status = RequestStatus.WAITING + self.waiting_queue.sync_q.put(already) + return SchedulerOutput([req], is_prefill=True) scheduled_requests.append(req) - # Return prefill batch if any waiting requests were scheduled if scheduled_requests: - is_prefill = True return SchedulerOutput( scheduled_requests=scheduled_requests, - is_prefill=is_prefill, + is_prefill=True, ) + return None + + def _try_schedule_decode(self) -> Optional[SchedulerOutput]: + """Pull running_queue requests into a decode batch.""" + scheduled_requests: List[InferenceRequest] = [] - # Process Running queue (decode phase) while len(scheduled_requests) < self.max_batch_size: try: req = self.running_queue.sync_q.get_nowait() except queue.Empty: break - # Skip requests that were already finished (e.g., timed out/canceled while running) + + # Skip requests that were already finished (timed out / canceled while running). if req.is_finished(): self.complete_requests([req]) continue - # Decode phase: allocate slot for newly generated token + # Decode phase: allocate slot for newly generated token. try: req.block_table, new_slot = self.cache_manager.append_slot( req.block_table, req.get_total_length(), req.get_all_token_ids() @@ -156,26 +239,30 @@ def schedule(self) -> Optional[SchedulerOutput]: req.num_blocks = len(req.block_table) req.num_cached_tokens = req.get_total_length() - 1 scheduled_requests.append(req) - except RuntimeError as e: raise RuntimeError("No available cache blocks for new token") from e - # Return decode batch if any running requests were scheduled if scheduled_requests: - is_prefill = False return SchedulerOutput( scheduled_requests=scheduled_requests, - is_prefill=is_prefill, + is_prefill=False, ) - return None + # ------------------------------------------------------------------ # + # External hooks (unchanged behavior) # + # ------------------------------------------------------------------ # def requeue_chunking(self, req: InferenceRequest): """Put a request back into the chunking queue after a chunk has run.""" self.chunking_queue.sync_q.put(req) def complete_requests(self, requests: List[InferenceRequest]): - """Handle completed requests and free their blocks.""" + """Handle completed requests and free their blocks. + + Active (non-finished) requests passed here are re-enqueued into the + running_queue — this is how prefill-finished requests migrate into + the decode pipeline. + """ for req in requests: if req.status in [ RequestStatus.FINISHED, @@ -235,4 +322,4 @@ def get_cache_stats(self) -> dict: "num_free_blocks": self.cache_manager.get_num_free_blocks(), "num_req_blocks": len(self.cache_manager.req_block_ids), "num_used_blocks": len(self.cache_manager.used_block_ids), - } + } \ No newline at end of file diff --git a/python/infinilm/processors/basic_llm_processor.py b/python/infinilm/processors/basic_llm_processor.py index f5e603ba..397e9068 100644 --- a/python/infinilm/processors/basic_llm_processor.py +++ b/python/infinilm/processors/basic_llm_processor.py @@ -183,62 +183,72 @@ def _build_model_input_from_batch_scheduler_output( for req in scheduler_output.scheduled_requests: num_cached = req.num_cached_tokens if scheduler_output.is_prefill: - # Prefill phase req_tokens = req.get_input_tokens() # Chunked-prefill: only feed [chunk_prefill_offset : +chunk_size). - # past_kv_lengths = chunk_prefill_offset (attention sees the prefix - # already committed); total_kv_lengths = chunk_prefill_offset + - # len(tokens_to_compute). This keeps batch_size=1 and total_tokens - # == chunk_size so the C++ ChunkPrefillCompiler graph hits. if req.is_chunking(): start = req.chunk_prefill_offset end = min(start + req.chunk_size, len(req_tokens)) tokens_to_compute = req_tokens[start:end] compute_len = len(tokens_to_compute) tokens.extend(tokens_to_compute) - seq_len = end # attention prefix length after this chunk + seq_len = end seq_lens.append(seq_len) current_offset += compute_len seq_offsets.append(current_offset) - slot_mapping.extend(req.slot_mapping[start:end]) + # req.slot_mapping has length (prompt_length - num_cached) and is + # indexed [0..prompt_length-num_cached). Translate absolute token + # indices to slot_mapping indices. + slot_start = start - num_cached + slot_end = end - num_cached + assert slot_start >= 0 and slot_end <= len(req.slot_mapping), ( + f"chunking slot slice out of range: start={start} " + f"end={end} num_cached={num_cached} " + f"len(slot_mapping)={len(req.slot_mapping)}" + ) + slot_mapping.extend(req.slot_mapping[slot_start:slot_end]) cached_lens.append(start) position_ids.extend(range(start, end)) else: tokens_to_compute = req_tokens[num_cached:] tokens.extend(tokens_to_compute) - compute_len = len(tokens_to_compute) seq_len = len(req_tokens) seq_lens.append(seq_len) - current_offset += compute_len seq_offsets.append(current_offset) - slot_mapping.extend(req.slot_mapping) cached_lens.append(num_cached) position_ids.extend(range(num_cached, num_cached + compute_len)) - else: - # Decode phase seq_len = req.get_total_length() last_token = req.generated_token_ids[-1] tokens.append(last_token) seq_lens.append(seq_len) - current_offset += 1 seq_offsets.append(current_offset) - slot_mapping.extend(req.slot_mapping) cached_lens.append(num_cached) position_ids.append(seq_len - 1) - # Pad block_table to same length padded_block_table = req.block_table + [-1] * ( max_block_table_len - len(req.block_table) ) block_tables.append(padded_block_table) cu_seqlens.append(cu_seqlens[-1] + seq_len) + + # guarantee non-empty tokens and slot_mapping to avoid downstream errors. If empty, raise with detailed debug info. + if not tokens or not slot_mapping: + states = [ + (r.request_id[:8], r.is_prefill, r.is_chunking(), + r.chunk_prefill_offset, r.prompt_length, r.num_cached_tokens, + len(r.slot_mapping), r.status.name) + for r in scheduler_output.scheduled_requests + ] + raise RuntimeError( + f"build_model_inputs got empty tokens/slot_mapping. " + f"is_prefill={scheduler_output.is_prefill}, states={states}" + ) return { "input_ids": infinicore.from_list([tokens], dtype=infinicore.int64),