diff --git a/pyproject.toml b/pyproject.toml index 1477e24..a130038 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,10 +19,12 @@ classifiers = [ ] keywords = ["dremio", "cli", "data-lake", "sql"] dependencies = [ - "typer[all]>=0.9", + "typer>=0.9", "httpx>=0.27", "pyyaml>=6", "pydantic>=2", + "rich>=13", + "prompt-toolkit>=3.0", ] [project.scripts] diff --git a/src/drs/chat_render.py b/src/drs/chat_render.py new file mode 100644 index 0000000..0746211 --- /dev/null +++ b/src/drs/chat_render.py @@ -0,0 +1,314 @@ +# +# Copyright (C) 2017-2026 Dremio Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Rich terminal renderer for Dremio AI Agent chat sessions.""" + +from __future__ import annotations + +import json +import sys +import threading +from typing import Any + +from rich.console import Console +from rich.markdown import Markdown +from rich.panel import Panel +from rich.syntax import Syntax +from rich.text import Text + +# Spinner frames for the "Thinking..." animation. +_SPINNER_FRAMES = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] +_SPINNER_INTERVAL = 0.08 + + +class _Spinner: + """A lightweight terminal spinner that does NOT use Rich's Live display. + + Rich's ``Status`` / ``Live`` captures all ``console.print()`` calls and + renders them on its own refresh cycle, which can visually delay SSE events. + This spinner writes its animation directly to *stderr* using ANSI escape + codes so that ``console.print()`` output flows to the terminal immediately. + """ + + def __init__(self, message: str = "Thinking...") -> None: + self._message = message + self._stop_event = threading.Event() + self._thread: threading.Thread | None = None + + def start(self) -> None: + if self._thread is not None: + return + self._stop_event.clear() + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + + def stop(self) -> None: + if self._thread is None: + return + self._stop_event.set() + self._thread.join(timeout=1.0) + self._thread = None + # Clear the spinner line + sys.stderr.write("\r\033[K") + sys.stderr.flush() + + def _run(self) -> None: + idx = 0 + while not self._stop_event.is_set(): + frame = _SPINNER_FRAMES[idx % len(_SPINNER_FRAMES)] + sys.stderr.write(f"\r{frame} {self._message}") + sys.stderr.flush() + idx += 1 + self._stop_event.wait(_SPINNER_INTERVAL) + + +class ChatRenderer: + """Renders agent SSE events to a Rich console (interactive mode).""" + + def __init__(self, console: Console | None = None) -> None: + self.console = console or Console() + self._spinner: _Spinner | None = None + + # -- Model output -- + + def render_model_chunk(self, name: str, result: dict) -> None: + """Render a model output chunk based on the task type.""" + text = result.get("text", "") + if not text: + return + + if name == "modelGenerateSql": + self.console.print(Syntax(text, "sql", theme="monokai", line_numbers=False)) + elif name == "modelReject": + self.console.print(Text(text, style="bold yellow")) + else: + # modelGeneric, modelSqlAnswer, and others + self.console.print(Markdown(text)) + + # -- Tool events -- + + def render_tool_request( + self, + call_id: str, + name: str, + arguments: dict | None = None, + title: str | None = None, + ) -> None: + """Show a tool call request in a bordered panel.""" + display_name = title or name + args_summary = "" + if arguments: + args_summary = _summarize_args(arguments) + + body = Text(args_summary, style="dim") if args_summary else Text("(no arguments)", style="dim") + self.console.print( + Panel(body, title=f"[bold cyan]Tool: {display_name}[/]", border_style="cyan", expand=False), + ) + + def render_tool_response(self, call_id: str, name: str, result: Any) -> None: + """Show a tool result in a muted panel.""" + if isinstance(result, dict): + text = json.dumps(result, indent=2, default=str) + if len(text) > 500: + text = text[:500] + "\n..." + elif isinstance(result, str): + text = result[:500] + ("..." if len(result) > 500 else "") + else: + text = str(result)[:500] + + self.console.print( + Panel(Text(text, style="dim"), title=f"[dim]{name} result[/]", border_style="dim", expand=False), + ) + + def render_tool_progress(self, status: str, message: str) -> None: + """Inline progress for long-running tools.""" + self.console.print(Text(f" ⏳ {message}", style="dim italic")) + + # -- Errors -- + + def render_error(self, error_type: str, message: str) -> None: + """Red error display.""" + self.console.print(Text(f"Error ({error_type}): {message}", style="bold red")) + + # -- Conversation metadata -- + + def render_conversation_title(self, title: str) -> None: + """Show conversation title update.""" + self.console.print(Text(f"📝 {title}", style="bold")) + + # -- Spinner -- + + def start_spinner(self) -> None: + """Start an animated 'Thinking...' indicator.""" + if self._spinner is None: + self._spinner = _Spinner() + self._spinner.start() + + def stop_spinner(self) -> None: + """Stop the spinner.""" + if self._spinner is not None: + self._spinner.stop() + self._spinner = None + + # -- Tool approval -- + + def prompt_tool_approval(self, nonce: str, tools: list[dict]) -> dict: + """Ask user Y/n for each pending tool call; return approval payload. + + Returns a dict suitable for the ``approvals`` field of the message body. + """ + decisions: list[dict] = [] + for tool in tools: + tool_name = tool.get("name", "unknown") + tool_id = tool.get("callId", tool.get("id", "")) + args = tool.get("arguments", {}) + self.render_tool_request(tool_id, tool_name, args) + try: + answer = self.console.input(f" Approve [bold cyan]{tool_name}[/]? [Y/n] ").strip().lower() + except (EOFError, KeyboardInterrupt): + answer = "n" + approved = answer in ("", "y", "yes") + decisions.append( + { + "callId": tool_id, + "decision": "approved" if approved else "denied", + } + ) + return { + "approvalNonce": nonce, + "toolDecisions": decisions, + } + + # -- Separators -- + + def print_separator(self) -> None: + """Print a visual separator between exchanges.""" + self.console.print(Text("─" * 40, style="dim")) + + def print_welcome(self, conv_id: str | None = None) -> None: + """Print welcome banner for interactive mode.""" + self.console.print( + Panel( + "[bold]Dremio AI Chat[/]\n" + "Type a question or use /help for commands.\n" + "Press [bold]Ctrl+D[/] or type [bold]/quit[/] to exit.", + border_style="blue", + expand=False, + ), + ) + if conv_id: + self.console.print(Text(f"Resuming conversation: {conv_id}", style="dim")) + + def print_help(self) -> None: + """Print slash command help.""" + help_text = ( + "[bold]Commands:[/]\n" + " /new Start a new conversation\n" + " /list List recent conversations\n" + " /continue Resume a conversation by ID\n" + " /history Show message history for current conversation\n" + " /cancel Cancel the active run\n" + " /delete [id] Delete current or specified conversation\n" + " /info Show current conversation metadata\n" + " /quit Exit (or Ctrl+D)" + ) + self.console.print(Panel(help_text, border_style="blue", expand=False)) + + +class PlainRenderer: + """Non-interactive renderer. + + When stdout is a terminal, model output is rendered as Rich Markdown. + When piped, plain text is written with no ANSI codes. + Tool events and progress always go to stderr. + """ + + def __init__(self) -> None: + self._is_tty = sys.stdout.isatty() + self._console = Console() if self._is_tty else None + self._stderr_console = Console(stderr=True, highlight=False) + self._spinner: _Spinner | None = None + + def render_model_chunk(self, name: str, result: dict) -> None: + text = result.get("text", "") + if not text: + return + if self._console is not None: + if name == "modelGenerateSql": + self._console.print(Syntax(text, "sql", theme="monokai", line_numbers=False)) + elif name == "modelReject": + self._console.print(Text(text, style="bold yellow")) + else: + self._console.print(Markdown(text)) + else: + sys.stdout.write(text) + sys.stdout.flush() + + def render_tool_request( + self, + call_id: str, + name: str, + arguments: dict | None = None, + title: str | None = None, + ) -> None: + self._stderr_console.print( + Text(f" ⚙ {title or name}", style="dim cyan"), + ) + + def render_tool_response(self, call_id: str, name: str, result: Any) -> None: + self._stderr_console.print( + Text(f" ✓ {name} done", style="dim"), + ) + + def render_tool_progress(self, status: str, message: str) -> None: + self._stderr_console.print( + Text(f" ⏳ {message}", style="dim italic"), + ) + + def render_error(self, error_type: str, message: str) -> None: + self._stderr_console.print( + Text(f"Error ({error_type}): {message}", style="bold red"), + ) + + def render_conversation_title(self, title: str) -> None: + pass + + def start_spinner(self) -> None: + if self._is_tty and self._spinner is None: + self._spinner = _Spinner() + self._spinner.start() + + def stop_spinner(self) -> None: + if self._spinner is not None: + self._spinner.stop() + self._spinner = None + + def print_separator(self) -> None: + sys.stdout.write("\n") + sys.stdout.flush() + + +def _summarize_args(args: dict, max_len: int = 200) -> str: + """Produce a compact summary of tool arguments.""" + parts: list[str] = [] + for k, v in args.items(): + s = str(v) + if len(s) > 60: + s = s[:57] + "..." + parts.append(f"{k}={s}") + text = ", ".join(parts) + if len(text) > max_len: + text = text[:max_len] + "..." + return text diff --git a/src/drs/cli.py b/src/drs/cli.py index 211da22..bcd55c7 100644 --- a/src/drs/cli.py +++ b/src/drs/cli.py @@ -19,6 +19,7 @@ import asyncio import json +import logging import sys from pathlib import Path @@ -27,12 +28,15 @@ from drs.auth import DrsConfig, load_config from drs.client import DremioClient -from drs.commands import engine, folder, grant, job, project, query, reflection, role, schema, tag, user, wiki +from drs.commands import chat, engine, folder, grant, job, project, query, reflection, role, schema, tag, user, wiki + +CONTEXT_SETTINGS = {"help_option_names": ["-h", "--help"]} app = typer.Typer( name="dremio", help="Developer CLI for Dremio Cloud.", no_args_is_help=True, + context_settings=CONTEXT_SETTINGS, ) # Register command groups @@ -48,6 +52,7 @@ app.add_typer(role.app, name="role") app.add_typer(grant.app, name="grant") app.add_typer(project.app, name="project") +app.add_typer(chat.app, name="chat") # Global state for config _config: DrsConfig | None = None @@ -62,8 +67,29 @@ def main( uri: str | None = typer.Option( None, "--uri", help="Dremio API base URI (e.g., https://api.dremio.cloud, https://api.eu.dremio.cloud)" ), + verbose: int = typer.Option( + 0, "--verbose", "-v", count=True, help="Increase logging verbosity (-v for debug, -vv for trace)" + ), ) -> None: """Global options for dremio CLI.""" + # Configure logging based on verbosity + if verbose >= 2: + log_level = logging.DEBUG + # Also enable httpx/httpcore debug logging for -vv + logging.getLogger("httpx").setLevel(logging.DEBUG) + logging.getLogger("httpcore").setLevel(logging.DEBUG) + elif verbose == 1: + log_level = logging.DEBUG + else: + log_level = logging.WARNING + + logging.basicConfig( + level=log_level, + format="%(asctime)s %(levelname)s [%(name)s] %(message)s", + datefmt="%H:%M:%S", + stream=sys.stderr, + ) + global _cli_opts _cli_opts = { "config_path": Path(config) if config else None, diff --git a/src/drs/client.py b/src/drs/client.py index c38ca44..e45ea69 100644 --- a/src/drs/client.py +++ b/src/drs/client.py @@ -108,24 +108,32 @@ async def _request_with_retry(self, method: str, url: str, **kwargs: Any) -> htt raise last_exc # type: ignore[misc] async def _get(self, url: str, params: dict | None = None) -> Any: + logger.debug("GET %s params=%s", url, params) resp = await self._request_with_retry("GET", url, params=params) + logger.debug("GET %s → %d (%d bytes)", url, resp.status_code, len(resp.content)) resp.raise_for_status() return resp.json() async def _post(self, url: str, json: dict | None = None) -> Any: + logger.debug("POST %s body=%s", url, json) resp = await self._request_with_retry("POST", url, json=json) + logger.debug("POST %s → %d (%d bytes)", url, resp.status_code, len(resp.content)) resp.raise_for_status() return resp.json() async def _put(self, url: str, json: dict | None = None) -> Any: + logger.debug("PUT %s body=%s", url, json) resp = await self._request_with_retry("PUT", url, json=json) + logger.debug("PUT %s → %d (%d bytes)", url, resp.status_code, len(resp.content)) resp.raise_for_status() if resp.content: return resp.json() return {"status": "ok"} async def _delete(self, url: str, params: dict | None = None) -> Any: + logger.debug("DELETE %s params=%s", url, params) resp = await self._request_with_retry("DELETE", url, params=params) + logger.debug("DELETE %s → %d (%d bytes)", url, resp.status_code, len(resp.content)) resp.raise_for_status() if resp.content: return resp.json() @@ -309,6 +317,91 @@ async def update_role(self, role_id: str, body: dict) -> dict: async def delete_role(self, role_id: str) -> dict: return await self._delete(self._v1(f"/roles/{role_id}")) + # -- Agent / Chat (SSE) -- + + def _agent(self, path: str) -> str: + """Agent API URL: /v1/projects/{pid}/agent/...""" + return f"{self.config.uri}/v1/projects/{self.config.project_id}/agent{path}" + + async def create_conversation(self, body: dict) -> dict: + """POST /agent/conversations — start a new conversation.""" + return await self._post(self._agent("/conversations"), json=body) + + async def send_conversation_message(self, conversation_id: str, body: dict) -> dict: + """POST /agent/conversations/{id}/messages — send a message or approval.""" + return await self._post( + self._agent(f"/conversations/{conversation_id}/messages"), + json=body, + ) + + async def stream_run(self, conversation_id: str, run_id: str) -> httpx.Response: + """GET /agent/conversations/{id}/runs/{runId} — returns raw SSE response. + + Returns the raw ``httpx.Response`` with ``stream=True``. Caller must + iterate ``resp.aiter_bytes()`` and close the response via ``async with``. + + We explicitly disable compression (``Accept-Encoding: identity``) so + that reverse proxies (GCP LB, envoy, etc.) do **not** gzip-buffer the + event stream — otherwise every SSE event is held until the stream ends. + """ + url = self._agent(f"/conversations/{conversation_id}/runs/{run_id}") + logger.debug("SSE GET %s", url) + resp = await self._client.send( + self._client.build_request( + "GET", + url, + headers={ + "Accept": "text/event-stream", + "Accept-Encoding": "identity", + "Cache-Control": "no-cache", + }, + ), + stream=True, + ) + logger.debug("SSE GET %s → %d", url, resp.status_code) + resp.raise_for_status() + return resp + + async def list_conversations( + self, + limit: int = 25, + page_token: str | None = None, + ) -> dict: + """GET /agent/conversations""" + params: dict[str, str | int] = {"maxResults": limit} + if page_token: + params["pageToken"] = page_token + return await self._get(self._agent("/conversations"), params=params) + + async def get_conversation_messages( + self, + conversation_id: str, + limit: int = 50, + page_token: str | None = None, + ) -> dict: + """GET /agent/conversations/{id}/messages""" + params: dict[str, str | int] = {"maxResults": limit} + if page_token: + params["pageToken"] = page_token + return await self._get( + self._agent(f"/conversations/{conversation_id}/messages"), + params=params, + ) + + async def delete_conversation(self, conversation_id: str) -> dict: + """DELETE /agent/conversations/{id}""" + return await self._delete(self._agent(f"/conversations/{conversation_id}")) + + async def cancel_conversation_run( + self, + conversation_id: str, + run_id: str, + ) -> dict: + """POST /agent/conversations/{id}/runs/{runId}:cancel""" + return await self._post( + self._agent(f"/conversations/{conversation_id}/runs/{run_id}:cancel"), + ) + # -- Grants (v1) -- async def get_grants(self, scope: str, scope_id: str, grantee_type: str, grantee_id: str) -> dict: diff --git a/src/drs/commands/chat.py b/src/drs/commands/chat.py new file mode 100644 index 0000000..a605323 --- /dev/null +++ b/src/drs/commands/chat.py @@ -0,0 +1,703 @@ +# +# Copyright (C) 2017-2026 Dremio Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""dremio chat — interactive and non-interactive chat with the Dremio AI Agent.""" + +from __future__ import annotations + +import asyncio +import contextlib +import json +import logging +import sys +from enum import StrEnum +from pathlib import Path +from typing import Any + +import httpx +import typer +from prompt_toolkit import PromptSession +from prompt_toolkit.history import InMemoryHistory +from rich.console import Console +from rich.json import JSON as RichJSON +from rich.markdown import Markdown +from rich.panel import Panel +from rich.table import Table +from rich.text import Text + +from drs.chat_render import ChatRenderer, PlainRenderer +from drs.client import DremioClient +from drs.output import error as print_error +from drs.sse import parse_sse_stream +from drs.utils import DremioAPIError, handle_api_error + +logger = logging.getLogger(__name__) + +app = typer.Typer(help="Chat with the Dremio AI Agent.", context_settings={"help_option_names": ["-h", "--help"]}) + + +# --------------------------------------------------------------------------- +# Output format helpers +# --------------------------------------------------------------------------- + + +class ChatFormat(StrEnum): + json = "json" + table = "table" + + +def _chat_output(data: Any, fmt: ChatFormat) -> None: + """Print chat subcommand results as JSON or a Rich table.""" + console = Console() + + if fmt == ChatFormat.json: + console.print(RichJSON(json.dumps(data, indent=2, default=str))) + return + + rows = data.get("data", data.get("conversations", data.get("messages", []))) + if not rows: + console.print("[dim]No results.[/]") + return + + if not isinstance(rows, list) or not isinstance(rows[0], dict): + console.print("[dim]No results.[/]") + return + + first = rows[0] + if "chunkType" in first: + _render_history_table(console, rows) + elif "conversationId" in first and "title" in first: + _render_conversations_table(console, rows) + else: + _render_generic_table(console, rows) + + +def _render_conversations_table(console: Console, rows: list[dict]) -> None: + """Render conversation list as a curated Rich table.""" + table = Table(show_edge=False, pad_edge=False, expand=False) + table.add_column("ID", style="cyan", no_wrap=True) + table.add_column("Title") + table.add_column("Model") + table.add_column("Modified") + + for row in rows: + modified = str(row.get("modifiedAt", "")) + if "T" in modified: + modified = modified.replace("T", " ").rstrip("Z").split(".")[0] + table.add_row( + row.get("conversationId", ""), + row.get("title", ""), + row.get("modelName", ""), + modified, + ) + console.print(table) + + +def _render_history_table(console: Console, rows: list[dict]) -> None: + """Render conversation history as a readable transcript.""" + for row in rows: + chunk_type = row.get("chunkType", "") + timestamp = str(row.get("createdAt", "")) + if "T" in timestamp: + timestamp = timestamp.replace("T", " ").rstrip("Z").split(".")[0] + + if chunk_type == "userMessage": + text = row.get("text", "") + console.print(Text(f" [{timestamp}]", style="dim")) + console.print(Panel(text, title="[bold green]You[/]", border_style="green", expand=False)) + + elif chunk_type == "model": + result = row.get("result", {}) + text = result.get("text", "") if isinstance(result, dict) else str(result) + name = row.get("name", "") + title = "[bold blue]Agent[/]" + if name and name != "modelGeneric": + title += f" [dim]({name})[/]" + console.print(Text(f" [{timestamp}]", style="dim")) + console.print(Panel(Markdown(text), title=title, border_style="blue", expand=False)) + + elif chunk_type == "toolRequest": + tool_name = row.get("name", "") + summarized = row.get("summarizedTitle", tool_name) + console.print(Text(f" ⚙ {summarized}", style="dim cyan")) + + elif chunk_type == "toolResponse": + tool_name = row.get("name", "") + console.print(Text(f" ✓ {tool_name} done", style="dim")) + + +def _render_generic_table(console: Console, rows: list[dict]) -> None: + """Fallback: render all columns as a Rich table.""" + columns = list(rows[0].keys()) + table = Table(show_edge=False, pad_edge=False, expand=False) + for col in columns: + table.add_column(col) + for row in rows: + table.add_row(*[str(row.get(c, "")) for c in columns]) + console.print(table) + + +# --------------------------------------------------------------------------- +# Core async functions (reusable, no UI) +# --------------------------------------------------------------------------- + + +async def create_conversation( + client: DremioClient, + text: str, + model: str | None = None, +) -> dict: + """POST /agent/conversations — start a new conversation.""" + body: dict[str, Any] = {"prompt": {"text": text}} + if model: + body["model"] = model + try: + return await client.create_conversation(body) + except httpx.HTTPStatusError as exc: + raise handle_api_error(exc) from exc + + +async def send_message( + client: DremioClient, + conversation_id: str, + text: str | None = None, + approvals: dict | None = None, + model: str | None = None, +) -> dict: + """POST /agent/conversations/{id}/messages.""" + body: dict[str, Any] = {"prompt": {}} + if text: + body["prompt"]["text"] = text + if approvals: + body["prompt"]["approvals"] = approvals + if model: + body["model"] = model + try: + return await client.send_conversation_message(conversation_id, body) + except httpx.HTTPStatusError as exc: + raise handle_api_error(exc) from exc + + +async def stream_run( + client: DremioClient, + conversation_id: str, + run_id: str, +): + """GET /agent/conversations/{id}/runs/{runId} as SSE — yields parsed events.""" + try: + resp = await client.stream_run(conversation_id, run_id) + except httpx.HTTPStatusError as exc: + raise handle_api_error(exc) from exc + + try: + async for event in parse_sse_stream(resp.aiter_bytes()): + yield event + finally: + await resp.aclose() + + +async def list_conversations(client: DremioClient, limit: int = 25) -> dict: + """GET /agent/conversations""" + try: + return await client.list_conversations(limit=limit) + except httpx.HTTPStatusError as exc: + raise handle_api_error(exc) from exc + + +async def get_messages( + client: DremioClient, + conversation_id: str, + limit: int = 50, +) -> dict: + """GET /agent/conversations/{id}/messages""" + try: + return await client.get_conversation_messages(conversation_id, limit=limit) + except httpx.HTTPStatusError as exc: + raise handle_api_error(exc) from exc + + +async def delete_conversation(client: DremioClient, conversation_id: str) -> dict: + """DELETE /agent/conversations/{id}""" + try: + return await client.delete_conversation(conversation_id) + except httpx.HTTPStatusError as exc: + raise handle_api_error(exc) from exc + + +async def cancel_run( + client: DremioClient, + conversation_id: str, + run_id: str, +) -> dict: + """POST /agent/conversations/{id}/runs/{runId}:cancel""" + try: + return await client.cancel_conversation_run(conversation_id, run_id) + except httpx.HTTPStatusError as exc: + raise handle_api_error(exc) from exc + + +def _extract_ids(result: dict) -> tuple[str | None, str | None]: + """Extract conversation_id and run_id from an API response.""" + conv_id = result.get("conversationId", result.get("id")) + run_id = result.get("currentRunId", result.get("runId", result.get("run", {}).get("id"))) + return conv_id, run_id + + +# --------------------------------------------------------------------------- +# SSE event dispatch +# --------------------------------------------------------------------------- + + +async def dispatch_events( + client: DremioClient, + renderer: ChatRenderer | PlainRenderer, + conversation_id: str, + run_id: str, + auto_approve: bool = False, + interactive: bool = True, + log_file: Any | None = None, +) -> str | None: + """Stream a run's SSE events and dispatch to the renderer. + + Returns the latest run_id (which may change after an approval cycle). + """ + renderer.start_spinner() + first_model_chunk = True + + try: + async for event in stream_run(client, conversation_id, run_id): + data = event.get("data", {}) + chunk_type = data.get("chunkType") + + if log_file: + log_file.write(json.dumps(data, default=str) + "\n") + log_file.flush() + + # All fields are at the top level of data (flat structure). + if chunk_type == "model": + if first_model_chunk: + renderer.stop_spinner() + first_model_chunk = False + name = data.get("name", "") + result = data.get("result", {}) + renderer.render_model_chunk(name, result) + + elif chunk_type == "toolRequest": + renderer.stop_spinner() + renderer.render_tool_request( + call_id=data.get("callId", ""), + name=data.get("name", ""), + arguments=data.get("arguments"), + title=data.get("summarizedTitle"), + ) + renderer.start_spinner() + + elif chunk_type == "toolResponse": + renderer.stop_spinner() + renderer.render_tool_response( + call_id=data.get("callId", ""), + name=data.get("name", ""), + result=data.get("result"), + ) + renderer.start_spinner() + + elif chunk_type == "toolProgress": + renderer.render_tool_progress( + status=data.get("status", ""), + message=data.get("message", ""), + ) + + elif chunk_type == "error": + renderer.stop_spinner() + renderer.render_error( + error_type=data.get("type", "unknown"), + message=data.get("message", str(data)), + ) + + elif chunk_type == "interrupt": + renderer.stop_spinner() + nonce = data.get("approvalNonce", "") + tools = data.get("toolDecisions", []) + + if interactive and isinstance(renderer, ChatRenderer): + approvals = renderer.prompt_tool_approval(nonce, tools) + else: + decisions = [] + for tool in tools: + decisions.append( + { + "callId": tool.get("callId", tool.get("id", "")), + "decision": "approved" if auto_approve else "denied", + } + ) + approvals = {"approvalNonce": nonce, "toolDecisions": decisions} + + resp = await send_message( + client, + conversation_id, + approvals=approvals, + ) + _, new_run_id = _extract_ids(resp) + if new_run_id: + run_id = new_run_id + renderer.start_spinner() + first_model_chunk = True + return await dispatch_events( + client, + renderer, + conversation_id, + run_id, + auto_approve=auto_approve, + interactive=interactive, + log_file=log_file, + ) + + elif chunk_type == "conversationUpdate": + title = data.get("title", "") + if title: + renderer.render_conversation_title(title) + + elif chunk_type == "endOfStream": + renderer.stop_spinner() + break + + elif chunk_type == "userMessage": + pass + + finally: + renderer.stop_spinner() + + return run_id + + +# --------------------------------------------------------------------------- +# Interactive REPL +# --------------------------------------------------------------------------- + + +async def chat_repl( + client: DremioClient, + renderer: ChatRenderer, + conv_id: str | None = None, + run_id: str | None = None, + model: str | None = None, + log_file: Any | None = None, +) -> None: + """Interactive REPL loop.""" + session: PromptSession = PromptSession(history=InMemoryHistory()) + renderer.print_welcome(conv_id) + + # If we were given a conversation + run_id, stream it first + if conv_id and run_id: + try: + run_id = await dispatch_events( + client, + renderer, + conv_id, + run_id, + interactive=True, + log_file=log_file, + ) + except DremioAPIError as exc: + renderer.render_error("api", str(exc)) + + while True: + try: + text = await session.prompt_async("You > ") + except (EOFError, KeyboardInterrupt): + break + + text = text.strip() + if not text: + continue + + # -- Slash commands -- + if text.startswith("/"): + parts = text.split(maxsplit=1) + cmd = parts[0].lower() + arg = parts[1].strip() if len(parts) > 1 else "" + + if cmd == "/quit": + break + if cmd == "/help": + renderer.print_help() + elif cmd == "/new": + conv_id = None + run_id = None + renderer.console.print("[dim]Starting new conversation.[/]") + elif cmd == "/list": + try: + result = await list_conversations(client, limit=25) + _render_conversations_table( + renderer.console, + result.get("data", result.get("conversations", [])), + ) + except DremioAPIError as exc: + renderer.render_error("api", str(exc)) + elif cmd == "/continue": + if not arg: + renderer.console.print("[yellow]Usage: /continue [/]") + else: + conv_id = arg + run_id = None + renderer.console.print(f"[dim]Switched to conversation: {conv_id}[/]") + elif cmd == "/history": + if not conv_id: + renderer.console.print("[yellow]No active conversation. Start one first.[/]") + else: + try: + result = await get_messages(client, conv_id, limit=50) + _render_history_table( + renderer.console, + result.get("data", result.get("messages", [])), + ) + except DremioAPIError as exc: + renderer.render_error("api", str(exc)) + elif cmd == "/cancel": + if not conv_id or not run_id: + renderer.console.print("[yellow]No active run to cancel.[/]") + else: + try: + await cancel_run(client, conv_id, run_id) + renderer.console.print("[dim]Run cancelled.[/]") + except DremioAPIError as exc: + renderer.render_error("api", str(exc)) + elif cmd == "/delete": + target = arg or conv_id + if not target: + renderer.console.print("[yellow]No conversation to delete. Provide an ID or start one first.[/]") + else: + try: + await delete_conversation(client, target) + renderer.console.print(f"[dim]Deleted conversation: {target}[/]") + if target == conv_id: + conv_id = None + run_id = None + except DremioAPIError as exc: + renderer.render_error("api", str(exc)) + elif cmd == "/info": + if not conv_id: + renderer.console.print("[yellow]No active conversation.[/]") + else: + renderer.console.print(f" Conversation: [cyan]{conv_id}[/]") + renderer.console.print(f" Run: [cyan]{run_id or '(none)'}[/]") + else: + renderer.console.print(f"[yellow]Unknown command: {cmd}. Type /help for commands.[/]") + continue + + # -- Send message -- + try: + if conv_id is None: + result = await create_conversation(client, text, model=model) + logger.debug("create_conversation response: %s", json.dumps(result, default=str)) + conv_id, run_id = _extract_ids(result) + else: + result = await send_message(client, conv_id, text=text, model=model) + logger.debug("send_message response: %s", json.dumps(result, default=str)) + _, run_id = _extract_ids(result) + + if run_id: + try: + run_id = await dispatch_events( + client, + renderer, + conv_id, + run_id, + interactive=True, + log_file=log_file, + ) + except KeyboardInterrupt: + renderer.stop_spinner() + renderer.console.print("\n[dim]Cancelling...[/]") + with contextlib.suppress(DremioAPIError): + await cancel_run(client, conv_id, run_id) + renderer.print_separator() + except DremioAPIError as exc: + renderer.render_error("api", str(exc)) + + +# --------------------------------------------------------------------------- +# Non-interactive (one-shot) mode +# --------------------------------------------------------------------------- + + +async def chat_oneshot( + client: DremioClient, + message: str, + conversation_id: str | None = None, + auto_approve: bool = False, + model: str | None = None, + log_file: Any | None = None, +) -> None: + """Send a single message and stream the response to stdout.""" + renderer = PlainRenderer() + + if conversation_id is None: + result = await create_conversation(client, message, model=model) + logger.debug("create_conversation response: %s", json.dumps(result, default=str)) + conversation_id, run_id = _extract_ids(result) + else: + result = await send_message(client, conversation_id, text=message, model=model) + logger.debug("send_message response: %s", json.dumps(result, default=str)) + _, run_id = _extract_ids(result) + + logger.debug("conversation_id=%s run_id=%s", conversation_id, run_id) + + if run_id: + await dispatch_events( + client, + renderer, + conversation_id, + run_id, + auto_approve=auto_approve, + interactive=False, + log_file=log_file, + ) + else: + logger.warning("No run_id found in response — cannot stream events") + # Ensure trailing newline for piped output + sys.stdout.write("\n") + sys.stdout.flush() + + +# --------------------------------------------------------------------------- +# CLI entry points +# --------------------------------------------------------------------------- + + +def _get_client() -> DremioClient: + # Deferred import to avoid circular dependency: cli.py imports this module. + from drs.cli import get_client + + return get_client() + + +@app.callback(invoke_without_command=True) +def chat_main( + ctx: typer.Context, + message: str | None = typer.Option(None, "--message", "-m", help="Send a single message (non-interactive mode)"), + conversation: str | None = typer.Option(None, "--conversation", "-C", help="Resume an existing conversation by ID"), + auto_approve: bool = typer.Option(False, "--auto-approve", help="Auto-approve tool calls (non-interactive only)"), + log_file: str | None = typer.Option(None, "--log-file", help="Path to JSON-lines event log file"), + model: str | None = typer.Option(None, "--model", help="Model override"), +) -> None: + """Chat with the Dremio AI Agent. Launches interactive REPL by default.""" + if ctx.invoked_subcommand is not None: + return + + client = _get_client() + + async def _run() -> None: + log_fh = None + try: + if log_file: + log_fh = Path(log_file).open("a") # noqa: SIM115 + if message is not None: + # Read from stdin if message is "-" + msg = sys.stdin.read().strip() if message == "-" else message + await chat_oneshot( + client, + msg, + conversation_id=conversation, + auto_approve=auto_approve, + model=model, + log_file=log_fh, + ) + else: + renderer = ChatRenderer() + await chat_repl( + client, + renderer, + conv_id=conversation, + model=model, + log_file=log_fh, + ) + finally: + await client.close() + if log_fh: + log_fh.close() + + try: + asyncio.run(_run()) + except DremioAPIError as exc: + print_error(str(exc)) + raise typer.Exit(1) + + +@app.command("list") +def chat_list( + limit: int = typer.Option(25, "--limit", "-n", help="Maximum conversations to return"), + fmt: ChatFormat = typer.Option(ChatFormat.table, "--format", "-f", help="Output format: json, table"), +) -> None: + """List recent conversations.""" + client = _get_client() + + async def _run(): + try: + return await list_conversations(client, limit=limit) + finally: + await client.close() + + try: + result = asyncio.run(_run()) + except DremioAPIError as exc: + print_error(str(exc)) + raise typer.Exit(1) + _chat_output(result, fmt) + + +@app.command("history") +def chat_history( + conversation_id: str = typer.Argument(help="Conversation ID"), + limit: int = typer.Option(50, "--limit", "-n", help="Maximum messages to return"), + fmt: ChatFormat = typer.Option(ChatFormat.table, "--format", "-f", help="Output format: json, table"), +) -> None: + """Show message history for a conversation.""" + client = _get_client() + + async def _run(): + try: + return await get_messages(client, conversation_id, limit=limit) + finally: + await client.close() + + try: + result = asyncio.run(_run()) + except DremioAPIError as exc: + print_error(str(exc)) + raise typer.Exit(1) + _chat_output(result, fmt) + + +@app.command("delete") +def chat_delete( + conversation_id: str = typer.Argument(help="Conversation ID to delete"), + fmt: ChatFormat = typer.Option(ChatFormat.json, "--format", "-f", help="Output format: json, table"), +) -> None: + """Delete a conversation.""" + client = _get_client() + + async def _run(): + try: + return await delete_conversation(client, conversation_id) + finally: + await client.close() + + try: + result = asyncio.run(_run()) + except DremioAPIError as exc: + print_error(str(exc)) + raise typer.Exit(1) + _chat_output(result, fmt) diff --git a/src/drs/commands/engine.py b/src/drs/commands/engine.py index 3c66a99..9db82ac 100644 --- a/src/drs/commands/engine.py +++ b/src/drs/commands/engine.py @@ -26,7 +26,7 @@ from drs.output import OutputFormat, error, output from drs.utils import handle_api_error -app = typer.Typer(help="Manage Dremio Cloud engines.") +app = typer.Typer(help="Manage Dremio Cloud engines.", context_settings={"help_option_names": ["-h", "--help"]}) async def list_engines(client: DremioClient) -> dict: diff --git a/src/drs/commands/folder.py b/src/drs/commands/folder.py index 47f9409..abe0547 100644 --- a/src/drs/commands/folder.py +++ b/src/drs/commands/folder.py @@ -27,7 +27,9 @@ from drs.output import OutputFormat, error, output from drs.utils import handle_api_error, parse_path, quote_path_sql -app = typer.Typer(help="Manage spaces and folders in the Dremio catalog.") +app = typer.Typer( + help="Manage spaces and folders in the Dremio catalog.", context_settings={"help_option_names": ["-h", "--help"]} +) async def list_catalog(client: DremioClient) -> dict: diff --git a/src/drs/commands/grant.py b/src/drs/commands/grant.py index dc1eec0..f2e33d0 100644 --- a/src/drs/commands/grant.py +++ b/src/drs/commands/grant.py @@ -26,7 +26,10 @@ from drs.output import OutputFormat, error, output from drs.utils import handle_api_error -app = typer.Typer(help="Manage grants on projects, engines, and org resources.") +app = typer.Typer( + help="Manage grants on projects, engines, and org resources.", + context_settings={"help_option_names": ["-h", "--help"]}, +) async def get_grants(client: DremioClient, scope: str, scope_id: str, grantee_type: str, grantee_id: str) -> dict: diff --git a/src/drs/commands/job.py b/src/drs/commands/job.py index fddad0b..65af403 100644 --- a/src/drs/commands/job.py +++ b/src/drs/commands/job.py @@ -27,7 +27,7 @@ from drs.output import OutputFormat, error, output from drs.utils import handle_api_error, validate_job_id, validate_job_state -app = typer.Typer(help="List and inspect query jobs.") +app = typer.Typer(help="List and inspect query jobs.", context_settings={"help_option_names": ["-h", "--help"]}) async def list_jobs(client: DremioClient, status_filter: str | None = None, limit: int = 25) -> dict: diff --git a/src/drs/commands/project.py b/src/drs/commands/project.py index 0af435a..db60789 100644 --- a/src/drs/commands/project.py +++ b/src/drs/commands/project.py @@ -26,7 +26,7 @@ from drs.output import OutputFormat, error, output from drs.utils import handle_api_error -app = typer.Typer(help="Manage Dremio Cloud projects.") +app = typer.Typer(help="Manage Dremio Cloud projects.", context_settings={"help_option_names": ["-h", "--help"]}) async def list_projects(client: DremioClient) -> dict: diff --git a/src/drs/commands/query.py b/src/drs/commands/query.py index c348b1d..e5c488c 100644 --- a/src/drs/commands/query.py +++ b/src/drs/commands/query.py @@ -27,7 +27,7 @@ from drs.output import OutputFormat, error, output from drs.utils import handle_api_error -app = typer.Typer(help="Run SQL queries against Dremio.") +app = typer.Typer(help="Run SQL queries against Dremio.", context_settings={"help_option_names": ["-h", "--help"]}) TERMINAL_STATES = {"COMPLETED", "FAILED", "CANCELED", "CANCELLED"} POLL_INTERVAL = 1.0 diff --git a/src/drs/commands/reflection.py b/src/drs/commands/reflection.py index bca2f59..2171ea2 100644 --- a/src/drs/commands/reflection.py +++ b/src/drs/commands/reflection.py @@ -27,7 +27,9 @@ from drs.output import OutputFormat, error, output from drs.utils import handle_api_error, parse_path -app = typer.Typer(help="Manage reflections (materialized views).") +app = typer.Typer( + help="Manage reflections (materialized views).", context_settings={"help_option_names": ["-h", "--help"]} +) async def create(client: DremioClient, path: str, rtype: str, display_fields: list[str] | None = None) -> dict: diff --git a/src/drs/commands/role.py b/src/drs/commands/role.py index 93ecf6b..6caa807 100644 --- a/src/drs/commands/role.py +++ b/src/drs/commands/role.py @@ -26,7 +26,7 @@ from drs.output import OutputFormat, error, output from drs.utils import handle_api_error -app = typer.Typer(help="Manage Dremio Cloud roles.") +app = typer.Typer(help="Manage Dremio Cloud roles.", context_settings={"help_option_names": ["-h", "--help"]}) async def list_roles(client: DremioClient) -> dict: diff --git a/src/drs/commands/schema.py b/src/drs/commands/schema.py index 83b96f3..ba3d239 100644 --- a/src/drs/commands/schema.py +++ b/src/drs/commands/schema.py @@ -27,7 +27,10 @@ from drs.output import OutputFormat, error, output from drs.utils import handle_api_error, parse_path, quote_path_sql -app = typer.Typer(help="Describe table schemas, trace lineage, and sample data.") +app = typer.Typer( + help="Describe table schemas, trace lineage, and sample data.", + context_settings={"help_option_names": ["-h", "--help"]}, +) async def describe(client: DremioClient, path: str) -> dict: diff --git a/src/drs/commands/tag.py b/src/drs/commands/tag.py index 565625e..3f0b0b9 100644 --- a/src/drs/commands/tag.py +++ b/src/drs/commands/tag.py @@ -26,7 +26,9 @@ from drs.output import OutputFormat, error, output from drs.utils import handle_api_error, parse_path -app = typer.Typer(help="Get and update tags on catalog entities.") +app = typer.Typer( + help="Get and update tags on catalog entities.", context_settings={"help_option_names": ["-h", "--help"]} +) async def get_tags(client: DremioClient, path: str) -> dict: diff --git a/src/drs/commands/user.py b/src/drs/commands/user.py index 73d0ab6..c5d91f5 100644 --- a/src/drs/commands/user.py +++ b/src/drs/commands/user.py @@ -26,7 +26,7 @@ from drs.output import OutputFormat, error, output from drs.utils import handle_api_error -app = typer.Typer(help="Manage Dremio Cloud users.") +app = typer.Typer(help="Manage Dremio Cloud users.", context_settings={"help_option_names": ["-h", "--help"]}) async def list_users(client: DremioClient, max_results: int = 100) -> dict: diff --git a/src/drs/commands/wiki.py b/src/drs/commands/wiki.py index 42af725..c79fc7f 100644 --- a/src/drs/commands/wiki.py +++ b/src/drs/commands/wiki.py @@ -26,7 +26,10 @@ from drs.output import OutputFormat, error, output from drs.utils import handle_api_error, parse_path -app = typer.Typer(help="Get and update wiki descriptions on catalog entities.") +app = typer.Typer( + help="Get and update wiki descriptions on catalog entities.", + context_settings={"help_option_names": ["-h", "--help"]}, +) async def get_wiki(client: DremioClient, path: str) -> dict: diff --git a/src/drs/sse.py b/src/drs/sse.py new file mode 100644 index 0000000..ef03ae6 --- /dev/null +++ b/src/drs/sse.py @@ -0,0 +1,78 @@ +# +# Copyright (C) 2017-2026 Dremio Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""SSE (Server-Sent Events) stream parser for text/event-stream responses.""" + +from __future__ import annotations + +import json +from collections.abc import AsyncIterator + + +async def parse_sse_stream(byte_stream: AsyncIterator[bytes]) -> AsyncIterator[dict]: + """Yield ``{"event": str, "data": dict}`` for each SSE event. + + Handles multi-line ``data:`` fields, ``event:`` types, comment lines + (``:`` prefix), empty-line delimiters, and partial chunk buffering. + """ + buf = "" + event_type = "message" + data_lines: list[str] = [] + + async for chunk in byte_stream: + buf += chunk.decode("utf-8", errors="replace") + while "\n" in buf: + line, buf = buf.split("\n", 1) + line = line.rstrip("\r") + + if not line: + # Empty line = event boundary + if data_lines: + raw = "\n".join(data_lines) + try: + data = json.loads(raw) + except json.JSONDecodeError: + data = {"raw": raw} + yield {"event": event_type, "data": data} + event_type = "message" + data_lines = [] + continue + + if line.startswith(":"): + # SSE comment — ignore + continue + + if line.startswith("event:"): + event_type = line[len("event:") :].strip() + elif line.startswith("data:"): + data_lines.append(line[len("data:") :].strip()) + + # Flush any remaining content left in buf (server closed without trailing \n) + if buf: + for line in buf.split("\n"): + line = line.rstrip("\r") + if line.startswith("data:"): + data_lines.append(line[len("data:") :].strip()) + elif line.startswith("event:"): + event_type = line[len("event:") :].strip() + + # Flush any remaining data (stream ended without trailing blank line) + if data_lines: + raw = "\n".join(data_lines) + try: + data = json.loads(raw) + except json.JSONDecodeError: + data = {"raw": raw} + yield {"event": event_type, "data": data} diff --git a/src/drs/utils.py b/src/drs/utils.py index 92f51b4..4019eb6 100644 --- a/src/drs/utils.py +++ b/src/drs/utils.py @@ -17,9 +17,12 @@ from __future__ import annotations +import logging import re from typing import TYPE_CHECKING, Any +logger = logging.getLogger(__name__) + if TYPE_CHECKING: import httpx @@ -172,7 +175,10 @@ def __init__(self, status_code: int, message: str, url: str = "") -> None: self.status_code = status_code self.message = message self.url = url - super().__init__(f"HTTP {status_code}: {message}") + text = f"HTTP {status_code}: {message}" + if url: + text += f"\n URL: {url}" + super().__init__(text) def to_dict(self) -> dict: d: dict[str, Any] = {"error": self.message, "status_code": self.status_code} @@ -186,17 +192,34 @@ def handle_api_error(exc: httpx.HTTPStatusError) -> DremioAPIError: status = exc.response.status_code url = str(exc.request.url) + # Always try to extract the server's error message first + server_msg = "" + try: + body = exc.response.json() + server_msg = body.get("errorMessage", body.get("message", "")) + except Exception: + server_msg = exc.response.text or "" + if status == 401: - msg = "Authentication failed — check your PAT token" + hint = "Authentication failed — check your PAT token" elif status == 403: - msg = "Permission denied — insufficient privileges for this operation" + hint = "Permission denied — insufficient privileges for this operation" elif status == 404: - msg = "Not found — check that the path or ID exists" + hint = "Not found — check that the path or ID exists" else: - try: - body = exc.response.json() - msg = body.get("errorMessage", body.get("message", str(exc))) - except Exception: - msg = exc.response.text or str(exc) + hint = "" + + # Log the full response for debugging + logger.debug( + "API error: %s %s → %d\n Response body: %s", + exc.request.method, + url, + status, + exc.response.text[:1000], + ) + + # Combine: server message + hint (if any), always include URL + parts = [p for p in (server_msg, hint) if p] + msg = " — ".join(parts) if parts else str(exc) return DremioAPIError(status, msg, url) diff --git a/tests/test_commands/test_chat.py b/tests/test_commands/test_chat.py new file mode 100644 index 0000000..e53dae2 --- /dev/null +++ b/tests/test_commands/test_chat.py @@ -0,0 +1,115 @@ +# +# Copyright (C) 2017-2026 Dremio Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Tests for drs.commands.chat — core async functions.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock + +import pytest + +from drs.commands.chat import ( + cancel_run, + create_conversation, + delete_conversation, + get_messages, + list_conversations, + send_message, +) + + +@pytest.mark.asyncio +async def test_create_conversation(mock_client) -> None: + mock_client.create_conversation = AsyncMock( + return_value={"id": "conv-1", "runId": "run-1"}, + ) + result = await create_conversation(mock_client, "hello") + mock_client.create_conversation.assert_called_once_with( + {"prompt": {"text": "hello"}}, + ) + assert result["id"] == "conv-1" + assert result["runId"] == "run-1" + + +@pytest.mark.asyncio +async def test_create_conversation_with_model(mock_client) -> None: + mock_client.create_conversation = AsyncMock(return_value={"id": "conv-1"}) + await create_conversation(mock_client, "hello", model="gpt-test") + call_args = mock_client.create_conversation.call_args[0][0] + assert call_args["model"] == "gpt-test" + + +@pytest.mark.asyncio +async def test_send_message_text(mock_client) -> None: + mock_client.send_conversation_message = AsyncMock( + return_value={"runId": "run-2"}, + ) + result = await send_message(mock_client, "conv-1", text="follow-up") + mock_client.send_conversation_message.assert_called_once() + body = mock_client.send_conversation_message.call_args[0][1] + assert body["prompt"]["text"] == "follow-up" + assert result["runId"] == "run-2" + + +@pytest.mark.asyncio +async def test_send_message_approval(mock_client) -> None: + mock_client.send_conversation_message = AsyncMock( + return_value={"runId": "run-3"}, + ) + approvals = { + "approvalNonce": "nonce-1", + "toolDecisions": [{"callId": "c1", "decision": "approved"}], + } + result = await send_message(mock_client, "conv-1", approvals=approvals) + body = mock_client.send_conversation_message.call_args[0][1] + assert body["prompt"]["approvals"] == approvals + assert result["runId"] == "run-3" + + +@pytest.mark.asyncio +async def test_list_conversations(mock_client) -> None: + mock_client.list_conversations = AsyncMock( + return_value={"data": [{"id": "c1", "title": "test"}]}, + ) + result = await list_conversations(mock_client, limit=10) + mock_client.list_conversations.assert_called_once_with(limit=10) + assert len(result["data"]) == 1 + + +@pytest.mark.asyncio +async def test_get_messages(mock_client) -> None: + mock_client.get_conversation_messages = AsyncMock( + return_value={"data": [{"role": "user", "content": "hi"}]}, + ) + result = await get_messages(mock_client, "conv-1", limit=25) + mock_client.get_conversation_messages.assert_called_once_with("conv-1", limit=25) + assert len(result["data"]) == 1 + + +@pytest.mark.asyncio +async def test_delete_conversation(mock_client) -> None: + mock_client.delete_conversation = AsyncMock(return_value={"status": "ok"}) + result = await delete_conversation(mock_client, "conv-1") + mock_client.delete_conversation.assert_called_once_with("conv-1") + assert result["status"] == "ok" + + +@pytest.mark.asyncio +async def test_cancel_run(mock_client) -> None: + mock_client.cancel_conversation_run = AsyncMock(return_value={"status": "ok"}) + result = await cancel_run(mock_client, "conv-1", "run-1") + mock_client.cancel_conversation_run.assert_called_once_with("conv-1", "run-1") + assert result["status"] == "ok" diff --git a/tests/test_sse.py b/tests/test_sse.py new file mode 100644 index 0000000..e2e2a20 --- /dev/null +++ b/tests/test_sse.py @@ -0,0 +1,130 @@ +# +# Copyright (C) 2017-2026 Dremio Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Tests for drs.sse — SSE stream parser.""" + +from __future__ import annotations + +import json + +import pytest + +from drs.sse import parse_sse_stream + + +async def _bytes_iter(chunks: list[bytes]): + """Helper: async iterator over byte chunks.""" + for chunk in chunks: + yield chunk + + +@pytest.mark.asyncio +async def test_parse_single_event(): + data = {"chunkType": "model", "model": {"name": "test", "result": {"text": "hello"}}} + raw = f"data: {json.dumps(data)}\n\n".encode() + events = [e async for e in parse_sse_stream(_bytes_iter([raw]))] + assert len(events) == 1 + assert events[0]["event"] == "message" + assert events[0]["data"] == data + + +@pytest.mark.asyncio +async def test_parse_multiple_events(): + data1 = {"chunkType": "model", "text": "a"} + data2 = {"chunkType": "endOfStream"} + raw = (f"data: {json.dumps(data1)}\n\ndata: {json.dumps(data2)}\n\n").encode() + events = [e async for e in parse_sse_stream(_bytes_iter([raw]))] + assert len(events) == 2 + assert events[0]["data"] == data1 + assert events[1]["data"] == data2 + + +@pytest.mark.asyncio +async def test_parse_event_type(): + data = {"foo": "bar"} + raw = f"event: custom\ndata: {json.dumps(data)}\n\n".encode() + events = [e async for e in parse_sse_stream(_bytes_iter([raw]))] + assert len(events) == 1 + assert events[0]["event"] == "custom" + assert events[0]["data"] == data + + +@pytest.mark.asyncio +async def test_comment_lines_ignored(): + data = {"chunkType": "model"} + raw = f": this is a comment\ndata: {json.dumps(data)}\n\n".encode() + events = [e async for e in parse_sse_stream(_bytes_iter([raw]))] + assert len(events) == 1 + assert events[0]["data"] == data + + +@pytest.mark.asyncio +async def test_partial_chunks(): + """Data split across multiple byte chunks.""" + data = {"chunkType": "model", "text": "hello world"} + full = f"data: {json.dumps(data)}\n\n" + mid = len(full) // 2 + chunk1 = full[:mid].encode() + chunk2 = full[mid:].encode() + events = [e async for e in parse_sse_stream(_bytes_iter([chunk1, chunk2]))] + assert len(events) == 1 + assert events[0]["data"] == data + + +@pytest.mark.asyncio +async def test_multiline_data(): + """Multiple data: lines for one event get joined.""" + raw = b'data: {"a":\ndata: 1}\n\n' + events = [e async for e in parse_sse_stream(_bytes_iter([raw]))] + assert len(events) == 1 + # Multiline data lines get joined with newline, parsed as raw if not valid JSON + assert "data" in events[0] + + +@pytest.mark.asyncio +async def test_flush_on_stream_end(): + """Data without trailing blank line gets flushed at end of stream.""" + data = {"chunkType": "endOfStream"} + raw = f"data: {json.dumps(data)}\n".encode() # No trailing blank line + events = [e async for e in parse_sse_stream(_bytes_iter([raw]))] + assert len(events) == 1 + assert events[0]["data"] == data + + +@pytest.mark.asyncio +async def test_flush_no_trailing_newline(): + """Data without any trailing newline gets flushed at end of stream.""" + data = {"chunkType": "endOfStream"} + raw = f"data: {json.dumps(data)}".encode() # No newline at all + events = [e async for e in parse_sse_stream(_bytes_iter([raw]))] + assert len(events) == 1 + assert events[0]["data"] == data + + +@pytest.mark.asyncio +async def test_empty_stream(): + events = [e async for e in parse_sse_stream(_bytes_iter([]))] + assert events == [] + + +@pytest.mark.asyncio +async def test_event_type_resets_after_event(): + """Event type resets to 'message' after each event.""" + data1 = {"a": 1} + data2 = {"b": 2} + raw = (f"event: custom\ndata: {json.dumps(data1)}\n\ndata: {json.dumps(data2)}\n\n").encode() + events = [e async for e in parse_sse_stream(_bytes_iter([raw]))] + assert events[0]["event"] == "custom" + assert events[1]["event"] == "message" diff --git a/uv.lock b/uv.lock index 1ed5beb..33108ae 100644 --- a/uv.lock +++ b/uv.lock @@ -22,15 +22,15 @@ wheels = [ [[package]] name = "anyio" -version = "4.12.1" +version = "4.13.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "idna" }, { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/96/f0/5eb65b2bb0d09ac6776f2eb54adee6abe8228ea05b20a5ad0e4945de8aac/anyio-4.12.1.tar.gz", hash = "sha256:41cfcc3a4c85d3f05c932da7c26d0201ac36f72abd4435ba90d0464a3ffed703", size = 228685, upload-time = "2026-01-06T11:45:21.246Z" } +sdist = { url = "https://files.pythonhosted.org/packages/19/14/2c5dd9f512b66549ae92767a9c7b330ae88e1932ca57876909410251fe13/anyio-4.13.0.tar.gz", hash = "sha256:334b70e641fd2221c1505b3890c69882fe4a2df910cba14d97019b90b24439dc", size = 231622, upload-time = "2026-03-24T12:59:09.671Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/38/0e/27be9fdef66e72d64c0cdc3cc2823101b80585f8119b5c112c2e8f5f7dab/anyio-4.12.1-py3-none-any.whl", hash = "sha256:d405828884fc140aa80a3c667b8beed277f1dfedec42ba031bd6ac3db606ab6c", size = 113592, upload-time = "2026-01-06T11:45:19.497Z" }, + { url = "https://files.pythonhosted.org/packages/da/42/e921fccf5015463e32a3cf6ee7f980a6ed0f395ceeaa45060b61d86486c2/anyio-4.13.0-py3-none-any.whl", hash = "sha256:08b310f9e24a9594186fd75b4f73f4a4152069e3853f1ed8bfbf58369f4ad708", size = 114353, upload-time = "2026-03-24T12:59:08.246Z" }, ] [[package]] @@ -86,8 +86,10 @@ name = "dremio-cli" source = { editable = "." } dependencies = [ { name = "httpx" }, + { name = "prompt-toolkit" }, { name = "pydantic" }, { name = "pyyaml" }, + { name = "rich" }, { name = "typer" }, ] @@ -102,9 +104,11 @@ dev = [ [package.metadata] requires-dist = [ { name = "httpx", specifier = ">=0.27" }, + { name = "prompt-toolkit", specifier = ">=3.0" }, { name = "pydantic", specifier = ">=2" }, { name = "pyyaml", specifier = ">=6" }, - { name = "typer", extras = ["all"], specifier = ">=0.9" }, + { name = "rich", specifier = ">=13" }, + { name = "typer", specifier = ">=0.9" }, ] [package.metadata.requires-dev] @@ -261,6 +265,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5d/19/fd3ef348460c80af7bb4669ea7926651d1f95c23ff2df18b9d24bab4f3fa/pre_commit-4.5.1-py2.py3-none-any.whl", hash = "sha256:3b3afd891e97337708c1674210f8eba659b52a38ea5f822ff142d10786221f77", size = 226437, upload-time = "2025-12-16T21:14:32.409Z" }, ] +[[package]] +name = "prompt-toolkit" +version = "3.0.52" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wcwidth" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a1/96/06e01a7b38dce6fe1db213e061a4602dd6032a8a97ef6c1a862537732421/prompt_toolkit-3.0.52.tar.gz", hash = "sha256:28cde192929c8e7321de85de1ddbe736f1375148b02f2e17edd840042b1be855", size = 434198, upload-time = "2025-08-27T15:24:02.057Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl", hash = "sha256:9aac639a3bbd33284347de5ad8d68ecc044b91a762dc39b7c21095fcd6a19955", size = 391431, upload-time = "2025-08-27T15:23:59.498Z" }, +] + [[package]] name = "pydantic" version = "2.12.5" @@ -576,3 +592,12 @@ sdist = { url = "https://files.pythonhosted.org/packages/aa/92/58199fe10049f9703 wheels = [ { url = "https://files.pythonhosted.org/packages/c6/59/7d02447a55b2e55755011a647479041bc92a82e143f96a8195cb33bd0a1c/virtualenv-21.2.0-py3-none-any.whl", hash = "sha256:1bd755b504931164a5a496d217c014d098426cddc79363ad66ac78125f9d908f", size = 5825084, upload-time = "2026-03-09T17:24:35.378Z" }, ] + +[[package]] +name = "wcwidth" +version = "0.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/35/a2/8e3becb46433538a38726c948d3399905a4c7cabd0df578ede5dc51f0ec2/wcwidth-0.6.0.tar.gz", hash = "sha256:cdc4e4262d6ef9a1a57e018384cbeb1208d8abbc64176027e2c2455c81313159", size = 159684, upload-time = "2026-02-06T19:19:40.919Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/68/5a/199c59e0a824a3db2b89c5d2dade7ab5f9624dbf6448dc291b46d5ec94d3/wcwidth-0.6.0-py3-none-any.whl", hash = "sha256:1a3a1e510b553315f8e146c54764f4fb6264ffad731b3d78088cdb1478ffbdad", size = 94189, upload-time = "2026-02-06T19:19:39.646Z" }, +]