Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 43 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,52 @@ async def invoke(payload):
# Limit tool calls to prevent infinite loops
tool_call_count = {"n": 0}

from strands.hooks import AfterToolCallEvent
from strands.hooks import BeforeToolCallEvent, BeforeModelCallEvent

def check_tool_limit(event: AfterToolCallEvent):
def check_tool_limit(event: BeforeToolCallEvent):
tool_call_count["n"] += 1
if tool_call_count["n"] >= 20:
if tool_call_count["n"] > 20:
logger.warning(f"⚠️ Tool call limit reached (20)")
raise RuntimeError("工具调用次数超过上限(20次),已强制停止。请简化问题后重试。")

agent.hooks.add_callback(AfterToolCallEvent, check_tool_limit)
event.cancel_tool = (
"工具调用次数已超过上限(20次)。"
"DO NOT CALL ANY MORE TOOLS. 请直接根据已有信息回答用户。"
)

def fix_messages_before_model(event: BeforeModelCallEvent):
"""Fix toolUse/toolResult mismatch right before model call.

Workaround for https://github.com/strands-agents/sdk-python/issues/2296
"""
messages = agent.messages
if not messages or len(messages) < 2:
return
for i in range(len(messages) - 1):
msg = messages[i]
if msg.get("role") != "assistant":
continue
tool_use_ids = [b["toolUse"]["toolUseId"] for b in msg.get("content", []) if "toolUse" in b]
if not tool_use_ids:
continue
next_msg = messages[i + 1]
if next_msg.get("role") != "user":
continue
next_content = next_msg.get("content", [])
tool_results = [b for b in next_content if "toolResult" in b]
if len(tool_results) == len(tool_use_ids):
continue
logger.warning(f"⚠️ Fixing toolUse/toolResult mismatch at msg {i}: "
f"{len(tool_use_ids)} toolUse vs {len(tool_results)} toolResult")
non_tool = [b for b in next_content if "toolResult" not in b]
fixed_results = []
for tid in tool_use_ids:
existing = next((b for b in tool_results
if b.get("toolResult", {}).get("toolUseId") == tid), None)
fixed_results.append(existing if existing else {"toolResult": {"toolUseId": tid,
"content": [{"text": "Tool execution was interrupted."}], "status": "error"}})
messages[i + 1]["content"] = non_tool + fixed_results

agent.hooks.add_callback(BeforeToolCallEvent, check_tool_limit)
agent.hooks.add_callback(BeforeModelCallEvent, fix_messages_before_model)

healthy_status.value = "HealthyBusy"
logger.info(f"🚀 Agent job starts | actor={actor_id} session={session_id}")
Expand Down