Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 20 additions & 31 deletions agentrun/server/agui_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,7 @@
import pydash

from ..utils.helper import merge, MergeOptions
from ..utils.reasoning import (
get_reasoning_content,
is_thinking_enabled_from_env,
)
from ..utils.reasoning import get_reasoning_content
from .model import (
AgentEvent,
AgentRequest,
Expand Down Expand Up @@ -466,8 +463,6 @@ def _process_event_with_boundaries(
ToolCallStartEvent,
)

thinking_enabled = is_thinking_enabled_from_env()

# RAW 事件直接透传
if event.event == EventType.RAW:
raw_data = event.data.get("raw", "")
Expand All @@ -478,34 +473,31 @@ def _process_event_with_boundaries(
return

if event.event == EventType.REASONING:
if thinking_enabled:
reasoning_content = (
event.data.get("delta")
or get_reasoning_content(event.data)
or ""
)
if reasoning_content:
for sse_data in state.end_text_if_open(self._encoder):
yield sse_data
for sse_data in state.end_all_tools(self._encoder):
yield sse_data
for sse_data in state.ensure_reasoning_started():
yield sse_data
yield _encode_reasoning_event(
"REASONING_MESSAGE_CONTENT",
messageId=state.reasoning.message_id,
delta=reasoning_content,
)
reasoning_content = (
event.data.get("delta")
or get_reasoning_content(event.data)
or ""
)
if reasoning_content:
for sse_data in state.end_text_if_open(self._encoder):
yield sse_data
for sse_data in state.end_all_tools(self._encoder):
yield sse_data
for sse_data in state.ensure_reasoning_started():
yield sse_data
yield _encode_reasoning_event(
"REASONING_MESSAGE_CONTENT",
messageId=state.reasoning.message_id,
delta=reasoning_content,
)
return

# TEXT 事件:在首个 TEXT 前注入 TEXT_MESSAGE_START
# AG-UI 协议要求:发送 TEXT_MESSAGE_START 前必须先结束所有未结束的 TOOL_CALL
if event.event == EventType.TEXT:
addition = self._strip_reasoning_from_addition(
event.addition, thinking_enabled
)
addition = self._strip_reasoning_from_addition(event.addition)
addition_reasoning = get_reasoning_content(event.addition or {})
if thinking_enabled and addition_reasoning:
if addition_reasoning:
for sse_data in state.ensure_reasoning_started():
yield sse_data
yield _encode_reasoning_event(
Expand Down Expand Up @@ -874,7 +866,6 @@ def _apply_addition(
def _strip_reasoning_from_addition(
self,
addition: Optional[Dict[str, Any]],
thinking_enabled: bool,
) -> Optional[Dict[str, Any]]:
if not addition:
return addition
Expand All @@ -890,8 +881,6 @@ def _strip_reasoning_from_addition(
else:
stripped.pop("additional_kwargs", None)

if not thinking_enabled:
return stripped
return stripped or None

async def _error_stream(self, message: str) -> AsyncIterator[str]:
Expand Down
48 changes: 19 additions & 29 deletions agentrun/server/openai_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
from fastapi.responses import JSONResponse, StreamingResponse
import pydash

from ..utils.reasoning import (
get_reasoning_content,
is_thinking_enabled_from_env,
)
from ..utils.reasoning import get_reasoning_content
from ..utils.helper import merge, MergeOptions
from .model import (
AgentEvent,
Expand Down Expand Up @@ -304,7 +301,6 @@ async def _format_stream(
# 状态追踪
sent_role = False
has_text = False
thinking_enabled = is_thinking_enabled_from_env()
tool_call_index = -1 # 从 -1 开始,第一个工具调用时变为 0
# 工具调用状态:{tool_id: {"started": bool, "index": int}}
tool_call_states: Dict[str, Dict[str, Any]] = {}
Expand Down Expand Up @@ -341,19 +337,18 @@ async def _format_stream(
event.addition_merge_options,
)

self._apply_reasoning_gate(delta, thinking_enabled)
self._promote_reasoning_content(delta)
yield self._build_chunk(context, delta)
continue

if event.event == EventType.REASONING:
if thinking_enabled:
reasoning_content = event.data.get("delta", "")
if reasoning_content:
has_text = True
yield self._build_chunk(
context,
{"reasoning_content": reasoning_content},
)
reasoning_content = event.data.get("delta", "")
if reasoning_content:
has_text = True
yield self._build_chunk(
Comment on lines 344 to +348
context,
{"reasoning_content": reasoning_content},
)
continue

# TOOL_CALL_CHUNK 事件
Expand Down Expand Up @@ -401,7 +396,7 @@ async def _format_stream(
event.addition_merge_options,
)

self._apply_reasoning_gate(delta, thinking_enabled)
self._promote_reasoning_content(delta)
yield self._build_chunk(context, delta)
continue

Expand Down Expand Up @@ -477,7 +472,6 @@ def _format_non_stream(
"""
content_parts: List[str] = []
reasoning_parts: List[str] = []
thinking_enabled = is_thinking_enabled_from_env()
# 工具调用状态:{tool_id: {id, name, arguments}}
tool_call_map: Dict[str, Dict[str, Any]] = {}
has_tool_calls = False
Expand All @@ -486,12 +480,12 @@ def _format_non_stream(
if event.event == EventType.TEXT:
content_parts.append(event.data.get("delta", ""))
reasoning_content = get_reasoning_content(event.addition or {})
if thinking_enabled and reasoning_content:
if reasoning_content:
reasoning_parts.append(reasoning_content)
Comment on lines 480 to 484

elif event.event == EventType.REASONING:
reasoning_content = event.data.get("delta", "")
if thinking_enabled and reasoning_content:
if reasoning_content:
reasoning_parts.append(reasoning_content)

elif event.event == EventType.TOOL_CALL_CHUNK:
Expand Down Expand Up @@ -564,18 +558,14 @@ def _apply_addition(

return merge(delta, addition, **(merge_options or {}))

def _apply_reasoning_gate(
self,
payload: Dict[str, Any],
thinking_enabled: bool,
) -> None:
if thinking_enabled:
reasoning_content = get_reasoning_content(payload)
if reasoning_content is not None:
payload["reasoning_content"] = reasoning_content
return

def _promote_reasoning_content(self, payload: Dict[str, Any]) -> None:
reasoning_content = get_reasoning_content(payload)
payload.pop("reasoning_content", None)
additional_kwargs = payload.get("additional_kwargs")
if isinstance(additional_kwargs, dict):
additional_kwargs.pop("reasoning_content", None)
if not additional_kwargs:
payload.pop("additional_kwargs", None)

if reasoning_content:
payload["reasoning_content"] = reasoning_content
29 changes: 18 additions & 11 deletions tests/unittests/server/test_agui_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,7 +1196,7 @@ async def invoke_agent(request: AgentRequest):


class TestAGUIReasoningContent:
"""测试 AG-UI reasoning 事件输出开关"""
"""测试 AG-UI reasoning 事件输出"""

def get_client(self, invoke_agent):
server = AgentRunServer(invoke_agent=invoke_agent)
Expand Down Expand Up @@ -1228,7 +1228,7 @@ async def invoke_agent(request: AgentRequest):
assert reasoning_event["delta"] == "thinking"
assert "TEXT_MESSAGE_CONTENT" in types

def test_stream_suppresses_reasoning_when_thinking_disabled(
def test_stream_includes_reasoning_when_thinking_disabled(
self, monkeypatch
):
monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": false}')
Expand All @@ -1246,9 +1246,14 @@ async def invoke_agent(request: AgentRequest):
)

events = _agui_sse_events(response)
assert "REASONING_MESSAGE_CONTENT" not in [
event["type"] for event in events
]
types = [event["type"] for event in events]
reasoning_event = next(
event
for event in events
if event["type"] == "REASONING_MESSAGE_CONTENT"
)
assert "REASONING_START" in types
assert reasoning_event["delta"] == "thinking"
text_event = next(
event for event in events if event["type"] == "TEXT_MESSAGE_CONTENT"
)
Expand All @@ -1257,7 +1262,7 @@ async def invoke_agent(request: AgentRequest):
def test_stream_promotes_chunk_additional_kwargs_reasoning(
self, monkeypatch
):
monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": true}')
monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": false}')

async def invoke_agent(request: AgentRequest):
yield SimpleNamespace(
Expand All @@ -1282,9 +1287,7 @@ async def invoke_agent(request: AgentRequest):
assert reasoning_event["delta"] == "thinking"
assert text_event["delta"] == "answer"

def test_text_addition_reasoning_is_emitted_before_text(
self, monkeypatch
):
def test_text_addition_reasoning_is_emitted_before_text(self, monkeypatch):
monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": true}')

async def invoke_agent(request: AgentRequest):
Expand Down Expand Up @@ -1314,7 +1317,7 @@ async def invoke_agent(request: AgentRequest):
assert text_event["delta"] == "answer"
assert "additional_kwargs" not in text_event

def test_text_addition_reasoning_is_stripped_when_thinking_disabled(
def test_text_addition_reasoning_is_emitted_when_thinking_disabled(
self, monkeypatch
):
monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": false}')
Expand All @@ -1335,7 +1338,11 @@ async def invoke_agent(request: AgentRequest):

events = _agui_sse_events(response)
types = [event["type"] for event in events]
assert all(not event_type.startswith("REASONING") for event_type in types)
assert types.index("REASONING_MESSAGE_CONTENT") < types.index(
"TEXT_MESSAGE_START"
)
assert "REASONING_MESSAGE_END" in types
assert "REASONING_END" in types
text_event = next(
event for event in events if event["type"] == "TEXT_MESSAGE_CONTENT"
)
Expand Down
52 changes: 41 additions & 11 deletions tests/unittests/server/test_openai_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,7 +1017,7 @@ def invoke_agent(request: AgentRequest):


class TestOpenAIReasoningContent:
"""测试 OpenAI reasoning_content 输出开关"""
"""测试 OpenAI reasoning_content 输出"""

def get_client(self, invoke_agent):
server = AgentRunServer(invoke_agent=invoke_agent)
Expand All @@ -1042,10 +1042,12 @@ async def invoke_agent(request: AgentRequest):
)

events = _openai_sse_events(response)
assert events[0]["choices"][0]["delta"]["reasoning_content"] == "thinking"
assert (
events[0]["choices"][0]["delta"]["reasoning_content"] == "thinking"
)
assert events[1]["choices"][0]["delta"]["content"] == "answer"

def test_stream_suppresses_reasoning_when_thinking_disabled(
def test_stream_includes_reasoning_when_thinking_disabled(
self, monkeypatch
):
monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": false}')
Expand All @@ -1066,11 +1068,10 @@ async def invoke_agent(request: AgentRequest):
)

events = _openai_sse_events(response)
assert all(
"reasoning_content" not in event["choices"][0]["delta"]
for event in events
assert (
events[0]["choices"][0]["delta"]["reasoning_content"] == "thinking"
)
assert events[0]["choices"][0]["delta"]["content"] == "answer"
assert events[1]["choices"][0]["delta"]["content"] == "answer"

def test_non_stream_includes_reasoning_when_thinking_enabled(
self, monkeypatch
Expand Down Expand Up @@ -1098,7 +1099,7 @@ def invoke_agent(request: AgentRequest):
assert message["content"] == "answer"
assert message["reasoning_content"] == "thinking"

def test_non_stream_suppresses_reasoning_when_thinking_disabled(
def test_non_stream_includes_reasoning_when_thinking_disabled(
self, monkeypatch
):
monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": false}')
Expand All @@ -1122,12 +1123,12 @@ def invoke_agent(request: AgentRequest):

message = response.json()["choices"][0]["message"]
assert message["content"] == "answer"
assert "reasoning_content" not in message
assert message["reasoning_content"] == "thinking"

def test_stream_promotes_chunk_additional_kwargs_reasoning(
self, monkeypatch
):
monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": true}')
monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": false}')

async def invoke_agent(request: AgentRequest):
yield SimpleNamespace(
Expand All @@ -1144,9 +1145,38 @@ async def invoke_agent(request: AgentRequest):
)

events = _openai_sse_events(response)
assert events[0]["choices"][0]["delta"]["reasoning_content"] == "thinking"
assert (
events[0]["choices"][0]["delta"]["reasoning_content"] == "thinking"
)
assert events[1]["choices"][0]["delta"]["content"] == "answer"

def test_stream_promotes_text_addition_reasoning_when_thinking_disabled(
self, monkeypatch
):
monkeypatch.setenv("MODEL_PARAMETER_RULES", '{"thinking": false}')

async def invoke_agent(request: AgentRequest):
yield AgentEvent(
event=EventType.TEXT,
data={"delta": "answer"},
addition={
"additional_kwargs": {"reasoning_content": "thinking"}
},
)

response = self.get_client(invoke_agent).post(
"/openai/v1/chat/completions",
json={
"messages": [{"role": "user", "content": "Hi"}],
"stream": True,
},
)

delta = _openai_sse_events(response)[0]["choices"][0]["delta"]
assert delta["content"] == "answer"
assert delta["reasoning_content"] == "thinking"
assert "additional_kwargs" not in delta

def test_parses_request_message_reasoning_content(self):
captured_request = {}

Expand Down
Loading