From 04b3909fdbc6aff1d2c20201be03b1b417fc1fc0 Mon Sep 17 00:00:00 2001 From: anandhkb Date: Wed, 13 May 2026 19:29:51 -0700 Subject: [PATCH] Fix necessary to address the proper handling of the agent's response structure by the API Catalog's UI --- .../nat_cuopt_agent/function/deepagent_fn.py | 159 ++++++++++++++---- 1 file changed, 125 insertions(+), 34 deletions(-) diff --git a/cuopt-agent/cuopt_agent/src/nat_cuopt_agent/function/deepagent_fn.py b/cuopt-agent/cuopt_agent/src/nat_cuopt_agent/function/deepagent_fn.py index 7f879c7..b774dbc 100755 --- a/cuopt-agent/cuopt_agent/src/nat_cuopt_agent/function/deepagent_fn.py +++ b/cuopt-agent/cuopt_agent/src/nat_cuopt_agent/function/deepagent_fn.py @@ -13,9 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datetime import logging import os import re +import uuid +from collections.abc import AsyncGenerator, AsyncIterator +from contextlib import asynccontextmanager from pathlib import Path from tempfile import TemporaryDirectory @@ -23,7 +27,14 @@ from nat.builder.framework_enum import LLMFrameworkEnum from nat.builder.function_info import FunctionInfo from nat.cli.register_workflow import register_function -from nat.data_models.api_server import ChatRequest, ChatRequestOrMessage, ChatResponse, Usage +from nat.data_models.api_server import ( + ChatRequest, + ChatRequestOrMessage, + ChatResponse, + ChatResponseChunk, + Usage, + UserMessageContentRoleType, +) from nat.data_models.component_ref import FunctionRef, LLMRef from nat.data_models.function import FunctionBaseConfig from nat.utils.type_converter import GlobalTypeConverter @@ -183,33 +194,21 @@ async def deep_agent(config: DeepAgentConfig, builder: Builder): # Workaround to strip reasoning patterns from the final response with minimax model strip_re = re.compile(config.strip_reasoning_pattern, re.DOTALL) if config.strip_reasoning_pattern else None - # Inner function that handles the agent invocation and response processing - async def _inner(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse | str: - """Inner function that handles the agent invocation and response processing. - Args: - chat_request_or_message: The chat request or message to process. - Returns: - A chat response or string. - """ - chat_request = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest) + @asynccontextmanager + async def _agent_session( + chat_request: ChatRequest, + ) -> AsyncIterator[tuple[object, list]]: + """Yield (agent, messages_dict_list) inside a sandbox; cleans up child processes on exit.""" messages = [m.model_dump() for m in chat_request.messages] - - # Create a temporary sandbox directory for the agent - # Note execute tool will create files on host, a more robust sandbox should be used for production. with TemporaryDirectory() as sandbox_dir: sandbox = Path(sandbox_dir) - populate_sandbox(sandbox, skills_src_dirs, agents_md_src, config.workspace_dirs) - - # Create a local shell backend for the agent backend = LocalShellBackend( root_dir=sandbox, virtual_mode=True, inherit_env=True, env=env, ) - - # create subagent dictionaries sub_agent_dicts: list[dict] = [] for ref in config.subagents: fn = await builder.get_function(ref) @@ -222,7 +221,6 @@ async def _inner(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse "Resolved %d subagent(s): %s", len(sub_agent_dicts), [sa.get("name", "?") for sa in sub_agent_dicts] ) - # Create a middleware chain for the agent to improve reliability and performance middleware = [ FixToolNamesMiddleware(), ToolRetryMiddleware(), @@ -236,7 +234,6 @@ async def _inner(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse ), ] - # Create a dictionary of agent configuration arguments, including subagents if configured agent_kwargs: dict = dict( tools=config.tools, model=llm, @@ -252,29 +249,123 @@ async def _inner(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse agent_kwargs["memory"] = effective_memory agent = create_deep_agent(**agent_kwargs) - - # Ensure child/orphaned processes are cleaned up pre_children = {c.pid for c in psutil.Process().children(recursive=True)} try: - agent_result = await agent.ainvoke({"messages": messages}) - - result_messages = agent_result["messages"] - content = result_messages[-1].content if result_messages else "" - content = strip_pattern(content, strip_re) + yield agent, messages finally: kill_orphaned_children(pre_children) - # Calculate usage metrics + def _usage_for_content(chat_request: ChatRequest, content: str) -> Usage: prompt_tokens = sum(len(str(m.content).split()) for m in chat_request.messages) completion_tokens = len(content.split()) if content else 0 - usage = Usage( + return Usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, ) - response = ChatResponse.from_string(content, usage=usage) - if chat_request_or_message.is_string: - return GlobalTypeConverter.get().convert(response, to_type=str) - return response - yield FunctionInfo.from_fn(_inner, description=config.description) + def _response_model(chat_request: ChatRequest) -> str: + return (chat_request.model or "").strip() or "unknown-model" + + async def _single(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse: + """Non-streaming OpenAI chat completion (root JSON object, no ``value`` wrapper).""" + chat_request = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest) + async with _agent_session(chat_request) as (agent, messages): + agent_result = await agent.ainvoke({"messages": messages}) + result_messages = agent_result["messages"] + content = result_messages[-1].content if result_messages else "" + content = strip_pattern(content, strip_re) + usage = _usage_for_content(chat_request, content) + return ChatResponse.from_string(content, usage=usage, model=_response_model(chat_request)) + + async def _stream_llm_chunks(agent: object, messages: list) -> AsyncGenerator[str, None]: + """Yield main-agent assistant text segments (async generator for typing).""" + try: + astream = agent.astream( + {"messages": messages}, + stream_mode="messages", + subgraphs=True, + version="v2", + ) + except TypeError: + astream = agent.astream({"messages": messages}, stream_mode="messages", subgraphs=True) + + async for chunk in astream: + if not isinstance(chunk, dict) or chunk.get("type") != "messages": + continue + payload = chunk.get("data") + if not isinstance(payload, (list, tuple)) or len(payload) < 1: + continue + token = payload[0] + ns = chunk.get("ns") + if isinstance(ns, str): + ns = (ns,) + elif ns is None: + ns = () + if any(isinstance(s, str) and s.startswith("tools:") for s in ns): + continue + if getattr(token, "type", None) != "ai": + continue + if getattr(token, "tool_call_chunks", None): + continue + text = getattr(token, "content", None) or "" + if text: + yield text + + async def _stream(chat_request_or_message: ChatRequestOrMessage) -> AsyncGenerator[ChatResponseChunk, None]: + """OpenAI-style SSE chunks via NAT ``ChatResponseChunk`` (``data:`` lines when framed by NAT).""" + chat_request = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest) + response_model = _response_model(chat_request) + stream_id = str(uuid.uuid4()) + created = datetime.datetime.now(datetime.UTC) + assembled: list[str] = [] + + async with _agent_session(chat_request) as (agent, messages): + yield ChatResponseChunk.create_streaming_chunk( + "", + id_=stream_id, + created=created, + model=response_model, + role=UserMessageContentRoleType.ASSISTANT, + ) + try: + async for text in _stream_llm_chunks(agent, messages): + assembled.append(text) + yield ChatResponseChunk.create_streaming_chunk( + text, + id_=stream_id, + created=created, + model=response_model, + ) + except Exception: + logger.exception("Token streaming failed; falling back to buffered completion") + agent_result = await agent.ainvoke({"messages": messages}) + result_messages = agent_result["messages"] + content = result_messages[-1].content if result_messages else "" + content = strip_pattern(content, strip_re) + assembled.clear() + assembled.append(content) + yield ChatResponseChunk.create_streaming_chunk( + content, + id_=stream_id, + created=created, + model=response_model, + ) + + content = strip_pattern("".join(assembled), strip_re) + + usage = _usage_for_content(chat_request, content) + yield ChatResponseChunk.create_streaming_chunk( + "", + id_=stream_id, + created=created, + model=response_model, + finish_reason="stop", + usage=usage, + ) + + yield FunctionInfo.create( + single_fn=_single, + stream_fn=_stream, + description=config.description, + )