diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py index ee53264d2..8f43eaff8 100644 --- a/python/copilot/__init__.py +++ b/python/copilot/__init__.py @@ -117,6 +117,7 @@ from .tools import ( Tool, ToolBinaryResult, + ToolError, ToolInvocation, ToolResult, ToolResultType, @@ -223,6 +224,7 @@ "TelemetryConfig", "Tool", "ToolBinaryResult", + "ToolError", "ToolInvocation", "ToolResult", "ToolResultType", diff --git a/python/copilot/tools.py b/python/copilot/tools.py index c6a29dc61..ee220344e 100644 --- a/python/copilot/tools.py +++ b/python/copilot/tools.py @@ -18,6 +18,13 @@ ToolResultType = Literal["success", "failure", "rejected", "denied", "timeout"] +class ToolError(Exception): + """ + Exception raised by tool handlers to return a failure result to the LLM. + Unlike other exceptions, the message is intentionally surfaced to the LLM. + """ + + @dataclass class ToolBinaryResult: """Binary content returned by a tool.""" @@ -215,6 +222,14 @@ async def wrapped_handler(invocation: ToolInvocation) -> ToolResult: return _normalize_result(result) + except ToolError as exc: + msg = str(exc) + return ToolResult( + text_result_for_llm=msg, + result_type="failure", + error=msg, + ) + except Exception as exc: # Don't expose detailed error information to the LLM for security reasons. # The actual error is stored in the 'error' field for debugging. diff --git a/python/test_tools.py b/python/test_tools.py index d583b59c0..f62906f7f 100644 --- a/python/test_tools.py +++ b/python/test_tools.py @@ -7,6 +7,7 @@ from copilot import define_tool from copilot.tools import ( + ToolError, ToolInvocation, ToolResult, _normalize_result, @@ -197,6 +198,30 @@ def failing_tool(params: Params, invocation: ToolInvocation) -> str: # But the actual error is stored internally assert result.error == "secret error message" + async def test_tool_error_is_surfaced_to_llm(self): + class Params(BaseModel): + pass + + @define_tool("failing", description="A failing tool") + def failing_tool(params: Params, invocation: ToolInvocation) -> str: + raise ToolError("public error message") + + invocation = ToolInvocation( + session_id="s1", + tool_call_id="c1", + tool_name="failing", + arguments={}, + ) + + result = await failing_tool.handler(invocation) + + assert result.result_type == "failure" + assert result.text_result_for_llm == "public error message" + assert result.error == "public error message" + # ToolError must take the deliberate-failure path so the structured + # result reaches the LLM verbatim. + assert result._from_exception is False + async def test_function_style_api(self): class Params(BaseModel): value: str