Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 125 additions & 34 deletions cuopt-agent/cuopt_agent/src/nat_cuopt_agent/function/deepagent_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import logging
import os
import re
import uuid
from collections.abc import AsyncGenerator, AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path
from tempfile import TemporaryDirectory

from nat.builder.builder import Builder
from nat.builder.framework_enum import LLMFrameworkEnum
from nat.builder.function_info import FunctionInfo
from nat.cli.register_workflow import register_function
from nat.data_models.api_server import ChatRequest, ChatRequestOrMessage, ChatResponse, Usage
from nat.data_models.api_server import (
ChatRequest,
ChatRequestOrMessage,
ChatResponse,
ChatResponseChunk,
Usage,
UserMessageContentRoleType,
)
from nat.data_models.component_ref import FunctionRef, LLMRef
from nat.data_models.function import FunctionBaseConfig
from nat.utils.type_converter import GlobalTypeConverter
Expand Down Expand Up @@ -183,33 +194,21 @@ async def deep_agent(config: DeepAgentConfig, builder: Builder):
# Workaround to strip reasoning patterns from the final response with minimax model
strip_re = re.compile(config.strip_reasoning_pattern, re.DOTALL) if config.strip_reasoning_pattern else None

# Inner function that handles the agent invocation and response processing
async def _inner(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse | str:
"""Inner function that handles the agent invocation and response processing.
Args:
chat_request_or_message: The chat request or message to process.
Returns:
A chat response or string.
"""
chat_request = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
@asynccontextmanager
async def _agent_session(
chat_request: ChatRequest,
) -> AsyncIterator[tuple[object, list]]:
"""Yield (agent, messages_dict_list) inside a sandbox; cleans up child processes on exit."""
messages = [m.model_dump() for m in chat_request.messages]

# Create a temporary sandbox directory for the agent
# Note execute tool will create files on host, a more robust sandbox should be used for production.
with TemporaryDirectory() as sandbox_dir:
sandbox = Path(sandbox_dir)

populate_sandbox(sandbox, skills_src_dirs, agents_md_src, config.workspace_dirs)

# Create a local shell backend for the agent
backend = LocalShellBackend(
root_dir=sandbox,
virtual_mode=True,
inherit_env=True,
env=env,
)

# create subagent dictionaries
sub_agent_dicts: list[dict] = []
for ref in config.subagents:
fn = await builder.get_function(ref)
Expand All @@ -222,7 +221,6 @@ async def _inner(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse
"Resolved %d subagent(s): %s", len(sub_agent_dicts), [sa.get("name", "?") for sa in sub_agent_dicts]
)

# Create a middleware chain for the agent to improve reliability and performance
middleware = [
FixToolNamesMiddleware(),
ToolRetryMiddleware(),
Expand All @@ -236,7 +234,6 @@ async def _inner(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse
),
]

# Create a dictionary of agent configuration arguments, including subagents if configured
agent_kwargs: dict = dict(
tools=config.tools,
model=llm,
Expand All @@ -252,29 +249,123 @@ async def _inner(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse
agent_kwargs["memory"] = effective_memory

agent = create_deep_agent(**agent_kwargs)

# Ensure child/orphaned processes are cleaned up
pre_children = {c.pid for c in psutil.Process().children(recursive=True)}
try:
agent_result = await agent.ainvoke({"messages": messages})

result_messages = agent_result["messages"]
content = result_messages[-1].content if result_messages else ""
content = strip_pattern(content, strip_re)
yield agent, messages
finally:
kill_orphaned_children(pre_children)

# Calculate usage metrics
def _usage_for_content(chat_request: ChatRequest, content: str) -> Usage:
    """Estimate token usage by counting whitespace-separated words.

    This is a word-count approximation, not a tokenizer-accurate figure.

    Args:
        chat_request: Request whose ``messages`` make up the prompt.
        content: Final assistant text; an empty value counts as zero tokens.

    Returns:
        A ``Usage`` carrying the prompt, completion and total counts.
    """
    # A message without content contributes nothing; ``str(None)`` would
    # otherwise be counted as the single word "None".
    prompt_tokens = sum(
        len(str(m.content).split())
        for m in chat_request.messages
        if m.content is not None
    )
    completion_tokens = len(content.split()) if content else 0
    return Usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )
response = ChatResponse.from_string(content, usage=usage)
if chat_request_or_message.is_string:
return GlobalTypeConverter.get().convert(response, to_type=str)
return response

yield FunctionInfo.from_fn(_inner, description=config.description)
def _response_model(chat_request: ChatRequest) -> str:
    """Return the request's model name, or ``"unknown-model"`` when unset or blank."""
    requested = chat_request.model
    if not requested:
        return "unknown-model"
    trimmed = requested.strip()
    if not trimmed:
        return "unknown-model"
    return trimmed

async def _single(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse:
    """Run the agent to completion and return one OpenAI-style chat completion.

    The response is the root JSON object (no ``value`` wrapper).

    Args:
        chat_request_or_message: Incoming request; converted to a ``ChatRequest``
            through the global type converter before use.

    Returns:
        A ``ChatResponse`` built from the content of the last message in the
        agent result (empty string when the result has no messages), after
        ``strip_pattern`` has been applied with the configured pattern.
    """
    request = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
    async with _agent_session(request) as (agent, history):
        outcome = await agent.ainvoke({"messages": history})
        replies = outcome["messages"]
        if replies:
            final_text = replies[-1].content
        else:
            final_text = ""
        final_text = strip_pattern(final_text, strip_re)
        return ChatResponse.from_string(
            final_text,
            usage=_usage_for_content(request, final_text),
            model=_response_model(request),
        )

async def _stream_llm_chunks(agent: object, messages: list) -> AsyncGenerator[str, None]:
    """Yield the main agent's assistant text segments as they are produced.

    The agent is streamed with ``stream_mode="messages"`` and ``subgraphs=True``.
    The ``version="v2"`` chunk format (dicts with ``type``/``ns``/``data``) is
    requested first; if ``astream`` rejects that keyword with ``TypeError`` the
    call is repeated without it, in which case chunks are expected to be
    ``(namespace, (message, metadata))`` tuples.

    Chunks are skipped when they are not message chunks, when any namespace
    segment starts with ``"tools:"``, when the message type is not ``"ai"``,
    or when the message carries tool-call chunks.

    Args:
        agent: Object exposing an ``astream`` async iterator factory.
        messages: Conversation history passed as ``{"messages": messages}``.

    Yields:
        Non-empty text fragments. List-style content is flattened to the
        concatenation of its plain strings and ``{"type": "text"}`` blocks.
    """

    def _as_text(content: object) -> str:
        # Message content may be a plain string or a list of content blocks;
        # only the textual parts are meaningful to a text stream.
        if isinstance(content, str):
            return content
        if isinstance(content, (list, tuple)):
            parts: list[str] = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and block.get("type") == "text":
                    piece = block.get("text")
                    if isinstance(piece, str):
                        parts.append(piece)
            return "".join(parts)
        return ""

    request = {"messages": messages}
    try:
        astream = agent.astream(
            request,
            stream_mode="messages",
            subgraphs=True,
            version="v2",
        )
    except TypeError:
        # ``astream`` does not accept ``version``; use the legacy tuple format.
        astream = agent.astream(request, stream_mode="messages", subgraphs=True)

    async for chunk in astream:
        if isinstance(chunk, dict):
            if chunk.get("type") != "messages":
                continue
            ns = chunk.get("ns")
            payload = chunk.get("data")
        elif isinstance(chunk, tuple) and len(chunk) == 2:
            # Legacy format: without this branch the fallback call above
            # would silently stream nothing.
            ns, payload = chunk
        else:
            continue

        if not isinstance(payload, (list, tuple)) or len(payload) < 1:
            continue
        token = payload[0]

        if isinstance(ns, str):
            ns = (ns,)
        elif ns is None:
            ns = ()
        if any(isinstance(s, str) and s.startswith("tools:") for s in ns):
            continue
        if getattr(token, "type", None) != "ai":
            continue
        if getattr(token, "tool_call_chunks", None):
            continue

        text = _as_text(getattr(token, "content", None))
        if text:
            yield text

async def _stream(chat_request_or_message: ChatRequestOrMessage) -> AsyncGenerator[ChatResponseChunk, None]:
    """Stream OpenAI-style chunks via NAT ``ChatResponseChunk`` (``data:`` lines when framed by NAT).

    Chunk order: one empty chunk carrying the assistant role, one chunk per
    text segment, then a final empty chunk with ``finish_reason="stop"`` and
    the usage estimate. All chunks share the same id, creation time and model.

    Args:
        chat_request_or_message: Incoming request; converted to a ``ChatRequest``
            through the global type converter before use.

    Yields:
        ``ChatResponseChunk`` objects in the order described above.

    Raises:
        Exception: Whatever token streaming raised, when it fails after text
            has already been yielded. Falling back at that point would re-run
            the agent and send the answer a second time on top of the partial
            text the client already has.
    """
    chat_request = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
    response_model = _response_model(chat_request)
    stream_id = str(uuid.uuid4())
    created = datetime.datetime.now(datetime.UTC)
    assembled: list[str] = []

    async with _agent_session(chat_request) as (agent, messages):
        yield ChatResponseChunk.create_streaming_chunk(
            "",
            id_=stream_id,
            created=created,
            model=response_model,
            role=UserMessageContentRoleType.ASSISTANT,
        )
        try:
            async for text in _stream_llm_chunks(agent, messages):
                assembled.append(text)
                yield ChatResponseChunk.create_streaming_chunk(
                    text,
                    id_=stream_id,
                    created=created,
                    model=response_model,
                )
        except Exception:
            if assembled:
                # Partial text is already with the client; a buffered re-run
                # would duplicate it and execute the agent a second time.
                logger.exception("Token streaming failed after partial output; aborting stream")
                raise
            logger.exception("Token streaming failed; falling back to buffered completion")
            agent_result = await agent.ainvoke({"messages": messages})
            result_messages = agent_result["messages"]
            content = result_messages[-1].content if result_messages else ""
            content = strip_pattern(content, strip_re)
            assembled.append(content)
            yield ChatResponseChunk.create_streaming_chunk(
                content,
                id_=stream_id,
                created=created,
                model=response_model,
            )

    # Usage is computed on the stripped text. Segments that were streamed
    # live were sent to the client before any stripping could be applied.
    content = strip_pattern("".join(assembled), strip_re)
    usage = _usage_for_content(chat_request, content)

    # Emitted after the session block has exited, so its cleanup does not
    # depend on the consumer reading this last chunk.
    yield ChatResponseChunk.create_streaming_chunk(
        "",
        id_=stream_id,
        created=created,
        model=response_model,
        finish_reason="stop",
        usage=usage,
    )

yield FunctionInfo.create(
single_fn=_single,
stream_fn=_stream,
description=config.description,
)