From efb8424b3ec067993e5d294792d953869ccc29ff Mon Sep 17 00:00:00 2001 From: Oliver Stanley Date: Sun, 15 Feb 2026 12:44:36 +0000 Subject: [PATCH 1/5] perf: reuse sessions, more sync code removal --- model-engine/model_engine_server/api/app.py | 97 ++++++---- .../inference/forwarding/forwarding.py | 169 +++++++++++------- ..._async_model_endpoint_inference_gateway.py | 8 +- ...eaming_model_endpoint_inference_gateway.py | 46 +++-- ...e_sync_model_endpoint_inference_gateway.py | 52 ++++-- .../live_batch_job_orchestration_service.py | 50 ++++-- .../tests/unit/inference/test_forwarding.py | 15 +- 7 files changed, 283 insertions(+), 154 deletions(-) diff --git a/model-engine/model_engine_server/api/app.py b/model-engine/model_engine_server/api/app.py index cac68cda2..bd53745c7 100644 --- a/model-engine/model_engine_server/api/app.py +++ b/model-engine/model_engine_server/api/app.py @@ -1,13 +1,14 @@ +import json as json_module import os import traceback import uuid from datetime import datetime from pathlib import Path +from typing import Any import pytz -from fastapi import FastAPI, HTTPException, Request, Response +from fastapi import FastAPI, HTTPException, Response from fastapi.openapi.docs import get_redoc_html -from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles from model_engine_server.api.batch_jobs_v1 import batch_job_router_v1 from model_engine_server.api.dependencies import get_or_create_aioredis_pool @@ -31,8 +32,8 @@ make_logger, ) from model_engine_server.core.tracing import get_tracing_gateway -from starlette.middleware import Middleware -from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.types import ASGIApp, Receive, Scope, Send logger = make_logger(logger_name()) @@ -43,32 +44,44 @@ concurrency=MAX_CONCURRENCY, fail_on_concurrency_limit=True ) -healthcheck_routes = ["/healthcheck", "/healthz", "/readyz"] +healthcheck_routes = {"/healthcheck", "/healthz", "/readyz"} tracing_gateway = get_tracing_gateway() -class CustomMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next): +class CustomMiddleware: + """ + Pure ASGI middleware for request tracking, tracing, and concurrency limiting. + Unlike BaseHTTPMiddleware this does not buffer the entire response body, + which is important for streaming inference responses. + """ + + def __init__(self, app: ASGIApp): + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send): + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + request = Request(scope, receive) + + LoggerTagManager.set(LoggerTagKey.REQUEST_ID, str(uuid.uuid4())) + LoggerTagManager.set(LoggerTagKey.REQUEST_SIZE, request.headers.get("content-length")) + if tracing_gateway: + tracing_gateway.extract_tracing_headers(request, service="model_engine_server") + + # Healthcheck routes bypass concurrency limiting + if request.url.path in healthcheck_routes: + await self.app(scope, receive, send) + return + try: - LoggerTagManager.set(LoggerTagKey.REQUEST_ID, str(uuid.uuid4())) - LoggerTagManager.set(LoggerTagKey.REQUEST_SIZE, request.headers.get("content-length")) - if tracing_gateway: - tracing_gateway.extract_tracing_headers(request, service="model_engine_server") - # we intentionally exclude healthcheck routes from the concurrency limiter - if request.url.path in healthcheck_routes: - return await call_next(request) with concurrency_limiter: - return await call_next(request) + await self.app(scope, receive, send) except HTTPException as e: timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z") - return JSONResponse( - status_code=e.status_code, - content={ - "error": e.detail, - "timestamp": timestamp, - }, - ) + await _send_json_error(send, e.status_code, e.detail, timestamp) except Exception as e: tb_str = traceback.format_exception(e) request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) @@ -79,22 +92,46 @@ async def dispatch(self, request: Request, call_next): "traceback": "".join(tb_str), } logger.error("Unhandled exception: %s", structured_log) - return JSONResponse( - status_code=500, - content={ - "error": "Internal error occurred. Our team has been notified.", - "timestamp": timestamp, - "request_id": request_id, - }, + await _send_json_error( + send, + 500, + "Internal error occurred. Our team has been notified.", + timestamp, + request_id=request_id, ) +async def _send_json_error( + send: Send, + status_code: int, + error: Any, + timestamp: str, + request_id: Any = None, +): + """Send a JSON error response directly via ASGI send.""" + body: dict = {"error": error, "timestamp": timestamp} + if request_id is not None: + body["request_id"] = request_id + body_bytes = json_module.dumps(body).encode("utf-8") + await send( + { + "type": "http.response.start", + "status": status_code, + "headers": [ + [b"content-type", b"application/json"], + [b"content-length", str(len(body_bytes)).encode()], + ], + } + ) + await send({"type": "http.response.body", "body": body_bytes}) + + app = FastAPI( title="launch", version="1.0.0", redoc_url=None, - middleware=[Middleware(CustomMiddleware)], ) +app.add_middleware(CustomMiddleware) app.include_router(batch_job_router_v1) app.include_router(inference_task_router_v1) diff --git a/model-engine/model_engine_server/inference/forwarding/forwarding.py b/model-engine/model_engine_server/inference/forwarding/forwarding.py index 5183955b9..7526c154d 100644 --- a/model-engine/model_engine_server/inference/forwarding/forwarding.py +++ b/model-engine/model_engine_server/inference/forwarding/forwarding.py @@ -42,6 +42,20 @@ DEFAULT_PORT: int = 5005 +TTL_DNS_CACHE: int = 300 + + +def _wait_for_service_ready(healthcheck_url: str): + """Block until the service at healthcheck_url returns 200.""" + while True: + try: + if requests.get(healthcheck_url, timeout=5).status_code == 200: + return + except requests.exceptions.ConnectionError: + pass + logger.info(f"Waiting for user-defined service to be ready at {healthcheck_url}...") + time.sleep(1) + class ModelEngineSerializationMixin: """Mixin class for optionally wrapping Model Engine requests.""" @@ -170,6 +184,21 @@ class Forwarder(ModelEngineSerializationMixin): forward_http_status_in_body: bool post_inference_hooks_handler: Optional[PostInferenceHooksHandler] = None + def __post_init__(self): + self._session: Optional[aiohttp.ClientSession] = None + + async def _get_session(self) -> aiohttp.ClientSession: + if self._session is None or self._session.closed: + connector = aiohttp.TCPConnector(ttl_dns_cache=TTL_DNS_CACHE) + self._session = aiohttp.ClientSession( + json_serialize=_serialize_json, connector=connector + ) + return self._session + + async def close(self): + if self._session and not self._session.closed: + await self._session.close() + async def forward(self, json_payload: Any, trace_config: Optional[str] = None) -> Any: json_payload, using_serialize_results_as_string = self.unwrap_json_payload(json_payload) json_payload_repr = json_payload.keys() if hasattr(json_payload, "keys") else json_payload @@ -177,18 +206,18 @@ async def forward(self, json_payload: Any, trace_config: Optional[str] = None) - logger.info(f"Accepted request, forwarding {json_payload_repr=}") try: - async with aiohttp.ClientSession(json_serialize=_serialize_json) as aioclient: - headers = {"Content-Type": "application/json"} - if trace_config and tracing_gateway: - headers.update(tracing_gateway.encode_trace_headers()) - response_raw = await aioclient.post( - self.predict_endpoint, - json=json_payload, - headers=headers, - ) - response = await response_raw.json( - content_type=None - ) # [Bug] upstream service doesn't always have the content type header set which causes aiohttp to error + aioclient = await self._get_session() + headers = {"Content-Type": "application/json"} + if trace_config and tracing_gateway: + headers.update(tracing_gateway.encode_trace_headers()) + response_raw = await aioclient.post( + self.predict_endpoint, + json=json_payload, + headers=headers, + ) + response = await response_raw.json( + content_type=None + ) # [Bug] upstream service doesn't always have the content type header set which causes aiohttp to error except Exception: logger.exception( @@ -345,15 +374,7 @@ def endpoint(route: str) -> str: logger.info(f"Prediction endpoint: {pred}") logger.info(f"Healthcheck endpoint: {hc}") - while True: - try: - if requests.get(hc).status_code == 200: - break - except requests.exceptions.ConnectionError: - pass - - logger.info(f"Waiting for user-defined service to be ready at {hc}...") - time.sleep(1) + _wait_for_service_ready(hc) logger.info(f"Unwrapping model engine payload formatting?: {self.model_engine_unwrap}") @@ -429,6 +450,21 @@ class StreamingForwarder(ModelEngineSerializationMixin): serialize_results_as_string: bool post_inference_hooks_handler: Optional[PostInferenceHooksHandler] = None # unused for now + def __post_init__(self): + self._session: Optional[aiohttp.ClientSession] = None + + async def _get_session(self) -> aiohttp.ClientSession: + if self._session is None or self._session.closed: + connector = aiohttp.TCPConnector(ttl_dns_cache=TTL_DNS_CACHE) + self._session = aiohttp.ClientSession( + json_serialize=_serialize_json, connector=connector + ) + return self._session + + async def close(self): + if self._session and not self._session.closed: + await self._session.close() + async def forward(self, json_payload: Any) -> AsyncGenerator[Any, None]: # pragma: no cover json_payload, using_serialize_results_as_string = self.unwrap_json_payload(json_payload) json_payload_repr = json_payload.keys() if hasattr(json_payload, "keys") else json_payload @@ -436,24 +472,24 @@ async def forward(self, json_payload: Any) -> AsyncGenerator[Any, None]: # prag logger.info(f"Accepted request, forwarding {json_payload_repr=}") try: - response: aiohttp.ClientResponse - async with aiohttp.ClientSession(json_serialize=_serialize_json) as aioclient: - response = await aioclient.post( - self.predict_endpoint, - json=json_payload, - headers={"Content-Type": "application/json"}, - ) + aioclient = await self._get_session() + response = await aioclient.post( + self.predict_endpoint, + json=json_payload, + headers={"Content-Type": "application/json"}, + ) - if response.status != 200: - raise HTTPException( - status_code=response.status, detail=await response.json(content_type=None) - ) # [Bug] upstream service doesn't always have the content type header set which causes aiohttp to error + if response.status != 200: + raise HTTPException( + status_code=response.status, + detail=await response.json(content_type=None), + ) # [Bug] upstream service doesn't always have the content type header set which causes aiohttp to error - async with EventSource(response=response) as event_source: - async for event in event_source: - yield self.get_response_payload_stream( - using_serialize_results_as_string, event.data - ) + async with EventSource(response=response) as event_source: + async for event in event_source: + yield self.get_response_payload_stream( + using_serialize_results_as_string, event.data + ) except Exception: logger.exception( @@ -567,15 +603,7 @@ def endpoint(route: str) -> str: logger.info(f"Prediction endpoint: {pred}") logger.info(f"Healthcheck endpoint: {hc}") - while True: - try: - if requests.get(hc).status_code == 200: - break - except requests.exceptions.ConnectionError: - pass - - logger.info(f"Waiting for user-defined service to be ready at {hc}...") - time.sleep(1) + _wait_for_service_ready(hc) logger.info(f"Unwrapping model engine payload formatting?: {self.model_engine_unwrap}") @@ -631,6 +659,19 @@ def endpoint(route: str) -> str: class PassthroughForwarder(ModelEngineSerializationMixin): passthrough_endpoint: str + def __post_init__(self): + self._session: Optional[aiohttp.ClientSession] = None + + async def _get_session(self) -> aiohttp.ClientSession: + if self._session is None or self._session.closed: + connector = aiohttp.TCPConnector(ttl_dns_cache=TTL_DNS_CACHE) + self._session = aiohttp.ClientSession(connector=connector) + return self._session + + async def close(self): + if self._session and not self._session.closed: + await self._session.close() + async def _make_request( self, request: Any, aioclient: aiohttp.ClientSession ) -> aiohttp.ClientResponse: @@ -651,28 +692,28 @@ async def _make_request( return await aioclient.request( method=request.method, url=target_url, - data=await request.body() if request.method in ["POST", "PUT", "PATCH"] else None, + data=(await request.body() if request.method in ["POST", "PUT", "PATCH"] else None), headers=headers, ) async def forward_stream(self, request: Any): - async with aiohttp.ClientSession() as aioclient: - response = await self._make_request(request, aioclient) - response_headers = response.headers - yield (response_headers, response.status) + aioclient = await self._get_session() + response = await self._make_request(request, aioclient) + response_headers = response.headers + yield (response_headers, response.status) - if response.status != 200: - yield await response.read() + if response.status != 200: + yield await response.read() - async for chunk in response.content.iter_chunks(): - yield chunk[0] + async for chunk in response.content.iter_chunks(): + yield chunk[0] - yield await response.read() + yield await response.read() async def forward_sync(self, request: Any): - async with aiohttp.ClientSession() as aioclient: - response = await self._make_request(request, aioclient) - return response + aioclient = await self._get_session() + response = await self._make_request(request, aioclient) + return response @dataclass(frozen=True) @@ -711,15 +752,7 @@ def endpoint(route: str) -> str: logger.info(f"Passthrough endpoint: {passthrough_endpoint}") logger.info(f"Healthcheck endpoint: {hc}") - while True: - try: - if requests.get(hc).status_code == 200: - break - except requests.exceptions.ConnectionError: - pass - - logger.info(f"Waiting for user-defined service to be ready at {hc}...") - time.sleep(1) + _wait_for_service_ready(hc) logger.info(f"Creating PassthroughForwarder with endpoint: {passthrough_endpoint}") return PassthroughForwarder(passthrough_endpoint=passthrough_endpoint) diff --git a/model-engine/model_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py index f2aa50328..d52196aec 100644 --- a/model-engine/model_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py @@ -1,4 +1,3 @@ -import json from datetime import datetime from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME @@ -31,9 +30,10 @@ async def create_task( *, task_name: str = DEFAULT_CELERY_TASK_NAME, ) -> CreateAsyncTaskV1Response: - # Use json.loads instead of predict_request.dict() because we have overridden the 'root' - # key in some fields, and root overriding only reflects in the json() output. - predict_args = json.loads(predict_request.json()) + # model_dump(mode="json") produces a JSON-compatible dict directly, handling + # RootModel fields (like RequestSchema and CallbackAuth) correctly in Pydantic v2 + # without the overhead of serializing to a JSON string and parsing it back + predict_args = predict_request.model_dump(mode="json") send_task_response = await self.task_queue_gateway.send_task_async( task_name=task_name, diff --git a/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py index 86151922f..f3d4d91fd 100644 --- a/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py @@ -46,6 +46,8 @@ 1.2 # Must be a float > 1.0, lower number means more retries but less time waiting. ) +TTL_DNS_CACHE: int = 300 + def _get_streaming_endpoint_url( service_name: str, path: str = "/stream", manually_resolve_dns: bool = False @@ -62,7 +64,8 @@ def _get_streaming_endpoint_url( elif manually_resolve_dns: protocol = "http" hostname = resolve_dns( - f"{service_name}.{hmi_config.endpoint_namespace}.svc.cluster.local", port=protocol + f"{service_name}.{hmi_config.endpoint_namespace}.svc.cluster.local", + port=protocol, ) else: protocol = "http" @@ -89,24 +92,37 @@ class LiveStreamingModelEndpointInferenceGateway(StreamingModelEndpointInference def __init__(self, monitoring_metrics_gateway: MonitoringMetricsGateway, use_asyncio: bool): self.monitoring_metrics_gateway = monitoring_metrics_gateway self.use_asyncio = use_asyncio + self._session: Optional[aiohttp.ClientSession] = None + + async def _get_session(self) -> aiohttp.ClientSession: + if self._session is None or self._session.closed: + connector = aiohttp.TCPConnector(ttl_dns_cache=TTL_DNS_CACHE) + self._session = aiohttp.ClientSession( + json_serialize=_serialize_json, connector=connector + ) + return self._session + + async def close(self): + if self._session and not self._session.closed: + await self._session.close() async def make_single_request(self, request_url: str, payload_json: Dict[str, Any]): errored = False if self.use_asyncio: - async with aiohttp.ClientSession(json_serialize=_serialize_json) as aioclient: - aio_resp = await aioclient.post( - request_url, - json=payload_json, - headers={"Content-Type": "application/json"}, - ) - status = aio_resp.status - if status == 200: - async with EventSource(response=aio_resp) as event_source: - async for event in event_source: - yield event.data - else: - content = await aio_resp.read() - errored = True + aioclient = await self._get_session() + aio_resp = await aioclient.post( + request_url, + json=payload_json, + headers={"Content-Type": "application/json"}, + ) + status = aio_resp.status + if status == 200: + async with EventSource(response=aio_resp) as event_source: + async for event in event_source: + yield event.data + else: + content = await aio_resp.read() + errored = True else: resp = requests.post( request_url, diff --git a/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py index e65c30440..4c90b5d68 100644 --- a/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py @@ -44,9 +44,13 @@ 1.2 # Must be a float > 1.0, lower number means more retries but less time waiting. ) +TTL_DNS_CACHE: int = 300 + def _get_sync_endpoint_url( - service_name: str, destination_path: str = "/predict", manually_resolve_dns: bool = False + service_name: str, + destination_path: str = "/predict", + manually_resolve_dns: bool = False, ) -> str: if CIRCLECI: # Circle CI: a NodePort is used to expose the service @@ -60,7 +64,8 @@ def _get_sync_endpoint_url( elif manually_resolve_dns: protocol = "http" hostname = resolve_dns( - f"{service_name}.{hmi_config.endpoint_namespace}.svc.cluster.local", port=protocol + f"{service_name}.{hmi_config.endpoint_namespace}.svc.cluster.local", + port=protocol, ) else: protocol = "http" @@ -89,6 +94,19 @@ def __init__( self.monitoring_metrics_gateway = monitoring_metrics_gateway self.tracing_gateway = tracing_gateway self.use_asyncio = use_asyncio + self._session: Optional[aiohttp.ClientSession] = None + + async def _get_session(self) -> aiohttp.ClientSession: + if self._session is None or self._session.closed: + connector = aiohttp.TCPConnector(ttl_dns_cache=TTL_DNS_CACHE) + self._session = aiohttp.ClientSession( + json_serialize=_serialize_json, connector=connector + ) + return self._session + + async def close(self): + if self._session and not self._session.closed: + await self._session.close() async def make_single_request(self, request_url: str, payload_json: Dict[str, Any]): # DEBUG: Log request details @@ -104,22 +122,22 @@ async def make_single_request(self, request_url: str, payload_json: Dict[str, An if self.use_asyncio: try: - async with aiohttp.ClientSession(json_serialize=_serialize_json) as client: - aio_resp = await client.post( - request_url, - json=payload_json, - headers=headers, + client = await self._get_session() + aio_resp = await client.post( + request_url, + json=payload_json, + headers=headers, + ) + status = aio_resp.status + if infra_config().debug_mode: # pragma: no cover + logger.info(f"DEBUG: Response status: {status}") + if status == 200: + return await aio_resp.json() + content = await aio_resp.read() + if infra_config().debug_mode: # pragma: no cover + logger.warning( + f"DEBUG: Non-200 response. Status: {status}, Content: {content.decode('utf-8', errors='replace')}" ) - status = aio_resp.status - if infra_config().debug_mode: # pragma: no cover - logger.info(f"DEBUG: Response status: {status}") - if status == 200: - return await aio_resp.json() - content = await aio_resp.read() - if infra_config().debug_mode: # pragma: no cover - logger.warning( - f"DEBUG: Non-200 response. Status: {status}, Content: {content.decode('utf-8', errors='replace')}" - ) except Exception as e: if infra_config().debug_mode: # pragma: no cover logger.error( diff --git a/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py b/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py index e3d04fd09..efe566958 100644 --- a/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py +++ b/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py @@ -5,7 +5,6 @@ import json import pickle import sys -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from datetime import datetime, timedelta from typing import List, Optional, Union @@ -165,8 +164,11 @@ async def _run_batch_job( status=BatchJobStatus.RUNNING, ) - results = self._poll_tasks( - owner=owner, job_id=job_id, task_ids=task_ids, timeout_timestamp=timeout_timestamp + results = await self._poll_tasks( + owner=owner, + job_id=job_id, + task_ids=task_ids, + timeout_timestamp=timeout_timestamp, ) result_location = batch_job_record.result_location @@ -202,7 +204,10 @@ async def _wait_for_endpoint_to_be_ready( model_endpoint = await self.model_endpoint_service.get_model_endpoint_record( model_endpoint_id=model_endpoint_id, ) - updating = {ModelEndpointStatus.UPDATE_PENDING, ModelEndpointStatus.UPDATE_IN_PROGRESS} + updating = { + ModelEndpointStatus.UPDATE_PENDING, + ModelEndpointStatus.UPDATE_IN_PROGRESS, + } assert model_endpoint while model_endpoint.status in updating: @@ -240,7 +245,9 @@ async def _read_or_submit_tasks( pending_task_ids_location = batch_job_record.task_ids_location if pending_task_ids_location is not None: with self.filesystem_gateway.open( - pending_task_ids_location, "r", aws_profile=infra_config().profile_ml_worker + pending_task_ids_location, + "r", + aws_profile=infra_config().profile_ml_worker, ) as f: task_ids_serialized = f.read().splitlines() task_ids = [ @@ -254,7 +261,9 @@ async def _read_or_submit_tasks( task_ids = await self._submit_tasks(queue_name, input_path, task_name) pending_task_ids_location = self._get_pending_task_ids_location(job_id) with self.filesystem_gateway.open( - pending_task_ids_location, "w", aws_profile=infra_config().profile_ml_worker + pending_task_ids_location, + "w", + aws_profile=infra_config().profile_ml_worker, ) as f: f.write("\n".join([tid.serialize() for tid in task_ids])) await self.batch_job_record_repository.update_batch_job_record( @@ -304,7 +313,7 @@ async def _create_task( task_ids = await asyncio.gather(*[_create_task(inp) for inp in inputs]) return list(task_ids) - def _poll_tasks( + async def _poll_tasks( self, owner: str, job_id: str, @@ -312,31 +321,38 @@ def _poll_tasks( timeout_timestamp: datetime, ) -> List[BatchEndpointInferencePredictionResponse]: # Poll the task queue until all tasks are complete. - # Python multithreading works here because retrieving the tasks is I/O bound. + # Uses run_in_executor for the blocking get_task calls and asyncio.sleep + # between poll rounds to avoid spinning. task_ids_only = [in_progress_task.task_id for in_progress_task in task_ids] task_id_to_ref_id_map = { in_progress_task.task_id: in_progress_task.reference_id for in_progress_task in task_ids } pending_task_ids_set = set(task_ids_only) task_id_to_result = {} - executor = ThreadPoolExecutor() + loop = asyncio.get_event_loop() progress = BatchJobProgress( num_tasks_pending=len(pending_task_ids_set), num_tasks_completed=0, ) self.batch_job_progress_gateway.update_progress(owner, job_id, progress) + poll_interval = 2 # seconds, will increase with backoff + terminal_task_states = {TaskStatus.SUCCESS, TaskStatus.FAILURE} while pending_task_ids_set: - new_results = executor.map( - self.async_model_endpoint_inference_gateway.get_task, pending_task_ids_set + new_results = await asyncio.gather( + *[ + loop.run_in_executor( + None, self.async_model_endpoint_inference_gateway.get_task, tid + ) + for tid in pending_task_ids_set + ] ) has_new_ready_tasks = False curr_timestamp = datetime.utcnow() - terminal_task_states = {TaskStatus.SUCCESS, TaskStatus.FAILURE} for r in new_results: if r.status in terminal_task_states or curr_timestamp > timeout_timestamp: has_new_ready_tasks = True task_id_to_result[r.task_id] = r - pending_task_ids_set.remove(r.task_id) + pending_task_ids_set.discard(r.task_id) if has_new_ready_tasks: logger.info( @@ -348,10 +364,16 @@ def _poll_tasks( num_tasks_completed=len(task_id_to_result), ) self.batch_job_progress_gateway.update_progress(owner, job_id, progress) + poll_interval = 2 # reset on progress + + if pending_task_ids_set: + await asyncio.sleep(poll_interval) + poll_interval = min(poll_interval * 1.5, 30) # backoff, cap at 30s results = [ BatchEndpointInferencePredictionResponse( - response=task_id_to_result[task_id], reference_id=task_id_to_ref_id_map[task_id] + response=task_id_to_result[task_id], + reference_id=task_id_to_ref_id_map[task_id], ) for task_id in task_ids_only ] diff --git a/model-engine/tests/unit/inference/test_forwarding.py b/model-engine/tests/unit/inference/test_forwarding.py index 3e0141c83..6e8d4a3e6 100644 --- a/model-engine/tests/unit/inference/test_forwarding.py +++ b/model-engine/tests/unit/inference/test_forwarding.py @@ -589,9 +589,10 @@ async def test_passthrough_forwarder(): fwd = PassthroughForwarder(passthrough_endpoint="http://localhost:5005/mcp/test") mock_request = MockRequest(method="POST", path="/mcp/test", query="param=value") - with mock.patch("aiohttp.ClientSession") as mock_session: + with mock.patch("aiohttp.ClientSession") as mock_session, mock.patch("aiohttp.TCPConnector"): mock_client = mocked_aiohttp_client_session() - mock_session.return_value.__aenter__.return_value = mock_client + mock_client.closed = False + mock_session.return_value = mock_client response_generator = fwd.forward_stream(mock_request) await _check_passthrough_response(response_generator) @@ -614,9 +615,10 @@ async def test_passthrough_forwarder_get_request(): fwd = PassthroughForwarder(passthrough_endpoint="http://localhost:5005/mcp/status") mock_request = MockRequest(method="GET", path="/mcp/status", query="", body_data=b"") - with mock.patch("aiohttp.ClientSession") as mock_session: + with mock.patch("aiohttp.ClientSession") as mock_session, mock.patch("aiohttp.TCPConnector"): mock_client = mocked_aiohttp_client_session() - mock_session.return_value.__aenter__.return_value = mock_client + mock_client.closed = False + mock_session.return_value = mock_client response_generator = fwd.forward_stream(mock_request) await _check_passthrough_response(response_generator) @@ -650,9 +652,10 @@ async def test_passthrough_forwarder_header_filtering(): mock_request = MockRequest(method="POST", path="/mcp/test", headers=headers_with_excluded) - with mock.patch("aiohttp.ClientSession") as mock_session: + with mock.patch("aiohttp.ClientSession") as mock_session, mock.patch("aiohttp.TCPConnector"): mock_client = mocked_aiohttp_client_session() - mock_session.return_value.__aenter__.return_value = mock_client + mock_client.closed = False + mock_session.return_value = mock_client response_generator = fwd.forward_stream(mock_request) await _check_passthrough_response(response_generator) From db315ad31a9e8437af18129b68beb0da41bf2a77 Mon Sep 17 00:00:00 2001 From: Oliver Stanley Date: Sun, 15 Feb 2026 12:45:51 +0000 Subject: [PATCH 2/5] simplify --- model-engine/model_engine_server/api/app.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/api/app.py b/model-engine/model_engine_server/api/app.py index bd53745c7..ec1ddc88d 100644 --- a/model-engine/model_engine_server/api/app.py +++ b/model-engine/model_engine_server/api/app.py @@ -1,4 +1,4 @@ -import json as json_module +import json import os import traceback import uuid @@ -112,7 +112,7 @@ async def _send_json_error( body: dict = {"error": error, "timestamp": timestamp} if request_id is not None: body["request_id"] = request_id - body_bytes = json_module.dumps(body).encode("utf-8") + body_bytes = json.dumps(body).encode("utf-8") await send( { "type": "http.response.start", From c8403876d4774bddbf3de49dce48f94bb88df660 Mon Sep 17 00:00:00 2001 From: Oliver Stanley Date: Sun, 15 Feb 2026 12:59:27 +0000 Subject: [PATCH 3/5] ci --- .../infra/services/live_batch_job_orchestration_service.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py b/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py index efe566958..5afb819b9 100644 --- a/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py +++ b/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py @@ -335,7 +335,7 @@ async def _poll_tasks( num_tasks_completed=0, ) self.batch_job_progress_gateway.update_progress(owner, job_id, progress) - poll_interval = 2 # seconds, will increase with backoff + poll_interval = 2.0 # seconds, will increase with backoff terminal_task_states = {TaskStatus.SUCCESS, TaskStatus.FAILURE} while pending_task_ids_set: new_results = await asyncio.gather( @@ -364,7 +364,7 @@ async def _poll_tasks( num_tasks_completed=len(task_id_to_result), ) self.batch_job_progress_gateway.update_progress(owner, job_id, progress) - poll_interval = 2 # reset on progress + poll_interval = 2.0 # reset on progress if pending_task_ids_set: await asyncio.sleep(poll_interval) From 99b04c3fb8fb1e3a054c95cc294a3cc44855be02 Mon Sep 17 00:00:00 2001 From: Oliver Stanley Date: Sun, 15 Feb 2026 13:25:43 +0000 Subject: [PATCH 4/5] fix ci --- model-engine/requirements-test.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/model-engine/requirements-test.txt b/model-engine/requirements-test.txt index 0115722b4..5d2c54b15 100644 --- a/model-engine/requirements-test.txt +++ b/model-engine/requirements-test.txt @@ -8,6 +8,8 @@ moto==3.1.12 mypy==1.3.0 pylint<3.0.0 pytest==7.2.0 +# typeguard 4.5.0 (2026-02-15) requires typing_extensions>=4.13.0 but we pin 4.10.0 +typeguard<4.5.0 pytest-asyncio==0.20.1 pytest-cov==2.10.0 pytest-mypy==0.9.1 From f2160fe3c9df5a55bb1d3cf790e7971548119198 Mon Sep 17 00:00:00 2001 From: Oliver Stanley Date: Sun, 15 Feb 2026 13:43:43 +0000 Subject: [PATCH 5/5] ci --- .../model_engine_server/inference/forwarding/forwarding.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/model-engine/model_engine_server/inference/forwarding/forwarding.py b/model-engine/model_engine_server/inference/forwarding/forwarding.py index 7526c154d..c4f4f0e7e 100644 --- a/model-engine/model_engine_server/inference/forwarding/forwarding.py +++ b/model-engine/model_engine_server/inference/forwarding/forwarding.py @@ -184,7 +184,7 @@ class Forwarder(ModelEngineSerializationMixin): forward_http_status_in_body: bool post_inference_hooks_handler: Optional[PostInferenceHooksHandler] = None - def __post_init__(self): + def __post_init__(self) -> None: self._session: Optional[aiohttp.ClientSession] = None async def _get_session(self) -> aiohttp.ClientSession: @@ -450,7 +450,7 @@ class StreamingForwarder(ModelEngineSerializationMixin): serialize_results_as_string: bool post_inference_hooks_handler: Optional[PostInferenceHooksHandler] = None # unused for now - def __post_init__(self): + def __post_init__(self) -> None: self._session: Optional[aiohttp.ClientSession] = None async def _get_session(self) -> aiohttp.ClientSession: @@ -659,7 +659,7 @@ def endpoint(route: str) -> str: class PassthroughForwarder(ModelEngineSerializationMixin): passthrough_endpoint: str - def __post_init__(self): + def __post_init__(self) -> None: self._session: Optional[aiohttp.ClientSession] = None async def _get_session(self) -> aiohttp.ClientSession: