diff --git a/README.md b/README.md index ac82480..f26d69d 100644 --- a/README.md +++ b/README.md @@ -125,6 +125,35 @@ Use `--json` when you want machine-readable output for an agent or script: npx -y github:illscience/vibe-debug debug-python ./buggy_invoice.py --break ./buggy_invoice.py:13 --eval "subtotal * (1 - rate)" --json ``` +For pytest failures, use `debug-pytest`. It launches pytest through the selected Python interpreter, so pytest comes from the target project environment rather than from `vibe-debug` itself: + +```bash +npx -y github:illscience/vibe-debug debug-pytest examples/test_buggy_discount.py::test_gold --json +``` + +By default, `debug-pytest` stops on the failing `AssertionError` frame and returns the same stopped location, locals, and optional evaluations as `debug-python`, plus a `pytest` block: + +```json +{ + "ok": true, + "pytest": { + "test_id": "examples/test_buggy_discount.py::test_gold", + "outcome": "failed" + }, + "exception": { + "name": "AssertionError" + } +} +``` + +Useful pytest options: + +```bash +npx -y github:illscience/vibe-debug debug-pytest examples/test_buggy_discount.py --break examples/test_buggy_discount.py:5 +npx -y github:illscience/vibe-debug debug-pytest examples/test_buggy_discount.py --pytest-arg -k --pytest-arg test_gold +npx -y github:illscience/vibe-debug debug-pytest examples/test_buggy_discount.py::test_gold --no-break-on-failure --json +``` + ## Status This is an alpha release. The first debugger backend is Python via [`debugpy`](https://github.com/microsoft/debugpy); the MCP server is designed to grow to TypeScript/Node and other language runtimes. @@ -208,6 +237,7 @@ Workflow tools: - `debug_guidance`: returns instructions that tell agents when to use the debugger. - `debug_python_repro`: best first tool for a reproducible Python bug. It launches a Python script under `debugpy`, sets breakpoints, continues to the first stop, and returns stack plus top-frame locals. +- `debug_pytest`: launches pytest under `debugpy`, stops on a failing assertion by default, and returns pytest outcome plus stopped-frame state. Debugger primitives: @@ -236,7 +266,7 @@ If using the local venv: .venv/bin/python tools/runtime_proof.py ``` -The proof talks to the MCP server over stdio, launches `examples/buggy_discount.py` under `debugpy`, sets a breakpoint, continues to it, steps into and out of functions, inspects local variables, evaluates expressions in a paused frame, tests attach mode, and cleans up the session. +The proof talks to the MCP server over stdio, launches `examples/buggy_discount.py` under `debugpy`, runs `examples/test_buggy_discount.py` with `debug_pytest`, sets a breakpoint, continues to it, steps into and out of functions, inspects local variables, evaluates expressions in a paused frame, tests attach mode, and cleans up the session. Expected output: @@ -247,6 +277,7 @@ Expected output: "MCP initialize/tools/list", "debug_guidance", "debug_python_repro", + "debug_pytest", "debug_launch", "debug_attach", "debug_set_breakpoints", @@ -350,7 +381,6 @@ Build a wheel: ## Roadmap -- `debug_pytest_failure`: run a failing pytest test under the debugger automatically. - Breakpoints by function name, symbol, marker comment, or exception type. - Richer first-stop summaries with surrounding source and suggested next debugger actions. - Agent-optimized CLI commands, with MCP as a thin wrapper for clients that prefer tools over shell commands. diff --git a/examples/test_buggy_discount.py b/examples/test_buggy_discount.py new file mode 100644 index 0000000..e57b8af --- /dev/null +++ b/examples/test_buggy_discount.py @@ -0,0 +1,8 @@ +from buggy_discount import apply_discount + + +def test_gold(): + price = 120.0 + loyalty_level = "gold" + actual = apply_discount(price, loyalty_level) + assert actual == 102.0 diff --git a/src/vibe_debug/cli.py b/src/vibe_debug/cli.py index 5d08c6b..6170ea6 100644 --- a/src/vibe_debug/cli.py +++ b/src/vibe_debug/cli.py @@ -3,9 +3,11 @@ import argparse import json import os +import shlex import shutil import subprocess import sys +import time from pathlib import Path from . import __version__ @@ -399,6 +401,90 @@ def _local_summaries(snapshot: dict[str, object]) -> list[dict[str, object]]: return summaries +def _exception_summary(exception_info: dict[str, object]) -> dict[str, object]: + exception_id = exception_info.get("exceptionId") + details = exception_info.get("details") + message = exception_info.get("description") + stack_trace = None + + if isinstance(details, dict): + detail_message = details.get("message") + if isinstance(detail_message, str) and detail_message: + message = detail_message + detail_stack = details.get("stackTrace") + if isinstance(detail_stack, str) and detail_stack: + stack_trace = detail_stack + + name = str(exception_id) if exception_id is not None else "" + if "." in name: + name = name.rsplit(".", 1)[-1] + + summary: dict[str, object] = { + "name": name or "Exception", + "message": message if isinstance(message, str) else "", + } + if stack_trace: + summary["stackTrace"] = stack_trace + return summary + + +def _pytest_outcome(stopped: dict[str, object], exception: dict[str, object] | None) -> str: + if exception and exception.get("name") == "AssertionError": + return "failed" + + body = stopped.get("body") + exit_code = body.get("exitCode") if isinstance(body, dict) else None + if exit_code == 0: + return "passed" + if isinstance(exit_code, int): + return "failed" + if stopped.get("state") == "stopped": + return "stopped" + return str(stopped.get("state") or "unknown") + + +def _pytest_args(values: list[str]) -> list[str]: + args: list[str] = [] + for value in values: + args.extend(shlex.split(value)) + return args + + +def _skip_pytest_exception_stop(stopped: dict[str, object], cwd: str | None) -> bool: + if stopped.get("state") != "stopped" or stopped.get("stoppedReason") != "exception": + return False + + location = stopped.get("location") + if not isinstance(location, dict): + return False + source = location.get("source") + if not isinstance(source, dict): + return False + file_name = source.get("path") + if not isinstance(file_name, str) or not file_name: + return False + + path = Path(file_name) + parts = set(path.parts) + if "site-packages" in parts or ".venv" in parts or "_pytest" in parts: + return True + + try: + path.resolve().relative_to(Path(cwd or os.getcwd()).resolve()) + except ValueError: + return True + return False + + +def _continue_pytest_execution(session: DebugSession, timeout: float, cwd: str | None) -> dict[str, object]: + stopped = session.continue_execution(timeout=timeout) + for _ in range(50): + if not _skip_pytest_exception_stop(stopped, cwd): + return stopped + stopped = session.continue_execution(timeout=timeout) + return stopped + + def _debug_python_payload(args: argparse.Namespace) -> dict[str, object]: breakpoints = args.breakpoints or [] stop_on_entry = bool(args.stop_on_entry or not breakpoints) @@ -466,6 +552,102 @@ def _debug_python_payload(args: argparse.Namespace) -> dict[str, object]: pass +def _debug_pytest_payload(args: argparse.Namespace) -> dict[str, object]: + breakpoints = args.breakpoints or [] + pytest_args = _pytest_args(args.pytest_args or []) + start = time.monotonic() + session: DebugSession | None = None + + try: + session = DebugSession.launch_pytest( + test_ids=args.test, + pytest_args=pytest_args, + cwd=args.cwd, + python=args.python, + timeout=float(args.timeout), + ) + exception_breakpoints: dict[str, object] = {} + if args.break_on_failure: + exception_breakpoints = session.set_exception_breakpoints( + exception_options=[ + { + "path": [{"names": ["Python Exceptions"]}, {"names": ["AssertionError"]}], + "breakMode": "always", + } + ] + ) + breakpoint_results = [ + session.set_breakpoints(file=str(item["file"]), lines=[int(item["line"])], cwd=args.cwd) + for item in breakpoints + ] + stopped = _continue_pytest_execution(session, timeout=float(args.timeout), cwd=args.cwd) + snapshot: dict[str, object] = {} + evaluations: list[dict[str, object]] = [] + exception: dict[str, object] | None = None + + if stopped.get("state") == "stopped": + snapshot = session.top_frame_locals(limit=int(args.locals_limit)) + if stopped.get("stoppedReason") == "exception": + try: + exception = _exception_summary(session.exception_info()) + except Exception as exc: + exception = {"name": type(exc).__name__, "message": str(exc)} + for expression in args.evaluate or []: + try: + result = session.evaluate(expression=expression) + evaluations.append( + { + "expression": expression, + "result": result.get("result"), + "type": result.get("type"), + } + ) + except Exception as exc: + evaluations.append( + { + "expression": expression, + "error": str(exc), + "exceptionType": type(exc).__name__, + } + ) + + stopped_summary: dict[str, object] = { + "state": stopped.get("state"), + "event": stopped.get("event"), + "reason": stopped.get("stoppedReason"), + } + stopped_summary.update(_location_summary(stopped.get("location"))) + + pytest_summary: dict[str, object] = { + "rootdir": str(Path(args.cwd or os.getcwd()).resolve()), + "test_id": " ".join(args.test), + "test_ids": args.test, + "pytest_args": pytest_args, + "outcome": _pytest_outcome(stopped, exception), + "duration_seconds": round(time.monotonic() - start, 3), + } + + payload: dict[str, object] = { + "ok": True, + "cwd": str(Path(args.cwd or os.getcwd()).resolve()), + "pytest": pytest_summary, + "breakpoints": _breakpoint_summaries(breakpoint_results), + "exceptionBreakpoints": exception_breakpoints, + "stopped": stopped_summary, + "locals": _local_summaries(snapshot), + "evaluations": evaluations, + } + if exception is not None: + payload["exception"] = exception + return payload + finally: + if session is not None: + try: + session.stop(terminate_debuggee=True) + except Exception: + pass + + def _print_debug_python_human(payload: dict[str, object]) -> None: stopped = payload.get("stopped") if isinstance(stopped, dict) and stopped.get("state") == "stopped": @@ -514,6 +696,19 @@ def _print_debug_python_human(payload: dict[str, object]) -> None: print(f" {expression} -> {item.get('result')}") +def _print_debug_pytest_human(payload: dict[str, object]) -> None: + pytest = payload.get("pytest") + if isinstance(pytest, dict): + print(f"Pytest: {pytest.get('test_id', 'unknown')}") + print(f"Outcome: {pytest.get('outcome', 'unknown')}") + + exception = payload.get("exception") + if isinstance(exception, dict): + print(f"Exception: {exception.get('name', 'Exception')}: {exception.get('message', '')}") + + _print_debug_python_human(payload) + + def _debug_python(args: argparse.Namespace) -> int: try: payload = _debug_python_payload(args) @@ -540,6 +735,32 @@ def _debug_python(args: argparse.Namespace) -> int: return 0 +def _debug_pytest(args: argparse.Namespace) -> int: + try: + payload = _debug_pytest_payload(args) + except Exception as exc: + if args.json: + print( + json.dumps( + { + "ok": False, + "error": str(exc), + "exceptionType": type(exc).__name__, + }, + indent=2, + ) + ) + else: + print(f"debug-pytest failed: {exc}", file=sys.stderr) + return 1 + + if args.json: + print(json.dumps(payload, indent=2, sort_keys=True)) + else: + _print_debug_pytest_human(payload) + return 0 + + def _load_json_maybe(value: object) -> object | None: if isinstance(value, dict): return value @@ -576,6 +797,10 @@ def _format_tool_use(name: str, tool_input: object) -> str: program = _basename(tool_input.get("program")) if program: return f"Tool: {_debugger_tool_name(name)} ({program})" + if name.endswith("__debug_pytest"): + test_id = tool_input.get("test_id") + if isinstance(test_id, str): + return f"Tool: {_debugger_tool_name(name)} ({test_id})" if name.endswith("__debug_evaluate"): expression = tool_input.get("expression") if isinstance(expression, str): @@ -590,6 +815,8 @@ def _format_tool_use(name: str, tool_input: object) -> str: command = tool_input.get("command") if isinstance(command, str) and "vibe-debug" in command and "debug-python" in command: return "Tool: Bash (vibe-debug debug-python)" + if isinstance(command, str) and "vibe-debug" in command and "debug-pytest" in command: + return "Tool: Bash (vibe-debug debug-pytest)" description = tool_input.get("description") if isinstance(description, str) and description: return f"Tool: Bash ({description})" @@ -799,7 +1026,25 @@ def flush_tool_errors(tool_name: str | None = None) -> None: return 0 +def _normalize_pytest_arg_flags(argv: list[str]) -> list[str]: + normalized: list[str] = [] + index = 0 + while index < len(argv): + item = argv[index] + if item == "--pytest-arg" and index + 1 < len(argv): + normalized.append(f"--pytest-arg={argv[index + 1]}") + index += 2 + continue + normalized.append(item) + index += 1 + return normalized + + def main(argv: list[str] | None = None) -> int: + if argv is None: + argv = sys.argv[1:] + argv = _normalize_pytest_arg_flags(argv) + parser = argparse.ArgumentParser(description="Utilities for the vibe-debug MCP server.") parser.add_argument("--version", action="version", version=f"vibe-debug {__version__}") subparsers = parser.add_subparsers(dest="command", required=True) @@ -852,6 +1097,36 @@ def main(argv: list[str] | None = None) -> int: debug_python.add_argument("--stop-on-entry", action="store_true") debug_python.add_argument("--eval", dest="evaluate", action="append", default=[], help="Evaluate expression at stop.") debug_python.add_argument("--json", action="store_true", help="Print machine-readable JSON output.") + debug_pytest = subparsers.add_parser( + "debug-pytest", + help="Run pytest under the debugger and print stopped-frame state.", + ) + debug_pytest.add_argument("test", nargs="+", help="Pytest node ID(s), such as tests/test_foo.py::test_bar.") + debug_pytest.add_argument( + "--pytest-arg", + dest="pytest_args", + action="append", + default=[], + help="Argument passed through to pytest. Repeat for multiple args.", + ) + debug_pytest.add_argument( + "--break", + "-b", + dest="breakpoints", + action="append", + type=_parse_breakpoint, + default=[], + metavar="FILE:LINE", + help="Set a line breakpoint before continuing. Repeat for multiple breakpoints.", + ) + debug_pytest.add_argument("--cwd", help="Working directory for pytest.") + debug_pytest.add_argument("--python", help="Python executable for debugpy and pytest.") + debug_pytest.add_argument("--timeout", type=float, default=30.0) + debug_pytest.add_argument("--locals-limit", type=int, default=40) + debug_pytest.add_argument("--eval", dest="evaluate", action="append", default=[], help="Evaluate expression at stop.") + debug_pytest.add_argument("--break-on-failure", action="store_true", default=True) + debug_pytest.add_argument("--no-break-on-failure", action="store_false", dest="break_on_failure") + debug_pytest.add_argument("--json", action="store_true", help="Print machine-readable JSON output.") subparsers.add_parser("claude-progress", help="Format Claude Code stream-json output for humans.") args = parser.parse_args(argv) @@ -873,6 +1148,8 @@ def main(argv: list[str] | None = None) -> int: return _demo_project(args.target, args.directory, args.force) if args.command == "debug-python": return _debug_python(args) + if args.command == "debug-pytest": + return _debug_pytest(args) if args.command == "claude-progress": return _format_claude_stream(sys.stdin, sys.stdout) parser.error(f"unknown command: {args.command}") diff --git a/src/vibe_debug/mcp_server.py b/src/vibe_debug/mcp_server.py index 0c25def..c857a84 100644 --- a/src/vibe_debug/mcp_server.py +++ b/src/vibe_debug/mcp_server.py @@ -2,8 +2,12 @@ import argparse import json +import os +import shlex import sys +import time import traceback +from pathlib import Path from typing import Any, Callable from . import __version__ @@ -72,6 +76,48 @@ def _tool_definitions() -> list[dict[str, Any]]: ["program"], ), }, + { + "name": "debug_pytest", + "description": ( + "Launch pytest under debugpy, optionally stop on AssertionError, and return pytest outcome, " + "stack, top-frame locals, and exception details." + ), + "inputSchema": _schema( + { + "test_id": {"type": "string", "description": "Pytest node ID, such as tests/test_foo.py::test_bar."}, + "test_ids": {"type": "array", "items": {"type": "string"}}, + "pytest_args": {"type": "array", "items": {"type": "string"}, "default": []}, + "cwd": {"type": "string", "description": "Working directory for pytest."}, + "python": {"type": "string", "description": "Python executable to use for debugpy and pytest."}, + "env": { + "type": "object", + "additionalProperties": {"type": "string"}, + "description": "Environment variables to add to pytest.", + }, + "breakpoints": { + "type": "array", + "description": "Breakpoints to set before pytest starts running.", + "items": { + "type": "object", + "properties": { + "file": {"type": "string"}, + "line": {"type": "integer"}, + }, + "required": ["file", "line"], + "additionalProperties": False, + }, + }, + "break_on_failure": {"type": "boolean", "default": True}, + "timeout": {"type": "number", "default": 30}, + "locals_limit": {"type": "integer", "default": 40}, + "keep_session": { + "type": "boolean", + "default": True, + "description": "Keep the session open after the first stop so the agent can step or inspect more.", + }, + }, + ), + }, { "name": "debug_launch", "description": "Launch a Python program under debugpy and pause before user code until breakpoints are configured.", @@ -215,12 +261,88 @@ def _tool_definitions() -> list[dict[str, Any]]: ] +def _exception_summary(exception_info: dict[str, Any]) -> dict[str, Any]: + exception_id = exception_info.get("exceptionId") + details = exception_info.get("details") + message = exception_info.get("description") + stack_trace = None + + if isinstance(details, dict): + detail_message = details.get("message") + if isinstance(detail_message, str) and detail_message: + message = detail_message + detail_stack = details.get("stackTrace") + if isinstance(detail_stack, str) and detail_stack: + stack_trace = detail_stack + + name = str(exception_id) if exception_id is not None else "" + if "." in name: + name = name.rsplit(".", 1)[-1] + + summary: dict[str, Any] = { + "name": name or "Exception", + "message": message if isinstance(message, str) else "", + } + if stack_trace: + summary["stackTrace"] = stack_trace + return summary + + +def _pytest_outcome(stopped: dict[str, Any], exception: dict[str, Any] | None) -> str: + if exception and exception.get("name") == "AssertionError": + return "failed" + + body = stopped.get("body") + exit_code = body.get("exitCode") if isinstance(body, dict) else None + if exit_code == 0: + return "passed" + if isinstance(exit_code, int): + return "failed" + if stopped.get("state") == "stopped": + return "stopped" + return str(stopped.get("state") or "unknown") + + +def _pytest_args(values: list[str]) -> list[str]: + args: list[str] = [] + for value in values: + args.extend(shlex.split(value)) + return args + + +def _skip_pytest_exception_stop(stopped: dict[str, Any], cwd: str | None) -> bool: + if stopped.get("state") != "stopped" or stopped.get("stoppedReason") != "exception": + return False + + location = stopped.get("location") + if not isinstance(location, dict): + return False + source = location.get("source") + if not isinstance(source, dict): + return False + file_name = source.get("path") + if not isinstance(file_name, str) or not file_name: + return False + + path = Path(file_name) + parts = set(path.parts) + if "site-packages" in parts or ".venv" in parts or "_pytest" in parts: + return True + + try: + path.resolve().relative_to(Path(cwd or os.getcwd()).resolve()) + except ValueError: + return True + return False + + class MCPDebuggerServer: def __init__(self) -> None: self.manager = DebugSessionManager() self.handlers: dict[str, ToolHandler] = { "debug_guidance": self._debug_guidance, "debug_python_repro": self._debug_python_repro, + "debug_pytest": self._debug_pytest, "debug_launch": self._debug_launch, "debug_attach": self._debug_attach, "debug_set_breakpoints": self._debug_set_breakpoints, @@ -238,6 +360,7 @@ def _debug_guidance(self, args: dict[str, Any]) -> dict[str, Any]: "guidance": AGENT_USAGE_GUIDANCE, "recommendedFirstTool": "debug_python_repro", "primitiveTools": [ + "debug_pytest", "debug_launch", "debug_attach", "debug_set_breakpoints", @@ -386,6 +509,90 @@ def _debug_python_repro(self, args: dict[str, Any]) -> dict[str, Any]: ], } + def _debug_pytest(self, args: dict[str, Any]) -> dict[str, Any]: + timeout = float(args.get("timeout", 30)) + test_ids = list(args.get("test_ids") or []) + if args.get("test_id"): + test_ids.append(str(args["test_id"])) + if not test_ids: + raise ValueError("debug_pytest requires test_id or test_ids") + + pytest_args = _pytest_args(args.get("pytest_args") or []) + start = time.monotonic() + launch = self.manager.launch_pytest( + test_ids=test_ids, + pytest_args=pytest_args, + cwd=args.get("cwd"), + python=args.get("python"), + env=args.get("env"), + timeout=timeout, + ) + session_id = launch["sessionId"] + session = self.manager.get(session_id) + breakpoint_results: list[dict[str, Any]] = [] + exception_breakpoints: dict[str, Any] = {} + + if bool(args.get("break_on_failure", True)): + exception_breakpoints = session.set_exception_breakpoints( + exception_options=[ + { + "path": [{"names": ["Python Exceptions"]}, {"names": ["AssertionError"]}], + "breakMode": "always", + } + ] + ) + + for item in args.get("breakpoints") or []: + breakpoint_results.append( + session.set_breakpoints( + file=item["file"], + lines=[int(item["line"])], + cwd=args.get("cwd"), + ) + ) + + stopped = session.continue_execution(timeout=timeout) + for _ in range(50): + if not _skip_pytest_exception_stop(stopped, args.get("cwd")): + break + stopped = session.continue_execution(timeout=timeout) + snapshot: dict[str, Any] = {} + exception: dict[str, Any] | None = None + if stopped.get("state") == "stopped": + snapshot = session.top_frame_locals(limit=int(args.get("locals_limit", 40))) + if stopped.get("stoppedReason") == "exception": + exception = _exception_summary(session.exception_info()) + + pytest = { + "rootdir": launch.get("cwd"), + "test_id": " ".join(test_ids), + "test_ids": test_ids, + "pytest_args": pytest_args, + "outcome": _pytest_outcome(stopped, exception), + "duration_seconds": round(time.monotonic() - start, 3), + } + + if not bool(args.get("keep_session", True)): + self.manager.stop(session_id) + + result: dict[str, Any] = { + "sessionId": session_id, + "launch": launch, + "pytest": pytest, + "breakpoints": breakpoint_results, + "exceptionBreakpoints": exception_breakpoints, + "stopped": stopped, + "snapshot": snapshot, + "nextActions": [ + "Use debug_step to move over/into/out from the current line.", + "Use debug_variables to expand object variablesReference values.", + "Use debug_stop when finished with this session.", + ], + } + if exception is not None: + result["exception"] = exception + return result + def _debug_attach(self, args: dict[str, Any]) -> dict[str, Any]: return self.manager.attach( host=args.get("host") or "127.0.0.1", diff --git a/src/vibe_debug/session.py b/src/vibe_debug/session.py index a5c85dc..3f2103b 100644 --- a/src/vibe_debug/session.py +++ b/src/vibe_debug/session.py @@ -122,6 +122,67 @@ def launch( session.state = "configuring" return session + @classmethod + def launch_pytest( + cls, + test_ids: list[str], + pytest_args: list[str] | None = None, + cwd: str | None = None, + python: str | None = None, + env: dict[str, str] | None = None, + timeout: float = 30.0, + ) -> "DebugSession": + python_executable = python or sys.executable + working_directory = _normalize_path(cwd or os.getcwd()) + host = "127.0.0.1" + port = _free_port() + args = [*(pytest_args or []), *test_ids] + + adapter_process = subprocess.Popen( + [python_executable, "-m", "debugpy.adapter", "--host", host, "--port", str(port)], + cwd=working_directory, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + text=True, + start_new_session=True, + ) + client = _connect_dap_with_retry(host, port, adapter_process=adapter_process, timeout=timeout) + session = cls( + session_id=str(uuid.uuid4()), + client=client, + adapter_process=adapter_process, + metadata={ + "mode": "pytest", + "module": "pytest", + "testIds": test_ids, + "pytestArgs": pytest_args or [], + "cwd": working_directory, + "adapterHost": host, + "adapterPort": port, + }, + ) + session._initialize() + + launch_args: dict[str, Any] = { + "name": "vibe-debug-pytest", + "type": "python", + "request": "launch", + "module": "pytest", + "cwd": working_directory, + "args": args, + "env": env or {}, + "console": "internalConsole", + "justMyCode": False, + "stopOnEntry": False, + "python": [python_executable], + } + session.launch_request_seq = client.send_request("launch", launch_args) + session.event_cursor = client.event_count() + initialized = client.wait_for_event("initialized", timeout=timeout, after=0) + session.event_cursor = max(session.event_cursor, client.events.index(initialized) + 1) + session.state = "configuring" + return session + @classmethod def attach( cls, @@ -168,6 +229,24 @@ def set_breakpoints(self, file: str, lines: list[int], cwd: str | None = None) - "breakpoints": body.get("breakpoints", []), } + def set_exception_breakpoints( + self, + filters: list[str] | None = None, + exception_options: list[dict[str, Any]] | None = None, + ) -> dict[str, Any]: + body = self.client.request( + "setExceptionBreakpoints", + { + "filters": filters or [], + "exceptionOptions": exception_options or [], + }, + ) + return { + "sessionId": self.session_id, + "state": self.state, + "breakpoints": body.get("breakpoints", []), + } + def continue_execution(self, timeout: float = 15.0) -> dict[str, Any]: if self.state == "configuring": start = self.client.event_count() @@ -273,6 +352,19 @@ def evaluate(self, expression: str, frame_id: int | None = None, context: str = "variablesReference": body.get("variablesReference"), } + def exception_info(self, thread_id: int | None = None) -> dict[str, Any]: + selected_thread = thread_id or self.stopped_thread_id + if selected_thread is None: + raise DebugSessionError("no stopped thread is available") + + body = self.client.request("exceptionInfo", {"threadId": selected_thread}) + return { + "sessionId": self.session_id, + "state": self.state, + "threadId": selected_thread, + **body, + } + def top_frame_locals(self, limit: int = 40) -> dict[str, Any]: stack = self.stack(levels=1) if not stack["frames"]: @@ -448,6 +540,15 @@ def launch(self, **kwargs: Any) -> dict[str, Any]: **session.metadata, } + def launch_pytest(self, **kwargs: Any) -> dict[str, Any]: + session = DebugSession.launch_pytest(**kwargs) + self._sessions[session.session_id] = session + return { + "sessionId": session.session_id, + "state": session.state, + **session.metadata, + } + def attach(self, **kwargs: Any) -> dict[str, Any]: session = DebugSession.attach(**kwargs) self._sessions[session.session_id] = session diff --git a/tests/test_cli.py b/tests/test_cli.py index 643642c..4aa607e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import sys import tempfile import unittest from contextlib import redirect_stdout @@ -137,6 +138,74 @@ def test_debug_python_stops_and_prints_locals(self) -> None: self.assertIn("y = 42", output) self.assertIn("y -> 42", output) + def test_debug_pytest_stops_on_assertion_failure_json(self) -> None: + test_id = "examples/test_buggy_discount.py::test_gold" + + code, output = call_cli( + [ + "debug-pytest", + test_id, + "--python", + sys.executable, + "--json", + ] + ) + + self.assertEqual(code, 0, output) + payload = json.loads(output) + self.assertEqual(payload["pytest"]["outcome"], "failed") + self.assertEqual(payload["pytest"]["test_id"], test_id) + self.assertEqual(payload["exception"]["name"], "AssertionError") + self.assertEqual(payload["stopped"]["reason"], "exception") + self.assertTrue(any(item["name"] == "price" for item in payload["locals"])) + self.assertTrue(any(item["name"] == "loyalty_level" for item in payload["locals"])) + + def test_debug_pytest_breakpoint_and_pytest_arg(self) -> None: + test_file = Path("examples/test_buggy_discount.py") + breakpoint_line = next( + index for index, line in enumerate(test_file.read_text().splitlines(), start=1) if "actual =" in line + ) + + code, output = call_cli( + [ + "debug-pytest", + str(test_file), + "--pytest-arg", + "-k", + "--pytest-arg", + "test_gold", + "--break", + f"{test_file}:{breakpoint_line}", + "--python", + sys.executable, + "--json", + ] + ) + + self.assertEqual(code, 0, output) + payload = json.loads(output) + self.assertEqual(payload["stopped"]["reason"], "breakpoint") + self.assertEqual(payload["stopped"]["function"], "test_gold") + self.assertEqual(payload["pytest"]["pytest_args"], ["-k", "test_gold"]) + + def test_debug_pytest_no_break_on_failure_reports_failed_exit(self) -> None: + code, output = call_cli( + [ + "debug-pytest", + "examples/test_buggy_discount.py::test_gold", + "--no-break-on-failure", + "--python", + sys.executable, + "--json", + ] + ) + + self.assertEqual(code, 0, output) + payload = json.loads(output) + self.assertEqual(payload["pytest"]["outcome"], "failed") + self.assertEqual(payload["stopped"]["state"], "exited") + self.assertNotIn("exception", payload) + def test_claude_progress_formats_debugger_events(self) -> None: events = [ { diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index 6d6a0bb..c73b927 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -23,6 +23,7 @@ def test_initialize_and_tool_list(self) -> None: tools = {tool["name"] for tool in listed["result"]["tools"]} self.assertIn("debug_guidance", tools) self.assertIn("debug_python_repro", tools) + self.assertIn("debug_pytest", tools) self.assertIn("debug_launch", tools) self.assertIn("debug_set_breakpoints", tools) self.assertIn("debug_step", tools) diff --git a/tools/runtime_proof.py b/tools/runtime_proof.py index b6786e8..26379c4 100644 --- a/tools/runtime_proof.py +++ b/tools/runtime_proof.py @@ -13,6 +13,7 @@ ROOT = Path(__file__).resolve().parents[1] TARGET = ROOT / "examples" / "buggy_discount.py" +PYTEST_TARGET = ROOT / "examples" / "test_buggy_discount.py" class MCPClient: @@ -136,6 +137,7 @@ def main() -> int: required = { "debug_guidance", "debug_python_repro", + "debug_pytest", "debug_launch", "debug_attach", "debug_set_breakpoints", @@ -169,6 +171,21 @@ def main() -> int: assert repro["stopped"]["location"]["name"] == "main", repro assert any(variable["name"] == "price" for variable in repro["snapshot"]["locals"]), repro + pytest_repro = client.call_tool( + "debug_pytest", + { + "test_id": f"{PYTEST_TARGET.relative_to(ROOT)}::test_gold", + "cwd": str(ROOT), + "python": sys.executable, + "keep_session": False, + "timeout": 30, + }, + timeout=50, + ) + assert pytest_repro["pytest"]["outcome"] == "failed", pytest_repro + assert pytest_repro["exception"]["name"] == "AssertionError", pytest_repro + assert pytest_repro["stopped"]["location"]["name"] == "test_gold", pytest_repro + launch = client.call_tool( "debug_launch", { @@ -315,6 +332,7 @@ def main() -> int: "MCP initialize/tools/list", "debug_guidance", "debug_python_repro", + "debug_pytest", "debug_launch", "debug_attach", "debug_set_breakpoints",