From 85d409f1652f0989304bf0b5b786de9bca13894b Mon Sep 17 00:00:00 2001 From: Mo King Date: Tue, 7 Apr 2026 13:18:24 -0400 Subject: [PATCH] Add lb websocket example --- serverless/load-balancing/build-a-worker.mdx | 100 +++++++++++++++++++ tests/TESTS.md | 1 + 2 files changed, 101 insertions(+) diff --git a/serverless/load-balancing/build-a-worker.mdx b/serverless/load-balancing/build-a-worker.mdx index 89007b2d..76eb5495 100644 --- a/serverless/load-balancing/build-a-worker.mdx +++ b/serverless/load-balancing/build-a-worker.mdx @@ -247,6 +247,105 @@ async def health_check(): ``` +## (Optional) WebSocket support + +Load balancing endpoints also support WebSocket connections. This section shows how to add a WebSocket endpoint to your worker and connect to it from a client. + + +You can clone the [worker-lb-websocket repository](https://github.com/runpod-workers/worker-lb-websocket) for a complete working example, including scaling tests. + + +### Add a WebSocket endpoint + +WebSocket endpoints in FastAPI use the `@app.websocket()` decorator. Add the following to your `app.py`: + +```python app.py +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +import asyncio + +app = FastAPI() + +# Track active connections +active_ws_connections: list[WebSocket] = [] + +@app.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket): + """Streaming WebSocket endpoint. + + Clients send JSON messages like: {"prompt": "Hello", "max_tokens": 50} + Server streams responses back and sends {"done": true} when complete. 
 + """ + await websocket.accept() + active_ws_connections.append(websocket) + + try: + while True: + data = await websocket.receive_json() + prompt = data.get("prompt", "") + + if not prompt: + await websocket.send_json({"error": "prompt is required"}) + continue + + # Simulate streaming response (replace with your model) + words = f"Response to: {prompt}".split() + for i, word in enumerate(words): + await websocket.send_json({"token": word, "index": i}) + await asyncio.sleep(0.05) # Simulate inference latency + + await websocket.send_json({"done": True}) + + except WebSocketDisconnect: + pass + finally: + active_ws_connections.remove(websocket) +``` + +### Connect from a client + +When connecting to a WebSocket endpoint on a load balancing worker, you must set the `open_timeout` parameter to allow time for workers to scale up. The default timeout of 10 seconds is usually not enough. + +```python client.py +import asyncio +import json +import websockets + +async def connect_to_worker(): + url = "wss://ENDPOINT_ID.api.runpod.ai/ws" + headers = [("Authorization", "Bearer RUNPOD_API_KEY")] + + # Set open_timeout to allow workers time to scale up (default is 10s) + async with websockets.connect( + url, + additional_headers=headers, + open_timeout=60.0, # Wait up to 60 seconds for connection + ) as ws: + # Send a request + await ws.send(json.dumps({"prompt": "Hello, world!", "max_tokens": 50})) + + # Receive streaming response + while True: + response = json.loads(await ws.recv()) + if response.get("done"): + print("Generation complete") + break + print(response.get("token", ""), end=" ") + +asyncio.run(connect_to_worker()) +``` + + +If you don't set `open_timeout`, connections will fail with a timeout error when workers need to scale up from zero. A value of 60 seconds works for most use cases. 
+ + +### Update requirements.txt + +Add the `websockets` library to your client's dependencies: + +``` +websockets==14.2 +``` + ## Troubleshooting Here are some common issues and methods for troubleshooting: @@ -256,6 +355,7 @@ Here are some common issues and methods for troubleshooting: - **API not accessible**: If your request returns `{"error":"not allowed for QB API"}`, verify that your endpoint type is set to "Load Balancer". - **Port issues**: Make sure the environment variable for `PORT` matches what your application is using, and that the `PORT_HEALTH` variable is set to a different port. - **Model errors**: Check your model's requirements and whether it's compatible with your GPU. +- **WebSocket timeout**: If WebSocket connections fail with timeout errors, increase the `open_timeout` parameter in your client code to allow workers time to scale up. See [(Optional) WebSocket support](#optional-websocket-support) for details. ## Next steps diff --git a/tests/TESTS.md b/tests/TESTS.md index 3b206a17..663e2c2a 100644 --- a/tests/TESTS.md +++ b/tests/TESTS.md @@ -129,6 +129,7 @@ Run all smoke tests using local docs | serverless-github-deploy | Deploy an endpoint from GitHub | Endpoint from GitHub repo | | serverless-ssh-worker | SSH into a running worker for debugging | SSH session established | | serverless-metrics | View endpoint metrics (execution time, delay) | Metrics data returned | +| serverless-lb-websocket | Deploy a load balancing worker with WebSocket support and connect to it | WebSocket connection succeeds and receives streaming response | ---