Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lmdeploy/cli/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def add_parser_api_server():
ArgumentHelper.enable_eplb(pt_group)
ArgumentHelper.role(pt_group)
ArgumentHelper.migration_backend(pt_group)
ArgumentHelper.cudagraph_capture_batch_sizes(pt_group)
# multi-node serving args
node_rank_act = ArgumentHelper.node_rank(pt_group)
num_nodes_act = ArgumentHelper.num_nodes(pt_group)
Expand Down Expand Up @@ -237,6 +238,7 @@ def api_server(args):
quant_policy=args.quant_policy,
eager_mode=args.eager_mode,
max_prefill_token_num=args.max_prefill_token_num,
cudagraph_capture_batch_sizes=args.cudagraph_capture_batch_sizes,
enable_microbatch=args.enable_microbatch,
enable_eplb=args.enable_eplb,
enable_metrics=not args.disable_metrics,
Expand Down
10 changes: 10 additions & 0 deletions lmdeploy/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,16 @@ def max_prefill_token_num(parser):
default=8192,
help='the max number of tokens per iteration during prefill')

@staticmethod
def cudagraph_capture_batch_sizes(parser):
    """Register the ``--cudagraph-capture-batch-sizes`` option on ``parser``.

    The option accepts one or more integers and defaults to ``None``.

    Returns:
        The object returned by ``parser.add_argument``.
    """
    help_text = ('Batch sizes to capture CUDA graphs for in the PyTorch engine. '
                 'If not specified, the engine infers them from max_batch_size. '
                 'max_batch_size is always captured')
    option_kwargs = dict(type=int, nargs='+', default=None, help=help_text)
    return parser.add_argument('--cudagraph-capture-batch-sizes', **option_kwargs)

@staticmethod
def vision_max_batch_size(parser):
    """Register the ``--vision-max-batch-size`` option (int, default 1) on ``parser``."""
    return parser.add_argument(
        '--vision-max-batch-size',
        help='the vision model batch size',
        default=1,
        type=int,
    )
Expand Down
4 changes: 4 additions & 0 deletions lmdeploy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,9 @@ class PytorchEngineConfig:
would be allocate according to current environment.
adapters: The path configs to lora adapters.
max_prefill_token_num: tokens per iteration.
cudagraph_capture_batch_sizes: Batch sizes to capture CUDA graphs for.
If not specified, the engine will infer them from max_batch_size.
max_batch_size is always captured.
thread_safe: thread safe engine instance.
enable_prefix_caching: Enable token match and sharing caches.
device_type: The inference device type, options ['cuda']
Expand Down Expand Up @@ -422,6 +425,7 @@ class PytorchEngineConfig:
num_gpu_blocks: int = 0
adapters: dict[str, str] = None
max_prefill_token_num: int = 8192
cudagraph_capture_batch_sizes: list[int] | None = None
thread_safe: bool = False
enable_prefix_caching: bool = False
device_type: str = 'cuda'
Expand Down
11 changes: 10 additions & 1 deletion lmdeploy/pytorch/backends/cuda/graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@

from lmdeploy.pytorch.backends.deepep_state import get_deepep_state
from lmdeploy.pytorch.backends.selector import get_backend
from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.pytorch.config import (
BackendConfig,
CacheConfig,
ModelConfig,
normalize_cudagraph_capture_batch_sizes,
)
from lmdeploy.pytorch.envs import fake_capture
from lmdeploy.pytorch.model_inputs import StepContext, get_step_ctx_manager
from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta
Expand Down Expand Up @@ -342,4 +347,8 @@ def update_inputs(self, inputs):

def get_capture_batch_sizes(self) -> list[int]:
    """Return the batch sizes that graphs should be captured for.

    When ``cache_config.cudagraph_capture_batch_sizes`` is set, it is
    normalized against ``cache_config.max_batches``, written back to the
    config, and returned. Otherwise the sizes come from
    ``_get_capture_batch_size_impl``.
    """
    cache_cfg = self.cache_config
    configured = cache_cfg.cudagraph_capture_batch_sizes
    if configured is None:
        return _get_capture_batch_size_impl(cache_cfg.max_batches)

    # The config may have been edited after construction, so normalize again
    # and store the result before handing it out.
    cache_cfg.cudagraph_capture_batch_sizes = normalize_cudagraph_capture_batch_sizes(
        configured, cache_cfg.max_batches)
    return cache_cfg.cudagraph_capture_batch_sizes
Comment thread
CUHKSZzxy marked this conversation as resolved.
11 changes: 10 additions & 1 deletion lmdeploy/pytorch/backends/graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@

import torch

from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.pytorch.config import (
BackendConfig,
CacheConfig,
ModelConfig,
normalize_cudagraph_capture_batch_sizes,
)
from lmdeploy.pytorch.model_inputs import StepContext


Expand Down Expand Up @@ -101,4 +106,8 @@ def update_inputs(self, inputs):

def get_capture_batch_sizes(self) -> list[int]:
    """Return the batch sizes to capture.

    A user-provided ``cudagraph_capture_batch_sizes`` on the cache config is
    normalized against ``max_batches``, stored back on the config, and
    returned; when it is ``None`` the result of
    ``_get_capture_batch_size_impl`` is returned instead.
    """
    config = self.cache_config
    user_sizes = config.cudagraph_capture_batch_sizes
    if user_sizes is None:
        return _get_capture_batch_size_impl(config.max_batches)

    normalized = normalize_cudagraph_capture_batch_sizes(user_sizes, config.max_batches)
    config.cudagraph_capture_batch_sizes = normalized
    return config.cudagraph_capture_batch_sizes
Comment thread
CUHKSZzxy marked this conversation as resolved.
22 changes: 22 additions & 0 deletions lmdeploy/pytorch/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,24 @@
logger = get_logger('lmdeploy')


def normalize_cudagraph_capture_batch_sizes(capture_sizes: list[int] | None, max_batches: int) -> list[int] | None:
"""Normalize configured cudagraph capture batch sizes."""
if capture_sizes is None:
return None

assert len(capture_sizes) > 0, 'cudagraph_capture_batch_sizes should not be empty'
assert all(isinstance(size, int) and size > 0 for size in capture_sizes), (
'cudagraph_capture_batch_sizes should be positive integers')

capture_sizes = sorted({size for size in capture_sizes if size <= max_batches})
assert len(capture_sizes) > 0, (
'cudagraph_capture_batch_sizes should contain at least one value '
f'<= max_batch_size ({max_batches})')
if capture_sizes[-1] != max_batches:
capture_sizes.append(max_batches)
return capture_sizes


def _update_torch_dtype(config: 'ModelConfig', dtype: str, device_type: str = 'auto'):
"""Update the torch dtype from the model config.

Expand Down Expand Up @@ -98,6 +116,7 @@ class CacheConfig:
window_size: int = -1
cache_max_entry_count: float = 0.8
max_prefill_token_num: int = 8192
cudagraph_capture_batch_sizes: list[int] | None = None
enable_prefix_caching: bool = False
quant_policy: QuantPolicy = QuantPolicy.NONE
device_type: str = 'cuda'
Expand All @@ -118,6 +137,8 @@ def __post_init__(self):
self.enable_prefix_caching = False
if self.kernel_block_size == -1:
self.kernel_block_size = self.block_size
self.cudagraph_capture_batch_sizes = normalize_cudagraph_capture_batch_sizes(
self.cudagraph_capture_batch_sizes, self.max_batches)


class TPMode(enum.Enum):
Expand Down Expand Up @@ -611,6 +632,7 @@ def from_config(
num_gpu_blocks=target_cache_cfg.num_gpu_blocks,
cache_max_entry_count=target_cache_cfg.cache_max_entry_count,
max_prefill_token_num=target_cache_cfg.max_prefill_token_num,
cudagraph_capture_batch_sizes=target_cache_cfg.cudagraph_capture_batch_sizes,
device_type=target_cache_cfg.device_type,
quant_policy=target_cache_cfg.quant_policy,
migration_backend=target_cache_cfg.migration_backend)
Expand Down
7 changes: 7 additions & 0 deletions lmdeploy/pytorch/engine/config_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
MiscConfig,
SchedulerConfig,
SpecDecodeConfig,
normalize_cudagraph_capture_batch_sizes,
)
from lmdeploy.utils import get_logger, get_max_batch_size, get_model

Expand Down Expand Up @@ -39,6 +40,11 @@ def update_engine_config(engine_config: PytorchEngineConfig):
f'since dllm_block_length({engine_config.dllm_block_length}) * max_batch_size '
f'({max_batch_size}) > max_prefill_token_num ({max_prefill_token_num}).')

capture_sizes = engine_config.cudagraph_capture_batch_sizes
if capture_sizes is not None:
engine_config.cudagraph_capture_batch_sizes = normalize_cudagraph_capture_batch_sizes(
capture_sizes, engine_config.max_batch_size)

if engine_config.dp != 1:
if engine_config.tp == 1 and engine_config.ep == 1:
logger.warning('Data parallelism is enabled but tensor parallelism and '
Expand Down Expand Up @@ -67,6 +73,7 @@ def build_cache_config(engine_config: PytorchEngineConfig):
num_gpu_blocks=engine_config.num_gpu_blocks,
cache_max_entry_count=engine_config.cache_max_entry_count,
max_prefill_token_num=engine_config.max_prefill_token_num,
cudagraph_capture_batch_sizes=engine_config.cudagraph_capture_batch_sizes,
enable_prefix_caching=engine_config.enable_prefix_caching,
quant_policy=engine_config.quant_policy,
device_type=engine_config.device_type,
Expand Down
55 changes: 55 additions & 0 deletions tests/pytorch/engine/test_cudagraph_capture_batch_sizes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest

from lmdeploy.messages import PytorchEngineConfig
from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner
from lmdeploy.pytorch.config import CacheConfig
from lmdeploy.pytorch.engine.config_builder import ConfigBuilder


def _cache_config(max_batches=8, cudagraph_capture_batch_sizes=None):
    """Build a ``CacheConfig`` with fixed block settings for these tests."""
    config_kwargs = dict(
        max_batches=max_batches,
        block_size=64,
        num_cpu_blocks=0,
        num_gpu_blocks=1,
        cudagraph_capture_batch_sizes=cudagraph_capture_batch_sizes,
    )
    return CacheConfig(**config_kwargs)


def test_custom_capture_batch_sizes_include_max_batch_size():
    """``update_engine_config`` turns ``[4, 1, 4, 16]`` into ``[1, 4, 8]`` for ``max_batch_size=8``."""
    raw_config = PytorchEngineConfig(max_batch_size=8, cudagraph_capture_batch_sizes=[4, 1, 4, 16])
    updated_config = ConfigBuilder.update_engine_config(raw_config)

    expected_sizes = [1, 4, 8]
    assert updated_config.cudagraph_capture_batch_sizes == expected_sizes


def test_cache_config_normalizes_capture_batch_sizes():
    """Constructing a ``CacheConfig`` with ``[4, 1, 4, 16]`` and ``max_batches=8`` stores ``[1, 4, 8]``."""
    requested_sizes = [4, 1, 4, 16]
    config = _cache_config(max_batches=8, cudagraph_capture_batch_sizes=requested_sizes)

    stored_sizes = config.cudagraph_capture_batch_sizes
    assert stored_sizes == [1, 4, 8]


@pytest.mark.parametrize('sizes', [[], [0], [-1], [1.5], ['1'], [16]])
def test_invalid_capture_batch_sizes_raise(sizes):
    """Each parametrized ``sizes`` value must make ``CacheConfig`` construction raise ``AssertionError``."""
    config_kwargs = {'max_batches': 8, 'cudagraph_capture_batch_sizes': sizes}
    with pytest.raises(AssertionError):
        _cache_config(**config_kwargs)


def test_capture_batch_size_miss_raises():
    """With sizes ``[1, 4]`` and ``max_batch_size=8``: 5 maps to 8, and 9 raises ``AssertionError``."""
    updated_config = ConfigBuilder.update_engine_config(
        PytorchEngineConfig(max_batch_size=8, cudagraph_capture_batch_sizes=[1, 4]))

    # Bypass __init__ so only ``cache_config`` has to be provided.
    graph_runner = object.__new__(CUDAGraphRunner)
    graph_runner.cache_config = ConfigBuilder.build_cache_config(updated_config)

    capture_tokens = graph_runner._get_capture_tokens(5)
    assert capture_tokens == 8
    with pytest.raises(AssertionError):
        graph_runner._get_capture_tokens(9)


def test_graph_runner_defensively_normalizes_capture_batch_sizes():
    """Sizes overwritten on the config after construction are normalized by ``get_capture_batch_sizes``."""
    config = _cache_config(max_batches=8, cudagraph_capture_batch_sizes=[1, 8])
    # Overwrite the already-normalized value with an unsorted, duplicated, oversized list.
    config.cudagraph_capture_batch_sizes = [4, 1, 4, 16]

    # Bypass __init__ so only ``cache_config`` has to be provided.
    graph_runner = object.__new__(CUDAGraphRunner)
    graph_runner.cache_config = config

    captured_sizes = graph_runner.get_capture_batch_sizes()
    assert captured_sizes == [1, 4, 8]
Loading