diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index b5016d801b..a1c4b583bf 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -128,6 +128,7 @@ def add_parser_api_server(): ArgumentHelper.enable_eplb(pt_group) ArgumentHelper.role(pt_group) ArgumentHelper.migration_backend(pt_group) + ArgumentHelper.cudagraph_capture_batch_sizes(pt_group) # multi-node serving args node_rank_act = ArgumentHelper.node_rank(pt_group) num_nodes_act = ArgumentHelper.num_nodes(pt_group) @@ -237,6 +238,7 @@ def api_server(args): quant_policy=args.quant_policy, eager_mode=args.eager_mode, max_prefill_token_num=args.max_prefill_token_num, + cudagraph_capture_batch_sizes=args.cudagraph_capture_batch_sizes, enable_microbatch=args.enable_microbatch, enable_eplb=args.enable_eplb, enable_metrics=not args.disable_metrics, diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index 2c9f8034f6..a5b0f9de33 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -616,6 +616,16 @@ def max_prefill_token_num(parser): default=8192, help='the max number of tokens per iteration during prefill') + @staticmethod + def cudagraph_capture_batch_sizes(parser): + return parser.add_argument('--cudagraph-capture-batch-sizes', + type=int, + nargs='+', + default=None, + help='Batch sizes to capture CUDA graphs for in the PyTorch engine. ' + 'If not specified, the engine infers them from max_batch_size. ' + 'max_batch_size is always captured') + @staticmethod def vision_max_batch_size(parser): return parser.add_argument('--vision-max-batch-size', type=int, default=1, help='the vision model batch size') diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index c262bdbe03..39a41263f6 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -362,6 +362,9 @@ class PytorchEngineConfig: would be allocate according to current environment. adapters: The path configs to lora adapters. max_prefill_token_num: tokens per iteration. + cudagraph_capture_batch_sizes: Batch sizes to capture CUDA graphs for. + If not specified, the engine will infer them from max_batch_size. + max_batch_size is always captured. thread_safe: thread safe engine instance. enable_prefix_caching: Enable token match and sharing caches. device_type: The inference device type, options ['cuda'] @@ -422,6 +425,7 @@ class PytorchEngineConfig: num_gpu_blocks: int = 0 adapters: dict[str, str] = None max_prefill_token_num: int = 8192 + cudagraph_capture_batch_sizes: list[int] | None = None thread_safe: bool = False enable_prefix_caching: bool = False device_type: str = 'cuda' diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index 9119c23e87..e82eefbdcb 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -7,7 +7,12 @@ from lmdeploy.pytorch.backends.deepep_state import get_deepep_state from lmdeploy.pytorch.backends.selector import get_backend -from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig +from lmdeploy.pytorch.config import ( + BackendConfig, + CacheConfig, + ModelConfig, + normalize_cudagraph_capture_batch_sizes, +) from lmdeploy.pytorch.envs import fake_capture from lmdeploy.pytorch.model_inputs import StepContext, get_step_ctx_manager from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta @@ -342,4 +347,8 @@ def update_inputs(self, inputs): def get_capture_batch_sizes(self) -> list[int]: """Capture batch sizes.""" + if self.cache_config.cudagraph_capture_batch_sizes is not None: + self.cache_config.cudagraph_capture_batch_sizes = normalize_cudagraph_capture_batch_sizes( + self.cache_config.cudagraph_capture_batch_sizes, self.cache_config.max_batches) + return self.cache_config.cudagraph_capture_batch_sizes return _get_capture_batch_size_impl(self.cache_config.max_batches) diff --git a/lmdeploy/pytorch/backends/graph_runner.py b/lmdeploy/pytorch/backends/graph_runner.py index 72f460ef5b..c977e6bc63 100644 --- a/lmdeploy/pytorch/backends/graph_runner.py +++ b/lmdeploy/pytorch/backends/graph_runner.py @@ -4,7 +4,12 @@ import torch -from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig +from lmdeploy.pytorch.config import ( + BackendConfig, + CacheConfig, + ModelConfig, + normalize_cudagraph_capture_batch_sizes, +) from lmdeploy.pytorch.model_inputs import StepContext @@ -101,4 +106,8 @@ def update_inputs(self, inputs): def get_capture_batch_sizes(self) -> list[int]: """Capture batch sizes.""" + if self.cache_config.cudagraph_capture_batch_sizes is not None: + self.cache_config.cudagraph_capture_batch_sizes = normalize_cudagraph_capture_batch_sizes( + self.cache_config.cudagraph_capture_batch_sizes, self.cache_config.max_batches) + return self.cache_config.cudagraph_capture_batch_sizes return _get_capture_batch_size_impl(self.cache_config.max_batches) diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index af26839855..13dbe61ebd 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -14,6 +14,24 @@ logger = get_logger('lmdeploy') +def normalize_cudagraph_capture_batch_sizes(capture_sizes: list[int] | None, max_batches: int) -> list[int] | None: + """Normalize configured cudagraph capture batch sizes.""" + if capture_sizes is None: + return None + + assert len(capture_sizes) > 0, 'cudagraph_capture_batch_sizes should not be empty' + assert all(isinstance(size, int) and size > 0 for size in capture_sizes), ( + 'cudagraph_capture_batch_sizes should be positive integers') + + capture_sizes = sorted({size for size in capture_sizes if size <= max_batches}) + assert len(capture_sizes) > 0, ( + 'cudagraph_capture_batch_sizes should contain at least one value ' + f'<= max_batch_size ({max_batches})') + if capture_sizes[-1] != max_batches: + capture_sizes.append(max_batches) + return capture_sizes + + def _update_torch_dtype(config: 'ModelConfig', dtype: str, device_type: str = 'auto'): """Update the torch dtype from the model config. @@ -98,6 +116,7 @@ class CacheConfig: window_size: int = -1 cache_max_entry_count: float = 0.8 max_prefill_token_num: int = 8192 + cudagraph_capture_batch_sizes: list[int] | None = None enable_prefix_caching: bool = False quant_policy: QuantPolicy = QuantPolicy.NONE device_type: str = 'cuda' @@ -118,6 +137,8 @@ def __post_init__(self): self.enable_prefix_caching = False if self.kernel_block_size == -1: self.kernel_block_size = self.block_size + self.cudagraph_capture_batch_sizes = normalize_cudagraph_capture_batch_sizes( + self.cudagraph_capture_batch_sizes, self.max_batches) class TPMode(enum.Enum): @@ -611,6 +632,7 @@ def from_config( num_gpu_blocks=target_cache_cfg.num_gpu_blocks, cache_max_entry_count=target_cache_cfg.cache_max_entry_count, max_prefill_token_num=target_cache_cfg.max_prefill_token_num, + cudagraph_capture_batch_sizes=target_cache_cfg.cudagraph_capture_batch_sizes, device_type=target_cache_cfg.device_type, quant_policy=target_cache_cfg.quant_policy, migration_backend=target_cache_cfg.migration_backend) diff --git a/lmdeploy/pytorch/engine/config_builder.py b/lmdeploy/pytorch/engine/config_builder.py index f80132eb92..6b4644d1e0 100644 --- a/lmdeploy/pytorch/engine/config_builder.py +++ b/lmdeploy/pytorch/engine/config_builder.py @@ -10,6 +10,7 @@ MiscConfig, SchedulerConfig, SpecDecodeConfig, + normalize_cudagraph_capture_batch_sizes, ) from lmdeploy.utils import get_logger, get_max_batch_size, get_model @@ -39,6 +40,11 @@ def update_engine_config(engine_config: PytorchEngineConfig): f'since dllm_block_length({engine_config.dllm_block_length}) * max_batch_size ' f'({max_batch_size}) > max_prefill_token_num ({max_prefill_token_num}).') + capture_sizes = engine_config.cudagraph_capture_batch_sizes + if capture_sizes is not None: + engine_config.cudagraph_capture_batch_sizes = normalize_cudagraph_capture_batch_sizes( + capture_sizes, engine_config.max_batch_size) + if engine_config.dp != 1: if engine_config.tp == 1 and engine_config.ep == 1: logger.warning('Data parallelism is enabled but tensor parallelism and ' @@ -67,6 +73,7 @@ def build_cache_config(engine_config: PytorchEngineConfig): num_gpu_blocks=engine_config.num_gpu_blocks, cache_max_entry_count=engine_config.cache_max_entry_count, max_prefill_token_num=engine_config.max_prefill_token_num, + cudagraph_capture_batch_sizes=engine_config.cudagraph_capture_batch_sizes, enable_prefix_caching=engine_config.enable_prefix_caching, quant_policy=engine_config.quant_policy, device_type=engine_config.device_type, diff --git a/tests/pytorch/engine/test_cudagraph_capture_batch_sizes.py b/tests/pytorch/engine/test_cudagraph_capture_batch_sizes.py new file mode 100644 index 0000000000..01cb783cd5 --- /dev/null +++ b/tests/pytorch/engine/test_cudagraph_capture_batch_sizes.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest + +from lmdeploy.messages import PytorchEngineConfig +from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner +from lmdeploy.pytorch.config import CacheConfig +from lmdeploy.pytorch.engine.config_builder import ConfigBuilder + + +def _cache_config(max_batches=8, cudagraph_capture_batch_sizes=None): + return CacheConfig(max_batches=max_batches, + block_size=64, + num_cpu_blocks=0, + num_gpu_blocks=1, + cudagraph_capture_batch_sizes=cudagraph_capture_batch_sizes) + + +def test_custom_capture_batch_sizes_include_max_batch_size(): + engine_config = PytorchEngineConfig(max_batch_size=8, cudagraph_capture_batch_sizes=[4, 1, 4, 16]) + + engine_config = ConfigBuilder.update_engine_config(engine_config) + + assert engine_config.cudagraph_capture_batch_sizes == [1, 4, 8] + + +def test_cache_config_normalizes_capture_batch_sizes(): + cache_config = _cache_config(max_batches=8, cudagraph_capture_batch_sizes=[4, 1, 4, 16]) + + assert cache_config.cudagraph_capture_batch_sizes == [1, 4, 8] + + +@pytest.mark.parametrize('sizes', [[], [0], [-1], [1.5], ['1'], [16]]) +def test_invalid_capture_batch_sizes_raise(sizes): + with pytest.raises(AssertionError): + _cache_config(max_batches=8, cudagraph_capture_batch_sizes=sizes) + + +def test_capture_batch_size_miss_raises(): + engine_config = PytorchEngineConfig(max_batch_size=8, cudagraph_capture_batch_sizes=[1, 4]) + engine_config = ConfigBuilder.update_engine_config(engine_config) + runner = object.__new__(CUDAGraphRunner) + runner.cache_config = ConfigBuilder.build_cache_config(engine_config) + + assert runner._get_capture_tokens(5) == 8 + with pytest.raises(AssertionError): + runner._get_capture_tokens(9) + + +def test_graph_runner_defensively_normalizes_capture_batch_sizes(): + cache_config = _cache_config(max_batches=8, cudagraph_capture_batch_sizes=[1, 8]) + cache_config.cudagraph_capture_batch_sizes = [4, 1, 4, 16] + runner = object.__new__(CUDAGraphRunner) + runner.cache_config = cache_config + + assert runner.get_capture_batch_sizes() == [1, 4, 8]