diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..521929e --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,73 @@ +name: Release + +# Triggered when a version tag is pushed, e.g.: +# git tag v0.3.2.9 && git push origin v0.3.2.9 +on: + push: + tags: + - "v*" + +permissions: + contents: read + +jobs: + # 1. Build the sdist + wheel and verify the tag matches pyproject version. + build: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v3 + with: + version: "latest" + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.13" + + - name: Verify tag matches pyproject version + run: | + TAG="${GITHUB_REF_NAME#v}" + PKG=$(uv version --short 2>/dev/null || python -c "import tomllib;print(tomllib.load(open('pyproject.toml','rb'))['project']['version'])") + echo "Tag version: $TAG" + echo "Package version: $PKG" + if [ "$TAG" != "$PKG" ]; then + echo "::error::Tag ($TAG) does not match pyproject.toml version ($PKG)." + exit 1 + fi + + - name: Build sdist and wheel + run: uv build + + - name: Check distribution metadata + run: uvx twine check dist/* + + - name: Upload build artifacts + uses: actions/upload-artifact@v4 + with: + name: dist + path: dist/ + + # 2. Create the GitHub Release with auto-generated notes and attach artifacts. + # PyPI publishing is handled manually (see Makefile: `make publish`). + github-release: + needs: build + runs-on: ubuntu-latest + permissions: + contents: write # required to create a release + steps: + - name: Download build artifacts + uses: actions/download-artifact@v4 + with: + name: dist + path: dist/ + + - name: Create GitHub Release + uses: softprops/action-gh-release@v2 + with: + generate_release_notes: true + files: dist/* + fail_on_unmatched_files: true diff --git a/CLAUDE.md b/CLAUDE.md index 1c3a312..7e13e48 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -5,8 +5,9 @@ core framework see `agentflow/CLAUDE.md`; for the TS client, docs, or playground for the monorepo overview see the workspace-root `CLAUDE.md`. - Package name (PyPI): `10xscale-agentflow-cli` -- Version: `0.3.2.9` (`pyproject.toml`) โ€” note the CLI's own `CLI_VERSION` constant and - `agentflow_cli.__version__` both say `1.0.0`; these are out of sync (see Known Doc Drift). +- Version: `0.3.2.9` (`pyproject.toml`). `CLI_VERSION` and `agentflow_cli.__version__` are + single-sourced from the installed distribution metadata (falling back to `pyproject.toml`), so + `agentflow version` reports `0.3.2.9` consistently. The previous `1.0.0` hardcode is gone. - Requires: Python >= 3.12 ยท Status: `4 - Beta` - Console entry point: `agentflow = agentflow_cli.cli.main:main` - Depends on the core framework: `10xscale-agentflow>=0.7.0`. @@ -131,9 +132,10 @@ ruff check . && ruff format . ## Known doc drift (do not trust without checking) -- **Version is inconsistent.** `pyproject.toml` = `0.3.2.9`, but `CLI_VERSION` constant and - `agentflow_cli.__version__` = `1.0.0`. `agentflow version` prints both the (hardcoded) CLI - version and the pyproject version, so it shows `1.0.0` and `0.3.2.9` side by side. +- **Version is now single-sourced.** `CLI_VERSION` (and `agentflow_cli.__version__`, which aliases + it) resolve from installed distribution metadata, falling back to `pyproject.toml`. `agentflow + version` reports `0.3.2.9` for both the CLI and package lines. (The old hardcoded `1.0.0` drift is + resolved.) - **README shows `agentflow init --prod`** โ€” that flag does not exist. `init` is interactive and only accepts `--path` / `--force`. - **`api`/`play` help text claims default host `0.0.0.0`** but `DEFAULT_HOST` is `127.0.0.1`. diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..a027f19 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Iamsdt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 37c2324..7c82b0e 100644 --- a/README.md +++ b/README.md @@ -1,95 +1,103 @@ -# AgentFlow CLI -A professional Python API framework for building agent-based applications with FastAPI, state graph orchestration, and comprehensive CLI tools. +# 10xScale Agentflow CLI + +[![CI](https://github.com/10xHub/agentflow-cli/actions/workflows/ci.yaml/badge.svg)](https://github.com/10xHub/agentflow-cli/actions/workflows/ci.yaml) +[![Release](https://github.com/10xHub/agentflow-cli/actions/workflows/release.yml/badge.svg)](https://github.com/10xHub/agentflow-cli/actions/workflows/release.yml) + +[![PyPI](https://img.shields.io/pypi/v/10xscale-agentflow-cli?color=blue)](https://pypi.org/project/10xscale-agentflow-cli/) +[![Python](https://img.shields.io/pypi/pyversions/10xscale-agentflow-cli)](https://pypi.org/project/10xscale-agentflow-cli/) +[![License](https://img.shields.io/github/license/10xHub/agentflow-cli)](https://github.com/10xHub/agentflow-cli/blob/main/LICENSE) +[![Coverage](https://img.shields.io/badge/coverage-90%25-brightgreen.svg)](https://github.com/10xHub/agentflow-cli/actions/workflows/ci.yaml) +[![Tests](https://img.shields.io/badge/tests-871%20passed-brightgreen.svg)](https://github.com/10xHub/agentflow-cli/actions/workflows/ci.yaml) +[![Status](https://img.shields.io/badge/status-beta-yellow.svg)](https://pypi.org/project/10xscale-agentflow-cli/) +[![Code style: ruff](https://img.shields.io/badge/code%20style-ruff-000000.svg)](https://github.com/astral-sh/ruff) + +**10xScale Agentflow CLI** turns an Agentflow `CompiledGraph` into a production-grade FastAPI service, plus a Typer-based command line to scaffold, run, build, test, and evaluate it. You write a graph, point `agentflow.json` at it, and `agentflow api` serves it over REST + WebSocket with authentication, rate limiting, media handling, checkpointer/thread management, and a memory store API. + +> ### ๐Ÿ“ฆ Part of the 10xScale Agentflow library +> +> This package (`10xscale-agentflow-cli`) is the **API server + CLI layer** of the larger +> [**10xScale Agentflow**](https://github.com/10xHub/agentflow) framework. The core orchestration +> engine โ€” `StateGraph`, `Agent`, `ToolNode`, state, persistence, memory, and tools โ€” lives in the +> separate [`10xscale-agentflow`](https://pypi.org/project/10xscale-agentflow/) package. This CLI +> builds on top of it to expose your agent graphs as a deployable service. +> +> - **Core framework:** [`10xscale-agentflow`](https://pypi.org/project/10xscale-agentflow/) ยท [source](https://github.com/10xHub/agentflow) +> - **This package (API + CLI):** [`10xscale-agentflow-cli`](https://pypi.org/project/10xscale-agentflow-cli/) +> - **TypeScript client:** [`@10xscale/agentflow-client`](https://www.npmjs.com/package/@10xscale/agentflow-client) +> - **Docs:** [agentflow.10xscale.ai](https://agentflow.10xscale.ai/) -## ๐Ÿ“š Documentation +--- + +## โœจ Key Features + +- **๐Ÿ–ฅ๏ธ Professional CLI** - Scaffold, run, build, test, and evaluate agents from one command line +- **โšก FastAPI Backend** - Your compiled graph auto-served over REST + WebSocket, high-performance and async +- **๐Ÿ”Œ Config-Driven** - One `agentflow.json` wires agent, auth, checkpointer, store, Redis, and rate limits +- **๐Ÿ” Authentication** - Built-in JWT auth, custom `BaseAuth` backends, and RBAC authorization +- **๐Ÿšฆ Rate Limiting** - Sliding-window limits with memory, Redis, or custom backends +- **๐Ÿ†” Distributed IDs** - Snowflake ID generation for multi-node deployments +- **๐Ÿงต Thread Management** - Conversation thread naming, listing, state, and message APIs +- **๐Ÿ–ผ๏ธ Multimodal & Media** - File upload/download endpoints and media handling for multimodal agents +- **๐ŸŽ™๏ธ Realtime Audio Bridge** - WebSocket endpoint for live audio-to-audio agents (Gemini Live) +- **๐Ÿณ Docker & Kubernetes Ready** - Generate production Dockerfiles and compose files with one command +- **๐Ÿ›ก๏ธ Production Hardening** - Error/log sanitization, request size limits, security headers, startup validation +- **๐Ÿ’‰ Dependency Injection** - InjectQ for clean, testable dependency wiring -- **[CLI Guide](./docs/cli-guide.md)** - Complete command-line interface reference -- **[Configuration Guide](./docs/configuration.md)** - All configuration options explained -- **[Deployment Guide](./docs/deployment.md)** - Docker, Kubernetes, and cloud deployment -- **[Authentication Guide](./docs/authentication.md)** - JWT and custom authentication -- **[Rate Limiting Guide](./docs/rate-limiting.md)** - Memory, Redis, and custom rate-limit backends -- **[ID Generation Guide](./docs/id-generation.md)** - Snowflake ID generation -- **[Thread Name Generator Guide](./docs/thread-name-generator.md)** - Thread naming strategies +--- -## Quick Start +## Installation -### Installation +**Basic installation:** ```bash pip install 10xscale-agentflow-cli ``` -Redis rate limiting is optional. Install the Redis extra only when you configure -`rate_limit.backend` as `redis`: +Optional extras โ€” install only what you configure: ```bash -pip install "10xscale-agentflow-cli[redis]" +pip install "10xscale-agentflow-cli[redis]" # Redis rate-limit / cache backend +pip install "10xscale-agentflow-cli[jwt]" # JWT authentication +pip install "10xscale-agentflow-cli[media]" # Document text extraction (multimodal) +pip install "10xscale-agentflow-cli[otel]" # OpenTelemetry tracing +pip install "10xscale-agentflow-cli[snowflakekit]" # Snowflake ID generation ``` -JWT auth and document text extraction are optional too. Install only the extra -you need: +Requires **Python โ‰ฅ 3.12**. Depends on the core `10xscale-agentflow` framework. -```bash -pip install "10xscale-agentflow-cli[jwt]" -pip install "10xscale-agentflow-cli[media]" -``` +--- -### Initialize a New Project +## ๐Ÿš€ Quick Start ```bash -# Create project structure (interactive: prompts for dev vs production, auth, rate limiting) +# 1. Scaffold a project (interactive: dev vs production, auth, rate limiting) agentflow init -``` - -### Start Development Server -```bash +# 2. Start the dev API server (127.0.0.1:8000) agentflow api -``` -### Start API With Play - -```bash +# 3. Or start the server and open the hosted playground agentflow play -``` - -### Generate Docker Files -```bash +# 4. Generate production Docker files agentflow build --docker-compose ``` -## Key Features - -- โœ… **CLI Tools** - Professional command-line interface for scaffolding and deployment -- โœ… **State Graph Orchestration** - Build complex agent workflows with LangGraph -- โœ… **FastAPI Backend** - High-performance async web framework -- โœ… **Authentication** - Built-in JWT auth and custom authentication support -- โœ… **Rate Limiting** - Sliding-window limits with memory, Redis, and custom backends -- โœ… **ID Generation** - Distributed Snowflake ID generation -- โœ… **Thread Management** - Intelligent thread naming and conversation management -- โœ… **Docker Ready** - Generate production-ready Docker files -- โœ… **Dependency Injection** - InjectQ for clean dependency management -- โœ… **Development Tools** - Hot-reload, pre-commit hooks, testing +--- -## CLI Commands +## ๐Ÿ–ฅ๏ธ CLI Commands -For detailed command documentation, see the [CLI Guide](./docs/cli-guide.md). +For detailed command documentation, see the **[CLI Guide](./docs/cli-guide.md)**. ### `agentflow init` -Initialize a new project with configuration and sample graph. +Initialize a new project with configuration and a sample graph. ```bash -# Basic initialization (interactive prompts choose dev vs production setup) -agentflow init - -# Custom directory -agentflow init --path ./my-project - -# Force overwrite existing files -agentflow init --force +agentflow init # interactive (chooses dev vs production setup) +agentflow init --path ./my-app # custom directory +agentflow init --force # overwrite existing files ``` ### `agentflow api` @@ -97,38 +105,21 @@ agentflow init --force Start the development API server. ```bash -# Start with defaults (localhost:8000) -agentflow api - -# Custom host and port -agentflow api --host 127.0.0.1 --port 9000 - -# Custom config file -agentflow api --config production.json - -# Disable auto-reload -agentflow api --no-reload - -# Verbose logging -agentflow api --verbose +agentflow api # defaults (127.0.0.1:8000) +agentflow api --host 127.0.0.1 --port 9000 # custom host/port +agentflow api --config production.json # custom config file +agentflow api --no-reload # disable auto-reload +agentflow api --verbose # verbose logging ``` ### `agentflow play` -Start the development API server and open the hosted playground with your local backend URL preconfigured. +Start the dev server and open the hosted playground with your local backend URL preconfigured. ```bash -# Start with defaults and open the playground agentflow play - -# Custom host and port agentflow play --host 127.0.0.1 --port 9000 - -# Custom config file agentflow play --config production.json - -# Disable auto-reload -agentflow play --no-reload ``` ### `agentflow build` @@ -136,28 +127,42 @@ agentflow play --no-reload Generate production Docker files. ```bash -# Generate Dockerfile -agentflow build +agentflow build # Dockerfile +agentflow build --docker-compose # Dockerfile + docker-compose.yml +agentflow build --python-version 3.12 --port 9000 +agentflow build --force +``` -# Generate Dockerfile and docker-compose.yml -agentflow build --docker-compose +### `agentflow eval` / `agentflow test` -# Custom Python version and port -agentflow build --python-version 3.12 --port 9000 +Run agent evaluations (discovers `*_eval.py` / `eval_*.py`, writes HTML + JSON to `eval_reports/`) and project tests (pytest). -# Force overwrite -agentflow build --force +```bash +agentflow eval --parallel --threshold 0.8 +agentflow test --coverage +``` + +### `agentflow skills` + +Install bundled coding-agent skills (Codex, Claude, GitHub Copilot) into your project so your AI assistant knows how to build with Agentflow. + +```bash +agentflow skills --all # install for every supported agent +agentflow skills --agent claude # install for one +agentflow skills --list # show supported agents ``` ### `agentflow version` -Display version information. +Display CLI and package version information. ```bash agentflow version ``` -## Configuration +--- + +## โš™๏ธ Configuration The configuration file (`agentflow.json`) defines your agent, authentication, and infrastructure settings: @@ -170,7 +175,8 @@ The configuration file (`agentflow.json`) defines your agent, authentication, an "injectq": null, "store": null, "redis": null, - "thread_name_generator": null + "thread_name_generator": null, + "rate_limit": {} } ``` @@ -178,37 +184,30 @@ The configuration file (`agentflow.json`) defines your agent, authentication, an | Field | Type | Description | |-------|------|-------------| -| `agent` | string | Path to your compiled agent graph (required) | +| `agent` | string | Path to your compiled agent graph, `"module:attribute"` (required) | | `env` | string | Path to environment variables file | -| `auth` | null\|"jwt"\|object | Authentication configuration | -| `checkpointer` | string\|null | Path to custom checkpointer | -| `injectq` | string\|null | Path to InjectQ container | -| `store` | string\|null | Path to data store | -| `redis` | string\|null | Redis connection URL | -| `rate_limit` | object\|null | Sliding-window rate limiting configuration | -| `thread_name_generator` | string\|null | Path to custom thread name generator | +| `auth` | null \| "jwt" \| object | Authentication configuration | +| `authorization` | string \| null | Path to an `AuthorizationBackend` (RBAC / per-tool access) | +| `checkpointer` | string \| null | Path to a custom checkpointer | +| `injectq` | string \| null | Path to an InjectQ container | +| `store` | string \| null | Path to a data store | +| `redis` | string \| null | Redis connection URL | +| `rate_limit` | object \| null | Sliding-window rate limiting configuration | +| `thread_name_generator` | string \| null | Path to a custom thread name generator | -See the [Configuration Guide](./docs/configuration.md) for complete details. +See the **[Configuration Guide](./docs/configuration.md)** for complete details. -## Authentication +--- -AgentFlow supports multiple authentication strategies. See the [Authentication Guide](./docs/authentication.md) for complete details. +## ๐Ÿ” Authentication -### No Authentication - -```json -{ - "auth": null -} -``` +Agentflow supports multiple authentication strategies. See the **[Authentication Guide](./docs/authentication.md)** for details. ### JWT Authentication **agentflow.json:** ```json -{ - "auth": "jwt" -} +{ "auth": "jwt" } ``` **.env:** @@ -221,12 +220,7 @@ JWT_ALGORITHM=HS256 **agentflow.json:** ```json -{ - "auth": { - "method": "custom", - "path": "auth.custom:MyAuthBackend" - } -} +{ "auth": { "method": "custom", "path": "auth.custom:MyAuthBackend" } } ``` **auth/custom.py:** @@ -235,51 +229,42 @@ from agentflow_cli import BaseAuth from fastapi import Response, HTTPException from fastapi.security import HTTPAuthorizationCredentials + class MyAuthBackend(BaseAuth): def authenticate( self, res: Response, - credential: HTTPAuthorizationCredentials + credential: HTTPAuthorizationCredentials, ) -> dict[str, any] | None: - # Your authentication logic token = credential.credentials user = verify_token(token) - if not user: raise HTTPException(401, "Invalid token") - - return { - "user_id": user.id, - "username": user.username, - "email": user.email - } + return {"user_id": user.id, "username": user.username, "email": user.email} ``` -## ID Generation +--- -AgentFlow includes Snowflake ID generation for distributed, time-sortable unique IDs. +## ๐Ÿ†” ID Generation + +Agentflow includes Snowflake ID generation for distributed, time-sortable unique IDs. ```bash pip install "10xscale-agentflow-cli[snowflakekit]" ``` -**Usage:** ```python from agentflow_cli import SnowFlakeIdGenerator -# Initialize generator = SnowFlakeIdGenerator( snowflake_epoch=1704067200000, # Jan 1, 2024 snowflake_node_id=1, - snowflake_worker_id=1 + snowflake_worker_id=1, ) - -# Generate ID -id = await generator.generate() -print(f"Generated ID: {id}") +new_id = await generator.generate() ``` -**Environment Configuration:** +**Environment configuration:** ```bash SNOWFLAKE_EPOCH=1704067200000 SNOWFLAKE_NODE_ID=1 @@ -289,9 +274,11 @@ SNOWFLAKE_NODE_BITS=5 SNOWFLAKE_WORKER_BITS=8 ``` -See the [ID Generation Guide](./docs/id-generation.md) for more details. +See the **[ID Generation Guide](./docs/id-generation.md)** for more details. + +--- -## Thread Name Generation +## ๐Ÿงต Thread Name Generation Generate human-friendly names for conversation threads. @@ -300,18 +287,18 @@ from agentflow_cli.src.app.utils.thread_name_generator import AIThreadNameGenera generator = AIThreadNameGenerator() name = generator.generate_name() -# Output: "thoughtful-dialogue", "exploring-ideas", etc. +# "thoughtful-dialogue", "exploring-ideas", ... ``` -See the [Thread Name Generator Guide](./docs/thread-name-generator.md) for custom implementations. +See the **[Thread Name Generator Guide](./docs/thread-name-generator.md)** for custom implementations. -## Security +--- -AgentFlow CLI provides enterprise-grade security features for production deployments. +## ๐Ÿ›ก๏ธ Security -### Security Features +Agentflow CLI provides production-grade security features. -- โœ… **Authentication** - Built-in JWT and custom authentication backends +- โœ… **Authentication** - JWT and custom authentication backends - โœ… **Authorization** - Resource-based access control with extensible backends - โœ… **Request Limits** - DoS protection with configurable size limits (default 10MB) - โœ… **Error Sanitization** - Production-safe error messages preventing information disclosure @@ -321,95 +308,25 @@ AgentFlow CLI provides enterprise-grade security features for production deploym ### Production Security Checklist -Before deploying to production, ensure: - ```bash -# Required: Set production mode -MODE=production - -# Required: Strong JWT secret (32+ characters) -JWT_SECRET_KEY= - -# Required: Disable debug mode -IS_DEBUG=false - -# Required: Specific CORS origins (not *) -ORIGINS=https://yourdomain.com - -# Required: Specific allowed hosts (not *) -ALLOWED_HOST=yourdomain.com - -# Recommended: Disable API docs -DOCS_PATH= +MODE=production # production mode +JWT_SECRET_KEY=<32+ chars> # strong secret (secrets.token_urlsafe(32)) +IS_DEBUG=false # disable debug +ORIGINS=https://yourdomain.com # specific CORS origins (never *) +ALLOWED_HOST=yourdomain.com # specific allowed hosts (never *) +DOCS_PATH= # recommended: disable API docs REDOCS_PATH= - -# Recommended: Configure request size limit -MAX_REQUEST_SIZE=10485760 # 10MB default -``` - -### Quick Security Setup - -**1. Enable JWT Authentication:** -```json -{ - "auth": "jwt" -} +MAX_REQUEST_SIZE=10485760 # request size limit (10MB default) ``` -**2. Implement Authorization:** -```python -# auth/rbac_backend.py -from agentflow_cli.src.app.core.auth.authorization import AuthorizationBackend - -class RBACAuthorizationBackend(AuthorizationBackend): - async def authorize(self, user, resource, action, resource_id=None, **context): - role = user.get("role", "viewer") - # Implement your authorization logic - return role == "admin" or (role == "developer" and action == "read") -``` - -**3. Configure in agentflow.json:** -```json -{ - "auth": "jwt", - "authorization": { - "path": "auth.rbac_backend:RBACAuthorizationBackend" - } -} -``` - -### Security Validation - -AgentFlow automatically validates your configuration and warns about security issues: - -``` -โš ๏ธ SECURITY WARNING: CORS ORIGINS='*' in production. - Set ORIGINS to specific domains. - -โš ๏ธ SECURITY WARNING: DEBUG mode enabled in production! - Set IS_DEBUG=false -``` - -### Comprehensive Security Guide - -For detailed security documentation, threat model, best practices, and deployment guidelines, see: +For deployment hardening and authentication patterns, see the +**[Deployment Guide](./docs/deployment.md)** and **[Authentication Guide](./docs/authentication.md)**. -๐Ÿ“– **[SECURITY.md](./SECURITY.md)** - Complete Security Guide - -Topics covered: -- Threat model and attack vectors -- Authentication and authorization patterns -- Production deployment checklist -- Docker and Kubernetes security configurations -- Security testing and monitoring -- Incident response procedures -- Vulnerability reporting - -## Deployment +--- -See the [Deployment Guide](./docs/deployment.md) for complete deployment instructions. +## ๐Ÿณ Deployment -### Docker Deployment +See the **[Deployment Guide](./docs/deployment.md)** for full instructions. ```bash # Generate Docker files @@ -422,137 +339,102 @@ docker compose up --build -d docker compose logs -f ``` -### Kubernetes - -See [Deployment Guide - Kubernetes](./docs/deployment.md#kubernetes) for complete manifests. +Cloud targets covered in the guide: [AWS ECS](./docs/deployment.md#aws-ecs), +[Google Cloud Run](./docs/deployment.md#google-cloud-run), +[Azure Container Instances](./docs/deployment.md#azure-container-instances), +[Kubernetes](./docs/deployment.md#kubernetes), and [Heroku](./docs/deployment.md#heroku). -### Cloud Platforms - -- [AWS ECS](./docs/deployment.md#aws-ecs) -- [Google Cloud Run](./docs/deployment.md#google-cloud-run) -- [Azure Container Instances](./docs/deployment.md#azure-container-instances) -- [Heroku](./docs/deployment.md#heroku) +--- -## Project Structure +## ๐Ÿ“ Project Structure ``` agentflow-cli/ โ”œโ”€โ”€ agentflow_cli/ # Main package -โ”‚ โ”œโ”€โ”€ __init__.py # Package exports -โ”‚ โ”œโ”€โ”€ cli/ # CLI commands -โ”‚ โ”‚ โ”œโ”€โ”€ main.py # CLI entry point -โ”‚ โ”‚ โ””โ”€โ”€ commands/ # Command implementations -โ”‚ โ””โ”€โ”€ src/ # Application source -โ”‚ โ””โ”€โ”€ app/ # FastAPI application -โ”‚ โ”œโ”€โ”€ main.py # App entry point -โ”‚ โ”œโ”€โ”€ core/ # Core functionality -โ”‚ โ”œโ”€โ”€ routers/ # API routes -โ”‚ โ””โ”€โ”€ utils/ # Utilities -โ”œโ”€โ”€ graph/ # Agent graphs -โ”‚ โ”œโ”€โ”€ __init__.py -โ”‚ โ””โ”€โ”€ react.py # Sample React agent -โ”œโ”€โ”€ docs/ # Documentation -โ”œโ”€โ”€ tests/ # Test suite -โ”œโ”€โ”€ agentflow.json # Configuration -โ”œโ”€โ”€ pyproject.toml # Project metadata -โ””โ”€โ”€ README.md # This file +โ”‚ โ”œโ”€โ”€ __init__.py # Package exports (BaseAuth, SnowFlakeIdGenerator, ThreadNameGenerator) +โ”‚ โ”œโ”€โ”€ cli/ # Typer CLI: main.py + commands/ + templates/ +โ”‚ โ””โ”€โ”€ src/app/ # FastAPI application (main.py, loader.py, core/, routers/, utils/) +โ”œโ”€โ”€ docs/ # Documentation +โ”œโ”€โ”€ tests/ # Test suite +โ”œโ”€โ”€ agentflow.json # Configuration +โ”œโ”€โ”€ pyproject.toml # Project metadata +โ””โ”€โ”€ README.md # This file ``` -## Development +--- -### Setup +## ๐Ÿ”ง Development ```bash -# Clone repository +# Clone and set up git clone https://github.com/10xHub/agentflow-cli.git cd agentflow-cli - -# Create virtual environment -python -m venv .venv -source .venv/bin/activate # On Windows: .venv\Scripts\activate - -# Install in development mode +python -m venv .venv && source .venv/bin/activate pip install -e ".[dev]" - -# Install pre-commit hooks pre-commit install -``` - -### Testing -```bash -# Run all tests -pytest - -# With coverage +# Quality gate +pytest # tests (coverage gate: 80%) pytest --cov=agentflow_cli --cov-report=html - -# Run specific test file -pytest tests/test_cli.py -v +ruff check . && ruff format . # lint + format +pre-commit run --all-files # full gate (ruff + bandit, pinned versions) ``` -### Code Quality +### Using the Makefile ```bash -# Format code -ruff format . - -# Lint code -ruff check . - -# Fix auto-fixable issues -ruff check --fix . +make build # build sdist + wheel +make test # run tests +make test-cov # run tests with coverage +make publish # upload to PyPI (maintainers) +make clean # remove build artifacts ``` -### Using the Makefile +### Releasing + +Releases are cut by pushing a version tag that matches `pyproject.toml`. The +[`release.yml`](./.github/workflows/release.yml) workflow then verifies the tag, builds the +sdist + wheel, checks the distribution metadata, and creates a GitHub Release with auto-generated +notes and the artifacts attached. PyPI publishing is manual (`make publish`). ```bash -# Show available commands -make help +git tag v0.3.2.9 && git push origin v0.3.2.9 +``` -# Install development dependencies -make dev-install +--- -# Run tests -make test +## ๐Ÿ“„ License -# Format and lint -make format -make lint +MIT License - see [LICENSE](https://github.com/10xHub/agentflow-cli/blob/main/LICENSE) for details. -# Build package -make build +--- -# Clean build artifacts -make clean -``` +## ๐Ÿ”— Links & Resources -## Contributing +- **[Documentation](https://agentflow.10xscale.ai/)** - Full framework docs +- **[Core framework (`10xscale-agentflow`)](https://github.com/10xHub/agentflow)** - The orchestration engine this CLI serves +- **[This repository](https://github.com/10xHub/agentflow-cli)** - Source code and issues +- **[PyPI Project](https://pypi.org/project/10xscale-agentflow-cli/)** - Package releases +- **[Local docs](./docs/)** - CLI, configuration, deployment, auth, rate limiting, IDs, thread names -Contributions are welcome! Please follow these steps: +--- -1. Fork the repository -2. Create a feature branch (`git checkout -b feature/amazing-feature`) -3. Make your changes -4. Run tests and linting -5. Commit your changes (`git commit -m 'Add amazing feature'`) -6. Push to the branch (`git push origin feature/amazing-feature`) -7. Open a Pull Request +## ๐Ÿ™ Contributing -## License +Contributions are welcome! Fork the repo, create a feature branch, run tests and linting, and open a +Pull Request. See the [repository](https://github.com/10xHub/agentflow-cli) for issue reporting and +guidelines. -MIT License - see LICENSE file for details. +--- -## Support +## ๐Ÿ’ฌ Support -- **Documentation:** [Complete Documentation](./docs/) +- **Documentation:** [agentflow.10xscale.ai](https://agentflow.10xscale.ai/) and [local docs](./docs/) - **Issues:** [GitHub Issues](https://github.com/10xHub/agentflow-cli/issues) - **Repository:** [GitHub](https://github.com/10xHub/agentflow-cli) -## Credits +--- Developed by [10xScale](https://10xscale.ai) and maintained by the community. ---- - **Made with โค๏ธ for the AI agent development community** diff --git a/agentflow_cli/cli/commands/skills.py b/agentflow_cli/cli/commands/skills.py index 6c9ba24..4325eeb 100644 --- a/agentflow_cli/cli/commands/skills.py +++ b/agentflow_cli/cli/commands/skills.py @@ -221,6 +221,34 @@ def _install_one( self.output.success( f"Installed Agentflow skills for {target.name} at {', '.join(installed_paths)}" ) + self._print_activation_hint(target) + + def _print_activation_hint(self, target: _AgentTarget) -> None: + """Tell the user how to make the freshly installed skill take effect. + + Coding agents load skills at session start. When ``agentflow skills`` + creates the skills directory for the first time during a running session, + the agent does not watch the new top-level directory until it is + restarted, so the skill appears "installed but unused". This note makes + the required restart explicit. + """ + if target.name == "Claude": + note = ( + "Restart Claude Code (or run /exit then `claude`) so it loads the new " + ".claude/skills/ directory. Claude auto-invokes the skill from its " + "description; type /agentflow to run it manually." + ) + elif target.name == "Codex": + note = ( + "Restart Codex so it picks up the new .agents/skills/ directory. Codex " + "auto-selects the skill when your task matches its description." + ) + else: # GitHub Copilot + note = ( + "Restart GitHub Copilot / your editor so it loads the new " + ".github/skills/ directory and .github/instructions/ file." + ) + self.output.warning(f"Activate: {note}") def _install_all(self, templates_root: Path, project_root: Path, *, force: bool) -> int: installed = 0 diff --git a/agentflow_cli/cli/templates/skills/agent-skills/references/agents-and-tools.md b/agentflow_cli/cli/templates/skills/agent-skills/references/agents-and-tools.md index 247a75e..bb71b7e 100644 --- a/agentflow_cli/cli/templates/skills/agent-skills/references/agents-and-tools.md +++ b/agentflow_cli/cli/templates/skills/agent-skills/references/agents-and-tools.md @@ -113,6 +113,21 @@ agent = Agent( `RetryConfig` fields: `max_retries` (3), `initial_delay` (1.0 s), `max_delay` (30.0 s), `backoff_factor` (2.0). +### Circuit breaker (opt-in) + +Complements retry + `fallback_models`. Once a `(provider, model)` target fails +`circuit_breaker_threshold` times in a row, its circuit opens and that target is skipped (straight +to the next fallback) for `circuit_breaker_reset_timeout` seconds, instead of being retried on every +call. Configure on `RetryConfig`: + +- `circuit_breaker_enabled: bool = False` +- `circuit_breaker_threshold: int = 5` +- `circuit_breaker_reset_timeout: float = 30.0` + +```python +RetryConfig(circuit_breaker_enabled=True, circuit_breaker_threshold=3, circuit_breaker_reset_timeout=60.0) +``` + ### Fallback models When the primary model exhausts all retries, AgentFlow tries each fallback in order: @@ -128,6 +143,23 @@ agent = Agent( ) ``` +### LLM call timeout + +All LLM clients apply a default request timeout (600 s) so a stalled provider cannot hang a run +indefinitely. Override globally via the `AGENTFLOW_LLM_TIMEOUT` env var (seconds), or +programmatically: + +```python +from agentflow.core.llm import ( + set_default_llm_timeout, get_default_llm_timeout, DEFAULT_LLM_TIMEOUT_SECONDS, +) + +set_default_llm_timeout(120.0) # set +set_default_llm_timeout(None) # reset to default +``` + +An explicit per-client `timeout=` still takes precedence. + ### Structured output `output_schema` with a Pydantic model forces JSON output. Requires `output_type="text"` (default). diff --git a/agentflow_cli/cli/templates/skills/agent-skills/references/architecture.md b/agentflow_cli/cli/templates/skills/agent-skills/references/architecture.md index 78ec4be..497f75f 100644 --- a/agentflow_cli/cli/templates/skills/agent-skills/references/architecture.md +++ b/agentflow_cli/cli/templates/skills/agent-skills/references/architecture.md @@ -21,7 +21,8 @@ Use this when deciding where a change belongs or explaining how Agentflow packag | Sub-package | Key exports | |---|---| | `agentflow.core` | `StateGraph`, `Agent`, `ToolNode`, `AgentState`, `Message`, `StreamChunk` | -| `agentflow.prebuilt.agent` | `ReactAgent`, `PlanActReflectAgent`, `StructuredOutputAgent`, `SupervisorTeamAgent`, `SwarmAgent`, `RAGAgent` | +| `agentflow.prebuilt.agent` | `ReactAgent`, `PlanActReflectAgent`, `StructuredOutputAgent`, `SupervisorTeamAgent`, `SwarmAgent`, `RAGAgent`, `AudioAgent` (realtime voice) | +| `agentflow.core.realtime` | `LiveInputQueue`, `RealtimeConfig`, `RealtimeEvent` (Gemini Live audio-to-audio) | | `agentflow.prebuilt.tools` | `fetch_url`, `safe_calculator`, `file_read`, `file_write`, `file_search`, `google_web_search`, `vertex_ai_search`, `memory_tool`, `create_handoff_tool` | | `agentflow.storage.checkpointer` | `InMemoryCheckpointer`, `PgCheckpointer` | | `agentflow.storage.store` | `QdrantStore`, `Mem0Store` | diff --git a/agentflow_cli/cli/templates/skills/agent-skills/references/callbacks-and-command.md b/agentflow_cli/cli/templates/skills/agent-skills/references/callbacks-and-command.md index 32ee7e1..0ac7872 100644 --- a/agentflow_cli/cli/templates/skills/agent-skills/references/callbacks-and-command.md +++ b/agentflow_cli/cli/templates/skills/agent-skills/references/callbacks-and-command.md @@ -149,6 +149,23 @@ app = graph.compile(callback_manager=callback_mgr) | `on_resume` | `AgentState \| None` | 0โ€“1 per call | Before `clear_interrupt()` | | `on_checkpoint` | `(AgentState, list[Message]) \| AgentState \| None` | 1โ€“N per run | Before every durable checkpoint write | | `on_state_update` | `AgentState \| None` | N per run (once per node) | After each node result is merged | +| `on_turn_start` | `AgentState \| None` | N (realtime only) | Start of each model turn | +| `on_turn_end` | `AgentState \| None` | N (realtime only) | End of each model turn | + +### Realtime hooks + +The same `GraphLifecycleHook` fires for realtime (audio) graphs, with two methods that fire **only** +in realtime (no-ops for `invoke` / `stream`): + +- `on_graph_start(ctx, state)` / `on_graph_end(ctx, final_state, messages, total_steps)` โ€” once per + session (the `LIVE` node *is* the graph); `total_steps` = number of turns. +- `on_turn_start(ctx, state, turn_index)` / `on_turn_end(ctx, state, turn_index)` โ€” per model turn + (1-based; a turn spans one model generation, bounded by `turn_complete` or a barge-in). A turn cut + off by session end still gets a balanced `on_turn_end`. + +Realtime has no `AI`-invocation callback or input-validator pass (no discrete LLM call); the per-turn +hooks are the observability stand-in. Tool/MCP `before/after/error` callbacks fire as usual. See +`realtime.md`. ### Example diff --git a/agentflow_cli/cli/templates/skills/agent-skills/references/prebuilt-agents-and-tools.md b/agentflow_cli/cli/templates/skills/agent-skills/references/prebuilt-agents-and-tools.md index 6add7cf..263a0fa 100644 --- a/agentflow_cli/cli/templates/skills/agent-skills/references/prebuilt-agents-and-tools.md +++ b/agentflow_cli/cli/templates/skills/agent-skills/references/prebuilt-agents-and-tools.md @@ -14,6 +14,7 @@ from agentflow.prebuilt.agent import ( SupervisorTeamAgent, SwarmAgent, RAGAgent, + AudioAgent, ) ``` @@ -240,6 +241,32 @@ from agentflow.prebuilt.agent import RAGAgent, CohereReranker agent = RAGAgent(model="gpt-4o", store=store, reranker=CohereReranker(api_key="...")) ``` +### AudioAgent + +Realtime, full-duplex voice agent (Gemini Live). React-style builder that wraps a `LiveAgent` as the +graph root. Requires `pip install "10xscale-agentflow[realtime]"` and `GEMINI_API_KEY`. Driven by +`arealtime()` / `realtime()`, not `invoke` / `stream`. + +```python +from agentflow.prebuilt.agent import AudioAgent +from agentflow.core.realtime import RealtimeConfig, LiveInputQueue + +app = AudioAgent( + "gemini-live-2.5-flash-preview", + realtime_config=RealtimeConfig(model="gemini-live-2.5-flash-preview", voice="Puck"), + tools=[my_tool], # React-style tool calling, including barge-in +).compile() + +queue = LiveInputQueue() +queue.send_audio(pcm16_bytes) +async for event in app.arealtime(queue, {"thread_id": "t1"}): + ... +``` + +`compile()` takes `checkpointer`, `store`, `callback_manager`, `shutdown_timeout` only (no +`media_store` / `interrupt_*`). `system_prompt`, `skills`, and `memory` work like `ReactAgent`. See +`realtime.md` for the full surface (events, reconnection, lifecycle hooks, WebSocket bridge). + --- ## Compile options (all prebuilt agents) diff --git a/agentflow_cli/cli/templates/skills/agent-skills/references/publishers-and-runtime-protocols.md b/agentflow_cli/cli/templates/skills/agent-skills/references/publishers-and-runtime-protocols.md index b9ccd7a..3e6f4dc 100644 --- a/agentflow_cli/cli/templates/skills/agent-skills/references/publishers-and-runtime-protocols.md +++ b/agentflow_cli/cli/templates/skills/agent-skills/references/publishers-and-runtime-protocols.md @@ -24,6 +24,13 @@ Implementations: Publishers receive structured events from graph/tool execution. Use them for tracing, monitoring, audit logs, and external streaming/event bus integrations. +`ConsolePublisher` is a dev/debug, opt-in publisher (use a real transport in production). It writes +to stdout by default; pass `ConsolePublisher({"use_logger": True})` to route events through the +`agentflow.publisher` logger instead of stdout. + +Realtime adds an `Event.REALTIME` event source and a `ContentType.TRANSCRIPT` content type (see +`realtime.md`). + ## Runtime Adapters LLM adapters: diff --git a/agentflow_cli/cli/templates/skills/agent-skills/references/realtime.md b/agentflow_cli/cli/templates/skills/agent-skills/references/realtime.md new file mode 100644 index 0000000..7014289 --- /dev/null +++ b/agentflow_cli/cli/templates/skills/agent-skills/references/realtime.md @@ -0,0 +1,172 @@ +# Realtime audio-to-audio (Gemini Live) + +Use this when building live, full-duplex audio (voice) agents. Unlike `invoke` / `stream` +(turn-based super-step traversal), a realtime graph is driven by a separate runtime because the +provider owns the turn loop. Provider neutrality is built in: contracts import no provider SDK, so +OpenAI Realtime can be added later behind the same `RealtimeClient` Protocol. Today only Gemini +Live is implemented. + +## Install and environment + +- `pip install "10xscale-agentflow[realtime]"` (pulls `google-genai>=1.56.0`). +- Set `GEMINI_API_KEY`; optionally `GEMINI_LIVE_MODEL` (default `gemini-live-2.5-flash-preview`). +- Provider SDK imports are lazy: importing `agentflow.core.realtime` never requires the extra. + +## Audio and media formats + +- Input audio: PCM16 mono @ 16 kHz. Output audio: PCM16 @ 24 kHz. +- Transcripts are persisted as `Message`s (`metadata={"modality": "audio"}`); raw audio is never + stored. +- Image / video input: still images and video frames are sent live to the model as JPEG frames via + `LiveInputQueue.send_image(...)`. Like ADK, there is no media store/offload in the realtime path; + image frames are not persisted (a reconnect reseeds text transcripts only). + +## Prebuilt `AudioAgent` + +`from agentflow.prebuilt.agent import AudioAgent` โ€” a React-style builder mirroring `ReactAgent`'s +construction surface, wrapping a `LiveAgent` as the graph root. + +```python +from agentflow.prebuilt.agent import AudioAgent +from agentflow.core.realtime import RealtimeConfig + +app = AudioAgent( + "gemini-live-2.5-flash-preview", + realtime_config=RealtimeConfig(model="gemini-live-2.5-flash-preview", voice="Puck"), + tools=[my_tool], # advertised to the model automatically; runs React-style +).compile() +``` + +- Constructor (positional): `model`, then optional `state`, `context_manager`, `publisher`, + `id_generator`, `container`; keyword-only: `realtime_config`, `system_prompt`, `tools`, `client`, + `pass_user_info_to_mcp`, `skills`, `memory`, `realtime_client_factory`, `live_node_name="LIVE"`. +- `compile()` takes `checkpointer`, `store`, `callback_manager`, `shutdown_timeout` (default 30.0). + It does **not** take `media_store` or `interrupt_before` / `interrupt_after` โ€” those belong to the + turn-based super-step executor, which realtime bypasses. +- Tools work like a normal `ToolNode` (reason -> tool -> respond, including barge-in). No + sub-agents / handoff in v1. +- `system_prompt`, `skills`, and `memory` work like `ReactAgent`: the agent's `system_prompt` (plus + the skills trigger table / session-mode skill content and the memory system prompt) is flattened + into the single Gemini Live `system_instruction` at connect, and `{field}` placeholders are + interpolated from state exactly like the turn-based path. Skill/memory **tools** are advertised + normally. Caveat: `system_instruction` is fixed for the session, so state-dependent content + (session-mode skill from a state field, memory preload) is a connect-time snapshot. Mid-session + dynamism goes through `set_skill` / memory tools, which work continuously. + +`LiveAgent` (the graph root `AudioAgent` wraps) is at +`from agentflow.core.realtime.live_agent import LiveAgent`. It is not re-exported from +`agentflow.core.realtime`; import it from the module if you build the graph by hand. + +## Driving a session: `CompiledGraph.arealtime` / `realtime` + +- `arealtime(input_queue, config=None, state=None)` is an async generator yielding normalized + `RealtimeEvent`s. `realtime(...)` is the sync wrapper (run with no active event loop). +- Forcing rule: the graph must contain exactly one `LiveAgent`; ordinary graphs raise. Conversely a + graph containing a `LiveAgent` must use `arealtime()` โ€” `invoke` / `stream` raise. + +```python +from agentflow.core.realtime import LiveInputQueue + +queue = LiveInputQueue() +queue.send_audio(pcm16_bytes) # non-blocking; safe from an audio callback +async for event in app.arealtime(queue, {"thread_id": "t1"}): + ... # AudioDeltaEvent / transcripts / ToolCallEvent / ... +queue.close() # ends the session once the provider goes idle +``` + +## Public API (`agentflow.core.realtime`) + +- `LiveInputQueue` / `LiveInput` / `LiveInputKind` โ€” non-blocking upstream input queue. Methods + (all synchronous, callable from any context): `send_audio`, `send_text`, `send_image` (still + image / video frame, default mime `image/jpeg`), `send_activity_start`, `send_activity_end`, + `close`. +- `RealtimeConfig` โ€” per-session config. Fields and defaults: + - `model: str` (required) + - `response_modalities: list[...] = ["AUDIO"]` (exactly one per session) + - `voice: str | None = None` + - `system_instruction: str | None = None` + - `input_audio_transcription: bool = True` + - `output_audio_transcription: bool = True` + - `vad: VADConfig = VADConfig()` + - `reconnect: ReconnectConfig = ReconnectConfig()` + - `context_window_compression: bool = False` + - `session_resumption: bool = True` + - `tools: list | None = None`, `tools_tags: list[str] | None = None` +- `VADConfig` โ€” voice-activity detection; disable for push-to-talk (manual activity via + `send_activity_start` / `send_activity_end`). +- `ReconnectConfig` โ€” reconnect/backoff for a dropped socket: `base_delay=0.5`, `max_delay=10.0`, + `max_attempts=5` (set `0` to disable error-driven reconnect). +- `RealtimeEvent` โ€” discriminated union (keyed on `type`): `AudioDeltaEvent`, + `InputTranscriptEvent`, `OutputTranscriptEvent`, `ToolCallEvent`, `ToolResultEvent`, + `TurnCompleteEvent`, `InterruptedEvent` (barge-in), `SessionUpdateEvent`, `GoAwayEvent`, + `AgentChangedEvent`, `ErrorEvent`. +- `RealtimeClient` โ€” provider Protocol (one implementation per provider). +- `GeminiLiveClient` / `normalize_message` โ€” the Gemini Live provider client. + +## Reconnection and resumption + +Reconnect is automatic inside the realtime runtime (the builder / `AudioAgent` wires nothing). + +- Provider `go_away` (planned rotation): reconnect immediately, no backoff. +- Transient drop / receive error: exponential backoff `min(base_delay * 2**(n-1), max_delay)`, up to + `max_attempts`, then a fatal `ErrorEvent` (`code="reconnect_failed"`) ends the session. +- Tune per session via `RealtimeConfig.reconnect`: + ```python + from agentflow.core.realtime import RealtimeConfig, ReconnectConfig + RealtimeConfig(model="...", reconnect=ReconnectConfig(base_delay=0.25, max_attempts=8)) + ``` +- Context across reconnects: Gemini streams a resumption handle (`session_update`) that is persisted + to checkpointer thread metadata; reconnect resumes provider-side context (requires + `session_resumption=True`, the default). With no handle (a fresh session on the same `thread_id`), + persisted transcript history is reseeded instead. Cross-session resume therefore needs a + checkpointer. + +## Session and turn lifecycle hooks + +Realtime fires graph/turn hooks through the same `GraphLifecycleHook` used by turn-based graphs +(register via `CallbackManager.register_lifecycle_hook`, pass the manager to +`compile(callback_manager=...)`). These fire only in realtime (no-ops for `invoke` / `stream`): + +- `on_graph_start(ctx, state)` / `on_graph_end(ctx, final_state, messages, total_steps)` โ€” once per + session (the `LIVE` node *is* the graph). `total_steps` = number of turns. +- `on_turn_start(ctx, state, turn_index)` / `on_turn_end(ctx, state, turn_index)` โ€” per model turn + (1-based; a turn spans one model generation, bounded by `turn_complete` or a barge-in). A turn cut + off by session end still gets a balanced `on_turn_end`. + +All hooks may return a modified state to replace the current one. Tool/MCP `before/after/error` +callbacks fire as usual (tools run through `ToolNode`). There is no `AI`-invocation callback or +input-validator pass in realtime (no discrete LLM call); `on_turn_start` / `on_turn_end` are the +per-turn observability stand-in. + +## API server WebSocket bridge (`/v1/graph/live`) + +`agentflow api` exposes `ws:///v1/graph/live` when the configured graph is rooted at a +`LiveAgent`. + +- First frame: a JSON control frame (e.g. `{"model": "...", "thread_id": "abc", "voice": "Puck"}`); + present fields override the agent's build-time config for that session. +- Upstream: binary frame = PCM16 input audio; JSON control frame = + `{"type": "text" | "activity_start" | "activity_end" | "close", ...}`. Image/video input is + currently SDK-only via `LiveInputQueue.send_image`; the WebSocket bridge does not forward image + frames yet. +- Downstream: binary frame = PCM16 model audio; JSON text frame = every other event (transcripts, + `turn_complete`, `interrupted`, `tool_call`, session / `go_away`, `error`). + +## Events / publisher additions + +`Event.REALTIME` event and `ContentType.TRANSCRIPT` content type live in +`agentflow.runtime.publisher.events`. + +## Examples + +`examples/realtime/`: headless WAV-in/WAV-out (`audio_agent_file.py`), live full-duplex microphone +with React-style tool calling (`audio_agent_mic.py`), and the API WebSocket setup +(`agentflow.json` + `graph.py`). See `examples/realtime/README.md`. + +## Source Map + +- Realtime package: https://github.com/10xHub/Agentflow/tree/main/agentflow/agentflow/core/realtime +- LiveAgent: https://github.com/10xHub/Agentflow/blob/main/agentflow/agentflow/core/realtime/live_agent.py +- AudioAgent: https://github.com/10xHub/Agentflow/blob/main/agentflow/agentflow/prebuilt/agent/audio.py +- Realtime drivers (`arealtime` / `realtime`): https://github.com/10xHub/Agentflow/blob/main/agentflow/agentflow/core/graph/compiled_graph.py +- WebSocket bridge: https://github.com/10xHub/agentflow-cli/tree/main/agentflow_cli/src/app/routers/graph diff --git a/agentflow_cli/cli/templates/skills/agent-skills/references/rest-api-and-errors.md b/agentflow_cli/cli/templates/skills/agent-skills/references/rest-api-and-errors.md index 589c548..a52d9a0 100644 --- a/agentflow_cli/cli/templates/skills/agent-skills/references/rest-api-and-errors.md +++ b/agentflow_cli/cli/templates/skills/agent-skills/references/rest-api-and-errors.md @@ -21,6 +21,11 @@ Registered in https://github.com/10xHub/agentflow-cli/blob/main/agentflow_cli/sr - `POST /v1/graph/stop` - `POST /v1/graph/setup` - `POST /v1/graph/fix` +- `WS /v1/graph/live` โ€” realtime audio bridge, exposed only when the configured graph is rooted at a + `LiveAgent`. First frame is a JSON control frame (`model`, `thread_id`, `voice`, ... override + build-time config); upstream binary = PCM16 input audio, JSON control = text/activity/close; + downstream binary = PCM16 model audio, JSON = transcripts/turn_complete/interrupted/tool_call/ + error. See `realtime.md`. ## Thread / Checkpointer Routes diff --git a/agentflow_cli/cli/templates/skills/agent-skills/references/security-and-validators.md b/agentflow_cli/cli/templates/skills/agent-skills/references/security-and-validators.md index e20ea01..83ea88d 100644 --- a/agentflow_cli/cli/templates/skills/agent-skills/references/security-and-validators.md +++ b/agentflow_cli/cli/templates/skills/agent-skills/references/security-and-validators.md @@ -41,6 +41,24 @@ Common safety points: - Model output validation before returning to API clients. - Store/memory write validation. +## Secret redaction for logs + +Helpers in `agentflow.utils` strip credentials from log output: + +- `mask_secrets(text)` โ€” redacts API keys, `Bearer` tokens, `key=value` secrets, and signed-URL + credential query params from a string. +- `SecretRedactionFilter` โ€” a `logging.Filter`; add it to a handler to cover all loggers that + propagate to it. +- `install_secret_redaction(logger_name="agentflow")` โ€” convenience installer that attaches the + filter and returns it. + +```python +from agentflow.utils import mask_secrets, SecretRedactionFilter, install_secret_redaction + +install_secret_redaction() # cover the agentflow logger tree +safe = mask_secrets(raw_text) # one-off redaction +``` + ## API Security Boundary This reference covers graph-level validators. For HTTP auth, read `auth-and-authorization.md`. diff --git a/agentflow_cli/cli/templates/skills/agent-skills/references/state-graph.md b/agentflow_cli/cli/templates/skills/agent-skills/references/state-graph.md index f11fe97..3db36bc 100644 --- a/agentflow_cli/cli/templates/skills/agent-skills/references/state-graph.md +++ b/agentflow_cli/cli/templates/skills/agent-skills/references/state-graph.md @@ -143,6 +143,17 @@ messages = result["messages"] `ainvoke()` is the async equivalent. Returns the full state dict. +### Lifecycle / async context manager + +`CompiledGraph` supports `async with`; `aclose()` runs on exit even if the body raises, and is +idempotent (a second call returns `{"status": "already_closed"}`). + +```python +async with await build_and_compile_graph() as graph: + await graph.ainvoke(input_data) +# aclose() runs automatically here +``` + --- ## Streaming @@ -160,6 +171,14 @@ for chunk in app.stream( `astream()` is the async equivalent. See `streaming.md` for full details on `StreamChunk` fields and `StreamEvent` values. +### Realtime (audio) graphs + +A graph rooted at a `LiveAgent` is driven by `arealtime(input_queue, config=None, state=None)` (async +generator of `RealtimeEvent`s) or the sync `realtime(...)` wrapper, not `invoke` / `stream`. The +forcing rule is mutual: a graph with a `LiveAgent` must use `arealtime()` (`invoke` / `stream` +raise), and `arealtime()` requires exactly one `LiveAgent` (ordinary graphs raise). See +`realtime.md`. + --- ## Config keys diff --git a/agentflow_cli/cli/templates/skills/claude/SKILL.md b/agentflow_cli/cli/templates/skills/claude/SKILL.md index e7b8258..648c2c9 100644 --- a/agentflow_cli/cli/templates/skills/claude/SKILL.md +++ b/agentflow_cli/cli/templates/skills/claude/SKILL.md @@ -1,7 +1,6 @@ --- name: agentflow description: Expert guidance for building, debugging, and extending applications with AgentFlow (10xscale-agentflow). TRIGGER when: code imports from agentflow (e.g. `from agentflow import`, `StateGraph`, `Agent`, `ToolNode`, `AgentState`); user references `agentflow.json` or CLI commands (`agentflow init`, `agentflow api`, `agentflow play`, `agentflow build`, `agentflow skills`); user is building graph-based multi-agent workflows, tools, memory, checkpointing, or streaming with this framework. SKIP: generic Python or multi-agent questions not referencing agentflow; other frameworks (LangGraph, CrewAI, AutoGen) unless comparing. -user-invocable: false --- # Agentflow Project Skill @@ -38,6 +37,7 @@ Treat https://agentflow.10xscale.ai/ as the first source of truth for public pac - Context management, ID generation, and background tasks: `.claude/skills/agentflow/references/context-id-background.md` - Provider internals and adapters: `.claude/skills/agentflow/references/providers-and-adapters.md` - Prompt-injection and validation safety: `.claude/skills/agentflow/references/security-and-validators.md` + - Realtime audio-to-audio voice agents (AudioAgent, Gemini Live, `arealtime`, WebSocket bridge): `.claude/skills/agentflow/references/realtime.md` ### API/CLI SDK - CLI commands and generated project files: `.claude/skills/agentflow/references/cli-commands.md` diff --git a/agentflow_cli/cli/templates/skills/codex/SKILL.md b/agentflow_cli/cli/templates/skills/codex/SKILL.md index 1eea1c6..8ebb067 100644 --- a/agentflow_cli/cli/templates/skills/codex/SKILL.md +++ b/agentflow_cli/cli/templates/skills/codex/SKILL.md @@ -37,6 +37,7 @@ Treat https://agentflow.10xscale.ai/ as the first source of truth for public pac - Context management, ID generation, and background tasks: `.agents/skills/agentflow/references/context-id-background.md` - Provider internals and adapters: `.agents/skills/agentflow/references/providers-and-adapters.md` - Prompt-injection and validation safety: `.agents/skills/agentflow/references/security-and-validators.md` + - Realtime audio-to-audio voice agents (AudioAgent, Gemini Live, `arealtime`, WebSocket bridge): `.agents/skills/agentflow/references/realtime.md` ### API/CLI SDK - CLI commands and generated project files: `.agents/skills/agentflow/references/cli-commands.md` diff --git a/agentflow_cli/cli/templates/skills/copilot/SKILL.md b/agentflow_cli/cli/templates/skills/copilot/SKILL.md index a42d13d..dd18d61 100644 --- a/agentflow_cli/cli/templates/skills/copilot/SKILL.md +++ b/agentflow_cli/cli/templates/skills/copilot/SKILL.md @@ -37,6 +37,7 @@ Treat https://agentflow.10xscale.ai/ as the first source of truth for public pac - Context management, ID generation, and background tasks: `.github/skills/agentflow/references/context-id-background.md` - Provider internals and adapters: `.github/skills/agentflow/references/providers-and-adapters.md` - Prompt-injection and validation safety: `.github/skills/agentflow/references/security-and-validators.md` + - Realtime audio-to-audio voice agents (AudioAgent, Gemini Live, `arealtime`, WebSocket bridge): `.github/skills/agentflow/references/realtime.md` ### API/CLI SDK - CLI commands and generated project files: `.github/skills/agentflow/references/cli-commands.md` diff --git a/agentflow_cli/cli/templates/skills/copilot/agentflow.instructions.md b/agentflow_cli/cli/templates/skills/copilot/agentflow.instructions.md index 9483393..af502e6 100644 --- a/agentflow_cli/cli/templates/skills/copilot/agentflow.instructions.md +++ b/agentflow_cli/cli/templates/skills/copilot/agentflow.instructions.md @@ -42,6 +42,7 @@ Compile graphs once at startup. Keep graph code storage-agnostic; wire dependenc - `SupervisorTeamAgent` โ€” central supervisor routes to specialist worker agents - `SwarmAgent` โ€” peer-to-peer agent handoff without a central supervisor - `RAGAgent` โ€” retrieves documents from a vector store before each LLM call + - `AudioAgent` โ€” realtime full-duplex voice agent (Gemini Live); needs the `[realtime]` extra and is driven by `arealtime()`, not `invoke`/`stream` - Persist conversation state with **checkpointers**. Use **stores** only for cross-thread memory. - Inject business services through **`InjectQ`**, not module-level globals. - Keep API/CLI graph modules storage-agnostic; wire dependencies via `agentflow.json`. diff --git a/agentflow_cli/src/app/core/auth/permissions.py b/agentflow_cli/src/app/core/auth/permissions.py index 48a882b..d37e282 100644 --- a/agentflow_cli/src/app/core/auth/permissions.py +++ b/agentflow_cli/src/app/core/auth/permissions.py @@ -6,11 +6,12 @@ """ from collections.abc import Callable -from typing import Any +from typing import Any, NoReturn -from fastapi import Depends, HTTPException, Request, Response -from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from fastapi import HTTPException, Response, WebSocket, WebSocketException +from fastapi.security import HTTPAuthorizationCredentials from injectq.integrations import InjectAPI +from starlette.requests import HTTPConnection from agentflow_cli.src.app.core import logger from agentflow_cli.src.app.core.auth.auth_backend import BaseAuth @@ -19,6 +20,88 @@ from agentflow_cli.src.app.core.utils.log_sanitizer import sanitize_for_logging +# Sec-WebSocket-Protocol sentinel carrying the bearer token for browser WebSocket clients. +# The client offers two subprotocols: this sentinel followed by the raw JWT, e.g. +# new WebSocket(url, ["agentflow-bearer", ""]) +# The token rides in a request header, so -- unlike ``?token=`` -- it never lands in URLs, +# access logs, or browser history. The server must echo the sentinel on accept() (see +# ``ws_bearer_subprotocol``) or browsers fail the handshake. +WS_BEARER_SUBPROTOCOL = "agentflow-bearer" + +# A bearer-carrying Sec-WebSocket-Protocol offer is exactly [sentinel, token]. +_BEARER_SUBPROTOCOL_PARTS = 2 + + +def _subprotocols(connection: HTTPConnection) -> list[str]: + header = connection.headers.get("sec-websocket-protocol") + if not header: + return [] + return [p.strip() for p in header.split(",") if p.strip()] + + +def ws_bearer_subprotocol(connection: HTTPConnection) -> str | None: + """Return the subprotocol to echo on ``accept()`` when the client used it for the token. + + Browsers require the server to confirm one of the offered subprotocols; when the bearer + sentinel was offered, accept with it. Returns ``None`` otherwise (plain ``accept()``). + """ + parts = _subprotocols(connection) + if parts and parts[0] == WS_BEARER_SUBPROTOCOL: + return WS_BEARER_SUBPROTOCOL + return None + + +def _extract_credential( + connection: HTTPConnection, +) -> HTTPAuthorizationCredentials | None: + """Extract bearer credentials from a request or WebSocket connection. + + Mirrors ``HTTPBearer(auto_error=False)`` but works for both HTTP and WebSocket routes + (FastAPI cannot inject ``HTTPBearer`` on a WebSocket route). Token sources, most secure + first: + + 1. ``Authorization: Bearer `` header -- for non-browser clients. + 2. ``Sec-WebSocket-Protocol: agentflow-bearer, `` -- browser-settable and kept out + of URLs/logs; the preferred browser mechanism. + 3. ``?token=`` query parameter -- last-resort fallback; the token is exposed in + URLs/access logs/history, so prefer (2) for browser clients. + """ + authorization = connection.headers.get("Authorization") + if authorization: + scheme, _, token = authorization.partition(" ") + if scheme.lower() == "bearer" and token: + return HTTPAuthorizationCredentials(scheme=scheme, credentials=token) + + # [sentinel, token] -- the two subprotocols a browser client offers to carry the bearer. + parts = _subprotocols(connection) + if len(parts) >= _BEARER_SUBPROTOCOL_PARTS and parts[0] == WS_BEARER_SUBPROTOCOL and parts[1]: + return HTTPAuthorizationCredentials(scheme="Bearer", credentials=parts[1]) + + ws_token = connection.query_params.get("token") + if ws_token: + return HTTPAuthorizationCredentials(scheme="Bearer", credentials=ws_token) + + return None + + +# RFC 6455 close code 1008 "Policy Violation" -- the right signal for an auth/authz +# rejection at the WebSocket handshake. +WS_POLICY_VIOLATION = 1008 + + +def _reject(connection: HTTPConnection, status_code: int, detail: str) -> NoReturn: + """Reject a connection with the error type appropriate to its protocol. + + FastAPI translates an ``HTTPException`` into a response only on HTTP routes; on a + WebSocket route it propagates unhandled and tears the socket down abruptly (close + 1006) with a server-side error log. Raise ``WebSocketException`` (close 1008) there + instead so the client sees a clean policy-violation rejection. + """ + if isinstance(connection, WebSocket): + raise WebSocketException(code=WS_POLICY_VIOLATION, reason=detail) + raise HTTPException(status_code=status_code, detail=detail) + + class RequirePermission: """ FastAPI dependency that combines authentication and authorization. @@ -44,7 +127,7 @@ def __init__( self, resource: str, action: str, - extract_resource_id: Callable[[Request], str | None] | None = None, + extract_resource_id: Callable[[HTTPConnection], str | None] | None = None, ): """ Initialize the permission requirement. @@ -60,11 +143,8 @@ def __init__( async def __call__( self, - request: Request, + connection: HTTPConnection, response: Response, - credential: HTTPAuthorizationCredentials = Depends( - HTTPBearer(auto_error=False), - ), config: GraphConfig = InjectAPI(GraphConfig), auth_backend: BaseAuth = InjectAPI(BaseAuth), authz: AuthorizationBackend = InjectAPI(AuthorizationBackend), @@ -72,20 +152,22 @@ async def __call__( """ Verify authentication and authorization. + ``connection`` is typed as ``HTTPConnection`` (the common base of ``Request`` + and ``WebSocket``) so this dependency resolves on both HTTP and WebSocket + routes; FastAPI cannot inject a ``Request`` on a WebSocket route. + Returns: dict: User information if authenticated and authorized Raises: - HTTPException: 403 if authorization fails + HTTPException: 401/403 on HTTP routes when auth or authz fails. + WebSocketException: close 1008 on WebSocket routes for the same failures + (FastAPI does not translate an HTTPException on a WS handshake). """ - # Fallback: WebSocket clients running in the browser cannot set the - # Authorization header. Accept the token from a ``?token=`` query - # parameter as an alternative. This path is only reached when no - # Authorization header was present (credential is None). - if credential is None: - ws_token = request.query_params.get("token") - if ws_token: - credential = HTTPAuthorizationCredentials(scheme="Bearer", credentials=ws_token) + # Extract bearer credentials from the Authorization header, with a + # ``?token=`` query fallback for browser WebSocket clients (see + # _extract_credential). Works for both HTTP and WebSocket connections. + credential = _extract_credential(connection) # Step 1: Check if auth is configured backend = config.auth_config() @@ -103,11 +185,16 @@ async def __call__( logger.error("Auth backend is not configured") user = {} else: - user_result = auth_backend.authenticate( - request, - response, - credential, - ) + try: + user_result = auth_backend.authenticate( + connection, + response, + credential, + ) + except HTTPException as exc: + # JWT/custom backends signal auth failure with HTTPException; convert it + # to a clean WebSocket close on WS routes (see _reject). + _reject(connection, exc.status_code, str(exc.detail)) if user_result and "user_id" not in user_result: logger.error("Authentication failed: 'user_id' not found in user info") user = user_result or {} @@ -115,9 +202,9 @@ async def __call__( # Step 3: Extract resource_id if available resource_id = None if self.extract_resource_id_fn: - resource_id = self.extract_resource_id_fn(request) + resource_id = self.extract_resource_id_fn(connection) else: - resource_id = self._extract_resource_id_from_path(request) + resource_id = self._extract_resource_id_from_path(connection) # Step 4: Authorization if not await authz.authorize( @@ -130,9 +217,10 @@ async def __call__( f"Authorization failed for user {user.get('user_id')} " f"on {self.resource}:{self.action}" ) - raise HTTPException( - status_code=403, - detail=f"Not authorized to {self.action} {self.resource}", + _reject( + connection, + 403, + f"Not authorized to {self.action} {self.resource}", ) # Log successful auth/authz (with sanitized user info) @@ -143,20 +231,20 @@ async def __call__( return user - def _extract_resource_id_from_path(self, request: Request) -> str | None: + def _extract_resource_id_from_path(self, connection: HTTPConnection) -> str | None: """ Extract resource ID from request path parameters. Looks for common patterns like thread_id, memory_id in path params. Args: - request: FastAPI request object + connection: FastAPI request or WebSocket connection Returns: Resource ID as string, or None if not found """ # Check path parameters - path_params = request.path_params + path_params = connection.path_params # Common resource ID patterns for param_name in ["thread_id", "memory_id", "namespace"]: diff --git a/agentflow_cli/src/app/core/config/graph_config.py b/agentflow_cli/src/app/core/config/graph_config.py index 344278d..168b82d 100644 --- a/agentflow_cli/src/app/core/config/graph_config.py +++ b/agentflow_cli/src/app/core/config/graph_config.py @@ -142,6 +142,39 @@ def from_dict(cls, data: dict) -> "RateLimitConfig": ) +@dataclass +class WebSocketConfig: + """WebSocket connection limits parsed from agentflow.json. + + Example:: + + "websocket": { + "max_connections": 100 + } + + ``max_connections`` caps the number of concurrent WebSocket connections this server + *process* accepts (realtime ``/v1/graph/live`` + streaming ``/v1/graph/ws``). ``None`` or + ``0`` means unlimited. It is a per-process limit, like the in-memory rate-limit backend; + run one limiter per worker. WebSocket handshakes are also subject to the global + ``rate_limit`` (they share the same bucket as REST requests), since rate-limit middleware + is HTTP-only and cannot see WebSocket scopes. + """ + + max_connections: int | None + + @classmethod + def from_dict(cls, data: dict) -> "WebSocketConfig": + if not isinstance(data, dict): + raise ValueError("websocket must be an object") + raw = data.get("max_connections") + if raw in (None, 0): + return cls(max_connections=None) + max_connections = int(raw) + if max_connections < 0: + raise ValueError("websocket.max_connections must be a non-negative integer") + return cls(max_connections=max_connections) + + class GraphConfig: def __init__(self, path: str = "agentflow.json"): with Path(path).open() as f: @@ -246,3 +279,14 @@ def rate_limit(self) -> RateLimitConfig | None: if not config.enabled: return None return config + + @property + def websocket(self) -> "WebSocketConfig": + """WebSocket connection limits from agentflow.json (``websocket`` key). + + Returns a config with ``max_connections=None`` (unlimited) when the key is absent. + """ + data = self.data.get("websocket", None) + if data is None: + return WebSocketConfig(max_connections=None) + return WebSocketConfig.from_dict(data) diff --git a/agentflow_cli/src/app/core/config/setup_logs.py b/agentflow_cli/src/app/core/config/setup_logs.py index a14f205..6fc73a0 100644 --- a/agentflow_cli/src/app/core/config/setup_logs.py +++ b/agentflow_cli/src/app/core/config/setup_logs.py @@ -31,31 +31,6 @@ def init_logger(level: int | str = logging.INFO) -> None: fastapi_logger.handlers = gunicorn_error_logger.handlers fastapi_logger.setLevel(level) - # will print debug sql - logger_db_client = logging.getLogger("db_client") - logger_db_client.setLevel(level) - logger_db_client.addHandler(fastapi_logger) - - logger_tortoise = logging.getLogger("tortoise") - logger_tortoise.setLevel(level) - logger_tortoise.addHandler(fastapi_logger) - - # register custom logger here - injector_logging = logging.getLogger("injector") - injector_logging.setLevel(level) - injector_logging.addHandler(fastapi_logger) - - # Register custom logger for coding - # TODO: Change the logger name to the appropriate name - backend_logging = logging.getLogger("BACKEND_BASE") - backend_logging.setLevel(level) - backend_logging.addHandler(fastapi_logger) - - # Package logger - package_logger = logging.getLogger("PACKAGE") - package_logger.setLevel(level) - package_logger.addHandler(fastapi_logger) - # Create console handler and set level to DEBUG console_handler = logging.StreamHandler(sys.stdout) console_handler.setLevel(level) @@ -71,3 +46,10 @@ def init_logger(level: int | str = logging.INFO) -> None: console_handler.setFormatter(formatter) # Add console handler to logger fastapi_logger.addHandler(console_handler) + + # Route application loggers through the same sanitizing console handler. + # NOTE: addHandler expects a logging.Handler, not a Logger. + for logger_name in ("db_client", "tortoise", "injector", "BACKEND_BASE", "PACKAGE"): + app_logger = logging.getLogger(logger_name) + app_logger.setLevel(level) + app_logger.addHandler(console_handler) diff --git a/agentflow_cli/src/app/core/middleware/rate_limit/keying.py b/agentflow_cli/src/app/core/middleware/rate_limit/keying.py new file mode 100644 index 0000000..17da09d --- /dev/null +++ b/agentflow_cli/src/app/core/middleware/rate_limit/keying.py @@ -0,0 +1,24 @@ +"""Shared client-key derivation for rate limiting. + +Used by both the HTTP ``RateLimitMiddleware`` and the WebSocket connection guard so that +WebSocket handshakes are counted against the *same* rate-limit bucket as REST requests. +Works on any ``HTTPConnection`` (both ``Request`` and ``WebSocket`` subclass it). +""" + +from starlette.requests import HTTPConnection + +from agentflow_cli.src.app.core.config.graph_config import RateLimitConfig + + +def client_key_for(connection: HTTPConnection, config: RateLimitConfig) -> str: + """Derive the rate-limit bucket key for a connection, honoring ``by`` and proxy headers.""" + if config.by == "global": + return "__global__" + + if config.trusted_proxy_headers: + forwarded_for = connection.headers.get("X-Forwarded-For") + if forwarded_for: + return forwarded_for.split(",")[0].strip() + + client = connection.client + return client.host if client else "unknown" diff --git a/agentflow_cli/src/app/core/middleware/rate_limit/middleware.py b/agentflow_cli/src/app/core/middleware/rate_limit/middleware.py index 336ce13..c364ed4 100644 --- a/agentflow_cli/src/app/core/middleware/rate_limit/middleware.py +++ b/agentflow_cli/src/app/core/middleware/rate_limit/middleware.py @@ -9,6 +9,7 @@ from agentflow_cli.src.app.core.config.graph_config import RateLimitConfig from .base import BaseRateLimitBackend +from .keying import client_key_for class RateLimitMiddleware(BaseHTTPMiddleware): @@ -26,16 +27,7 @@ def __init__( self._exclude = frozenset(config.exclude_paths) def _client_key(self, request: Request) -> str: - if self.config.by == "global": - return "__global__" - - if self.config.trusted_proxy_headers: - forwarded_for = request.headers.get("X-Forwarded-For") - if forwarded_for: - return forwarded_for.split(",")[0].strip() - - client = request.client - return client.host if client else "unknown" + return client_key_for(request, self.config) async def dispatch(self, request: Request, call_next): if request.url.path in self._exclude: diff --git a/agentflow_cli/src/app/core/middleware/request_limits.py b/agentflow_cli/src/app/core/middleware/request_limits.py index b2a833d..8699e3e 100644 --- a/agentflow_cli/src/app/core/middleware/request_limits.py +++ b/agentflow_cli/src/app/core/middleware/request_limits.py @@ -38,11 +38,11 @@ async def dispatch(self, request: Request, call_next): content_length = request.headers.get("content-length") if content_length: - content_length = int(content_length) + content_length_bytes = int(content_length) - if content_length > self.max_size: + if content_length_bytes > self.max_size: logger.warning( - f"Request rejected: size {content_length} bytes " + f"Request rejected: size {content_length_bytes} bytes " f"exceeds limit of {self.max_size} bytes " f"({self.max_size_mb:.1f}MB)" ) diff --git a/agentflow_cli/src/app/core/utils/log_sanitizer.py b/agentflow_cli/src/app/core/utils/log_sanitizer.py index f85773f..d714d3e 100644 --- a/agentflow_cli/src/app/core/utils/log_sanitizer.py +++ b/agentflow_cli/src/app/core/utils/log_sanitizer.py @@ -5,6 +5,7 @@ preventing tokens, passwords, and other credentials from appearing in logs. """ +import logging import re from typing import Any @@ -154,11 +155,14 @@ def sanitize_log_message(message: str, *args: Any, **kwargs: Any) -> tuple[str, return message, sanitized_args, sanitized_kwargs -class SanitizingFormatter: +class SanitizingFormatter(logging.Formatter): """ - A mixin or wrapper for log formatters that sanitizes sensitive data. + A wrapper for log formatters that sanitizes sensitive data. - This can be used to wrap existing formatters to add sanitization. + This can be used to wrap existing formatters to add sanitization. It is a + proper ``logging.Formatter`` subclass so it can be passed to + ``Handler.setFormatter`` without type errors, while delegating the actual + formatting to the wrapped base formatter. Example: import logging @@ -168,16 +172,17 @@ class SanitizingFormatter: handler.setFormatter(sanitizing_formatter) """ - def __init__(self, base_formatter): + def __init__(self, base_formatter: logging.Formatter) -> None: """ Initialize the sanitizing formatter. Args: base_formatter: The underlying formatter to wrap """ + super().__init__() self.base_formatter = base_formatter - def format(self, record): + def format(self, record: logging.LogRecord) -> str: """ Format the log record with sanitization. diff --git a/agentflow_cli/src/app/routers/graph/realtime_guard.py b/agentflow_cli/src/app/routers/graph/realtime_guard.py new file mode 100644 index 0000000..6917ba8 --- /dev/null +++ b/agentflow_cli/src/app/routers/graph/realtime_guard.py @@ -0,0 +1,91 @@ +"""Connection guard for the WebSocket endpoints. + +Rate-limit and request-size middleware are ``BaseHTTPMiddleware`` and Starlette runs them +only for HTTP scopes, so WebSocket handshakes bypass them entirely. This module re-applies +two protections at the handshake, as a FastAPI dependency: + +1. The same global rate limit as REST (shared backend + bucket), so opening a socket counts + like any other request. +2. A per-process cap on concurrent WebSocket connections (``websocket.max_connections`` in + agentflow.json). + +Rejections raise ``WebSocketException`` before ``accept()``, so the handshake fails with a +close code instead of leaving a half-open socket. The concurrency slot is released on +teardown of the (yield) dependency, i.e. when the handler returns or the client disconnects. +""" + +from collections.abc import AsyncIterator + +from fastapi import WebSocket, WebSocketException +from injectq.integrations import InjectAPI + +from agentflow_cli.src.app.core import logger +from agentflow_cli.src.app.core.config.graph_config import GraphConfig +from agentflow_cli.src.app.core.middleware.rate_limit.keying import client_key_for + + +# RFC 6455 close code 1013 "Try Again Later" -- the right signal for shed-load rejections. +WS_TRY_AGAIN_LATER = 1013 + + +class _ConnectionRegistry: + """Per-process counter of active WebSocket connections. + + The event loop is single-threaded, so the check-and-increment in :meth:`try_acquire` is + atomic (no ``await`` between the test and the mutation) and needs no lock. This counter is + per process, matching the in-memory rate-limit backend's scope; for a multi-worker + deployment, set ``max_connections`` per worker accordingly. + """ + + def __init__(self) -> None: + self._active = 0 + + @property + def active(self) -> int: + return self._active + + def try_acquire(self, max_connections: int | None) -> bool: + if max_connections is not None and self._active >= max_connections: + return False + self._active += 1 + return True + + def release(self) -> None: + if self._active > 0: + self._active -= 1 + + +# Module-level singleton (per process). +_registry = _ConnectionRegistry() + + +async def realtime_connection_guard( + websocket: WebSocket, + config: GraphConfig = InjectAPI(GraphConfig), +) -> AsyncIterator[None]: + """Gate a WebSocket handshake on the global rate limit and the concurrent-connection cap.""" + # 1) Shared global rate limit (same backend/bucket as the REST middleware). + rl_config = config.rate_limit + backend = getattr(getattr(websocket, "app", None), "state", None) + backend = getattr(backend, "rate_limit_backend", None) + if rl_config is not None and backend is not None: + key = client_key_for(websocket, rl_config) + decision = await backend.check(key, limit=rl_config.requests, window=rl_config.window) + if not decision.allowed: + logger.warning("WebSocket rate limit exceeded for %s", key) + raise WebSocketException(code=WS_TRY_AGAIN_LATER, reason="Rate limit exceeded") + + # 2) Per-process concurrent-connection cap. + max_conn = config.websocket.max_connections + if not _registry.try_acquire(max_conn): + logger.warning( + "WebSocket connection limit reached (active=%d, max=%s)", + _registry.active, + max_conn, + ) + raise WebSocketException(code=WS_TRY_AGAIN_LATER, reason="Too many connections") + + try: + yield + finally: + _registry.release() diff --git a/agentflow_cli/src/app/routers/graph/router.py b/agentflow_cli/src/app/routers/graph/router.py index 817aa9b..12bbe09 100644 --- a/agentflow_cli/src/app/routers/graph/router.py +++ b/agentflow_cli/src/app/routers/graph/router.py @@ -1,13 +1,22 @@ +import asyncio import contextlib +import json from typing import Any +from agentflow.core.realtime.base import ErrorEvent +from agentflow.core.realtime.queue import LiveInputQueue from agentflow.core.state import StreamChunk, StreamEvent from fastapi import APIRouter, Depends, Request, WebSocket, WebSocketDisconnect from fastapi.logger import logger from fastapi.responses import StreamingResponse from injectq.integrations import InjectAPI +from pydantic import ValidationError -from agentflow_cli.src.app.core.auth.permissions import RequirePermission +from agentflow_cli.src.app.core.auth.permissions import ( + RequirePermission, + ws_bearer_subprotocol, +) +from agentflow_cli.src.app.routers.graph.realtime_guard import realtime_connection_guard from agentflow_cli.src.app.routers.graph.schemas.graph_schemas import ( FixGraphRequestSchema, GraphInputSchema, @@ -22,6 +31,24 @@ from agentflow_cli.src.app.utils.swagger_helper import generate_swagger_responses +# Grace period (seconds) to let the model's final response drain to the client after the +# client side ends (``close`` control frame or disconnect). Closing the input queue makes +# the live agent finish its turn and stop once the provider goes idle, so this normally +# returns well within the window; it only bounds a provider that never goes idle. +REALTIME_DRAIN_TIMEOUT = 30.0 + +# Bound the upstream audio queue. WebSocket frames bypass RequestSizeLimitMiddleware (it is +# HTTP-only), so the realtime path must guard memory itself. At ~50 input frames/sec a depth +# of 1000 is ~20s of buffered audio; if the provider send stalls past that the oldest frames +# are dropped (logged) instead of growing memory without bound -- the audio was already behind. +REALTIME_INPUT_QUEUE_MAXSIZE = 1000 + +# Hard cap on a single binary (PCM16) input frame. Normal client chunks are tens of KB +# (16kHz PCM16 = 32KB/s); 1 MiB is ~30s of audio in one frame, far beyond any real chunk. +# Oversized frames are dropped rather than enqueued (again, the middleware does not see them). +REALTIME_MAX_FRAME_BYTES = 1024 * 1024 + + router = APIRouter( tags=["Graph"], ) @@ -279,6 +306,7 @@ async def fix_graph( @router.websocket("/v1/graph/ws") async def websocket_graph( websocket: WebSocket, + _guard: None = Depends(realtime_connection_guard), service: GraphService = InjectAPI(GraphService), user: dict[str, Any] = Depends(RequirePermission("graph", "stream")), ): @@ -307,15 +335,18 @@ async def websocket_graph( Authentication -------------- - Bearer token sent as the standard ``Authorization`` header during the - WebSocket handshake โ€” identical to the HTTP stream route. + Bearer token via the ``Authorization`` header, the ``agentflow-bearer`` + Sec-WebSocket-Protocol (browser-safe), or the ``?token=`` query fallback โ€” + identical to the HTTP stream route. Handshakes are subject to the global rate + limit and the ``websocket.max_connections`` cap. Close codes ----------- 1000 normal closure (client disconnected cleanly) 1011 unexpected server error + 1013 rejected: rate limit or connection cap exceeded (try again later) """ - await websocket.accept() + await websocket.accept(subprotocol=ws_bearer_subprotocol(websocket)) logger.info("WebSocket graph connection accepted") try: @@ -362,3 +393,161 @@ async def websocket_graph( logger.error("WebSocket graph connection error: %s", e) with contextlib.suppress(Exception): await websocket.close(code=1011) + + +def _realtime_event_json(event: Any) -> str: + """Serialize a non-audio RealtimeEvent to a JSON text frame for the client.""" + try: + payload = event.model_dump(mode="json") + except Exception as e: + logger.warning("Realtime event serialization failed (%s): %s", type(event).__name__, e) + payload = {"type": getattr(event, "type", "unknown")} + return json.dumps(payload) + + +@router.websocket("/v1/graph/live") +async def realtime_graph_ws( # noqa: PLR0915 + websocket: WebSocket, + _guard: None = Depends(realtime_connection_guard), + service: GraphService = InjectAPI(GraphService), + user: dict[str, Any] = Depends(RequirePermission("graph", "stream")), +): + """Realtime (audio-to-audio) WebSocket bridge over ``CompiledGraph.arealtime``. + + Protocol (provider-neutral; the client never sees Gemini vs OpenAI) + ------------------------------------------------------------------ + First frame : JSON control ``{model, thread_id?, voice?, modalities?, vad?, ...}`` + Upstream : binary frame = PCM16 input audio + JSON control = ``{type:"activity_start"|"activity_end"|"text"|"close", ...}`` + Downstream : binary frame = PCM16 model audio (``audio_delta``) + JSON text frame = every other event (transcripts, turn_complete, + interrupted, tool_call, session/go_away, error) + + Auth: ``RequirePermission("graph","stream")`` โ€” bearer via the ``Authorization`` header, + the ``agentflow-bearer`` Sec-WebSocket-Protocol (browser-safe), or the ``?token=`` query + fallback. Handshakes are subject to the global rate limit and the + ``websocket.max_connections`` cap (rejected with close code 1013). + """ + await websocket.accept(subprotocol=ws_bearer_subprotocol(websocket)) + logger.info("Realtime WebSocket connection accepted") + + try: + init = await websocket.receive_json() + except WebSocketDisconnect: + logger.info("Realtime client disconnected before init") + return + except Exception as e: + logger.warning("Realtime init frame invalid: %s", e) + with contextlib.suppress(Exception): + await websocket.close(code=1003) + return + + if not isinstance(init, dict): + logger.warning("Realtime init frame must be a JSON object, got %s", type(init).__name__) + with contextlib.suppress(Exception): + await websocket.close(code=1003) + return + + queue = LiveInputQueue(maxsize=REALTIME_INPUT_QUEUE_MAXSIZE) + + async def upstream() -> None: + """Pump client frames into the input queue until close/disconnect.""" + try: + while True: + message = await websocket.receive() + if message.get("type") == "websocket.disconnect": + break + data = message.get("bytes") + if data is not None: + if len(data) > REALTIME_MAX_FRAME_BYTES: + logger.warning( + "Realtime upstream: dropping oversized audio frame (%d bytes > %d)", + len(data), + REALTIME_MAX_FRAME_BYTES, + ) + continue + queue.send_audio(data) + continue + text = message.get("text") + if text is None: + continue + try: + control = json.loads(text) + except json.JSONDecodeError: + logger.warning("Realtime upstream: non-JSON text frame ignored") + continue + ctype = control.get("type") + if ctype == "text": + queue.send_text(control.get("text", "")) + elif ctype == "activity_start": + queue.send_activity_start() + elif ctype == "activity_end": + queue.send_activity_end() + elif ctype == "close": + break + except WebSocketDisconnect: + logger.info("Realtime client disconnected (upstream)") + finally: + queue.close() + + async def downstream() -> None: + """Stream normalized events back: audio as binary, everything else as JSON.""" + try: + async for event in service.realtime_graph(queue, init, user): + if getattr(event, "type", None) == "audio_delta": + await websocket.send_bytes(event.data) + else: + await websocket.send_text(_realtime_event_json(event)) + except (ValidationError, ValueError) as e: + # Bad session config from the init frame (e.g. an invalid ``modalities`` value) + # is a client error, not a server fault. Send a normalized, fatal error event + # so the client can show why instead of seeing an opaque 1011 close. + logger.warning("Realtime session config rejected: %s", e) + with contextlib.suppress(Exception): + await websocket.send_text( + _realtime_event_json( + ErrorEvent(code="invalid_config", message=str(e), fatal=True) + ) + ) + + up_task = asyncio.create_task(upstream()) + down_task = asyncio.create_task(downstream()) + try: + _done, pending = await asyncio.wait( + {up_task, down_task}, return_when=asyncio.FIRST_COMPLETED + ) + + # If the client side finished first (sent ``close`` or disconnected) while the + # model is still responding, give downstream a bounded grace period to drain the + # session's final events instead of cutting them off. Closing the input queue ends + # the live agent's turn and stops it once the provider goes idle, so this returns + # promptly; on a real disconnect the next send fails fast and downstream ends. + if down_task in pending: + queue.close() + await asyncio.wait({down_task}, timeout=REALTIME_DRAIN_TIMEOUT) + + # Cancel whatever is still running: downstream that overran the grace window, or + # upstream once the session ended. + for task in (up_task, down_task): + if not task.done(): + task.cancel() + with contextlib.suppress(asyncio.CancelledError, Exception): + await task + + # asyncio.wait never re-raises; surface any failure from the finished task(s) + # (e.g. arealtime rejecting a non-live graph, or a provider/checkpointer error). + for task in (up_task, down_task): + if task.done() and not task.cancelled(): + task.result() + except WebSocketDisconnect: + # The client going away mid-session is a normal termination, not a server fault; + # don't log it as an error or attempt a 1011 close on an already-closed socket. + logger.info("Realtime WebSocket client disconnected") + except Exception as e: + logger.error("Realtime WebSocket error: %s", e) + with contextlib.suppress(Exception): + await websocket.close(code=1011) + finally: + queue.close() + with contextlib.suppress(Exception): + await websocket.close() diff --git a/agentflow_cli/src/app/routers/graph/services/graph_service.py b/agentflow_cli/src/app/routers/graph/services/graph_service.py index f553df1..74d4623 100644 --- a/agentflow_cli/src/app/routers/graph/services/graph_service.py +++ b/agentflow_cli/src/app/routers/graph/services/graph_service.py @@ -96,7 +96,7 @@ async def _save_thread_name( return thread_name - async def _save_thread(self, config: dict[str, Any], thread_id: int): + async def _save_thread(self, config: dict[str, Any], thread_id: str): """ Save the generated thread name to the database. """ @@ -385,6 +385,62 @@ async def stream_graph( + "\n" ) + async def realtime_graph( + self, + input_queue: Any, + init: dict[str, Any], + user: dict[str, Any], + ): + """Bridge a realtime (audio) session over the compiled graph. + + Thin wrapper over ``CompiledGraph.arealtime``: builds the per-session config from + the init control frame + authenticated user, persists thread info, and yields the + normalized RealtimeEvents. The compiled graph must be rooted at a LiveAgent (e.g. + an ``AudioAgent``); otherwise ``arealtime`` raises. + """ + thread_id = init.get("thread_id") or str(uuid4()) + config: dict[str, Any] = { + "thread_id": thread_id, + "user": user, + "user_id": user.get("user_id", "anonymous"), + } + # Map the client init frame onto RealtimeConfig field names so the live agent can + # apply per-session overrides (model/voice/modalities/vad/...). Only present keys + # are forwarded; absent ones fall back to the agent's build-time config. + realtime = self._realtime_overrides(init) + if realtime: + config["realtime"] = realtime + await self._save_thread(config, thread_id) + logger.info("Realtime graph session starting: thread_id=%s", thread_id) + + async for event in self._graph.arealtime(input_queue, config): + yield event + + logger.info("Realtime graph session completed: thread_id=%s", thread_id) + + @staticmethod + def _realtime_overrides(init: dict[str, Any]) -> dict[str, Any]: + """Translate the client init frame into RealtimeConfig field overrides.""" + # init key -> RealtimeConfig field name + mapping = { + "model": "model", + "voice": "voice", + "modalities": "response_modalities", + "vad": "vad", + "system_prompt": "system_instruction", + "tools_tags": "tools_tags", + } + overrides: dict[str, Any] = {} + for init_key, field in mapping.items(): + value = init.get(init_key) + if value is not None: + overrides[field] = value + # Clients commonly send a single modality as a bare string ("AUDIO"); RealtimeConfig + # expects a list. Coerce so the shorthand doesn't trip the one-modality validator. + if isinstance(overrides.get("response_modalities"), str): + overrides["response_modalities"] = [overrides["response_modalities"]] + return overrides + async def graph_details(self) -> GraphSchema: try: logger.info("Getting graph details") diff --git a/pyproject.toml b/pyproject.toml index 548f6a3..b5dcbb6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -206,7 +206,7 @@ addopts = [ "--cov-report=html", "--cov-report=term-missing", "--cov-report=xml", - "--cov-fail-under=0", + "--cov-fail-under=80", "--strict-markers", "-v" ] diff --git a/tests/unit_tests/test_permissions_auth.py b/tests/unit_tests/test_permissions_auth.py index a6113c5..3c23efe 100644 --- a/tests/unit_tests/test_permissions_auth.py +++ b/tests/unit_tests/test_permissions_auth.py @@ -9,9 +9,13 @@ @pytest.fixture def mock_request(): - """Create a mock FastAPI request.""" + """Create a mock FastAPI request/connection.""" request = MagicMock(spec=Request) request.path_params = {} + # _extract_credential reads these off the connection; use real containers so the + # Authorization-header / ?token= parsing works against the mock. + request.headers = {} + request.query_params = {} return request @@ -118,7 +122,7 @@ async def test_call_with_auth_not_configured( perm = RequirePermission("graph", "invoke") result = await perm( - mock_request, mock_response, mock_credential, mock_config, mock_auth_backend, mock_authz + mock_request, mock_response, mock_config, mock_auth_backend, mock_authz ) assert result == {} @@ -139,7 +143,7 @@ async def test_call_with_valid_auth_and_authz( perm = RequirePermission("graph", "invoke") result = await perm( - mock_request, mock_response, mock_credential, mock_config, mock_auth_backend, mock_authz + mock_request, mock_response, mock_config, mock_auth_backend, mock_authz ) assert result == {"user_id": "test-user"} @@ -162,7 +166,6 @@ async def test_call_auth_backend_not_configured( result = await perm( mock_request, mock_response, - mock_credential, mock_config, None, MagicMock(authorize=AsyncMock(return_value=True)), @@ -193,7 +196,6 @@ async def test_call_authorization_failed( await perm( mock_request, mock_response, - mock_credential, mock_config, mock_auth_backend, mock_authz, @@ -223,7 +225,6 @@ async def test_call_authentication_missing_user_id( result = await perm( mock_request, mock_response, - mock_credential, mock_config, mock_auth_backend, mock_authz, @@ -250,7 +251,7 @@ def custom_extractor(request): perm = RequirePermission("graph", "invoke", extract_resource_id=custom_extractor) result = await perm( - mock_request, mock_response, mock_credential, mock_config, mock_auth_backend, mock_authz + mock_request, mock_response, mock_config, mock_auth_backend, mock_authz ) # Verify authorize was called with the custom resource ID @@ -357,7 +358,7 @@ async def test_full_flow_with_auth_configured_and_authorized( perm = RequirePermission("checkpointer", "read") result = await perm( - mock_request, mock_response, mock_credential, mock_config, mock_auth_backend, mock_authz + mock_request, mock_response, mock_config, mock_auth_backend, mock_authz ) assert result == {"user_id": "user-123", "role": "admin"} @@ -386,7 +387,7 @@ async def test_full_flow_auth_not_configured_skips_checks( perm = RequirePermission("graph", "invoke") result = await perm( - mock_request, mock_response, mock_credential, mock_config, mock_auth_backend, mock_authz + mock_request, mock_response, mock_config, mock_auth_backend, mock_authz ) assert result == {} diff --git a/tests/unit_tests/test_websocket_auth.py b/tests/unit_tests/test_websocket_auth.py new file mode 100644 index 0000000..fc6dffc --- /dev/null +++ b/tests/unit_tests/test_websocket_auth.py @@ -0,0 +1,118 @@ +"""Regression tests for auth on WebSocket routes. + +These drive ``RequirePermission`` through FastAPI's real dependency-injection path on a +``@app.websocket(...)`` route (not by calling the endpoint function directly). That is the +only way to catch the class of bug where a dependency declares ``request: Request`` and +FastAPI cannot inject it on a WebSocket connection -- the previous failure mode raised +``TypeError: __call__() missing 1 required positional argument: 'request'`` at connect time, +which endpoint-level unit tests (that pass ``user=...`` directly) could never surface. +""" + +from typing import Any + +import pytest +from fastapi import Depends, FastAPI, WebSocket +from fastapi.testclient import TestClient +from injectq import InjectQ +from injectq.integrations import setup_fastapi + +from agentflow_cli.src.app.core.auth.auth_backend import BaseAuth +from agentflow_cli.src.app.core.auth.authorization import AuthorizationBackend +from agentflow_cli.src.app.core.auth.permissions import ( + RequirePermission, + _extract_credential, +) +from agentflow_cli.src.app.core.config.graph_config import GraphConfig + + +class _FakeAuth(BaseAuth): + """Authenticate by treating the bearer token as the user_id.""" + + def authenticate(self, request, response, credential): # type: ignore[override] + if credential is None: + return {} + return {"user_id": credential.credentials} + + +class _AllowAuthz(AuthorizationBackend): + def __init__(self): + pass + + async def authorize(self, user, resource, action, resource_id=None): # type: ignore[override] + return True + + +def _build_client(auth_configured: bool) -> TestClient: + container = InjectQ() + + config = type( + "_Cfg", + (), + {"auth_config": staticmethod(lambda: "custom" if auth_configured else None)}, + )() + container.bind_instance(GraphConfig, config, allow_none=True) + container.bind_instance(BaseAuth, _FakeAuth(), allow_none=True) + container.bind_instance(AuthorizationBackend, _AllowAuthz(), allow_none=True) + + app = FastAPI() + setup_fastapi(container, app) + + @app.websocket("/ws") + async def ws( + websocket: WebSocket, + user: dict[str, Any] = Depends(RequirePermission("graph", "stream")), + ): + await websocket.accept() + await websocket.send_json(user) + await websocket.close() + + return TestClient(app) + + +class TestWebSocketAuthResolves: + def test_token_query_param_authenticates_on_websocket(self): + """The ?token= fallback must resolve the dependency on a WS route (was a TypeError).""" + client = _build_client(auth_configured=True) + with client.websocket_connect("/ws?token=alice") as conn: + assert conn.receive_json() == {"user_id": "alice"} + + def test_authorization_header_authenticates_on_websocket(self): + client = _build_client(auth_configured=True) + with client.websocket_connect( + "/ws", headers={"Authorization": "Bearer bob"} + ) as conn: + assert conn.receive_json() == {"user_id": "bob"} + + def test_auth_not_configured_yields_empty_user_on_websocket(self): + client = _build_client(auth_configured=False) + with client.websocket_connect("/ws") as conn: + assert conn.receive_json() == {} + + +class TestExtractCredential: + def test_bearer_header_parsed(self): + conn = type("_C", (), {"headers": {"Authorization": "Bearer xyz"}, "query_params": {}})() + cred = _extract_credential(conn) + assert cred is not None + assert cred.credentials == "xyz" + assert cred.scheme == "Bearer" + + def test_query_token_fallback_when_no_header(self): + conn = type("_C", (), {"headers": {}, "query_params": {"token": "qtok"}})() + cred = _extract_credential(conn) + assert cred is not None + assert cred.credentials == "qtok" + + def test_header_takes_priority_over_query(self): + conn = type( + "_C", (), {"headers": {"Authorization": "Bearer hdr"}, "query_params": {"token": "q"}} + )() + assert _extract_credential(conn).credentials == "hdr" + + def test_non_bearer_scheme_ignored(self): + conn = type("_C", (), {"headers": {"Authorization": "Basic abc"}, "query_params": {}})() + assert _extract_credential(conn) is None + + def test_no_credentials_returns_none(self): + conn = type("_C", (), {"headers": {}, "query_params": {}})() + assert _extract_credential(conn) is None diff --git a/tests/unit_tests/test_websocket_guard.py b/tests/unit_tests/test_websocket_guard.py new file mode 100644 index 0000000..e993ee9 --- /dev/null +++ b/tests/unit_tests/test_websocket_guard.py @@ -0,0 +1,195 @@ +"""Tests for the WebSocket connection guard and secure token transport. + +Covers what middleware cannot, because rate-limit / request-size middleware are HTTP-only: + - the global rate limit applied at the WS handshake (shared backend/bucket with REST), + - the per-process concurrent-connection cap (websocket.max_connections), + - bearer token via the Sec-WebSocket-Protocol sentinel (kept out of URLs/logs). + +These drive the guard through FastAPI's real DI on a @app.websocket route. +""" + +from typing import Any + +import pytest +from fastapi import Depends, FastAPI, WebSocket +from fastapi.testclient import TestClient +from injectq import InjectQ +from injectq.integrations import setup_fastapi +from starlette.websockets import WebSocketDisconnect + +from agentflow_cli.src.app.core.auth.auth_backend import BaseAuth +from agentflow_cli.src.app.core.auth.authorization import AuthorizationBackend +from agentflow_cli.src.app.core.auth.permissions import ( + WS_BEARER_SUBPROTOCOL, + RequirePermission, + ws_bearer_subprotocol, +) +from agentflow_cli.src.app.core.config.graph_config import GraphConfig, WebSocketConfig +from agentflow_cli.src.app.core.middleware.rate_limit.base import RateLimitDecision +from agentflow_cli.src.app.routers.graph import realtime_guard +from agentflow_cli.src.app.routers.graph.realtime_guard import realtime_connection_guard + + +class _FakeAuth(BaseAuth): + def authenticate(self, request, response, credential): # type: ignore[override] + return {} if credential is None else {"user_id": credential.credentials} + + +class _AllowAuthz(AuthorizationBackend): + def __init__(self): + pass + + async def authorize(self, user, resource, action, resource_id=None): # type: ignore[override] + return True + + +class _FakeRateBackend: + def __init__(self, allowed: bool): + self._allowed = allowed + self.calls = 0 + + async def check(self, key, *, limit, window): + self.calls += 1 + return RateLimitDecision(allowed=self._allowed, remaining=0, reset_after=5) + + async def close(self): + pass + + +class _StubConfig: + def __init__(self, rate_limit, max_connections, auth_configured=True): + self._rl = rate_limit + self._ws = WebSocketConfig(max_connections=max_connections) + self._auth_configured = auth_configured + + def auth_config(self): + return "custom" if self._auth_configured else None + + @property + def rate_limit(self): + return self._rl + + @property + def websocket(self): + return self._ws + + +def _build_app(config, rate_backend=None, *, auth=False): + container = InjectQ() + container.bind_instance(GraphConfig, config, allow_none=True) + container.bind_instance(BaseAuth, _FakeAuth(), allow_none=True) + container.bind_instance(AuthorizationBackend, _AllowAuthz(), allow_none=True) + + app = FastAPI() + setup_fastapi(container, app) + if rate_backend is not None: + app.state.rate_limit_backend = rate_backend + + deps = [Depends(realtime_connection_guard)] + if auth: + deps.append(Depends(RequirePermission("graph", "stream"))) + + @app.websocket("/ws") + async def ws( + websocket: WebSocket, + _guard: None = deps[0], + user: dict[str, Any] = (deps[1] if auth else Depends(lambda: {})), + ): + await websocket.accept(subprotocol=ws_bearer_subprotocol(websocket)) + await websocket.send_json({"active": realtime_guard._registry.active, "user": user}) + await websocket.close() + + return app + + +@pytest.fixture(autouse=True) +def _reset_registry(): + realtime_guard._registry._active = 0 + yield + realtime_guard._registry._active = 0 + + +class TestConcurrencyCap: + def test_under_cap_connects_and_tracks_active(self): + app = _build_app(_StubConfig(None, 2)) + client = TestClient(app) + with client.websocket_connect("/ws") as conn: + msg = conn.receive_json() + assert msg["active"] == 1 + + def test_slot_released_after_disconnect(self): + app = _build_app(_StubConfig(None, 1)) + client = TestClient(app) + # Two sequential connections both succeed because the first releases its slot. + for _ in range(2): + with client.websocket_connect("/ws") as conn: + assert conn.receive_json()["active"] == 1 + assert realtime_guard._registry.active == 0 + + def test_over_cap_rejected(self): + app = _build_app(_StubConfig(None, 1)) + realtime_guard._registry._active = 1 # simulate one already-active connection + client = TestClient(app) + with pytest.raises(WebSocketDisconnect) as exc: + with client.websocket_connect("/ws") as conn: + conn.receive_json() + assert exc.value.code == realtime_guard.WS_TRY_AGAIN_LATER + + +class TestHandshakeRateLimit: + def test_rate_limited_handshake_rejected(self): + backend = _FakeRateBackend(allowed=False) + config = _StubConfig(_RL(), None) + app = _build_app(config, rate_backend=backend) + client = TestClient(app) + with pytest.raises(WebSocketDisconnect) as exc: + with client.websocket_connect("/ws") as conn: + conn.receive_json() + assert exc.value.code == realtime_guard.WS_TRY_AGAIN_LATER + assert backend.calls == 1 + + def test_allowed_handshake_passes(self): + backend = _FakeRateBackend(allowed=True) + config = _StubConfig(_RL(), None) + app = _build_app(config, rate_backend=backend) + client = TestClient(app) + with client.websocket_connect("/ws") as conn: + assert conn.receive_json()["active"] == 1 + assert backend.calls == 1 + + +class _RL: + """Minimal RateLimitConfig stand-in for keying + check().""" + + by = "global" + requests = 100 + window = 60 + trusted_proxy_headers = False + + +class TestSubprotocolToken: + def test_token_via_subprotocol_authenticates_and_is_echoed(self): + app = _build_app(_StubConfig(None, None), auth=True) + client = TestClient(app) + with client.websocket_connect( + "/ws", subprotocols=[WS_BEARER_SUBPROTOCOL, "alice"] + ) as conn: + msg = conn.receive_json() + assert msg["user"] == {"user_id": "alice"} + # Server must echo the sentinel subprotocol or browsers fail the handshake. + assert conn.accepted_subprotocol == WS_BEARER_SUBPROTOCOL + + +class TestWebSocketConfig: + def test_absent_is_unlimited(self): + assert WebSocketConfig.from_dict({}).max_connections is None + + def test_zero_is_unlimited(self): + assert WebSocketConfig.from_dict({"max_connections": 0}).max_connections is None + + def test_positive_value(self): + assert WebSocketConfig.from_dict({"max_connections": 25}).max_connections == 25 + + def test_negative_rejected(self): + with pytest.raises(ValueError): + WebSocketConfig.from_dict({"max_connections": -1}) diff --git a/tests/unit_tests/test_websocket_realtime.py b/tests/unit_tests/test_websocket_realtime.py new file mode 100644 index 0000000..da15a1c --- /dev/null +++ b/tests/unit_tests/test_websocket_realtime.py @@ -0,0 +1,291 @@ +"""Unit tests for the /v1/graph/live realtime WebSocket endpoint. + +Mirrors test_websocket_graph.py: the GraphService is mocked (so no live provider), +and a mock WebSocket drives the binary/JSON frame protocol. +""" + +import asyncio +import json +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastapi import WebSocketDisconnect + +from agentflow_cli.src.app.routers.graph.router import realtime_graph_ws + + +def _audio_event(data: bytes): + return SimpleNamespace(type="audio_delta", data=data, model_dump=lambda mode=None: {}) + + +def _json_event(type_: str, **fields): + payload = {"type": type_, **fields} + return SimpleNamespace(type=type_, model_dump=lambda mode=None: payload, **fields) + + +def _make_websocket(receive_side_effects: list, init: dict): + ws = MagicMock() + ws.accept = AsyncMock() + ws.receive_json = AsyncMock(return_value=init) + ws.receive = AsyncMock(side_effect=receive_side_effects) + ws.send_bytes = AsyncMock() + ws.send_text = AsyncMock() + ws.send_json = AsyncMock() + ws.close = AsyncMock() + return ws + + +def _make_service(events: list): + service = MagicMock() + captured = {} + + async def _gen(input_queue, init, user): + captured["queue"] = input_queue + captured["init"] = init + captured["user"] = user + for e in events: + yield e + + service.realtime_graph = _gen + service._captured = captured + return service + + +class TestRealtimeWebSocket: + @pytest.mark.asyncio + async def test_audio_events_sent_as_binary_others_as_json(self): + events = [_audio_event(b"\x01\x02"), _json_event("turn_complete")] + ws = _make_websocket([WebSocketDisconnect()], init={"model": "gemini-2.5-flash-live"}) + service = _make_service(events) + + await realtime_graph_ws(websocket=ws, service=service, user={"user_id": "u1"}) + + ws.accept.assert_awaited() + ws.send_bytes.assert_awaited_once_with(b"\x01\x02") + sent_text = [c.args[0] for c in ws.send_text.call_args_list] + assert any(json.loads(t)["type"] == "turn_complete" for t in sent_text) + + @pytest.mark.asyncio + async def test_upstream_binary_frame_becomes_audio_input(self): + ws = _make_websocket( + [ + {"type": "websocket.receive", "bytes": b"\xaa\xbb"}, + {"type": "websocket.receive", "text": json.dumps({"type": "close"})}, + ], + init={"model": "gemini-2.5-flash-live"}, + ) + service = _make_service([]) # downstream finishes immediately + + await realtime_graph_ws(websocket=ws, service=service, user={"user_id": "u1"}) + + q = service._captured["queue"] + item = q.get_nowait() + assert item.kind == "audio" + assert item.data == b"\xaa\xbb" + + @pytest.mark.asyncio + async def test_oversized_binary_frame_dropped(self): + from agentflow_cli.src.app.routers.graph.router import REALTIME_MAX_FRAME_BYTES + + big = b"\x00" * (REALTIME_MAX_FRAME_BYTES + 1) + ws = _make_websocket( + [ + {"type": "websocket.receive", "bytes": big}, + {"type": "websocket.receive", "bytes": b"\x01\x02"}, + {"type": "websocket.receive", "text": json.dumps({"type": "close"})}, + ], + init={"model": "gemini-2.5-flash-live"}, + ) + service = _make_service([]) + + await realtime_graph_ws(websocket=ws, service=service, user={"user_id": "u1"}) + + q = service._captured["queue"] + kinds = [] + try: + while True: + kinds.append(q.get_nowait()) + except Exception: + pass + # Only the small frame is enqueued; the oversized one is dropped. + audio = [i for i in kinds if i.kind == "audio"] + assert len(audio) == 1 + assert audio[0].data == b"\x01\x02" + + @pytest.mark.asyncio + async def test_upstream_text_control_frames_dispatch(self): + ws = _make_websocket( + [ + {"type": "websocket.receive", "text": json.dumps({"type": "activity_start"})}, + {"type": "websocket.receive", "text": json.dumps({"type": "text", "text": "hi"})}, + {"type": "websocket.receive", "text": json.dumps({"type": "close"})}, + ], + init={"model": "gemini-2.5-flash-live"}, + ) + service = _make_service([]) + + await realtime_graph_ws(websocket=ws, service=service, user={"user_id": "u1"}) + + q = service._captured["queue"] + kinds = [] + try: + while True: + kinds.append(q.get_nowait().kind) + except Exception: + pass + assert "activity_start" in kinds + assert "text" in kinds + + @pytest.mark.asyncio + async def test_downstream_error_closes_with_1011(self): + ws = _make_websocket([WebSocketDisconnect()], init={"model": "gemini-2.5-flash-live"}) + service = MagicMock() + + async def _boom(input_queue, init, user): + raise RuntimeError("graph has no LiveAgent") + yield # make it an async generator + + service.realtime_graph = _boom + + await realtime_graph_ws(websocket=ws, service=service, user={"user_id": "u1"}) + + close_codes = [c.kwargs.get("code") for c in ws.close.call_args_list] + assert 1011 in close_codes + + @pytest.mark.asyncio + async def test_non_dict_init_frame_rejected(self): + ws = _make_websocket([WebSocketDisconnect()], init=["not", "a", "dict"]) + service = _make_service([]) + + await realtime_graph_ws(websocket=ws, service=service, user={"user_id": "u1"}) + + close_codes = [c.kwargs.get("code") for c in ws.close.call_args_list] + assert 1003 in close_codes + # service must never be reached with a malformed init + assert "init" not in service._captured + + @pytest.mark.asyncio + async def test_init_frame_passed_to_service(self): + init = {"model": "gemini-2.5-flash-live", "thread_id": "t-99", "voice": "Puck"} + ws = _make_websocket([WebSocketDisconnect()], init=init) + service = _make_service([]) + + await realtime_graph_ws(websocket=ws, service=service, user={"user_id": "u1"}) + + assert service._captured["init"]["thread_id"] == "t-99" + assert service._captured["user"] == {"user_id": "u1"} + + @pytest.mark.asyncio + async def test_close_frame_does_not_truncate_final_events(self): + """A `close` control frame ends input but must not cut off the model's final + response: downstream is drained, not cancelled, when the client side finishes first. + """ + events = [ + _audio_event(b"a"), + _json_event("output_transcript", text="bye", finished=True), + _json_event("turn_complete"), + ] + + async def _slow_gen(input_queue, init, user): + # Still producing after the client closed input (model finishing its turn). + for e in events: + await asyncio.sleep(0.01) + yield e + + service = MagicMock() + service.realtime_graph = _slow_gen + + ws = _make_websocket( + [{"type": "websocket.receive", "text": json.dumps({"type": "close"})}], + init={"model": "gemini-2.5-flash-live"}, + ) + + await realtime_graph_ws(websocket=ws, service=service, user={"user_id": "u1"}) + + # All three trailing events must still reach the client. + assert ws.send_bytes.await_count == 1 + sent_types = [json.loads(c.args[0])["type"] for c in ws.send_text.call_args_list] + assert "output_transcript" in sent_types + assert "turn_complete" in sent_types + + @pytest.mark.asyncio + async def test_invalid_modalities_sends_error_event_not_opaque_close(self): + """A bad session config surfaces as a normalized fatal error frame, not a bare 1011.""" + + async def _boom(input_queue, init, user): + raise ValueError("response_modalities must contain exactly one modality") + yield # make it an async generator + + service = MagicMock() + service.realtime_graph = _boom + + ws = _make_websocket([WebSocketDisconnect()], init={"model": "gemini-x"}) + + await realtime_graph_ws(websocket=ws, service=service, user={"user_id": "u1"}) + + sent = [json.loads(c.args[0]) for c in ws.send_text.call_args_list] + errors = [m for m in sent if m.get("type") == "error"] + assert errors and errors[0]["fatal"] is True + assert errors[0]["code"] == "invalid_config" + + +class TestRealtimeGraphService: + @pytest.mark.asyncio + async def test_init_session_params_mapped_into_realtime_config(self): + from agentflow_cli.src.app.routers.graph.services.graph_service import GraphService + + captured = {} + + async def _arealtime(input_queue, config): + captured["config"] = config + for _ in (): + yield + + graph = MagicMock() + graph.arealtime = _arealtime + svc = GraphService(graph=graph, checkpointer=AsyncMock(), config=MagicMock()) + + init = { + "model": "gemini-x", + "voice": "Puck", + "modalities": ["TEXT"], + "vad": {"enabled": False}, + "system_prompt": "be brief", + "tools_tags": ["weather"], + "thread_id": "t1", + } + + async for _ in svc.realtime_graph(MagicMock(), init, {"user_id": "u1"}): + pass + + rt = captured["config"]["realtime"] + assert rt["model"] == "gemini-x" + assert rt["voice"] == "Puck" + assert rt["response_modalities"] == ["TEXT"] + assert rt["vad"] == {"enabled": False} + assert rt["system_instruction"] == "be brief" + assert rt["tools_tags"] == ["weather"] + + @pytest.mark.asyncio + async def test_string_modalities_coerced_to_list(self): + from agentflow_cli.src.app.routers.graph.services.graph_service import GraphService + + captured = {} + + async def _arealtime(input_queue, config): + captured["config"] = config + for _ in (): + yield + + graph = MagicMock() + graph.arealtime = _arealtime + svc = GraphService(graph=graph, checkpointer=AsyncMock(), config=MagicMock()) + + # Client shorthand: a bare string instead of a list. + init = {"model": "gemini-x", "modalities": "TEXT"} + + async for _ in svc.realtime_graph(MagicMock(), init, {"user_id": "u1"}): + pass + + assert captured["config"]["realtime"]["response_modalities"] == ["TEXT"]