diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index e10f20f44e..bc1cd77bac 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -100,6 +100,7 @@ class VariableInterface: api_server_url: str | None = None allow_terminate_by_client: bool = False enable_abort_handling: bool = False + api_keys: list[str] | None = None response_parser_cls: type[ResponseParser] | None = None @classmethod @@ -1213,6 +1214,8 @@ async def startup_event(): url = f'{VariableInterface.proxy_url}/nodes/add' data = {'url': VariableInterface.api_server_url, 'status': {'models': get_model_list(), 'role': engine_role}} headers = {'accept': 'application/json', 'Content-Type': 'application/json'} + if isinstance(VariableInterface.api_keys, list) and len(VariableInterface.api_keys) > 0: + headers['Authorization'] = f'Bearer {VariableInterface.api_keys[0]}' response = requests.post(url, headers=headers, json=data) if response.status_code != 200: @@ -1394,6 +1397,9 @@ def serve(model_path: str, VariableInterface.allow_terminate_by_client = allow_terminate_by_client VariableInterface.enable_abort_handling = enable_abort_handling + if isinstance(api_keys, str): + api_keys = [api_keys] + VariableInterface.api_keys = api_keys ssl_keyfile, ssl_certfile, http_or_https = None, None, 'http' if ssl: diff --git a/lmdeploy/serve/proxy/proxy.py b/lmdeploy/serve/proxy/proxy.py index 667886273e..5360875171 100644 --- a/lmdeploy/serve/proxy/proxy.py +++ b/lmdeploy/serve/proxy/proxy.py @@ -921,6 +921,8 @@ def proxy(server_name: str = '0.0.0.0', with_gdr=True, ) node_manager.cache_status = not disable_cache_status + if isinstance(api_keys, str): + api_keys = [api_keys] if api_keys is not None and (tokens := [key for key in api_keys if key]): from lmdeploy.serve.utils.server_utils import AuthenticationMiddleware diff --git a/lmdeploy/serve/utils/server_utils.py b/lmdeploy/serve/utils/server_utils.py index f032b2bc14..349b1abdb1 100644 --- a/lmdeploy/serve/utils/server_utils.py +++ b/lmdeploy/serve/utils/server_utils.py @@ -85,12 +85,13 @@ class AuthenticationMiddleware: def __init__(self, app: ASGIApp, tokens: list[str]) -> None: self.app = app self.api_tokens = [hashlib.sha256(t.encode('utf-8')).digest() for t in tokens] - # Path prefixes that bypass authentication + # Path prefixes that bypass authentication. Keep this list limited to + # passive public endpoints; proxy node-management routes mutate routing + # state and must stay behind bearer-token authentication. self.skip_prefixes = [ '/health', # Health check endpoints '/docs', # Swagger UI documentation '/redoc', # ReDoc documentation - '/nodes', # Endpoints about node operation between proxy and api_server ] def verify_token(self, headers: Headers) -> bool: diff --git a/tests/test_lmdeploy/serve/test_authentication_middleware.py b/tests/test_lmdeploy/serve/test_authentication_middleware.py new file mode 100644 index 0000000000..78eb9d05be --- /dev/null +++ b/tests/test_lmdeploy/serve/test_authentication_middleware.py @@ -0,0 +1,59 @@ +import importlib.util +from pathlib import Path + +from fastapi import FastAPI +from fastapi.testclient import TestClient + + +def _load_authentication_middleware(): + repo_root = Path(__file__).resolve().parents[3] + module_path = repo_root / 'lmdeploy' / 'serve' / 'utils' / 'server_utils.py' + spec = importlib.util.spec_from_file_location('server_utils_for_test', module_path) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module.AuthenticationMiddleware + + +AuthenticationMiddleware = _load_authentication_middleware() + + +def _make_client(): + app = FastAPI() + + @app.get('/health') + def health(): + return {'ok': True} + + @app.post('/nodes/add') + def add_node(): + return {'added': True} + + @app.get('/nodes/status') + def node_status(): + return {'nodes': []} + + @app.post('/v1/chat/completions') + def chat_completions(): + return {'ok': True} + + app.add_middleware(AuthenticationMiddleware, tokens=['secret']) + return TestClient(app) + + +def test_auth_middleware_protects_node_management_routes(): + client = _make_client() + + assert client.post('/nodes/add').status_code == 401 + assert client.get('/nodes/status').status_code == 401 + + headers = {'Authorization': 'Bearer secret'} + assert client.post('/nodes/add', headers=headers).status_code == 200 + assert client.get('/nodes/status', headers=headers).status_code == 200 + + +def test_auth_middleware_keeps_passive_health_public(): + client = _make_client() + + assert client.get('/health').status_code == 200 + assert client.post('/v1/chat/completions').status_code == 401