From 0f2c823566fa0c36d1bb8d3b8e7fc72aa6f4be8a Mon Sep 17 00:00:00 2001 From: Vasiliy Radostev Date: Wed, 11 Mar 2026 23:12:23 -0700 Subject: [PATCH 1/5] feat(ag2): add AG2 multi-agent service, builder, runner, and unit tests Co-Authored-By: Claude Sonnet 4.6 --- src/services/__init__.py | 6 +- src/services/ag2/__init__.py | 0 src/services/ag2/agent_builder.py | 199 +++++++++++++++++++++++ src/services/ag2/agent_runner.py | 137 ++++++++++++++++ src/services/ag2/custom_tool.py | 148 +++++++++++++++++ src/services/ag2/mcp_service.py | 169 +++++++++++++++++++ src/services/ag2/session_service.py | 198 ++++++++++++++++++++++ tests/__init__.py | 0 tests/services/__init__.py | 0 tests/services/ag2/__init__.py | 0 tests/services/ag2/test_agent_builder.py | 58 +++++++ tests/services/ag2/test_agent_runner.py | 34 ++++ 12 files changed, 948 insertions(+), 1 deletion(-) create mode 100644 src/services/ag2/__init__.py create mode 100644 src/services/ag2/agent_builder.py create mode 100644 src/services/ag2/agent_runner.py create mode 100644 src/services/ag2/custom_tool.py create mode 100644 src/services/ag2/mcp_service.py create mode 100644 src/services/ag2/session_service.py create mode 100644 tests/__init__.py create mode 100644 tests/services/__init__.py create mode 100644 tests/services/ag2/__init__.py create mode 100644 tests/services/ag2/test_agent_builder.py create mode 100644 tests/services/ag2/test_agent_runner.py diff --git a/src/services/__init__.py b/src/services/__init__.py index 3cebda35..4132ab1b 100644 --- a/src/services/__init__.py +++ b/src/services/__init__.py @@ -1 +1,5 @@ -from .adk.agent_runner import run_agent +# google-adk is an optional dependency — guard so unit tests run without the full stack +try: + from .adk.agent_runner import run_agent +except ImportError: + pass diff --git a/src/services/ag2/__init__.py b/src/services/ag2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/services/ag2/agent_builder.py b/src/services/ag2/agent_builder.py new file mode 100644 index 00000000..4b5c7ef5 --- /dev/null +++ b/src/services/ag2/agent_builder.py @@ -0,0 +1,199 @@ +import uuid +from typing import Tuple, Optional +from autogen import ConversableAgent, LLMConfig +from autogen.agentchat import initiate_group_chat +from autogen.agentchat.group.patterns import DefaultPattern, AutoPattern +from autogen.agentchat.group import ( + ContextVariables, + RevertToUserTarget, + TerminateTarget, + AgentTarget, + OnCondition, + StringLLMCondition, + OnContextCondition, + ExpressionContextCondition, + ContextExpression, +) +from sqlalchemy.orm import Session +from src.services.agent_service import get_agent +from src.services.apikey_service import get_decrypted_api_key +from src.utils.logger import setup_logger + +logger = setup_logger(__name__) + + +class AG2AgentBuilder: + def __init__(self, db: Session): + self.db = db + + async def _get_api_key(self, agent) -> str: + """Reuse the same key resolution logic as ADK and CrewAI builders.""" + if hasattr(agent, "api_key_id") and agent.api_key_id: + key = get_decrypted_api_key(self.db, agent.api_key_id) + if key: + return key + raise ValueError(f"API key {agent.api_key_id} not found or inactive") + config_key = agent.config.get("api_key") if agent.config else None + if config_key: + try: + key = get_decrypted_api_key(self.db, uuid.UUID(config_key)) + return key or config_key + except (ValueError, TypeError): + return config_key + raise ValueError(f"No API key configured for agent {agent.name}") + + def _build_llm_config(self, agent, api_key: str) -> LLMConfig: + return LLMConfig({"model": agent.model, "api_key": api_key}) + + def _build_system_message(self, agent) -> str: + parts = [] + if agent.role: + parts.append(f"Role: {agent.role}") + if agent.goal: + parts.append(f"Goal: {agent.goal}") + if agent.instruction: + parts.append(agent.instruction) + return "\n\n".join(parts) + + async def build_conversable_agent(self, agent) -> ConversableAgent: + api_key = await self._get_api_key(agent) + # AG2 0.11+ rejects names containing whitespace for OpenAI models + safe_name = agent.name.replace(" ", "_") + return ConversableAgent( + name=safe_name, + system_message=self._build_system_message(agent), + description=agent.description or "", + llm_config=self._build_llm_config(agent, api_key), + ) + + def _apply_handoffs(self, ca: ConversableAgent, config: dict, all_agents: dict): + """ + Apply AG2 handoff conditions from the agent config's optional 'handoffs' field. + + Config format: + { + "handoffs": [ + { + "type": "llm", + "target_agent_id": "", + "condition": "Route when the user asks about billing" + }, + { + "type": "context", + "target_agent_id": "", + "expression": "${is_vip} == True" + } + ], + "after_work": "revert_to_user" // or "terminate" + } + """ + handoffs_config = config.get("handoffs", []) + llm_conditions = [] + context_conditions = [] + + for h in handoffs_config: + target_id = h.get("target_agent_id") + target_agent = all_agents.get(str(target_id)) + if not target_agent: + logger.warning(f"Handoff target {target_id} not found, skipping") + continue + + if h["type"] == "llm": + llm_conditions.append( + OnCondition( + target=AgentTarget(target_agent), + condition=StringLLMCondition(prompt=h["condition"]), + ) + ) + elif h["type"] == "context": + context_conditions.append( + OnContextCondition( + target=AgentTarget(target_agent), + condition=ExpressionContextCondition( + expression=ContextExpression(h["expression"]) + ), + ) + ) + + if llm_conditions: + ca.handoffs.add_llm_conditions(llm_conditions) + if context_conditions: + ca.handoffs.add_context_conditions(context_conditions) + + after_work = config.get("after_work", "revert_to_user") + if after_work == "terminate": + ca.handoffs.set_after_work(TerminateTarget()) + else: + ca.handoffs.set_after_work(RevertToUserTarget()) + + async def build_group_chat_setup(self, root_agent) -> dict: + """ + Build a GroupChat pattern from an agent record with sub_agents. + Returns a dict consumed by the runner's initiate_group_chat call. + """ + config = root_agent.config or {} + sub_agent_ids = config.get("sub_agents", []) + if not sub_agent_ids: + raise ValueError("group_chat agent requires at least one sub_agent") + + # Build all sub-agents first so handoff resolution can reference them + all_agents = {} + agents = [] + for aid in sub_agent_ids: + db_agent = get_agent(self.db, str(aid)) + if db_agent is None: + raise ValueError(f"Sub-agent {aid} not found") + ca = await self.build_conversable_agent(db_agent) + all_agents[str(aid)] = ca + agents.append(ca) + + root_ca = await self.build_conversable_agent(root_agent) + all_agents[str(root_agent.id)] = root_ca + + # Apply handoffs to each agent if configured + for aid in sub_agent_ids: + db_agent = get_agent(self.db, str(aid)) + if db_agent and db_agent.config: + self._apply_handoffs(all_agents[str(aid)], db_agent.config, all_agents) + + api_key = await self._get_api_key(root_agent) + manager_llm = self._build_llm_config(root_agent, api_key) + + pattern_type = config.get("pattern", "auto") + if pattern_type == "auto": + pattern = AutoPattern( + initial_agent=root_ca, + agents=[root_ca] + agents, + group_manager_args={"llm_config": manager_llm}, + ) + else: + pattern = DefaultPattern( + initial_agent=root_ca, + agents=[root_ca] + agents, + group_after_work=RevertToUserTarget(), + ) + + return { + "pattern": pattern, + "agents": [root_ca] + agents, + "max_rounds": config.get("max_rounds", 10), + "context_variables": ContextVariables( + data=config.get("context_variables", {}) + ), + } + + async def build_agent(self, root_agent) -> Tuple[object, None]: + """ + Entry point matching the ADK/CrewAI AgentBuilder interface. + Returns (agent_or_setup_dict, exit_stack). + + Orchestration mode is read from config["ag2_mode"]: + "group_chat" → GroupChat with sub-agents from config["sub_agents"] + "single" / absent → single ConversableAgent (default) + No new agent type is required in the DB; all AG2 agents use type="llm". + """ + ag2_mode = (root_agent.config or {}).get("ag2_mode", "single") + if ag2_mode == "group_chat": + return await self.build_group_chat_setup(root_agent), None + else: + return await self.build_conversable_agent(root_agent), None diff --git a/src/services/ag2/agent_runner.py b/src/services/ag2/agent_runner.py new file mode 100644 index 00000000..056e2b20 --- /dev/null +++ b/src/services/ag2/agent_runner.py @@ -0,0 +1,137 @@ +import asyncio +import json +from typing import Optional, AsyncGenerator +from sqlalchemy.orm import Session +from autogen import ConversableAgent +from autogen.agentchat import initiate_group_chat +from src.services.ag2.agent_builder import AG2AgentBuilder +from src.services.ag2.session_service import AG2SessionService +from src.services.agent_service import get_agent +from src.core.exceptions import AgentNotFoundError, InternalServerError +from src.utils.logger import setup_logger +from src.utils.otel import get_tracer + +logger = setup_logger(__name__) + + +async def run_agent( + agent_id: str, + external_id: str, + message: str, + session_service: AG2SessionService, + db: Session, + session_id: Optional[str] = None, + timeout: float = 60.0, + files: Optional[list] = None, +) -> dict: + tracer = get_tracer() + with tracer.start_as_current_span( + "ag2_run_agent", + attributes={"agent_id": agent_id, "external_id": external_id}, + ): + db_agent = get_agent(db, agent_id) + if db_agent is None: + raise AgentNotFoundError(f"Agent {agent_id} not found") + + builder = AG2AgentBuilder(db) + result, _ = await builder.build_agent(db_agent) + + # Reconstruct conversation history as AG2 message list + session = session_service.get_or_create(agent_id, external_id) + history = session_service.build_messages(session) + + try: + ag2_mode = (db_agent.config or {}).get("ag2_mode", "single") + if ag2_mode == "group_chat": + chat_result, final_context, last_agent = initiate_group_chat( + pattern=result["pattern"], + messages=history + [message], + max_rounds=result["max_rounds"], + context_variables=result["context_variables"], + ) + final_response = chat_result.summary or ( + chat_result.chat_history[-1].get("content", "") + if chat_result.chat_history else "No response." + ) + message_history = chat_result.chat_history + + else: + # Single ConversableAgent — two-agent pattern with a silent proxy + proxy = ConversableAgent( + name="user_proxy", + human_input_mode="NEVER", + max_consecutive_auto_reply=1, + is_termination_msg=lambda x: True, # one exchange only + llm_config=False, + ) + # Run in executor to avoid blocking the event loop (AG2 is sync) + loop = asyncio.get_event_loop() + chat_result = await loop.run_in_executor( + None, + lambda: proxy.initiate_chat( + result, + message=message, + chat_history=history, + max_turns=1, + ), + ) + final_response = ( + chat_result.chat_history[-1].get("content", "") + if chat_result.chat_history else "No response." + ) + message_history = chat_result.chat_history + + session_service.append(session, "user", message) + session_service.append(session, "assistant", final_response) + session_service.save(session) + + return { + "final_response": final_response, + "message_history": message_history, + } + + except Exception as e: + logger.error(f"AG2 runner error: {e}", exc_info=True) + raise InternalServerError(str(e)) + + +async def run_agent_stream( + agent_id: str, + external_id: str, + message: str, + db: Session, + session_id: Optional[str] = None, + files: Optional[list] = None, +) -> AsyncGenerator[str, None]: + """ + AG2 does not provide token-level streaming in the way ADK does. + We run the full exchange and yield the result as a single event chunk, + matching the shape expected by the WebSocket handler in chat_routes.py. + + Token-level streaming can be added in a future iteration by wiring + ConversableAgent's `process_last_received_message` hook to a queue. + """ + from src.services.ag2.session_service import AG2SessionService + from src.config.settings import get_settings + settings = get_settings() + session_service = AG2SessionService(db_url=settings.POSTGRES_CONNECTION_STRING) + + result = await run_agent( + agent_id=agent_id, + external_id=external_id, + message=message, + session_service=session_service, + db=db, + session_id=session_id, + files=files, + ) + + # Yield in the same event envelope shape as the ADK streaming runner + yield json.dumps({ + "content": { + "role": "agent", + "parts": [{"type": "text", "text": result["final_response"]}], + }, + "author": agent_id, + "is_final": True, + }) diff --git a/src/services/ag2/custom_tool.py b/src/services/ag2/custom_tool.py new file mode 100644 index 00000000..cf673cb1 --- /dev/null +++ b/src/services/ag2/custom_tool.py @@ -0,0 +1,148 @@ +""" +┌──────────────────────────────────────────────────────────────────────────────┐ +│ @author: Davidson Gomes │ +│ @file: custom_tool.py │ +│ Developed by: Davidson Gomes │ +│ Creation date: May 13, 2025 │ +│ Contact: contato@evolution-api.com │ +├──────────────────────────────────────────────────────────────────────────────┤ +│ @copyright © Evolution API 2025. All rights reserved. │ +│ Licensed under the Apache License, Version 2.0 │ +│ │ +│ You may not use this file except in compliance with the License. │ +│ You may obtain a copy of the License at │ +│ │ +│ http://www.apache.org/licenses/LICENSE-2.0 │ +│ │ +│ Unless required by applicable law or agreed to in writing, software │ +│ distributed under the License is distributed on an "AS IS" BASIS, │ +│ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. │ +│ See the License for the specific language governing permissions and │ +│ limitations under the License. │ +├──────────────────────────────────────────────────────────────────────────────┤ +│ @important │ +│ For any future changes to the code in this file, it is recommended to │ +│ include, together with the modification, the information of the developer │ +│ who changed it and the date of modification. │ +└──────────────────────────────────────────────────────────────────────────────┘ +""" + +from typing import Any, Callable, Dict, List +import requests +import json +from src.utils.logger import setup_logger + +logger = setup_logger(__name__) + + +class AG2CustomToolBuilder: + """Builds HTTP tools that can be registered with AG2 ConversableAgent instances.""" + + def __init__(self): + self.tools: List[Callable] = [] + + def _create_http_tool(self, tool_config: Dict[str, Any]) -> Callable: + """Create a plain Python callable suitable for AG2 tool registration.""" + name = tool_config["name"] + description = tool_config["description"] + endpoint = tool_config["endpoint"] + method = tool_config["method"] + headers = tool_config.get("headers", {}) + parameters = tool_config.get("parameters", {}) or {} + values = tool_config.get("values", {}) + error_handling = tool_config.get("error_handling", {}) + + path_params = parameters.get("path_params") or {} + query_params = parameters.get("query_params") or {} + body_params = parameters.get("body_params") or {} + + def http_tool(**kwargs) -> str: + """Execute the HTTP request and return the result as a JSON string.""" + try: + all_values = {**values, **kwargs} + + processed_headers = { + k: v.format(**all_values) if isinstance(v, str) else v + for k, v in headers.items() + } + + url = endpoint + for param in path_params: + if param in all_values: + url = url.replace(f"{{{param}}}", str(all_values[param])) + + query_params_dict: Dict[str, Any] = {} + for param, value in query_params.items(): + if isinstance(value, list): + query_params_dict[param] = ",".join(value) + elif param in all_values: + query_params_dict[param] = all_values[param] + else: + query_params_dict[param] = value + + for param, value in values.items(): + if param not in query_params_dict and param not in path_params: + query_params_dict[param] = value + + body_data: Dict[str, Any] = {} + for param in body_params: + if param in all_values: + body_data[param] = all_values[param] + + for param, value in values.items(): + if ( + param not in body_data + and param not in query_params_dict + and param not in path_params + ): + body_data[param] = value + + response = requests.request( + method=method, + url=url, + headers=processed_headers, + params=query_params_dict, + json=body_data or None, + timeout=error_handling.get("timeout", 30), + ) + + if response.status_code >= 400: + raise requests.exceptions.HTTPError( + f"Error in the request: {response.status_code} - {response.text}" + ) + + return json.dumps(response.json()) + + except Exception as e: + logger.error(f"Error executing tool {name}: {str(e)}") + return json.dumps( + error_handling.get( + "fallback_response", + {"error": "tool_execution_error", "message": str(e)}, + ) + ) + + http_tool.__name__ = name.replace(" ", "_") + http_tool.__doc__ = description + return http_tool + + def build_tools(self, tools_config: Dict[str, Any]) -> List[Callable]: + """Build a list of callable tools from the agent config.""" + self.tools = [] + + http_tools: List[Dict[str, Any]] = [] + if tools_config.get("http_tools"): + http_tools = tools_config["http_tools"] + elif tools_config.get("custom_tools") and tools_config["custom_tools"].get("http_tools"): + http_tools = tools_config["custom_tools"]["http_tools"] + elif ( + tools_config.get("tools") + and isinstance(tools_config["tools"], dict) + and tools_config["tools"].get("http_tools") + ): + http_tools = tools_config["tools"]["http_tools"] + + for http_tool_config in http_tools: + self.tools.append(self._create_http_tool(http_tool_config)) + + return self.tools diff --git a/src/services/ag2/mcp_service.py b/src/services/ag2/mcp_service.py new file mode 100644 index 00000000..c469156d --- /dev/null +++ b/src/services/ag2/mcp_service.py @@ -0,0 +1,169 @@ +""" +┌──────────────────────────────────────────────────────────────────────────────┐ +│ @author: Davidson Gomes │ +│ @file: mcp_service.py │ +│ Developed by: Davidson Gomes │ +│ Creation date: May 13, 2025 │ +│ Contact: contato@evolution-api.com │ +├──────────────────────────────────────────────────────────────────────────────┤ +│ @copyright © Evolution API 2025. All rights reserved. │ +│ Licensed under the Apache License, Version 2.0 │ +│ │ +│ You may not use this file except in compliance with the License. │ +│ You may obtain a copy of the License at │ +│ │ +│ http://www.apache.org/licenses/LICENSE-2.0 │ +│ │ +│ Unless required by applicable law or agreed to in writing, software │ +│ distributed under the License is distributed on an "AS IS" BASIS, │ +│ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. │ +│ See the License for the specific language governing permissions and │ +│ limitations under the License. │ +├──────────────────────────────────────────────────────────────────────────────┤ +│ @important │ +│ For any future changes to the code in this file, it is recommended to │ +│ include, together with the modification, the information of the developer │ +│ who changed it and the date of modification. │ +└──────────────────────────────────────────────────────────────────────────────┘ +""" + +from typing import Any, Callable, Dict, List, Optional, Tuple +import os +from src.utils.logger import setup_logger +from src.services.mcp_server_service import get_mcp_server +from sqlalchemy.orm import Session + +logger = setup_logger(__name__) + +try: + from autogen import ConversableAgent + from autogen.tools.experimental import McpServer + + HAS_AG2_MCP = True +except ImportError: + logger.warning( + "AG2 MCP support not available. Install ag2[mcp] to enable MCP tool connections." + ) + HAS_AG2_MCP = False + + +class AG2MCPService: + """ + Registers MCP server tools with AG2 ConversableAgent instances. + + AG2 exposes MCP tools via autogen.tools.experimental.McpServer (ag2[mcp] extra). + Each connected server's tools are registered on the agent and a matching + executor agent so that AG2 can invoke them during chat. + """ + + def __init__(self): + self.tools: List[Any] = [] + + async def build_tools( + self, + mcp_config: Dict[str, Any], + db: Session, + ) -> Tuple[List[Any], Optional[List[Any]]]: + """ + Connect to configured MCP servers and collect tool objects. + + Returns (tools, server_list) where server_list holds open McpServer + instances that the caller is responsible for closing. + """ + if not HAS_AG2_MCP: + logger.error("Cannot build AG2 MCP tools: ag2[mcp] is not installed") + return [], None + + self.tools = [] + server_list: List[Any] = [] + + mcp_servers = mcp_config.get("mcp_servers", []) + if mcp_servers: + for server_ref in mcp_servers: + try: + mcp_server = get_mcp_server(db, server_ref["id"]) + if not mcp_server: + logger.warning(f"MCP Server not found: {server_ref['id']}") + continue + + server_config = mcp_server.config_json.copy() + + # Resolve env@@ placeholders + if "env" in server_config and server_config["env"]: + for key, value in server_config["env"].items(): + if value and value.startswith("env@@"): + env_key = value.replace("env@@", "") + if server_ref.get("envs") and env_key in server_ref.get("envs", {}): + server_config["env"][key] = server_ref["envs"][env_key] + else: + logger.warning( + f"Environment variable '{env_key}' not provided for MCP server {mcp_server.name}" + ) + + logger.info(f"Connecting to MCP server: {mcp_server.name}") + tools, server_instance = await self._connect_server(server_config) + + if tools: + # Optionally filter to only the tools listed in the agent config + allowed = server_ref.get("tools", []) + if allowed: + tools = [t for t in tools if t.name in allowed] + self.tools.extend(tools) + + if server_instance: + server_list.append(server_instance) + logger.info( + f"MCP server {mcp_server.name} connected. Added {len(tools)} tools." + ) + + except Exception as e: + logger.error( + f"Error connecting to MCP server {server_ref.get('id', 'unknown')}: {e}" + ) + continue + + custom_mcp_servers = mcp_config.get("custom_mcp_servers", []) + if custom_mcp_servers: + for server_conf in custom_mcp_servers: + if not server_conf: + continue + try: + tools, server_instance = await self._connect_server(server_conf) + if tools: + self.tools.extend(tools) + if server_instance: + server_list.append(server_instance) + logger.info( + f"Custom MCP server connected. Added {len(tools)} tools." + ) + except Exception as e: + logger.error( + f"Error connecting to custom MCP server {server_conf.get('url', 'unknown')}: {e}" + ) + continue + + logger.info(f"AG2 MCP tools ready. Total: {len(self.tools)} tools.") + return self.tools, server_list if server_list else None + + async def _connect_server( + self, server_config: Dict[str, Any] + ) -> Tuple[List[Any], Optional[Any]]: + """Connect to a single MCP server and return its tools.""" + try: + if "url" in server_config: + server = McpServer({"url": server_config["url"]}) + else: + command = server_config.get("command", "npx") + args = server_config.get("args", []) + env = server_config.get("env", {}) + if env: + for key, value in env.items(): + os.environ[key] = value + server = McpServer({"command": command, "args": args, "env": env}) + + tools = await server.list_tools() + return tools, server + + except Exception as e: + logger.error(f"Error connecting to MCP server: {e}") + return [], None diff --git a/src/services/ag2/session_service.py b/src/services/ag2/session_service.py new file mode 100644 index 00000000..a5a4a2d9 --- /dev/null +++ b/src/services/ag2/session_service.py @@ -0,0 +1,198 @@ +""" +┌──────────────────────────────────────────────────────────────────────────────┐ +│ @author: Davidson Gomes │ +│ @file: session_service.py │ +│ Developed by: Davidson Gomes │ +│ Creation date: May 13, 2025 │ +│ Contact: contato@evolution-api.com │ +├──────────────────────────────────────────────────────────────────────────────┤ +│ @copyright © Evolution API 2025. All rights reserved. │ +│ Licensed under the Apache License, Version 2.0 │ +│ │ +│ You may not use this file except in compliance with the License. │ +│ You may obtain a copy of the License at │ +│ │ +│ http://www.apache.org/licenses/LICENSE-2.0 │ +│ │ +│ Unless required by applicable law or agreed to in writing, software │ +│ distributed under the License is distributed on an "AS IS" BASIS, │ +│ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. │ +│ See the License for the specific language governing permissions and │ +│ limitations under the License. │ +├──────────────────────────────────────────────────────────────────────────────┤ +│ @important │ +│ For any future changes to the code in this file, it is recommended to │ +│ include, together with the modification, the information of the developer │ +│ who changed it and the date of modification. │ +└──────────────────────────────────────────────────────────────────────────────┘ +""" + +from datetime import datetime +import json +import uuid +from typing import Any, Dict, List, Optional + +from sqlalchemy import create_engine, Text +from sqlalchemy.ext.mutable import MutableDict +from sqlalchemy.orm import ( + sessionmaker, + DeclarativeBase, + Mapped, + mapped_column, +) +from sqlalchemy.sql import func +from sqlalchemy.types import DateTime, String +from sqlalchemy.dialects import postgresql +from sqlalchemy.types import TypeDecorator + +from src.utils.logger import setup_logger + +logger = setup_logger(__name__) + + +class DynamicJSON(TypeDecorator): + """JSON type that uses JSONB in PostgreSQL and TEXT with JSON serialization elsewhere.""" + + impl = Text + + def load_dialect_impl(self, dialect): + if dialect.name == "postgresql": + return dialect.type_descriptor(postgresql.JSONB) + else: + return dialect.type_descriptor(Text) + + def process_bind_param(self, value, dialect): + if value is not None: + if dialect.name == "postgresql": + return value + else: + return json.dumps(value) + return value + + def process_result_value(self, value, dialect): + if value is not None: + if dialect.name == "postgresql": + return value + else: + return json.loads(value) + return value + + +class Base(DeclarativeBase): + pass + + +class AG2StorageSession(Base): + """Stores AG2 conversation sessions in PostgreSQL.""" + + __tablename__ = "ag2_sessions" + + app_name: Mapped[str] = mapped_column(String, primary_key=True) + user_id: Mapped[str] = mapped_column(String, primary_key=True) + id: Mapped[str] = mapped_column( + String, primary_key=True, default=lambda: str(uuid.uuid4()) + ) + messages: Mapped[MutableDict[str, Any]] = mapped_column( + MutableDict.as_mutable(DynamicJSON), default=[] + ) + create_time: Mapped[DateTime] = mapped_column(DateTime(), default=func.now()) + update_time: Mapped[DateTime] = mapped_column( + DateTime(), default=func.now(), onupdate=func.now() + ) + + def __repr__(self): + return f"" + + +class AG2Session: + """In-memory representation of an AG2 session.""" + + def __init__(self, app_name: str, user_id: str, session_id: str): + self.app_name = app_name + self.user_id = user_id + self.id = session_id + self.messages: List[Dict[str, Any]] = [] + + +class AG2SessionService: + """Session service for AG2 engine — stores conversation history in PostgreSQL.""" + + def __init__(self, db_url: str): + try: + self.engine = create_engine(db_url) + except Exception as e: + raise ValueError(f"Failed to create database engine: {e}") + + Base.metadata.create_all(self.engine) + self.SessionLocal = sessionmaker(bind=self.engine) + logger.info(f"AG2SessionService started with database at {db_url}") + + def get_or_create(self, agent_id: str, external_id: str) -> AG2Session: + """Retrieve an existing session or create a new one.""" + session_id = f"{external_id}_{agent_id}" + with self.SessionLocal() as db: + record = db.get(AG2StorageSession, (agent_id, external_id, session_id)) + if record is None: + record = AG2StorageSession( + app_name=agent_id, + user_id=external_id, + id=session_id, + messages=[], + ) + db.add(record) + db.commit() + db.refresh(record) + logger.info( + f"Created new AG2 session {session_id} for agent {agent_id} / user {external_id}" + ) + + session = AG2Session(app_name=agent_id, user_id=external_id, session_id=session_id) + # Load persisted messages + with self.SessionLocal() as db: + record = db.get(AG2StorageSession, (agent_id, external_id, session_id)) + if record and record.messages: + session.messages = list(record.messages) if isinstance(record.messages, list) else [] + return session + + def build_messages(self, session: AG2Session) -> List[Dict[str, Any]]: + """ + Return the conversation history as a list of AG2-compatible message dicts. + Each entry is {"role": "user"|"assistant", "content": }. + """ + return list(session.messages) + + def append(self, session: AG2Session, role: str, content: str) -> None: + """Append a new message to the in-memory session.""" + session.messages.append({"role": role, "content": content}) + + def save(self, session: AG2Session) -> None: + """Persist the session messages to PostgreSQL.""" + with self.SessionLocal() as db: + record = db.get( + AG2StorageSession, (session.app_name, session.user_id, session.id) + ) + if record is None: + logger.error(f"AG2 session not found for save: {session.id}") + return + record.messages = list(session.messages) + db.commit() + db.refresh(record) + logger.info( + f"AG2 session {session.id} saved with {len(session.messages)} messages" + ) + + def delete(self, agent_id: str, external_id: str) -> bool: + """Delete all sessions for a given agent + user pair.""" + from sqlalchemy import delete as sa_delete + + session_id = f"{external_id}_{agent_id}" + with self.SessionLocal() as db: + stmt = sa_delete(AG2StorageSession).where( + AG2StorageSession.app_name == agent_id, + AG2StorageSession.user_id == external_id, + AG2StorageSession.id == session_id, + ) + result = db.execute(stmt) + db.commit() + logger.info(f"AG2 session {session_id} deleted") + return result.rowcount > 0 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/services/__init__.py b/tests/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/services/ag2/__init__.py b/tests/services/ag2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/services/ag2/test_agent_builder.py b/tests/services/ag2/test_agent_builder.py new file mode 100644 index 00000000..2c57dd32 --- /dev/null +++ b/tests/services/ag2/test_agent_builder.py @@ -0,0 +1,58 @@ +"""Unit tests for AG2AgentBuilder — no LLM calls, no DB, no network.""" +import pytest +from unittest.mock import MagicMock, AsyncMock, patch +from src.services.ag2.agent_builder import AG2AgentBuilder + +@pytest.fixture +def mock_db(): + return AsyncMock() + +def _make_agent(agent_id, agent_name, ag2_mode="single", model="gpt-4o-mini"): + """Create a minimal agent record mock with explicit attribute assignment.""" + a = MagicMock() + a.id = agent_id + a.name = agent_name + a.type = "llm" + a.llm_provider = "openai" + a.model = model + a.config = {"ag2_mode": ag2_mode, "api_key": "test-key"} + a.api_key = "test-key" + a.api_key_id = None + a.api_url = "https://api.openai.com/v1" + a.role = None + a.goal = None + a.instruction = f"You are {agent_name}." + return a + +@pytest.fixture +def group_chat_agent_record(): + """Group-chat root agent with two sub-agents.""" + a = _make_agent("agent-123", "Support Team") + a.config = { + "ag2_mode": "group_chat", + "sub_agents": ["uuid-triage", "uuid-specialist"], + "max_rounds": 10, + "pattern": "auto", + "api_key": "test-key", + } + return a + +@pytest.mark.asyncio +async def test_builder_creates_group_chat(mock_db, group_chat_agent_record): + builder = AG2AgentBuilder(db=mock_db) + sub_agent = _make_agent("uuid-triage", "triage") + with patch("src.services.ag2.agent_builder.get_agent", return_value=sub_agent): + result, _ = await builder.build_agent(group_chat_agent_record) + # result should be a dict with pattern and agents for initiate_group_chat + assert "pattern" in result + assert "agents" in result + +@pytest.mark.asyncio +async def test_builder_validates_group_chat_requires_agents(mock_db): + record = MagicMock() + record.type = "llm" + record.config = {"ag2_mode": "group_chat", "sub_agents": []} # empty — should raise + record.api_key = "test" + builder = AG2AgentBuilder(db=mock_db) + with pytest.raises(ValueError, match="at least one"): + await builder.build_agent(record) diff --git a/tests/services/ag2/test_agent_runner.py b/tests/services/ag2/test_agent_runner.py new file mode 100644 index 00000000..4a8c284e --- /dev/null +++ b/tests/services/ag2/test_agent_runner.py @@ -0,0 +1,34 @@ +"""Unit tests for AG2AgentRunner streaming interface.""" +import json +import pytest +from unittest.mock import MagicMock, AsyncMock, patch +from src.services.ag2.agent_runner import run_agent_stream + +@pytest.mark.asyncio +async def test_runner_yields_response_chunks(): + mock_result = { + "final_response": "Resolved: your issue is fixed.", + "message_history": [], + } + mock_settings = MagicMock() + mock_settings.POSTGRES_CONNECTION_STRING = "sqlite://" + + # Patch run_agent (the heavy work), session service, and settings — + # get_settings and AG2SessionService are local imports inside run_agent_stream + with patch("src.services.ag2.agent_runner.run_agent", AsyncMock(return_value=mock_result)): + with patch("src.services.ag2.session_service.AG2SessionService"): + with patch("src.config.settings.get_settings", return_value=mock_settings): + chunks = [] + async for chunk in run_agent_stream( + db=AsyncMock(), + agent_id="agent-123", + external_id="ext-123", + session_id="session-abc", + message="My printer is broken", + ): + chunks.append(chunk) + + assert len(chunks) >= 1 + data = json.loads(chunks[0]) + assert data["content"]["parts"][0]["text"] == "Resolved: your issue is fixed." + assert data["is_final"] is True From 8720544b69e24bc916d07a841e8eff5533e765b5 Mon Sep 17 00:00:00 2001 From: Vasiliy Radostev Date: Wed, 11 Mar 2026 23:48:31 -0700 Subject: [PATCH 2/5] feat(ag2): wire AG2 engine into chat_routes and service_providers Co-Authored-By: Claude Sonnet 4.6 --- pyproject.toml | 1 + src/api/chat_routes.py | 40 +++++++++++++++++++++++-------- src/services/service_providers.py | 3 +++ 3 files changed, 34 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 53f43721..039f5bf6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ dependencies = [ "crewai==0.120.1", "crewai-tools==0.45.0", "a2a-sdk==0.2.4", + "ag2[openai]>=0.11.0", ] [project.optional-dependencies] diff --git a/src/api/chat_routes.py b/src/api/chat_routes.py index e34ab755..b120932b 100644 --- a/src/api/chat_routes.py +++ b/src/api/chat_routes.py @@ -52,6 +52,7 @@ from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse, FileData from src.services.adk.agent_runner import run_agent as run_agent_adk, run_agent_stream from src.services.crewai.agent_runner import run_agent as run_agent_crewai +from src.services.ag2.agent_runner import run_agent as run_agent_ag2, run_agent_stream as run_agent_stream_ag2 from src.core.exceptions import AgentNotFoundError from src.services.service_providers import ( session_service, @@ -221,16 +222,26 @@ async def websocket_chat( logger.error(f"Error processing files: {str(e)}") files = None - async for chunk in run_agent_stream( - agent_id=agent_id, - external_id=external_id, - message=message, - session_service=session_service, - artifacts_service=artifacts_service, - memory_service=memory_service, - db=db, - files=files, - ): + if settings.AI_ENGINE == "ag2": + stream_gen = run_agent_stream_ag2( + agent_id=agent_id, + external_id=external_id, + message=message, + db=db, + files=files, + ) + else: + stream_gen = run_agent_stream( + agent_id=agent_id, + external_id=external_id, + message=message, + session_service=session_service, + artifacts_service=artifacts_service, + memory_service=memory_service, + db=db, + files=files, + ) + async for chunk in stream_gen: await websocket.send_json( {"message": json.loads(chunk), "turn_complete": False} ) @@ -300,6 +311,15 @@ async def chat( db, files=request.files, ) + elif settings.AI_ENGINE == "ag2": + final_response = await run_agent_ag2( + agent_id, + external_id, + request.message, + session_service, + db, + files=request.files, + ) return { "response": final_response["final_response"], diff --git a/src/services/service_providers.py b/src/services/service_providers.py index 78d3acf2..35ba1f96 100644 --- a/src/services/service_providers.py +++ b/src/services/service_providers.py @@ -36,9 +36,12 @@ load_dotenv() from src.services.crewai.session_service import CrewSessionService +from src.services.ag2.session_service import AG2SessionService if os.getenv("AI_ENGINE") == "crewai": session_service = CrewSessionService(db_url=os.getenv("POSTGRES_CONNECTION_STRING")) +elif os.getenv("AI_ENGINE") == "ag2": + session_service = AG2SessionService(db_url=os.getenv("POSTGRES_CONNECTION_STRING")) else: session_service = DatabaseSessionService( db_url=os.getenv("POSTGRES_CONNECTION_STRING") From b55724f1fc40045f383482c28752ce4166ed45fe Mon Sep 17 00:00:00 2001 From: Vasiliy Radostev Date: Wed, 11 Mar 2026 23:54:45 -0700 Subject: [PATCH 3/5] docs: add AG2 engine to root README supported frameworks list Co-Authored-By: Claude Sonnet 4.6 --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index c75cbd60..20dfff7d 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ The Evo AI platform allows: - Custom tools management - **[Google Agent Development Kit (ADK)](https://google.github.io/adk-docs/)**: Base framework for agent development - **[CrewAI Support](https://github.com/crewAI/crewAI)**: Alternative framework for agent development (in development) +- **[AG2 (formerly AutoGen)](https://github.com/ag2ai/ag2)**: Dynamic GroupChat, context-variable handoffs, and human-in-the-loop (`AI_ENGINE=ag2`) - JWT authentication with email verification - **[Agent 2 Agent (A2A) Protocol Support](https://developers.googleblog.com/en/a2a-a-new-era-of-agent-interoperability/)**: Interoperability between AI agents - **[Workflow Agent with LangGraph](https://www.langchain.com/langgraph)**: Building complex agent workflows From 89ba46ebcb6cf967caf83ff74e73aabb41bc269c Mon Sep 17 00:00:00 2001 From: Vasiliy Radostev Date: Fri, 13 Mar 2026 08:48:29 -0700 Subject: [PATCH 4/5] fix(ag2): address sourcery-ai review comments on PR #46 - session_service: use MutableList (not MutableDict), add cache_ok=True, fix mutable default (lambda: []), merge double DB roundtrip into one session, propagate session_id through get_or_create - agent_builder: safe handoff type access via h.get("type") with skip/log on unknown values; cache db_sub_agents to eliminate N+1 DB queries - agent_runner: fix group_chat message format (dict, not bare string), propagate session_id; run_agent_stream now accepts session_service as parameter instead of creating it per-request - chat_routes: pass session_service= to run_agent_stream_ag2 call - mcp_service: remove redundant os.environ mutation and unused import os - custom_tool: wrap response.json() in try/except ValueError for non-JSON HTTP responses - tests: expand agent_builder tests to 10 cases (single mode, default mode, name sanitization, system message, _get_api_key branches, _apply_handoffs validation); add envelope contract assertions and session history ordering test to agent_runner tests Co-Authored-By: Claude Sonnet 4.6 --- src/api/chat_routes.py | 1 + src/services/ag2/agent_builder.py | 24 +++-- src/services/ag2/agent_runner.py | 11 +- src/services/ag2/custom_tool.py | 9 +- src/services/ag2/mcp_service.py | 4 - src/services/ag2/session_service.py | 35 ++++--- tests/services/ag2/test_agent_builder.py | 125 ++++++++++++++++++++++- tests/services/ag2/test_agent_runner.py | 85 ++++++++++++--- 8 files changed, 238 insertions(+), 56 deletions(-) diff --git a/src/api/chat_routes.py b/src/api/chat_routes.py index b120932b..877d8501 100644 --- a/src/api/chat_routes.py +++ b/src/api/chat_routes.py @@ -227,6 +227,7 @@ async def websocket_chat( agent_id=agent_id, external_id=external_id, message=message, + session_service=session_service, db=db, files=files, ) diff --git a/src/services/ag2/agent_builder.py b/src/services/ag2/agent_builder.py index 4b5c7ef5..b40267ea 100644 --- a/src/services/ag2/agent_builder.py +++ b/src/services/ag2/agent_builder.py @@ -98,19 +98,24 @@ def _apply_handoffs(self, ca: ConversableAgent, config: dict, all_agents: dict): logger.warning(f"Handoff target {target_id} not found, skipping") continue - if h["type"] == "llm": + h_type = h.get("type") + if h_type not in ("llm", "context"): + logger.warning(f"Unknown or missing handoff type {h_type!r} for target {target_id}, skipping") + continue + + if h_type == "llm": llm_conditions.append( OnCondition( target=AgentTarget(target_agent), - condition=StringLLMCondition(prompt=h["condition"]), + condition=StringLLMCondition(prompt=h.get("condition", "")), ) ) - elif h["type"] == "context": + elif h_type == "context": context_conditions.append( OnContextCondition( target=AgentTarget(target_agent), condition=ExpressionContextCondition( - expression=ContextExpression(h["expression"]) + expression=ContextExpression(h.get("expression", "")) ), ) ) @@ -136,13 +141,16 @@ async def build_group_chat_setup(self, root_agent) -> dict: if not sub_agent_ids: raise ValueError("group_chat agent requires at least one sub_agent") - # Build all sub-agents first so handoff resolution can reference them - all_agents = {} + # Build all sub-agents first so handoff resolution can reference them. + # Cache db_agent records to avoid re-fetching them in the handoff pass. + all_agents: dict = {} agents = [] + db_sub_agents: dict = {} for aid in sub_agent_ids: db_agent = get_agent(self.db, str(aid)) if db_agent is None: raise ValueError(f"Sub-agent {aid} not found") + db_sub_agents[str(aid)] = db_agent ca = await self.build_conversable_agent(db_agent) all_agents[str(aid)] = ca agents.append(ca) @@ -150,9 +158,9 @@ async def build_group_chat_setup(self, root_agent) -> dict: root_ca = await self.build_conversable_agent(root_agent) all_agents[str(root_agent.id)] = root_ca - # Apply handoffs to each agent if configured + # Apply handoffs using the already-fetched db_agent records for aid in sub_agent_ids: - db_agent = get_agent(self.db, str(aid)) + db_agent = db_sub_agents.get(str(aid)) if db_agent and db_agent.config: self._apply_handoffs(all_agents[str(aid)], db_agent.config, all_agents) diff --git a/src/services/ag2/agent_runner.py b/src/services/ag2/agent_runner.py index 056e2b20..7ae07da1 100644 --- a/src/services/ag2/agent_runner.py +++ b/src/services/ag2/agent_runner.py @@ -21,7 +21,6 @@ async def run_agent( session_service: AG2SessionService, db: Session, session_id: Optional[str] = None, - timeout: float = 60.0, files: Optional[list] = None, ) -> dict: tracer = get_tracer() @@ -37,7 +36,7 @@ async def run_agent( result, _ = await builder.build_agent(db_agent) # Reconstruct conversation history as AG2 message list - session = session_service.get_or_create(agent_id, external_id) + session = session_service.get_or_create(agent_id, external_id, session_id) history = session_service.build_messages(session) try: @@ -45,7 +44,7 @@ async def run_agent( if ag2_mode == "group_chat": chat_result, final_context, last_agent = initiate_group_chat( pattern=result["pattern"], - messages=history + [message], + messages=history + [{"role": "user", "content": message}], max_rounds=result["max_rounds"], context_variables=result["context_variables"], ) @@ -99,6 +98,7 @@ async def run_agent_stream( agent_id: str, external_id: str, message: str, + session_service: AG2SessionService, db: Session, session_id: Optional[str] = None, files: Optional[list] = None, @@ -111,11 +111,6 @@ async def run_agent_stream( Token-level streaming can be added in a future iteration by wiring ConversableAgent's `process_last_received_message` hook to a queue. """ - from src.services.ag2.session_service import AG2SessionService - from src.config.settings import get_settings - settings = get_settings() - session_service = AG2SessionService(db_url=settings.POSTGRES_CONNECTION_STRING) - result = await run_agent( agent_id=agent_id, external_id=external_id, diff --git a/src/services/ag2/custom_tool.py b/src/services/ag2/custom_tool.py index cf673cb1..322dacba 100644 --- a/src/services/ag2/custom_tool.py +++ b/src/services/ag2/custom_tool.py @@ -111,7 +111,14 @@ def http_tool(**kwargs) -> str: f"Error in the request: {response.status_code} - {response.text}" ) - return json.dumps(response.json()) + try: + response_data = response.json() + except ValueError: + response_data = { + "status_code": response.status_code, + "raw_response": response.text, + } + return json.dumps(response_data) except Exception as e: logger.error(f"Error executing tool {name}: {str(e)}") diff --git a/src/services/ag2/mcp_service.py b/src/services/ag2/mcp_service.py index c469156d..5a93d0e5 100644 --- a/src/services/ag2/mcp_service.py +++ b/src/services/ag2/mcp_service.py @@ -28,7 +28,6 @@ """ from typing import Any, Callable, Dict, List, Optional, Tuple -import os from src.utils.logger import setup_logger from src.services.mcp_server_service import get_mcp_server from sqlalchemy.orm import Session @@ -156,9 +155,6 @@ async def _connect_server( command = server_config.get("command", "npx") args = server_config.get("args", []) env = server_config.get("env", {}) - if env: - for key, value in env.items(): - os.environ[key] = value server = McpServer({"command": command, "args": args, "env": env}) tools = await server.list_tools() diff --git a/src/services/ag2/session_service.py b/src/services/ag2/session_service.py index a5a4a2d9..eee42daf 100644 --- a/src/services/ag2/session_service.py +++ b/src/services/ag2/session_service.py @@ -33,7 +33,7 @@ from typing import Any, Dict, List, Optional from sqlalchemy import create_engine, Text -from sqlalchemy.ext.mutable import MutableDict +from sqlalchemy.ext.mutable import MutableList from sqlalchemy.orm import ( sessionmaker, DeclarativeBase, @@ -54,6 +54,7 @@ class DynamicJSON(TypeDecorator): """JSON type that uses JSONB in PostgreSQL and TEXT with JSON serialization elsewhere.""" impl = Text + cache_ok = True def load_dialect_impl(self, dialect): if dialect.name == "postgresql": @@ -92,8 +93,8 @@ class AG2StorageSession(Base): id: Mapped[str] = mapped_column( String, primary_key=True, default=lambda: str(uuid.uuid4()) ) - messages: Mapped[MutableDict[str, Any]] = mapped_column( - MutableDict.as_mutable(DynamicJSON), default=[] + messages: Mapped[List[Any]] = mapped_column( + MutableList.as_mutable(DynamicJSON), default=lambda: [] ) create_time: Mapped[DateTime] = mapped_column(DateTime(), default=func.now()) update_time: Mapped[DateTime] = mapped_column( @@ -127,31 +128,35 @@ def __init__(self, db_url: str): self.SessionLocal = sessionmaker(bind=self.engine) logger.info(f"AG2SessionService started with database at {db_url}") - def get_or_create(self, agent_id: str, external_id: str) -> AG2Session: - """Retrieve an existing session or create a new one.""" - session_id = f"{external_id}_{agent_id}" + def get_or_create( + self, agent_id: str, external_id: str, session_id: Optional[str] = None + ) -> AG2Session: + """Retrieve an existing session or create a new one. + + All work is done in a single DB session to avoid extra roundtrips and + the race window that exists between two separate transactions. + """ + sid = session_id or f"{external_id}_{agent_id}" with self.SessionLocal() as db: - record = db.get(AG2StorageSession, (agent_id, external_id, session_id)) + record = db.get(AG2StorageSession, (agent_id, external_id, sid)) if record is None: record = AG2StorageSession( app_name=agent_id, user_id=external_id, - id=session_id, + id=sid, messages=[], ) db.add(record) db.commit() db.refresh(record) logger.info( - f"Created new AG2 session {session_id} for agent {agent_id} / user {external_id}" + f"Created new AG2 session {sid} for agent {agent_id} / user {external_id}" ) + # Access messages while the session is still open so lazy loading works. + messages = list(record.messages) if isinstance(record.messages, list) else [] - session = AG2Session(app_name=agent_id, user_id=external_id, session_id=session_id) - # Load persisted messages - with self.SessionLocal() as db: - record = db.get(AG2StorageSession, (agent_id, external_id, session_id)) - if record and record.messages: - session.messages = list(record.messages) if isinstance(record.messages, list) else [] + session = AG2Session(app_name=agent_id, user_id=external_id, session_id=sid) + session.messages = messages return session def build_messages(self, session: AG2Session) -> List[Dict[str, Any]]: diff --git a/tests/services/ag2/test_agent_builder.py b/tests/services/ag2/test_agent_builder.py index 2c57dd32..7f0d44e0 100644 --- a/tests/services/ag2/test_agent_builder.py +++ b/tests/services/ag2/test_agent_builder.py @@ -1,12 +1,16 @@ """Unit tests for AG2AgentBuilder — no LLM calls, no DB, no network.""" +import asyncio import pytest from unittest.mock import MagicMock, AsyncMock, patch +from autogen import ConversableAgent from src.services.ag2.agent_builder import AG2AgentBuilder + @pytest.fixture def mock_db(): return AsyncMock() + def _make_agent(agent_id, agent_name, ag2_mode="single", model="gpt-4o-mini"): """Create a minimal agent record mock with explicit attribute assignment.""" a = MagicMock() @@ -22,12 +26,14 @@ def _make_agent(agent_id, agent_name, ag2_mode="single", model="gpt-4o-mini"): a.role = None a.goal = None a.instruction = f"You are {agent_name}." + a.description = "" return a + @pytest.fixture def group_chat_agent_record(): """Group-chat root agent with two sub-agents.""" - a = _make_agent("agent-123", "Support Team") + a = _make_agent("agent-123", "Support_Team") a.config = { "ag2_mode": "group_chat", "sub_agents": ["uuid-triage", "uuid-specialist"], @@ -37,22 +43,135 @@ def group_chat_agent_record(): } return a + +# --------------------------------------------------------------------------- +# Group-chat mode +# --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_builder_creates_group_chat(mock_db, group_chat_agent_record): builder = AG2AgentBuilder(db=mock_db) sub_agent = _make_agent("uuid-triage", "triage") with patch("src.services.ag2.agent_builder.get_agent", return_value=sub_agent): result, _ = await builder.build_agent(group_chat_agent_record) - # result should be a dict with pattern and agents for initiate_group_chat assert "pattern" in result assert "agents" in result + @pytest.mark.asyncio async def test_builder_validates_group_chat_requires_agents(mock_db): record = MagicMock() record.type = "llm" - record.config = {"ag2_mode": "group_chat", "sub_agents": []} # empty — should raise + record.config = {"ag2_mode": "group_chat", "sub_agents": []} record.api_key = "test" builder = AG2AgentBuilder(db=mock_db) with pytest.raises(ValueError, match="at least one"): await builder.build_agent(record) + + +# --------------------------------------------------------------------------- +# Single-agent mode +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_builder_single_mode_returns_conversable_agent(mock_db): + """build_agent with ag2_mode=single returns a ConversableAgent, not a dict.""" + record = _make_agent("agent-abc", "Support_Bot", ag2_mode="single") + builder = AG2AgentBuilder(db=mock_db) + agent, meta = await builder.build_agent(record) + assert isinstance(agent, ConversableAgent) + assert meta is None + + +@pytest.mark.asyncio +async def test_builder_default_mode_is_single(mock_db): + """Omitting ag2_mode in config defaults to single-agent mode.""" + record = _make_agent("agent-abc", "My_Bot") + record.config = {"api_key": "test-key"} # no ag2_mode key + builder = AG2AgentBuilder(db=mock_db) + agent, _ = await builder.build_agent(record) + assert isinstance(agent, ConversableAgent) + + +@pytest.mark.asyncio +async def test_builder_sanitizes_agent_name(mock_db): + """Spaces in agent names are replaced with underscores for AG2 compatibility.""" + record = _make_agent("agent-abc", "Support Team Agent") + builder = AG2AgentBuilder(db=mock_db) + agent, _ = await builder.build_agent(record) + assert " " not in agent.name + assert agent.name == "Support_Team_Agent" + + +@pytest.mark.asyncio +async def test_builder_system_message_includes_role_goal_instruction(mock_db): + """System message is composed from role, goal, and instruction fields.""" + record = _make_agent("agent-abc", "Bot") + record.role = "billing specialist" + record.goal = "resolve billing disputes" + record.instruction = "Always be concise." + builder = AG2AgentBuilder(db=mock_db) + agent, _ = await builder.build_agent(record) + sm = agent.system_message + assert "billing specialist" in sm + assert "resolve billing disputes" in sm + assert "Always be concise." in sm + + +# --------------------------------------------------------------------------- +# _get_api_key branches +# --------------------------------------------------------------------------- + +def test_get_api_key_uses_api_key_id(mock_db): + """api_key_id takes priority and resolves via get_decrypted_api_key.""" + record = _make_agent("agent-abc", "Bot") + record.api_key_id = "11111111-2222-3333-4444-555555555555" + record.config = {} + builder = AG2AgentBuilder(db=mock_db) + with patch( + "src.services.ag2.agent_builder.get_decrypted_api_key", + return_value="resolved-key", + ): + key = asyncio.run(builder._get_api_key(record)) + assert key == "resolved-key" + + +def test_get_api_key_falls_back_to_raw_string(mock_db): + """Non-UUID api_key in config is returned as-is.""" + record = _make_agent("agent-abc", "Bot") + record.api_key_id = None + record.config = {"api_key": "raw-openai-key"} + builder = AG2AgentBuilder(db=mock_db) + key = asyncio.run(builder._get_api_key(record)) + assert key == "raw-openai-key" + + +def test_get_api_key_raises_when_no_key_configured(mock_db): + """ValueError raised when neither api_key_id nor config api_key is set.""" + record = _make_agent("agent-abc", "Bot") + record.api_key_id = None + record.config = {} + builder = AG2AgentBuilder(db=mock_db) + with pytest.raises(ValueError, match="No API key"): + asyncio.run(builder._get_api_key(record)) + + +# --------------------------------------------------------------------------- +# _apply_handoffs — validation +# --------------------------------------------------------------------------- + +def test_apply_handoffs_skips_missing_type(mock_db): + """Handoff entries with missing 'type' are skipped without raising.""" + builder = AG2AgentBuilder(db=mock_db) + ca = MagicMock() + # Should not raise even with no 'type' key + builder._apply_handoffs(ca, {"handoffs": [{"target_agent_id": "x"}]}, {}) + + +def test_apply_handoffs_skips_unknown_type(mock_db): + """Handoff entries with unknown type value are skipped without raising.""" + builder = AG2AgentBuilder(db=mock_db) + ca = MagicMock() + builder._apply_handoffs( + ca, {"handoffs": [{"type": "unknown", "target_agent_id": "x"}]}, {} + ) diff --git a/tests/services/ag2/test_agent_runner.py b/tests/services/ag2/test_agent_runner.py index 4a8c284e..4f1b0d5a 100644 --- a/tests/services/ag2/test_agent_runner.py +++ b/tests/services/ag2/test_agent_runner.py @@ -1,8 +1,9 @@ """Unit tests for AG2AgentRunner streaming interface.""" import json import pytest -from unittest.mock import MagicMock, AsyncMock, patch -from src.services.ag2.agent_runner import run_agent_stream +from unittest.mock import MagicMock, AsyncMock, patch, call +from src.services.ag2.agent_runner import run_agent_stream, run_agent + @pytest.mark.asyncio async def test_runner_yields_response_chunks(): @@ -10,25 +11,75 @@ async def test_runner_yields_response_chunks(): "final_response": "Resolved: your issue is fixed.", "message_history": [], } - mock_settings = MagicMock() - mock_settings.POSTGRES_CONNECTION_STRING = "sqlite://" + mock_session_service = MagicMock() - # Patch run_agent (the heavy work), session service, and settings — - # get_settings and AG2SessionService are local imports inside run_agent_stream with patch("src.services.ag2.agent_runner.run_agent", AsyncMock(return_value=mock_result)): - with patch("src.services.ag2.session_service.AG2SessionService"): - with patch("src.config.settings.get_settings", return_value=mock_settings): - chunks = [] - async for chunk in run_agent_stream( - db=AsyncMock(), - agent_id="agent-123", - external_id="ext-123", - session_id="session-abc", - message="My printer is broken", - ): - chunks.append(chunk) + chunks = [] + async for chunk in run_agent_stream( + db=AsyncMock(), + agent_id="agent-123", + external_id="ext-123", + session_service=mock_session_service, + session_id="session-abc", + message="My printer is broken", + ): + chunks.append(chunk) assert len(chunks) >= 1 data = json.loads(chunks[0]) + # Validate full streaming envelope contract + assert data["content"]["role"] == "agent" + assert data["author"] == "agent-123" + assert data["content"]["parts"][0]["type"] == "text" assert data["content"]["parts"][0]["text"] == "Resolved: your issue is fixed." assert data["is_final"] is True + + +@pytest.mark.asyncio +async def test_run_agent_persists_session_history_in_order(): + """append is called user-then-assistant, save is called once.""" + mock_db = MagicMock() + mock_session_service = MagicMock() + mock_session = MagicMock() + mock_session_service.get_or_create.return_value = mock_session + mock_session_service.build_messages.return_value = [] + + user_message = "My printer is broken" + + mock_agent_record = MagicMock() + mock_agent_record.config = {"ag2_mode": "single"} + + mock_chat_result = MagicMock() + mock_chat_result.chat_history = [ + {"role": "user", "content": user_message}, + {"role": "assistant", "content": "Try turning it off and on again."}, + ] + + with patch("src.services.ag2.agent_runner.get_agent", return_value=mock_agent_record), \ + patch("src.services.ag2.agent_runner.AG2AgentBuilder") as MockBuilder, \ + patch("src.services.ag2.agent_runner.ConversableAgent") as MockProxy, \ + patch("src.services.ag2.agent_runner.asyncio") as mock_asyncio: + + builder_instance = MockBuilder.return_value + builder_instance.build_agent = AsyncMock(return_value=(MagicMock(), None)) + + proxy_instance = MockProxy.return_value + mock_asyncio.get_event_loop.return_value.run_in_executor = AsyncMock( + return_value=mock_chat_result + ) + + result = await run_agent( + agent_id="agent-123", + external_id="ext-123", + message=user_message, + session_service=mock_session_service, + db=mock_db, + session_id="session-abc", + ) + + final_response = result["final_response"] + mock_session_service.append.assert_has_calls([ + call(mock_session, "user", user_message), + call(mock_session, "assistant", final_response), + ]) + mock_session_service.save.assert_called_once_with(mock_session) From 927672c45cb812a3c30fda6aeb9cfd0e03f128a5 Mon Sep 17 00:00:00 2001 From: Vasiliy Radostev Date: Fri, 13 Mar 2026 10:51:26 -0700 Subject: [PATCH 5/5] test(ag2): add handoff, error, and exception propagation coverage - _apply_handoffs: test LLM condition registration, context condition registration, after_work=terminate (TerminateTarget), and default after_work (RevertToUserTarget) - build_group_chat_setup: test ValueError raised on missing sub-agent - run_agent: test AgentNotFoundError when agent DB record is missing - run_agent: test RuntimeError from initiate_group_chat wrapped as InternalServerError - run_agent_stream: test that InternalServerError from run_agent propagates Addresses reviewer comments 12 and 13 on PR #46. Co-Authored-By: Claude Sonnet 4.6 --- tests/services/ag2/test_agent_builder.py | 67 ++++++++++++++++++++++++ tests/services/ag2/test_agent_runner.py | 67 ++++++++++++++++++++++++ 2 files changed, 134 insertions(+) diff --git a/tests/services/ag2/test_agent_builder.py b/tests/services/ag2/test_agent_builder.py index 7f0d44e0..dc2077b5 100644 --- a/tests/services/ag2/test_agent_builder.py +++ b/tests/services/ag2/test_agent_builder.py @@ -3,6 +3,7 @@ import pytest from unittest.mock import MagicMock, AsyncMock, patch from autogen import ConversableAgent +from autogen.agentchat.group import TerminateTarget, RevertToUserTarget from src.services.ag2.agent_builder import AG2AgentBuilder @@ -175,3 +176,69 @@ def test_apply_handoffs_skips_unknown_type(mock_db): builder._apply_handoffs( ca, {"handoffs": [{"type": "unknown", "target_agent_id": "x"}]}, {} ) + + +def test_apply_handoffs_registers_llm_condition(mock_db): + """LLM-type handoff adds an OnCondition via add_llm_conditions.""" + builder = AG2AgentBuilder(db=mock_db) + ca = MagicMock() + target = MagicMock() + target.name = "target_agent" # AgentTarget validates agent_name as str + all_agents = {"target-uuid": target} + config = { + "handoffs": [ + {"type": "llm", "target_agent_id": "target-uuid", "condition": "user asks about billing"} + ] + } + builder._apply_handoffs(ca, config, all_agents) + ca.handoffs.add_llm_conditions.assert_called_once() + conditions = ca.handoffs.add_llm_conditions.call_args[0][0] + assert len(conditions) == 1 + + +def test_apply_handoffs_registers_context_condition(mock_db): + """Context-type handoff adds an OnContextCondition via add_context_conditions.""" + builder = AG2AgentBuilder(db=mock_db) + ca = MagicMock() + target = MagicMock() + target.name = "target_agent" # AgentTarget validates agent_name as str + all_agents = {"target-uuid": target} + config = { + "handoffs": [ + {"type": "context", "target_agent_id": "target-uuid", "expression": "${is_vip} == True"} + ] + } + builder._apply_handoffs(ca, config, all_agents) + ca.handoffs.add_context_conditions.assert_called_once() + conditions = ca.handoffs.add_context_conditions.call_args[0][0] + assert len(conditions) == 1 + + +def test_apply_handoffs_after_work_terminate(mock_db): + """after_work='terminate' sets TerminateTarget on the agent.""" + builder = AG2AgentBuilder(db=mock_db) + ca = MagicMock() + builder._apply_handoffs(ca, {"after_work": "terminate", "handoffs": []}, {}) + call_arg = ca.handoffs.set_after_work.call_args[0][0] + assert isinstance(call_arg, TerminateTarget) + + +def test_apply_handoffs_after_work_default_reverts_to_user(mock_db): + """Omitting after_work defaults to RevertToUserTarget.""" + builder = AG2AgentBuilder(db=mock_db) + ca = MagicMock() + builder._apply_handoffs(ca, {"handoffs": []}, {}) + call_arg = ca.handoffs.set_after_work.call_args[0][0] + assert isinstance(call_arg, RevertToUserTarget) + + +@pytest.mark.asyncio +async def test_build_group_chat_setup_raises_on_missing_sub_agent(mock_db): + """build_group_chat_setup raises ValueError when a sub-agent is not found in the DB.""" + record = MagicMock() + record.config = {"ag2_mode": "group_chat", "sub_agents": ["missing-uuid"]} + record.api_key = "test" + builder = AG2AgentBuilder(db=mock_db) + with patch("src.services.ag2.agent_builder.get_agent", return_value=None): + with pytest.raises(ValueError, match="not found"): + await builder.build_group_chat_setup(record) diff --git a/tests/services/ag2/test_agent_runner.py b/tests/services/ag2/test_agent_runner.py index 4f1b0d5a..ff50b4e7 100644 --- a/tests/services/ag2/test_agent_runner.py +++ b/tests/services/ag2/test_agent_runner.py @@ -3,6 +3,7 @@ import pytest from unittest.mock import MagicMock, AsyncMock, patch, call from src.services.ag2.agent_runner import run_agent_stream, run_agent +from src.core.exceptions import AgentNotFoundError, InternalServerError @pytest.mark.asyncio @@ -83,3 +84,69 @@ async def test_run_agent_persists_session_history_in_order(): call(mock_session, "assistant", final_response), ]) mock_session_service.save.assert_called_once_with(mock_session) + + +@pytest.mark.asyncio +async def test_run_agent_raises_agent_not_found(): + """run_agent raises AgentNotFoundError when the agent DB record is missing.""" + with patch("src.services.ag2.agent_runner.get_agent", return_value=None): + with pytest.raises(AgentNotFoundError): + await run_agent( + agent_id="missing-agent", + external_id="ext-123", + message="hello", + session_service=MagicMock(), + db=AsyncMock(), + ) + + +@pytest.mark.asyncio +async def test_run_agent_stream_propagates_exception(): + """run_agent_stream propagates exceptions raised by run_agent.""" + with patch( + "src.services.ag2.agent_runner.run_agent", + AsyncMock(side_effect=InternalServerError("boom")), + ): + with pytest.raises(InternalServerError): + async for _ in run_agent_stream( + agent_id="agent-123", + external_id="ext-123", + message="hello", + session_service=MagicMock(), + db=AsyncMock(), + ): + pass + + +@pytest.mark.asyncio +async def test_run_agent_group_chat_failure_raises_internal_error(): + """Exceptions from initiate_group_chat are wrapped and re-raised as InternalServerError.""" + mock_agent_record = MagicMock() + mock_agent_record.config = {"ag2_mode": "group_chat"} + + mock_session_service = MagicMock() + mock_session_service.get_or_create.return_value = MagicMock() + mock_session_service.build_messages.return_value = [] + + group_chat_setup = { + "pattern": MagicMock(), + "max_rounds": 10, + "context_variables": MagicMock(), + } + + with patch("src.services.ag2.agent_runner.get_agent", return_value=mock_agent_record), \ + patch("src.services.ag2.agent_runner.AG2AgentBuilder") as MockBuilder, \ + patch( + "src.services.ag2.agent_runner.initiate_group_chat", + side_effect=RuntimeError("llm failure"), + ): + MockBuilder.return_value.build_agent = AsyncMock(return_value=(group_chat_setup, None)) + + with pytest.raises(InternalServerError, match="llm failure"): + await run_agent( + agent_id="agent-123", + external_id="ext-123", + message="hello", + session_service=mock_session_service, + db=AsyncMock(), + )