diff --git a/main.py b/main.py index 8d09e8d..e3fb980 100644 --- a/main.py +++ b/main.py @@ -176,15 +176,52 @@ async def invoke(payload): # Limit tool calls to prevent infinite loops tool_call_count = {"n": 0} - from strands.hooks import AfterToolCallEvent + from strands.hooks import BeforeToolCallEvent, BeforeModelCallEvent - def check_tool_limit(event: AfterToolCallEvent): + def check_tool_limit(event: BeforeToolCallEvent): tool_call_count["n"] += 1 - if tool_call_count["n"] >= 20: + if tool_call_count["n"] > 20: logger.warning(f"⚠️ Tool call limit reached (20)") - raise RuntimeError("工具调用次数超过上限(20次),已强制停止。请简化问题后重试。") - - agent.hooks.add_callback(AfterToolCallEvent, check_tool_limit) + event.cancel_tool = ( + "工具调用次数已超过上限(20次)。" + "DO NOT CALL ANY MORE TOOLS. 请直接根据已有信息回答用户。" + ) + + def fix_messages_before_model(event: BeforeModelCallEvent): + """Fix toolUse/toolResult mismatch right before model call. + + Workaround for https://github.com/strands-agents/sdk-python/issues/2296 + """ + messages = agent.messages + if not messages or len(messages) < 2: + return + for i in range(len(messages) - 1): + msg = messages[i] + if msg.get("role") != "assistant": + continue + tool_use_ids = [b["toolUse"]["toolUseId"] for b in msg.get("content", []) if "toolUse" in b] + if not tool_use_ids: + continue + next_msg = messages[i + 1] + if next_msg.get("role") != "user": + continue + next_content = next_msg.get("content", []) + tool_results = [b for b in next_content if "toolResult" in b] + if len(tool_results) == len(tool_use_ids): + continue + logger.warning(f"⚠️ Fixing toolUse/toolResult mismatch at msg {i}: " + f"{len(tool_use_ids)} toolUse vs {len(tool_results)} toolResult") + non_tool = [b for b in next_content if "toolResult" not in b] + fixed_results = [] + for tid in tool_use_ids: + existing = next((b for b in tool_results + if b.get("toolResult", {}).get("toolUseId") == tid), None) + fixed_results.append(existing if existing else {"toolResult": {"toolUseId": tid, + "content": [{"text": "Tool execution was interrupted."}], "status": "error"}}) + messages[i + 1]["content"] = non_tool + fixed_results + + agent.hooks.add_callback(BeforeToolCallEvent, check_tool_limit) + agent.hooks.add_callback(BeforeModelCallEvent, fix_messages_before_model) healthy_status.value = "HealthyBusy" logger.info(f"🚀 Agent job starts | actor={actor_id} session={session_id}")