From cd23a52365ff3d4e97f432e492e1d58fed85632a Mon Sep 17 00:00:00 2001 From: zxy Date: Fri, 8 May 2026 21:21:01 +0800 Subject: [PATCH 1/4] feat: configure cudagraph capture batch sizes --- lmdeploy/cli/serve.py | 2 + lmdeploy/cli/utils.py | 9 ++ lmdeploy/messages.py | 3 + .../pytorch/backends/cuda/graph_runner.py | 11 ++- lmdeploy/pytorch/backends/graph_runner.py | 2 + lmdeploy/pytorch/config.py | 5 +- lmdeploy/pytorch/engine/config_builder.py | 13 +++ .../test_cudagraph_capture_batch_sizes.py | 95 +++++++++++++++++++ 8 files changed, 136 insertions(+), 4 deletions(-) create mode 100644 tests/pytorch/backends/test_cudagraph_capture_batch_sizes.py diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 9ca662b961..71d6c7965a 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -128,6 +128,7 @@ def add_parser_api_server(): ArgumentHelper.enable_eplb(pt_group) ArgumentHelper.role(pt_group) ArgumentHelper.migration_backend(pt_group) + ArgumentHelper.cudagraph_capture_batch_sizes(pt_group) # multi-node serving args node_rank_act = ArgumentHelper.node_rank(pt_group) num_nodes_act = ArgumentHelper.num_nodes(pt_group) @@ -236,6 +237,7 @@ def api_server(args): quant_policy=args.quant_policy, eager_mode=args.eager_mode, max_prefill_token_num=args.max_prefill_token_num, + cudagraph_capture_batch_sizes=args.cudagraph_capture_batch_sizes, enable_microbatch=args.enable_microbatch, enable_eplb=args.enable_eplb, enable_metrics=not args.disable_metrics, diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index 762bd2f04b..389ab64806 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -604,6 +604,15 @@ def max_prefill_token_num(parser): default=8192, help='the max number of tokens per iteration during prefill') + @staticmethod + def cudagraph_capture_batch_sizes(parser): + return parser.add_argument('--cudagraph-capture-batch-sizes', + type=int, + nargs='+', + default=None, + help='Batch sizes to capture CUDA graphs for in the PyTorch engine. ' + 'If not specified, the engine infers them from max_batch_size') + @staticmethod def vision_max_batch_size(parser): return parser.add_argument('--vision-max-batch-size', type=int, default=1, help='the vision model batch size') diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 9041486c38..dddbb2c981 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -352,6 +352,8 @@ class PytorchEngineConfig: would be allocate according to current environment. adapters: The path configs to lora adapters. max_prefill_token_num: tokens per iteration. + cudagraph_capture_batch_sizes: Batch sizes to capture CUDA graphs for. + If not specified, the engine will infer them from max_batch_size. thread_safe: thread safe engine instance. enable_prefix_caching: Enable token match and sharing caches. device_type: The inference device type, options ['cuda'] @@ -411,6 +413,7 @@ class PytorchEngineConfig: num_gpu_blocks: int = 0 adapters: dict[str, str] = None max_prefill_token_num: int = 8192 + cudagraph_capture_batch_sizes: list[int] | None = None thread_safe: bool = False enable_prefix_caching: bool = False device_type: str = 'cuda' diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index d0ea917d14..d13c161581 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -188,7 +188,7 @@ def _get_capture_tokens(self, batch_size: int): for size in cap_sizes: if size >= batch_size: return size - assert False, f'Unsupported batch_size={batch_size}' + return None def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: list, attn_metadata: TritonAttentionMetadata, inputs_embeds: torch.Tensor, **kwargs): @@ -237,6 +237,10 @@ def __call__(self, **kwargs): graph_key = self.get_graph_key(**kwargs) max_batches = graph_key[0] + if max_batches is None: + with record_function('forward_eager'): + output = self.model(**kwargs) + return self.model.make_output_buffers(output) is_decoding = graph_key[1] decode_query_len = graph_key[3] if graph_key not in self._runner_map: @@ -303,9 +307,12 @@ def update_inputs(self, inputs): meta = self.get_meta() padding_batch_size = meta.padding_batch_size tp_size = self._get_capture_tokens(padding_batch_size) - dp_meta.sync_tp_size(tp_size) + if tp_size is not None: + dp_meta.sync_tp_size(tp_size) return inputs def get_capture_batch_sizes(self) -> list[int]: """Capture batch sizes.""" + if self.cache_config.cudagraph_capture_batch_sizes is not None: + return self.cache_config.cudagraph_capture_batch_sizes return _get_capture_batch_size_impl(self.cache_config.max_batches) diff --git a/lmdeploy/pytorch/backends/graph_runner.py b/lmdeploy/pytorch/backends/graph_runner.py index 72f460ef5b..c3224f52bf 100644 --- a/lmdeploy/pytorch/backends/graph_runner.py +++ b/lmdeploy/pytorch/backends/graph_runner.py @@ -101,4 +101,6 @@ def update_inputs(self, inputs): def get_capture_batch_sizes(self) -> list[int]: """Capture batch sizes.""" + if self.cache_config.cudagraph_capture_batch_sizes is not None: + return self.cache_config.cudagraph_capture_batch_sizes return _get_capture_batch_size_impl(self.cache_config.max_batches) diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 9dfecaf30d..26ca2eb692 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -98,6 +98,7 @@ class CacheConfig: window_size: int = -1 cache_max_entry_count: float = 0.8 max_prefill_token_num: int = 8192 + cudagraph_capture_batch_sizes: list[int] | None = None enable_prefix_caching: bool = False quant_policy: QuantPolicy = QuantPolicy.NONE device_type: str = 'cuda' @@ -390,9 +391,8 @@ def from_pretrained( activations. Refer to `PyTorchEngineConfig` for details hf_overrides (dict[str, Any]): overrides for the HF config. """ - from transformers import AutoConfig - from lmdeploy.pytorch.transformers import config_from_pretrained + from transformers import AutoConfig hf_config = config_from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code) if getattr(hf_config, 'model_type', None) in ['phi3']: # phi3 + trust_remote_code leads to error when tp. @@ -587,6 +587,7 @@ def from_config( num_gpu_blocks=target_cache_cfg.num_gpu_blocks, cache_max_entry_count=target_cache_cfg.cache_max_entry_count, max_prefill_token_num=target_cache_cfg.max_prefill_token_num, + cudagraph_capture_batch_sizes=target_cache_cfg.cudagraph_capture_batch_sizes, device_type=target_cache_cfg.device_type, migration_backend=target_cache_cfg.migration_backend) obj = cls( diff --git a/lmdeploy/pytorch/engine/config_builder.py b/lmdeploy/pytorch/engine/config_builder.py index c96fa5b67a..95fba7b9c6 100644 --- a/lmdeploy/pytorch/engine/config_builder.py +++ b/lmdeploy/pytorch/engine/config_builder.py @@ -39,6 +39,18 @@ def update_engine_config(engine_config: PytorchEngineConfig): f'since dllm_block_length({engine_config.dllm_block_length}) * max_batch_size ' f'({max_batch_size}) > max_prefill_token_num ({max_prefill_token_num}).') + capture_sizes = engine_config.cudagraph_capture_batch_sizes + if capture_sizes is not None: + assert len(capture_sizes) > 0, 'cudagraph_capture_batch_sizes should not be empty' + assert all(isinstance(size, int) for size in capture_sizes), ( + 'cudagraph_capture_batch_sizes should be integers') + assert all(size > 0 for size in capture_sizes), 'cudagraph_capture_batch_sizes should be positive' + capture_sizes = sorted({size for size in capture_sizes if size <= engine_config.max_batch_size}) + assert len(capture_sizes) > 0, ( + 'cudagraph_capture_batch_sizes should contain at least one value ' + f'<= max_batch_size ({engine_config.max_batch_size})') + engine_config.cudagraph_capture_batch_sizes = capture_sizes + if engine_config.dp != 1: if engine_config.tp == 1 and engine_config.ep == 1: logger.warning('Data parallelism is enabled but tensor parallelism and ' @@ -67,6 +79,7 @@ def build_cache_config(engine_config: PytorchEngineConfig): num_gpu_blocks=engine_config.num_gpu_blocks, cache_max_entry_count=engine_config.cache_max_entry_count, max_prefill_token_num=engine_config.max_prefill_token_num, + cudagraph_capture_batch_sizes=engine_config.cudagraph_capture_batch_sizes, enable_prefix_caching=engine_config.enable_prefix_caching, quant_policy=engine_config.quant_policy, device_type=engine_config.device_type, diff --git a/tests/pytorch/backends/test_cudagraph_capture_batch_sizes.py b/tests/pytorch/backends/test_cudagraph_capture_batch_sizes.py new file mode 100644 index 0000000000..4b2fcdf0a6 --- /dev/null +++ b/tests/pytorch/backends/test_cudagraph_capture_batch_sizes.py @@ -0,0 +1,95 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +from types import SimpleNamespace + +import pytest + +from lmdeploy.cli.utils import ArgumentHelper +from lmdeploy.messages import PytorchEngineConfig +from lmdeploy.pytorch.backends.cuda.graph_runner import ( + CUDAGraphRunner, + _get_capture_batch_size_impl, +) +from lmdeploy.pytorch.config import CacheConfig +from lmdeploy.pytorch.engine.config_builder import ConfigBuilder + + +def _cache_config(max_batches=8, cudagraph_capture_batch_sizes=None): + return CacheConfig(max_batches=max_batches, + block_size=64, + num_cpu_blocks=0, + num_gpu_blocks=1, + cudagraph_capture_batch_sizes=cudagraph_capture_batch_sizes) + + +def test_default_capture_batch_sizes_are_unchanged(): + cache_config = _cache_config(max_batches=512) + runner = object.__new__(CUDAGraphRunner) + runner.cache_config = cache_config + + assert runner.get_capture_batch_sizes() == _get_capture_batch_size_impl(512) + + +def test_custom_capture_batch_sizes_are_normalized_in_engine_config(): + engine_config = PytorchEngineConfig(max_batch_size=8, cudagraph_capture_batch_sizes=[8, 1, 4, 4]) + + engine_config = ConfigBuilder.update_engine_config(engine_config) + cache_config = ConfigBuilder.build_cache_config(engine_config) + + assert engine_config.cudagraph_capture_batch_sizes == [1, 4, 8] + assert cache_config.cudagraph_capture_batch_sizes == [1, 4, 8] + + +def test_capture_batch_sizes_larger_than_max_batch_size_are_filtered(): + engine_config = PytorchEngineConfig(max_batch_size=8, cudagraph_capture_batch_sizes=[1, 4, 16]) + + engine_config = ConfigBuilder.update_engine_config(engine_config) + + assert engine_config.cudagraph_capture_batch_sizes == [1, 4] + + +@pytest.mark.parametrize('sizes', [[], [0], [-1], [1.5], ['1'], [16]]) +def test_invalid_capture_batch_sizes_raise(sizes): + engine_config = PytorchEngineConfig(max_batch_size=8, cudagraph_capture_batch_sizes=sizes) + + with pytest.raises(AssertionError): + ConfigBuilder.update_engine_config(engine_config) + + +def test_graph_runner_uses_custom_capture_batch_sizes(): + cache_config = _cache_config(max_batches=8, cudagraph_capture_batch_sizes=[1, 4]) + runner = object.__new__(CUDAGraphRunner) + runner.cache_config = cache_config + + assert runner.get_capture_batch_sizes() == [1, 4] + assert runner._get_capture_tokens(2) == 4 + assert runner._get_capture_tokens(8) is None + + +def test_runtime_batch_larger_than_capture_sizes_falls_back_to_eager_forward(): + + class FakeModel: + + def __call__(self, **kwargs): + return 'eager-output' + + def make_output_buffers(self, output): + return {'output': output} + + runner = object.__new__(CUDAGraphRunner) + runner.backend_config = SimpleNamespace(eager_mode=True) + runner.model = FakeModel() + runner._prepare_inputs = lambda **kwargs: kwargs + runner.enable_graph = lambda **kwargs: True + runner.get_graph_key = lambda **kwargs: (None, True, False, 1) + + assert runner(input_ids='dummy', attn_metadata='dummy') == {'output': 'eager-output'} + + +def test_cudagraph_capture_batch_sizes_cli_arg(): + parser = argparse.ArgumentParser() + ArgumentHelper.cudagraph_capture_batch_sizes(parser) + + args = parser.parse_args(['--cudagraph-capture-batch-sizes', '1', '2', '4', '8']) + + assert args.cudagraph_capture_batch_sizes == [1, 2, 4, 8] From 6593d977e2b58c375f69571667cb0bc11c3bebd9 Mon Sep 17 00:00:00 2001 From: zxy Date: Tue, 26 May 2026 11:43:03 +0800 Subject: [PATCH 2/4] chore: simplify cudagraph capture sizes --- lmdeploy/pytorch/config.py | 3 +- lmdeploy/pytorch/engine/config_builder.py | 5 +- .../test_cudagraph_capture_batch_sizes.py | 95 ------------------- 3 files changed, 4 insertions(+), 99 deletions(-) delete mode 100644 tests/pytorch/backends/test_cudagraph_capture_batch_sizes.py diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 26ca2eb692..8450961c27 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -391,8 +391,9 @@ def from_pretrained( activations. Refer to `PyTorchEngineConfig` for details hf_overrides (dict[str, Any]): overrides for the HF config. """ - from lmdeploy.pytorch.transformers import config_from_pretrained from transformers import AutoConfig + + from lmdeploy.pytorch.transformers import config_from_pretrained hf_config = config_from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code) if getattr(hf_config, 'model_type', None) in ['phi3']: # phi3 + trust_remote_code leads to error when tp. diff --git a/lmdeploy/pytorch/engine/config_builder.py b/lmdeploy/pytorch/engine/config_builder.py index 95fba7b9c6..1209170cc6 100644 --- a/lmdeploy/pytorch/engine/config_builder.py +++ b/lmdeploy/pytorch/engine/config_builder.py @@ -42,9 +42,8 @@ def update_engine_config(engine_config: PytorchEngineConfig): capture_sizes = engine_config.cudagraph_capture_batch_sizes if capture_sizes is not None: assert len(capture_sizes) > 0, 'cudagraph_capture_batch_sizes should not be empty' - assert all(isinstance(size, int) for size in capture_sizes), ( - 'cudagraph_capture_batch_sizes should be integers') - assert all(size > 0 for size in capture_sizes), 'cudagraph_capture_batch_sizes should be positive' + assert all(isinstance(size, int) and size > 0 for size in capture_sizes), ( + 'cudagraph_capture_batch_sizes should be positive integers') capture_sizes = sorted({size for size in capture_sizes if size <= engine_config.max_batch_size}) assert len(capture_sizes) > 0, ( 'cudagraph_capture_batch_sizes should contain at least one value ' diff --git a/tests/pytorch/backends/test_cudagraph_capture_batch_sizes.py b/tests/pytorch/backends/test_cudagraph_capture_batch_sizes.py deleted file mode 100644 index 4b2fcdf0a6..0000000000 --- a/tests/pytorch/backends/test_cudagraph_capture_batch_sizes.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import argparse -from types import SimpleNamespace - -import pytest - -from lmdeploy.cli.utils import ArgumentHelper -from lmdeploy.messages import PytorchEngineConfig -from lmdeploy.pytorch.backends.cuda.graph_runner import ( - CUDAGraphRunner, - _get_capture_batch_size_impl, -) -from lmdeploy.pytorch.config import CacheConfig -from lmdeploy.pytorch.engine.config_builder import ConfigBuilder - - -def _cache_config(max_batches=8, cudagraph_capture_batch_sizes=None): - return CacheConfig(max_batches=max_batches, - block_size=64, - num_cpu_blocks=0, - num_gpu_blocks=1, - cudagraph_capture_batch_sizes=cudagraph_capture_batch_sizes) - - -def test_default_capture_batch_sizes_are_unchanged(): - cache_config = _cache_config(max_batches=512) - runner = object.__new__(CUDAGraphRunner) - runner.cache_config = cache_config - - assert runner.get_capture_batch_sizes() == _get_capture_batch_size_impl(512) - - -def test_custom_capture_batch_sizes_are_normalized_in_engine_config(): - engine_config = PytorchEngineConfig(max_batch_size=8, cudagraph_capture_batch_sizes=[8, 1, 4, 4]) - - engine_config = ConfigBuilder.update_engine_config(engine_config) - cache_config = ConfigBuilder.build_cache_config(engine_config) - - assert engine_config.cudagraph_capture_batch_sizes == [1, 4, 8] - assert cache_config.cudagraph_capture_batch_sizes == [1, 4, 8] - - -def test_capture_batch_sizes_larger_than_max_batch_size_are_filtered(): - engine_config = PytorchEngineConfig(max_batch_size=8, cudagraph_capture_batch_sizes=[1, 4, 16]) - - engine_config = ConfigBuilder.update_engine_config(engine_config) - - assert engine_config.cudagraph_capture_batch_sizes == [1, 4] - - -@pytest.mark.parametrize('sizes', [[], [0], [-1], [1.5], ['1'], [16]]) -def test_invalid_capture_batch_sizes_raise(sizes): - engine_config = PytorchEngineConfig(max_batch_size=8, cudagraph_capture_batch_sizes=sizes) - - with pytest.raises(AssertionError): - ConfigBuilder.update_engine_config(engine_config) - - -def test_graph_runner_uses_custom_capture_batch_sizes(): - cache_config = _cache_config(max_batches=8, cudagraph_capture_batch_sizes=[1, 4]) - runner = object.__new__(CUDAGraphRunner) - runner.cache_config = cache_config - - assert runner.get_capture_batch_sizes() == [1, 4] - assert runner._get_capture_tokens(2) == 4 - assert runner._get_capture_tokens(8) is None - - -def test_runtime_batch_larger_than_capture_sizes_falls_back_to_eager_forward(): - - class FakeModel: - - def __call__(self, **kwargs): - return 'eager-output' - - def make_output_buffers(self, output): - return {'output': output} - - runner = object.__new__(CUDAGraphRunner) - runner.backend_config = SimpleNamespace(eager_mode=True) - runner.model = FakeModel() - runner._prepare_inputs = lambda **kwargs: kwargs - runner.enable_graph = lambda **kwargs: True - runner.get_graph_key = lambda **kwargs: (None, True, False, 1) - - assert runner(input_ids='dummy', attn_metadata='dummy') == {'output': 'eager-output'} - - -def test_cudagraph_capture_batch_sizes_cli_arg(): - parser = argparse.ArgumentParser() - ArgumentHelper.cudagraph_capture_batch_sizes(parser) - - args = parser.parse_args(['--cudagraph-capture-batch-sizes', '1', '2', '4', '8']) - - assert args.cudagraph_capture_batch_sizes == [1, 2, 4, 8] From f540f50384a5fb72a924965853acce92e5e5b1a8 Mon Sep 17 00:00:00 2001 From: zxy Date: Tue, 26 May 2026 11:53:50 +0800 Subject: [PATCH 3/4] fix: require cudagraph capture coverage --- lmdeploy/cli/utils.py | 3 ++- lmdeploy/messages.py | 1 + lmdeploy/pytorch/backends/cuda/graph_runner.py | 9 ++------- lmdeploy/pytorch/engine/config_builder.py | 2 ++ 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index 389ab64806..e9860e0f49 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -611,7 +611,8 @@ def cudagraph_capture_batch_sizes(parser): nargs='+', default=None, help='Batch sizes to capture CUDA graphs for in the PyTorch engine. ' - 'If not specified, the engine infers them from max_batch_size') + 'If not specified, the engine infers them from max_batch_size. ' + 'max_batch_size is always captured') @staticmethod def vision_max_batch_size(parser): diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index dddbb2c981..79e550911a 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -354,6 +354,7 @@ class PytorchEngineConfig: max_prefill_token_num: tokens per iteration. cudagraph_capture_batch_sizes: Batch sizes to capture CUDA graphs for. If not specified, the engine will infer them from max_batch_size. + max_batch_size is always captured. thread_safe: thread safe engine instance. enable_prefix_caching: Enable token match and sharing caches. device_type: The inference device type, options ['cuda'] diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index d13c161581..56f2748fe1 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -188,7 +188,7 @@ def _get_capture_tokens(self, batch_size: int): for size in cap_sizes: if size >= batch_size: return size - return None + assert False, f'Unsupported batch_size={batch_size}' def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: list, attn_metadata: TritonAttentionMetadata, inputs_embeds: torch.Tensor, **kwargs): @@ -237,10 +237,6 @@ def __call__(self, **kwargs): graph_key = self.get_graph_key(**kwargs) max_batches = graph_key[0] - if max_batches is None: - with record_function('forward_eager'): - output = self.model(**kwargs) - return self.model.make_output_buffers(output) is_decoding = graph_key[1] decode_query_len = graph_key[3] if graph_key not in self._runner_map: @@ -307,8 +303,7 @@ def update_inputs(self, inputs): meta = self.get_meta() padding_batch_size = meta.padding_batch_size tp_size = self._get_capture_tokens(padding_batch_size) - if tp_size is not None: - dp_meta.sync_tp_size(tp_size) + dp_meta.sync_tp_size(tp_size) return inputs def get_capture_batch_sizes(self) -> list[int]: diff --git a/lmdeploy/pytorch/engine/config_builder.py b/lmdeploy/pytorch/engine/config_builder.py index 1209170cc6..4373535e08 100644 --- a/lmdeploy/pytorch/engine/config_builder.py +++ b/lmdeploy/pytorch/engine/config_builder.py @@ -48,6 +48,8 @@ def update_engine_config(engine_config: PytorchEngineConfig): assert len(capture_sizes) > 0, ( 'cudagraph_capture_batch_sizes should contain at least one value ' f'<= max_batch_size ({engine_config.max_batch_size})') + if capture_sizes[-1] != engine_config.max_batch_size: + capture_sizes.append(engine_config.max_batch_size) engine_config.cudagraph_capture_batch_sizes = capture_sizes if engine_config.dp != 1: From 93496c591e3d3469b0336d429e06f226642e22d5 Mon Sep 17 00:00:00 2001 From: zxy Date: Tue, 26 May 2026 16:21:31 +0800 Subject: [PATCH 4/4] fix: normalize cudagraph capture sizes --- .../pytorch/backends/cuda/graph_runner.py | 9 ++- lmdeploy/pytorch/backends/graph_runner.py | 9 ++- lmdeploy/pytorch/config.py | 20 +++++++ lmdeploy/pytorch/engine/config_builder.py | 13 +---- .../test_cudagraph_capture_batch_sizes.py | 55 +++++++++++++++++++ 5 files changed, 94 insertions(+), 12 deletions(-) create mode 100644 tests/pytorch/engine/test_cudagraph_capture_batch_sizes.py diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index 56f2748fe1..bc7bf09b1b 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -7,7 +7,12 @@ from lmdeploy.pytorch.backends.deepep_state import get_deepep_state from lmdeploy.pytorch.backends.selector import get_backend -from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig +from lmdeploy.pytorch.config import ( + BackendConfig, + CacheConfig, + ModelConfig, + normalize_cudagraph_capture_batch_sizes, +) from lmdeploy.pytorch.model_inputs import StepContext, get_step_ctx_manager from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta from lmdeploy.pytorch.strategies.base import StrategyFactoryBase @@ -309,5 +314,7 @@ def update_inputs(self, inputs): def get_capture_batch_sizes(self) -> list[int]: """Capture batch sizes.""" if self.cache_config.cudagraph_capture_batch_sizes is not None: + self.cache_config.cudagraph_capture_batch_sizes = normalize_cudagraph_capture_batch_sizes( + self.cache_config.cudagraph_capture_batch_sizes, self.cache_config.max_batches) return self.cache_config.cudagraph_capture_batch_sizes return _get_capture_batch_size_impl(self.cache_config.max_batches) diff --git a/lmdeploy/pytorch/backends/graph_runner.py b/lmdeploy/pytorch/backends/graph_runner.py index c3224f52bf..c977e6bc63 100644 --- a/lmdeploy/pytorch/backends/graph_runner.py +++ b/lmdeploy/pytorch/backends/graph_runner.py @@ -4,7 +4,12 @@ import torch -from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig +from lmdeploy.pytorch.config import ( + BackendConfig, + CacheConfig, + ModelConfig, + normalize_cudagraph_capture_batch_sizes, +) from lmdeploy.pytorch.model_inputs import StepContext @@ -102,5 +107,7 @@ def update_inputs(self, inputs): def get_capture_batch_sizes(self) -> list[int]: """Capture batch sizes.""" if self.cache_config.cudagraph_capture_batch_sizes is not None: + self.cache_config.cudagraph_capture_batch_sizes = normalize_cudagraph_capture_batch_sizes( + self.cache_config.cudagraph_capture_batch_sizes, self.cache_config.max_batches) return self.cache_config.cudagraph_capture_batch_sizes return _get_capture_batch_size_impl(self.cache_config.max_batches) diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 8450961c27..f56937166c 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -14,6 +14,24 @@ logger = get_logger('lmdeploy') +def normalize_cudagraph_capture_batch_sizes(capture_sizes: list[int] | None, max_batches: int) -> list[int] | None: + """Normalize configured cudagraph capture batch sizes.""" + if capture_sizes is None: + return None + + assert len(capture_sizes) > 0, 'cudagraph_capture_batch_sizes should not be empty' + assert all(isinstance(size, int) and size > 0 for size in capture_sizes), ( + 'cudagraph_capture_batch_sizes should be positive integers') + + capture_sizes = sorted({size for size in capture_sizes if size <= max_batches}) + assert len(capture_sizes) > 0, ( + 'cudagraph_capture_batch_sizes should contain at least one value ' + f'<= max_batch_size ({max_batches})') + if capture_sizes[-1] != max_batches: + capture_sizes.append(max_batches) + return capture_sizes + + def _update_torch_dtype(config: 'ModelConfig', dtype: str, device_type: str = 'auto'): """Update the torch dtype from the model config. @@ -119,6 +137,8 @@ def __post_init__(self): self.enable_prefix_caching = False if self.kernel_block_size == -1: self.kernel_block_size = self.block_size + self.cudagraph_capture_batch_sizes = normalize_cudagraph_capture_batch_sizes( + self.cudagraph_capture_batch_sizes, self.max_batches) class TPMode(enum.Enum): diff --git a/lmdeploy/pytorch/engine/config_builder.py b/lmdeploy/pytorch/engine/config_builder.py index 4373535e08..e147747000 100644 --- a/lmdeploy/pytorch/engine/config_builder.py +++ b/lmdeploy/pytorch/engine/config_builder.py @@ -10,6 +10,7 @@ MiscConfig, SchedulerConfig, SpecDecodeConfig, + normalize_cudagraph_capture_batch_sizes, ) from lmdeploy.utils import get_logger, get_max_batch_size, get_model @@ -41,16 +42,8 @@ def update_engine_config(engine_config: PytorchEngineConfig): capture_sizes = engine_config.cudagraph_capture_batch_sizes if capture_sizes is not None: - assert len(capture_sizes) > 0, 'cudagraph_capture_batch_sizes should not be empty' - assert all(isinstance(size, int) and size > 0 for size in capture_sizes), ( - 'cudagraph_capture_batch_sizes should be positive integers') - capture_sizes = sorted({size for size in capture_sizes if size <= engine_config.max_batch_size}) - assert len(capture_sizes) > 0, ( - 'cudagraph_capture_batch_sizes should contain at least one value ' - f'<= max_batch_size ({engine_config.max_batch_size})') - if capture_sizes[-1] != engine_config.max_batch_size: - capture_sizes.append(engine_config.max_batch_size) - engine_config.cudagraph_capture_batch_sizes = capture_sizes + engine_config.cudagraph_capture_batch_sizes = normalize_cudagraph_capture_batch_sizes( + capture_sizes, engine_config.max_batch_size) if engine_config.dp != 1: if engine_config.tp == 1 and engine_config.ep == 1: diff --git a/tests/pytorch/engine/test_cudagraph_capture_batch_sizes.py b/tests/pytorch/engine/test_cudagraph_capture_batch_sizes.py new file mode 100644 index 0000000000..01cb783cd5 --- /dev/null +++ b/tests/pytorch/engine/test_cudagraph_capture_batch_sizes.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest + +from lmdeploy.messages import PytorchEngineConfig +from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner +from lmdeploy.pytorch.config import CacheConfig +from lmdeploy.pytorch.engine.config_builder import ConfigBuilder + + +def _cache_config(max_batches=8, cudagraph_capture_batch_sizes=None): + return CacheConfig(max_batches=max_batches, + block_size=64, + num_cpu_blocks=0, + num_gpu_blocks=1, + cudagraph_capture_batch_sizes=cudagraph_capture_batch_sizes) + + +def test_custom_capture_batch_sizes_include_max_batch_size(): + engine_config = PytorchEngineConfig(max_batch_size=8, cudagraph_capture_batch_sizes=[4, 1, 4, 16]) + + engine_config = ConfigBuilder.update_engine_config(engine_config) + + assert engine_config.cudagraph_capture_batch_sizes == [1, 4, 8] + + +def test_cache_config_normalizes_capture_batch_sizes(): + cache_config = _cache_config(max_batches=8, cudagraph_capture_batch_sizes=[4, 1, 4, 16]) + + assert cache_config.cudagraph_capture_batch_sizes == [1, 4, 8] + + +@pytest.mark.parametrize('sizes', [[], [0], [-1], [1.5], ['1'], [16]]) +def test_invalid_capture_batch_sizes_raise(sizes): + with pytest.raises(AssertionError): + _cache_config(max_batches=8, cudagraph_capture_batch_sizes=sizes) + + +def test_capture_batch_size_miss_raises(): + engine_config = PytorchEngineConfig(max_batch_size=8, cudagraph_capture_batch_sizes=[1, 4]) + engine_config = ConfigBuilder.update_engine_config(engine_config) + runner = object.__new__(CUDAGraphRunner) + runner.cache_config = ConfigBuilder.build_cache_config(engine_config) + + assert runner._get_capture_tokens(5) == 8 + with pytest.raises(AssertionError): + runner._get_capture_tokens(9) + + +def test_graph_runner_defensively_normalizes_capture_batch_sizes(): + cache_config = _cache_config(max_batches=8, cudagraph_capture_batch_sizes=[1, 8]) + cache_config.cudagraph_capture_batch_sizes = [4, 1, 4, 16] + runner = object.__new__(CUDAGraphRunner) + runner.cache_config = cache_config + + assert runner.get_capture_batch_sizes() == [1, 4, 8]