Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 67 additions & 30 deletions model-engine/model_engine_server/api/app.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import json
import os
import traceback
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any

import pytz
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi import FastAPI, HTTPException, Response
from fastapi.openapi.docs import get_redoc_html
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from model_engine_server.api.batch_jobs_v1 import batch_job_router_v1
from model_engine_server.api.dependencies import get_or_create_aioredis_pool
Expand All @@ -31,8 +32,8 @@
make_logger,
)
from model_engine_server.core.tracing import get_tracing_gateway
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.types import ASGIApp, Receive, Scope, Send

logger = make_logger(logger_name())

Expand All @@ -43,32 +44,44 @@
concurrency=MAX_CONCURRENCY, fail_on_concurrency_limit=True
)

healthcheck_routes = ["/healthcheck", "/healthz", "/readyz"]
healthcheck_routes = {"/healthcheck", "/healthz", "/readyz"}

tracing_gateway = get_tracing_gateway()


class CustomMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
class CustomMiddleware:
"""
Pure ASGI middleware for request tracking, tracing, and concurrency limiting.
Unlike BaseHTTPMiddleware this does not buffer the entire response body,
which is important for streaming inference responses.
"""

def __init__(self, app: ASGIApp):
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] != "http":
await self.app(scope, receive, send)
return

request = Request(scope, receive)

LoggerTagManager.set(LoggerTagKey.REQUEST_ID, str(uuid.uuid4()))
LoggerTagManager.set(LoggerTagKey.REQUEST_SIZE, request.headers.get("content-length"))
if tracing_gateway:
tracing_gateway.extract_tracing_headers(request, service="model_engine_server")

# Healthcheck routes bypass concurrency limiting
if request.url.path in healthcheck_routes:
await self.app(scope, receive, send)
return

try:
LoggerTagManager.set(LoggerTagKey.REQUEST_ID, str(uuid.uuid4()))
LoggerTagManager.set(LoggerTagKey.REQUEST_SIZE, request.headers.get("content-length"))
if tracing_gateway:
tracing_gateway.extract_tracing_headers(request, service="model_engine_server")
# we intentionally exclude healthcheck routes from the concurrency limiter
if request.url.path in healthcheck_routes:
return await call_next(request)
with concurrency_limiter:
return await call_next(request)
await self.app(scope, receive, send)
except HTTPException as e:
timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z")
return JSONResponse(
status_code=e.status_code,
content={
"error": e.detail,
"timestamp": timestamp,
},
)
await _send_json_error(send, e.status_code, e.detail, timestamp)
except Exception as e:
tb_str = traceback.format_exception(e)
request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID)
Expand All @@ -79,22 +92,46 @@ async def dispatch(self, request: Request, call_next):
"traceback": "".join(tb_str),
}
logger.error("Unhandled exception: %s", structured_log)
return JSONResponse(
status_code=500,
content={
"error": "Internal error occurred. Our team has been notified.",
"timestamp": timestamp,
"request_id": request_id,
},
await _send_json_error(
send,
500,
"Internal error occurred. Our team has been notified.",
timestamp,
request_id=request_id,
)


async def _send_json_error(
    send: Send,
    status_code: int,
    error: Any,
    timestamp: str,
    request_id: Any = None,
):
    """Serialize an error payload and emit it directly as an ASGI response.

    Used by the pure-ASGI middleware, which has no Response object to return;
    the error is written straight to the transport via ``send``.

    :param send: ASGI send callable for the current connection.
    :param status_code: HTTP status code to emit.
    :param error: JSON-serializable error detail placed under the "error" key.
    :param timestamp: Preformatted timestamp string for the payload.
    :param request_id: Optional request id; included only when not None.
    """
    payload: dict = {"error": error, "timestamp": timestamp}
    if request_id is not None:
        payload["request_id"] = request_id
    encoded = json.dumps(payload).encode("utf-8")
    # Content-length is computed from the encoded bytes so the two values
    # can never disagree.
    start_message = {
        "type": "http.response.start",
        "status": status_code,
        "headers": [
            [b"content-type", b"application/json"],
            [b"content-length", str(len(encoded)).encode()],
        ],
    }
    await send(start_message)
    await send({"type": "http.response.body", "body": encoded})


app = FastAPI(
title="launch",
version="1.0.0",
redoc_url=None,
middleware=[Middleware(CustomMiddleware)],
)
app.add_middleware(CustomMiddleware)

app.include_router(batch_job_router_v1)
app.include_router(inference_task_router_v1)
Expand Down
169 changes: 101 additions & 68 deletions model-engine/model_engine_server/inference/forwarding/forwarding.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,20 @@

DEFAULT_PORT: int = 5005

TTL_DNS_CACHE: int = 300


def _wait_for_service_ready(healthcheck_url: str) -> None:
    """Block until the service at healthcheck_url returns 200.

    Polls the healthcheck endpoint roughly once per second, forever; callers
    rely on an external supervisor to bound the total wait if needed.

    :param healthcheck_url: Full URL of the user-defined service healthcheck.
    """
    while True:
        try:
            if requests.get(healthcheck_url, timeout=5).status_code == 200:
                return
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            # The service is not accepting connections yet, or is too slow to
            # answer within the 5s budget -- keep polling instead of letting
            # a Timeout escape and crash startup.
            pass
        logger.info(f"Waiting for user-defined service to be ready at {healthcheck_url}...")
        time.sleep(1)


class ModelEngineSerializationMixin:
"""Mixin class for optionally wrapping Model Engine requests."""
Expand Down Expand Up @@ -170,25 +184,40 @@ class Forwarder(ModelEngineSerializationMixin):
forward_http_status_in_body: bool
post_inference_hooks_handler: Optional[PostInferenceHooksHandler] = None

    def __post_init__(self) -> None:
        # Lazily-created shared HTTP session; populated by _get_session on
        # first use and torn down by close().
        self._session: Optional[aiohttp.ClientSession] = None

async def _get_session(self) -> aiohttp.ClientSession:
if self._session is None or self._session.closed:
connector = aiohttp.TCPConnector(ttl_dns_cache=TTL_DNS_CACHE)
self._session = aiohttp.ClientSession(
json_serialize=_serialize_json, connector=connector
)
return self._session

async def close(self):
if self._session and not self._session.closed:
await self._session.close()

async def forward(self, json_payload: Any, trace_config: Optional[str] = None) -> Any:
json_payload, using_serialize_results_as_string = self.unwrap_json_payload(json_payload)
json_payload_repr = json_payload.keys() if hasattr(json_payload, "keys") else json_payload

logger.info(f"Accepted request, forwarding {json_payload_repr=}")

try:
async with aiohttp.ClientSession(json_serialize=_serialize_json) as aioclient:
headers = {"Content-Type": "application/json"}
if trace_config and tracing_gateway:
headers.update(tracing_gateway.encode_trace_headers())
response_raw = await aioclient.post(
self.predict_endpoint,
json=json_payload,
headers=headers,
)
response = await response_raw.json(
content_type=None
) # [Bug] upstream service doesn't always have the content type header set which causes aiohttp to error
aioclient = await self._get_session()
headers = {"Content-Type": "application/json"}
if trace_config and tracing_gateway:
headers.update(tracing_gateway.encode_trace_headers())
response_raw = await aioclient.post(
self.predict_endpoint,
json=json_payload,
headers=headers,
)
response = await response_raw.json(
content_type=None
) # [Bug] upstream service doesn't always have the content type header set which causes aiohttp to error

except Exception:
logger.exception(
Expand Down Expand Up @@ -345,15 +374,7 @@ def endpoint(route: str) -> str:
logger.info(f"Prediction endpoint: {pred}")
logger.info(f"Healthcheck endpoint: {hc}")

while True:
try:
if requests.get(hc).status_code == 200:
break
except requests.exceptions.ConnectionError:
pass

logger.info(f"Waiting for user-defined service to be ready at {hc}...")
time.sleep(1)
_wait_for_service_ready(hc)

logger.info(f"Unwrapping model engine payload formatting?: {self.model_engine_unwrap}")

Expand Down Expand Up @@ -429,31 +450,46 @@ class StreamingForwarder(ModelEngineSerializationMixin):
serialize_results_as_string: bool
post_inference_hooks_handler: Optional[PostInferenceHooksHandler] = None # unused for now

    def __post_init__(self) -> None:
        # Lazily-created shared HTTP session; populated by _get_session on
        # first use and torn down by close().
        self._session: Optional[aiohttp.ClientSession] = None

async def _get_session(self) -> aiohttp.ClientSession:
if self._session is None or self._session.closed:
connector = aiohttp.TCPConnector(ttl_dns_cache=TTL_DNS_CACHE)
self._session = aiohttp.ClientSession(
json_serialize=_serialize_json, connector=connector
)
return self._session

    async def close(self):
        # Release pooled connections; safe to call when no session was created.
        if self._session and not self._session.closed:
            await self._session.close()

async def forward(self, json_payload: Any) -> AsyncGenerator[Any, None]: # pragma: no cover
json_payload, using_serialize_results_as_string = self.unwrap_json_payload(json_payload)
json_payload_repr = json_payload.keys() if hasattr(json_payload, "keys") else json_payload

logger.info(f"Accepted request, forwarding {json_payload_repr=}")

try:
response: aiohttp.ClientResponse
async with aiohttp.ClientSession(json_serialize=_serialize_json) as aioclient:
response = await aioclient.post(
self.predict_endpoint,
json=json_payload,
headers={"Content-Type": "application/json"},
)
aioclient = await self._get_session()
response = await aioclient.post(
self.predict_endpoint,
json=json_payload,
headers={"Content-Type": "application/json"},
)

if response.status != 200:
raise HTTPException(
status_code=response.status, detail=await response.json(content_type=None)
) # [Bug] upstream service doesn't always have the content type header set which causes aiohttp to error
if response.status != 200:
raise HTTPException(
status_code=response.status,
detail=await response.json(content_type=None),
) # [Bug] upstream service doesn't always have the content type header set which causes aiohttp to error

async with EventSource(response=response) as event_source:
async for event in event_source:
yield self.get_response_payload_stream(
using_serialize_results_as_string, event.data
)
async with EventSource(response=response) as event_source:
async for event in event_source:
yield self.get_response_payload_stream(
using_serialize_results_as_string, event.data
)

except Exception:
logger.exception(
Expand Down Expand Up @@ -567,15 +603,7 @@ def endpoint(route: str) -> str:
logger.info(f"Prediction endpoint: {pred}")
logger.info(f"Healthcheck endpoint: {hc}")

while True:
try:
if requests.get(hc).status_code == 200:
break
except requests.exceptions.ConnectionError:
pass

logger.info(f"Waiting for user-defined service to be ready at {hc}...")
time.sleep(1)
_wait_for_service_ready(hc)

logger.info(f"Unwrapping model engine payload formatting?: {self.model_engine_unwrap}")

Expand Down Expand Up @@ -631,6 +659,19 @@ def endpoint(route: str) -> str:
class PassthroughForwarder(ModelEngineSerializationMixin):
passthrough_endpoint: str

    def __post_init__(self) -> None:
        # Lazily-created shared HTTP session; populated by _get_session on
        # first use and torn down by close().
        self._session: Optional[aiohttp.ClientSession] = None

async def _get_session(self) -> aiohttp.ClientSession:
if self._session is None or self._session.closed:
connector = aiohttp.TCPConnector(ttl_dns_cache=TTL_DNS_CACHE)
self._session = aiohttp.ClientSession(connector=connector)
return self._session

    async def close(self):
        # Release pooled connections; safe to call when no session was created.
        if self._session and not self._session.closed:
            await self._session.close()

async def _make_request(
self, request: Any, aioclient: aiohttp.ClientSession
) -> aiohttp.ClientResponse:
Expand All @@ -651,28 +692,28 @@ async def _make_request(
return await aioclient.request(
method=request.method,
url=target_url,
data=await request.body() if request.method in ["POST", "PUT", "PATCH"] else None,
data=(await request.body() if request.method in ["POST", "PUT", "PATCH"] else None),
headers=headers,
)

async def forward_stream(self, request: Any):
async with aiohttp.ClientSession() as aioclient:
response = await self._make_request(request, aioclient)
response_headers = response.headers
yield (response_headers, response.status)
aioclient = await self._get_session()
response = await self._make_request(request, aioclient)
response_headers = response.headers
yield (response_headers, response.status)

if response.status != 200:
yield await response.read()
if response.status != 200:
yield await response.read()

async for chunk in response.content.iter_chunks():
yield chunk[0]
async for chunk in response.content.iter_chunks():
yield chunk[0]

yield await response.read()
yield await response.read()

async def forward_sync(self, request: Any):
async with aiohttp.ClientSession() as aioclient:
response = await self._make_request(request, aioclient)
return response
aioclient = await self._get_session()
response = await self._make_request(request, aioclient)
return response


@dataclass(frozen=True)
Expand Down Expand Up @@ -711,15 +752,7 @@ def endpoint(route: str) -> str:
logger.info(f"Passthrough endpoint: {passthrough_endpoint}")
logger.info(f"Healthcheck endpoint: {hc}")

while True:
try:
if requests.get(hc).status_code == 200:
break
except requests.exceptions.ConnectionError:
pass

logger.info(f"Waiting for user-defined service to be ready at {hc}...")
time.sleep(1)
_wait_for_service_ready(hc)

logger.info(f"Creating PassthroughForwarder with endpoint: {passthrough_endpoint}")
return PassthroughForwarder(passthrough_endpoint=passthrough_endpoint)
Expand Down
Loading