diff --git a/src/bub/builtin/agent.py b/src/bub/builtin/agent.py
index 9d5362e9..a7c3311d 100644
--- a/src/bub/builtin/agent.py
+++ b/src/bub/builtin/agent.py
@@ -24,6 +24,7 @@
StreamEvent,
StreamState,
TapeContext,
+ Tool,
ToolAutoResult,
ToolContext,
)
@@ -34,7 +35,7 @@
from bub.builtin.tape import TapeService
from bub.framework import BubFramework
from bub.skills import discover_skills, render_skills_prompt
-from bub.tools import REGISTRY, model_tools, render_tools_prompt
+from bub.tools import REGISTRY, model_tools, render_tools_prompt, resolve_tool_names
from bub.types import State
from bub.utils import workspace_from_state
@@ -530,12 +531,12 @@ async def _run_once(
) -> AsyncStreamEvents | ToolAutoResult:
prompt_text = prompt if isinstance(prompt, str) else _extract_text_from_parts(prompt)
if allowed_tools is not None:
- allowed_tools = {name.casefold() for name in allowed_tools}
+ allowed_tools = resolve_tool_names(allowed_tools)
if allowed_skills is not None:
allowed_skills = {name.casefold() for name in allowed_skills}
tape.context.state["allowed_skills"] = list(allowed_skills)
if allowed_tools is not None:
- tools = [tool for tool in REGISTRY.values() if tool.name.casefold() in allowed_tools]
+ tools = [tool for tool in REGISTRY.values() if tool.name in allowed_tools]
else:
tools = list(REGISTRY.values())
async with asyncio.timeout(self.settings.model_timeout_seconds):
@@ -543,7 +544,7 @@ async def _run_once(
return await tape.stream_events_async(
prompt=prompt,
system_prompt=self._system_prompt(
- prompt_text, state=tape.context.state, allowed_skills=allowed_skills
+ prompt_text, state=tape.context.state, allowed_skills=allowed_skills, tools=tools
),
max_tokens=self.settings.max_tokens,
tools=model_tools(tools),
@@ -553,18 +554,20 @@ async def _run_once(
return await tape.run_tools_async(
prompt=prompt,
system_prompt=self._system_prompt(
- prompt_text, state=tape.context.state, allowed_skills=allowed_skills
+ prompt_text, state=tape.context.state, allowed_skills=allowed_skills, tools=tools
),
max_tokens=self.settings.max_tokens,
tools=model_tools(tools),
model=model,
)
- def _system_prompt(self, prompt: str, state: State, allowed_skills: set[str] | None = None) -> str:
+ def _system_prompt(
+ self, prompt: str, state: State, allowed_skills: set[str] | None = None, tools: Iterable[Tool] | None = None
+ ) -> str:
blocks: list[str] = []
if result := self.framework.get_system_prompt(prompt=prompt, state=state):
blocks.append(result)
- tools_prompt = render_tools_prompt(REGISTRY.values())
+ tools_prompt = render_tools_prompt(tools if tools is not None else REGISTRY.values())
if tools_prompt:
blocks.append(tools_prompt)
workspace = workspace_from_state(state)
diff --git a/src/bub/tools.py b/src/bub/tools.py
index 046cc4b1..7cdbd799 100644
--- a/src/bub/tools.py
+++ b/src/bub/tools.py
@@ -186,13 +186,24 @@ def model_tools(tools: Iterable[Tool]) -> list[Tool]:
return [replace(tool, name=_to_model_name(tool.name)) for tool in tools]
+def _tool_signature(tool: Tool) -> str:
+ """Render ``tool`` as a call-style signature string.
+
+ The result is ``<name>(<params>)``, where ``<name>`` is
+ ``_to_model_name(tool.name)`` and ``<params>`` are the keys of
+ ``tool.parameters["properties"]`` in iteration order, joined by ", ".
+ Keys not listed in ``tool.parameters["required"]`` get a trailing ``?``.
+ """
+ properties = tool.parameters.get("properties", {})
+ # Missing, empty, or non-dict "properties": render an empty argument list.
+ if not isinstance(properties, dict) or not properties:
+ return f"{_to_model_name(tool.name)}()"
+
+ required = tool.parameters.get("required", [])
+ # A "required" value that is not a list is ignored, so every parameter is marked with "?".
+ required_names = set(required) if isinstance(required, list) else set()
+ params = [name if name in required_names else f"{name}?" for name in properties]
+ return f"{_to_model_name(tool.name)}({', '.join(params)})"
+
+
def render_tools_prompt(tools: Iterable[Tool]) -> str:
"""Render a human-readable description of tools for model prompts."""
if not tools:
return ""
lines = []
for tool in tools:
- line = f"- {_to_model_name(tool.name)}"
+ line = f"- {_tool_signature(tool)}"
if tool.description:
line += f": {tool.description}"
lines.append(line)
diff --git a/tests/test_builtin_agent.py b/tests/test_builtin_agent.py
index 4adeb169..2ad86f76 100644
--- a/tests/test_builtin_agent.py
+++ b/tests/test_builtin_agent.py
@@ -12,6 +12,7 @@
import bub.builtin.agent as agent_module
from bub.builtin.agent import Agent
from bub.builtin.settings import AgentSettings
+from bub.tools import REGISTRY, tool
def test_build_llm_passes_codex_resolver_to_republic(monkeypatch) -> None:
@@ -84,6 +85,7 @@ class _FakeTapeService:
def __init__(self, fork_capture: _ForkCapture) -> None:
self._fork = fork_capture
self.run_tools_model: str | None = None
+ self.stream_kwargs: dict[str, Any] | None = None
def session_tape(self, session_id: str, workspace: Any) -> MagicMock:
tape = MagicMock()
@@ -92,6 +94,7 @@ def session_tape(self, session_id: str, workspace: Any) -> MagicMock:
async def fake_stream_events_async(**kwargs: Any) -> AsyncStreamEvents:
self.run_tools_model = kwargs.get("model")
+ self.stream_kwargs = kwargs
async def iterator():
yield StreamEvent("final", {"text": "done"})
@@ -184,3 +187,56 @@ async def test_agent_run_model_defaults_to_none() -> None:
[event async for event in result]
assert fake_tapes.run_tools_model is None
+
+
+@pytest.mark.asyncio
+async def test_agent_run_resolves_allowed_tool_aliases_and_limits_prompt() -> None:
+ """Check that an aliased ``allowed_tools`` entry limits both the tools and the system prompt."""
+ allowed_name = "tests.allowed_agent_tool"
+ denied_name = "tests.denied_agent_tool"
+ # Remove any entries left under these names; presumably the @tool decorators below re-register them.
+ REGISTRY.pop(allowed_name, None)
+ REGISTRY.pop(denied_name, None)
+
+ @tool(name=allowed_name, description="Allowed tool")
+ def allowed_agent_tool() -> str:
+ return "allowed"
+
+ @tool(name=denied_name, description="Denied tool")
+ def denied_agent_tool() -> str:
+ return "denied"
+
+ agent = _make_agent()
+ fork_capture = _ForkCapture()
+ fake_tapes = _FakeTapeService(fork_capture)
+ agent.tapes = fake_tapes # type: ignore[assignment]
+
+ result = await agent.run_stream(
+ session_id="user/s1",
+ prompt="hello",
+ state={"_runtime_workspace": "/tmp"}, # noqa: S108
+ # Underscored spelling with surrounding spaces, not the dotted name used at registration.
+ allowed_tools=[" tests_allowed_agent_tool "],
+ )
+ # Consume every event from the returned stream.
+ [event async for event in result]
+
+ assert fake_tapes.stream_kwargs is not None
+ # Only the allowed tool is passed on, under its underscored name.
+ assert [tool.name for tool in fake_tapes.stream_kwargs["tools"]] == ["tests_allowed_agent_tool"]
+ system_prompt = fake_tapes.stream_kwargs["system_prompt"]
+ # The prompt lists the allowed tool's signature and never mentions the denied one.
+ assert "- tests_allowed_agent_tool(): Allowed tool" in system_prompt
+ assert "tests_denied_agent_tool" not in system_prompt
+
+
+@pytest.mark.asyncio
+async def test_agent_run_rejects_unknown_allowed_tools() -> None:
+ """Check that an ``allowed_tools`` name that is expected to be unknown surfaces as a ``ValueError`` naming it."""
+ agent = _make_agent()
+ fork_capture = _ForkCapture()
+ fake_tapes = _FakeTapeService(fork_capture)
+ agent.tapes = fake_tapes # type: ignore[assignment]
+
+ # The run_stream call sits outside pytest.raises, so it must return without raising.
+ stream = await agent.run_stream(
+ session_id="user/s1",
+ prompt="hello",
+ state={"_runtime_workspace": "/tmp"}, # noqa: S108
+ allowed_tools=["tests_missing_agent_tool"],
+ )
+
+ # The error is expected only once the stream is iterated, and its message must contain the bad name.
+ with pytest.raises(ValueError, match="tests_missing_agent_tool"):
+ [event async for event in stream]
diff --git a/tests/test_tools.py b/tests/test_tools.py
index e30588fa..927de9c3 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -93,13 +93,15 @@ def test_model_tools_rewrites_dotted_names_without_mutating_original() -> None:
REGISTRY.pop(tool_name, None)
@tool(name=tool_name, description="rename")
- def rename_me() -> str:
+ def rename_me(value: str) -> str:
return "ok"
rewritten = model_tools([rename_me])
assert [item.name for item in rewritten] == ["tests_rename_me"]
+ assert rewritten[0].parameters == rename_me.parameters
assert rename_me.name == tool_name
+ assert "additionalProperties" not in rename_me.parameters
def test_render_tools_prompt_renders_available_tools_block() -> None:
@@ -118,7 +120,20 @@ def prompt_two() -> str:
rendered = render_tools_prompt([prompt_one, prompt_two])
- assert rendered == "\n- tests_prompt_one: First tool\n- tests_prompt_two\n"
+ assert rendered == "\n- tests_prompt_one(): First tool\n- tests_prompt_two()\n"
+
+
+def test_render_tools_prompt_includes_model_name_and_parameter_signature() -> None:
+ """Check the rendered prompt line shows the underscored tool name and its parameter list."""
+ tool_name = "tests.prompt_signature"
+ # Remove any entry left under this name; presumably the @tool decorator below re-registers it.
+ REGISTRY.pop(tool_name, None)
+
+ @tool(name=tool_name, description="Read a file")
+ def prompt_signature(path: str, offset: int = 0) -> str:
+ return f"{path}:{offset}"
+
+ rendered = render_tools_prompt([prompt_signature])
+
+ # `path` has no default and is expected bare; `offset` has a default and is expected with "?".
+ assert rendered == "\n- tests_prompt_signature(path, offset?): Read a file\n"
def test_render_tools_prompt_returns_empty_string_for_empty_input() -> None:
@@ -128,8 +143,10 @@ def test_render_tools_prompt_returns_empty_string_for_empty_input() -> None:
def test_resolve_tool_names_accepts_runtime_names_and_model_aliases() -> None:
dotted_name = "tests.resolve_alias"
underscored_name = "tests_with_underscore"
+ excluded_name = "tests.excluded_tool"
REGISTRY.pop(dotted_name, None)
REGISTRY.pop(underscored_name, None)
+ REGISTRY.pop(excluded_name, None)
@tool(name=dotted_name)
def resolve_alias() -> str:
@@ -139,11 +156,18 @@ def resolve_alias() -> str:
def resolve_runtime_name() -> str:
return "runtime"
- assert resolve_tool_names([" tests_resolve_alias ", " tests_with_underscore "], exclude={" subagent "}) == {
+ @tool(name=excluded_name)
+ def excluded_tool() -> str:
+ return "excluded"
+
+ assert resolve_tool_names(
+ [" tests_resolve_alias ", " tests_with_underscore "], exclude={" tests_excluded_tool "}
+ ) == {
dotted_name,
underscored_name,
}
assert dotted_name not in resolve_tool_names(None, exclude={" tests_resolve_alias "})
+ assert excluded_name not in resolve_tool_names(None, exclude={" tests_excluded_tool "})
assert resolve_tool_names(None, exclude={" tests_resolve_alias "}) >= {underscored_name}