-
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
feat: add stop functionality for active agent sessions and improve handling of stop requests #5380
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -137,6 +137,8 @@ async def reset( | |
| self.tool_executor = tool_executor | ||
| self.agent_hooks = agent_hooks | ||
| self.run_context = run_context | ||
| self._stop_requested = False | ||
| self._aborted = False | ||
|
|
||
| # These two are used for tool schema mode handling | ||
| # We now have two modes: | ||
|
|
@@ -328,6 +330,14 @@ async def step(self): | |
| ), | ||
| ), | ||
| ) | ||
| if self._stop_requested: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (complexity): 建议提取构造“被中断的 LLM 响应”和“完成中止流程”的辅助函数,从而简化 step 的停止/中止逻辑。你可以通过抽取两个小 helper 来集中停止/中止路径并消除重复:一个用于构造“中断”的响应,另一个用于完成中止流程。 1. 抽取构造中断响应的辅助函数
|
||
| llm_resp_result = LLMResponse( | ||
| role="assistant", | ||
| completion_text="[SYSTEM: User actively interrupted the response generation. Partial output before interruption is preserved.]", | ||
| reasoning_content=llm_response.reasoning_content, | ||
| reasoning_signature=llm_response.reasoning_signature, | ||
|
Comment on lines
+334
to
+338
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion: 用户中断的系统消息字符串被重复使用了;建议将其集中管理。 这个中断标记字符串在这里以及循环结束后的回退分支中都被硬编码了,建议提取为模块级常量统一引用。 Suggested implementation: if self._stop_requested:
llm_resp_result = LLMResponse(
role="assistant",
completion_text=USER_ABORT_SYSTEM_MESSAGE,
reasoning_content=llm_response.reasoning_content,
reasoning_signature=llm_response.reasoning_signature,
)
break if not llm_resp_result:
if self._stop_requested:
llm_resp_result = LLMResponse(
role="assistant",
completion_text=USER_ABORT_SYSTEM_MESSAGE,
)
else:
return要完整实现这一建议,还需要:
USER_ABORT_SYSTEM_MESSAGE = (
"[SYSTEM: User actively interrupted the response generation. "
"Partial output before interruption is preserved.]"
)
如果你更倾向于用 helper 而不是裸常量(比如希望在有推理字段时保留下来),可以在本模块中定义一个类似 Original comment in Englishsuggestion: The user-interruption system message string is duplicated; consider centralizing it. This interruption marker string is hard-coded both here and in the Suggested implementation: if self._stop_requested:
llm_resp_result = LLMResponse(
role="assistant",
completion_text=USER_ABORT_SYSTEM_MESSAGE,
reasoning_content=llm_response.reasoning_content,
reasoning_signature=llm_response.reasoning_signature,
)
break if not llm_resp_result:
if self._stop_requested:
llm_resp_result = LLMResponse(
role="assistant",
completion_text=USER_ABORT_SYSTEM_MESSAGE,
)
else:
returnTo fully implement the suggestion, also:
USER_ABORT_SYSTEM_MESSAGE = (
"[SYSTEM: User actively interrupted the response generation. "
"Partial output before interruption is preserved.]"
)
If you prefer a helper instead of a bare constant (e.g., to preserve reasoning fields when available), you can define a function like |
||
| ) | ||
|
Comment on lines
+334
to
+339
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The runner currently does not accumulate streaming chunks. When a stop is requested, |
||
| break | ||
| continue | ||
| llm_resp_result = llm_response | ||
|
|
||
|
|
@@ -339,6 +349,48 @@ async def step(self): | |
| break # got final response | ||
|
|
||
| if not llm_resp_result: | ||
| if self._stop_requested: | ||
| llm_resp_result = LLMResponse(role="assistant", completion_text="") | ||
| else: | ||
| return | ||
|
|
||
| if self._stop_requested: | ||
| logger.info("Agent execution was requested to stop by user.") | ||
| llm_resp = llm_resp_result | ||
| if llm_resp.role != "assistant": | ||
| llm_resp = LLMResponse( | ||
| role="assistant", | ||
| completion_text="[SYSTEM: User actively interrupted the response generation. Partial output before interruption is preserved.]", | ||
| ) | ||
| self.final_llm_resp = llm_resp | ||
| self._aborted = True | ||
| self._transition_state(AgentState.DONE) | ||
| self.stats.end_time = time.time() | ||
|
|
||
| parts = [] | ||
| if llm_resp.reasoning_content or llm_resp.reasoning_signature: | ||
| parts.append( | ||
| ThinkPart( | ||
| think=llm_resp.reasoning_content, | ||
| encrypted=llm_resp.reasoning_signature, | ||
| ) | ||
| ) | ||
| if llm_resp.completion_text: | ||
| parts.append(TextPart(text=llm_resp.completion_text)) | ||
| if parts: | ||
| self.run_context.messages.append( | ||
| Message(role="assistant", content=parts) | ||
| ) | ||
|
|
||
| try: | ||
| await self.agent_hooks.on_agent_done(self.run_context, llm_resp) | ||
| except Exception as e: | ||
| logger.error(f"Error in on_agent_done hook: {e}", exc_info=True) | ||
|
|
||
| yield AgentResponse( | ||
| type="aborted", | ||
| data=AgentResponseData(chain=MessageChain(type="aborted")), | ||
| ) | ||
| return | ||
|
|
||
| # 处理 LLM 响应 | ||
|
|
@@ -848,5 +900,11 @@ def done(self) -> bool: | |
| """检查 Agent 是否已完成工作""" | ||
| return self._state in (AgentState.DONE, AgentState.ERROR) | ||
|
|
||
| def request_stop(self) -> None: | ||
| self._stop_requested = True | ||
|
|
||
| def was_aborted(self) -> bool: | ||
| return self._aborted | ||
|
|
||
| def get_final_llm_resp(self) -> LLMResponse | None: | ||
| return self.final_llm_resp | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,10 @@ | |
| AgentRunner = ToolLoopAgentRunner[AstrAgentContext] | ||
|
|
||
|
|
||
| def _should_stop_agent(astr_event) -> bool: | ||
| return astr_event.is_stopped() or bool(astr_event.get_extra("agent_stop_requested")) | ||
|
|
||
|
|
||
| async def run_agent( | ||
| agent_runner: AgentRunner, | ||
| max_step: int = 30, | ||
|
|
@@ -48,10 +52,28 @@ async def run_agent( | |
| ) | ||
| ) | ||
|
|
||
| stop_watcher = asyncio.create_task( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (complexity): 考虑重构停止处理逻辑,使只有一个 watcher 负责 你可以在不改变行为的前提下简化新的停止逻辑:
1.
|
||
| _watch_agent_stop_signal(agent_runner, astr_event), | ||
| ) | ||
|
Comment on lines
+55
to
+57
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Creating and cancelling the |
||
| try: | ||
| async for resp in agent_runner.step(): | ||
| if astr_event.is_stopped(): | ||
| if _should_stop_agent(astr_event): | ||
| agent_runner.request_stop() | ||
|
|
||
| if resp.type == "aborted": | ||
| if not stop_watcher.done(): | ||
| stop_watcher.cancel() | ||
| try: | ||
| await stop_watcher | ||
| except asyncio.CancelledError: | ||
| pass | ||
|
Comment on lines
+64
to
+69
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| astr_event.set_extra("agent_user_aborted", True) | ||
| astr_event.set_extra("agent_stop_requested", False) | ||
| return | ||
|
|
||
| if _should_stop_agent(astr_event): | ||
| continue | ||
|
|
||
| if resp.type == "tool_call_result": | ||
| msg_chain = resp.data["chain"] | ||
|
|
||
|
|
@@ -120,6 +142,12 @@ async def run_agent( | |
| # display the reasoning content only when configured | ||
| continue | ||
| yield resp.data["chain"] # MessageChain | ||
| if not stop_watcher.done(): | ||
| stop_watcher.cancel() | ||
| try: | ||
| await stop_watcher | ||
| except asyncio.CancelledError: | ||
| pass | ||
| if agent_runner.done(): | ||
| # send agent stats to webchat | ||
| if astr_event.get_platform_name() == "webchat": | ||
|
|
@@ -133,6 +161,12 @@ async def run_agent( | |
| break | ||
|
|
||
| except Exception as e: | ||
| if "stop_watcher" in locals() and not stop_watcher.done(): | ||
| stop_watcher.cancel() | ||
| try: | ||
| await stop_watcher | ||
| except asyncio.CancelledError: | ||
| pass | ||
| logger.error(traceback.format_exc()) | ||
|
|
||
| err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在平台日志查看和分享错误详情。\n" | ||
|
|
@@ -155,6 +189,14 @@ async def run_agent( | |
| return | ||
|
|
||
|
|
||
| async def _watch_agent_stop_signal(agent_runner: AgentRunner, astr_event) -> None: | ||
| while not agent_runner.done(): | ||
| if _should_stop_agent(astr_event): | ||
| agent_runner.request_stop() | ||
| return | ||
| await asyncio.sleep(0.5) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
|
|
||
| async def run_live_agent( | ||
| agent_runner: AgentRunner, | ||
| tts_provider: TTSProvider | None = None, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -247,13 +247,16 @@ async def process( | |
| yield | ||
|
|
||
| # 保存历史记录 | ||
| if not event.is_stopped() and agent_runner.done(): | ||
| if agent_runner.done() and ( | ||
| not event.is_stopped() or agent_runner.was_aborted() | ||
| ): | ||
| await self._save_to_history( | ||
| event, | ||
| req, | ||
| agent_runner.get_final_llm_resp(), | ||
| agent_runner.run_context.messages, | ||
| agent_runner.stats, | ||
| user_aborted=agent_runner.was_aborted(), | ||
| ) | ||
|
|
||
| elif streaming_response and not stream_to_general: | ||
|
|
@@ -308,13 +311,14 @@ async def process( | |
| ) | ||
|
|
||
| # 检查事件是否被停止,如果被停止则不保存历史记录 | ||
| if not event.is_stopped(): | ||
| if not event.is_stopped() or agent_runner.was_aborted(): | ||
| await self._save_to_history( | ||
| event, | ||
| req, | ||
| final_resp, | ||
| agent_runner.run_context.messages, | ||
| agent_runner.stats, | ||
| user_aborted=agent_runner.was_aborted(), | ||
| ) | ||
|
|
||
| asyncio.create_task( | ||
|
|
@@ -340,16 +344,29 @@ async def _save_to_history( | |
| llm_response: LLMResponse | None, | ||
| all_messages: list[Message], | ||
| runner_stats: AgentStats | None, | ||
| user_aborted: bool = False, | ||
| ) -> None: | ||
| if ( | ||
| not req | ||
| or not req.conversation | ||
| or not llm_response | ||
| or llm_response.role != "assistant" | ||
| ): | ||
| if not req or not req.conversation: | ||
| return | ||
|
|
||
| if not llm_response.completion_text and not req.tool_calls_result: | ||
| if not llm_response and not user_aborted: | ||
| return | ||
|
|
||
| if llm_response and llm_response.role != "assistant": | ||
| if not user_aborted: | ||
| return | ||
| llm_response = LLMResponse( | ||
| role="assistant", | ||
| completion_text=llm_response.completion_text or "", | ||
| ) | ||
| elif llm_response is None: | ||
| llm_response = LLMResponse(role="assistant", completion_text="") | ||
|
|
||
| if ( | ||
| not llm_response.completion_text | ||
| and not req.tool_calls_result | ||
| and not user_aborted | ||
| ): | ||
| logger.debug("LLM 响应为空,不保存记录。") | ||
| return | ||
|
|
||
|
|
@@ -363,6 +380,14 @@ async def _save_to_history( | |
| continue | ||
| message_to_save.append(message.model_dump()) | ||
|
|
||
| # if user_aborted: | ||
| # message_to_save.append( | ||
| # Message( | ||
| # role="assistant", | ||
| # content="[User aborted this request. Partial output before abort was preserved.]", | ||
| # ).model_dump() | ||
| # ) | ||
|
Comment on lines
+383
to
+389
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| token_usage = None | ||
| if runner_stats: | ||
| # token_usage = runner_stats.token_usage.total | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,7 +13,9 @@ | |
| from astrbot.core import logger, sp | ||
| from astrbot.core.core_lifecycle import AstrBotCoreLifecycle | ||
| from astrbot.core.db import BaseDatabase | ||
| from astrbot.core.platform.message_type import MessageType | ||
| from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr | ||
| from astrbot.core.utils.active_event_registry import active_event_registry | ||
| from astrbot.core.utils.astrbot_path import get_astrbot_data_path | ||
|
|
||
| from .route import Response, Route, RouteContext | ||
|
|
@@ -41,6 +43,7 @@ def __init__( | |
| "/chat/new_session": ("GET", self.new_session), | ||
| "/chat/sessions": ("GET", self.get_sessions), | ||
| "/chat/get_session": ("GET", self.get_session), | ||
| "/chat/stop": ("POST", self.stop_session), | ||
| "/chat/delete_session": ("GET", self.delete_webchat_session), | ||
| "/chat/update_session_display_name": ( | ||
| "POST", | ||
|
|
@@ -466,13 +469,13 @@ async def stream(): | |
| if tc_id in tool_calls: | ||
| tool_calls[tc_id]["result"] = tcr.get("result") | ||
| tool_calls[tc_id]["finished_ts"] = tcr.get("ts") | ||
| accumulated_parts.append( | ||
| { | ||
| "type": "tool_call", | ||
| "tool_calls": [tool_calls[tc_id]], | ||
| } | ||
| ) | ||
| tool_calls.pop(tc_id, None) | ||
| accumulated_parts.append( | ||
| { | ||
| "type": "tool_call", | ||
| "tool_calls": [tool_calls[tc_id]], | ||
| } | ||
| ) | ||
| tool_calls.pop(tc_id, None) | ||
| elif chain_type == "reasoning": | ||
| accumulated_reasoning += result_text | ||
| elif streaming: | ||
|
|
@@ -603,6 +606,36 @@ async def stream(): | |
| response.timeout = None # fix SSE auto disconnect issue | ||
| return response | ||
|
|
||
| async def stop_session(self): | ||
| """Stop active agent runs for a session.""" | ||
| post_data = await request.json | ||
| if post_data is None: | ||
| return Response().error("Missing JSON body").__dict__ | ||
|
|
||
| session_id = post_data.get("session_id") | ||
| if not session_id: | ||
| return Response().error("Missing key: session_id").__dict__ | ||
|
|
||
| username = g.get("username", "guest") | ||
| session = await self.db.get_platform_session_by_id(session_id) | ||
| if not session: | ||
| return Response().error(f"Session {session_id} not found").__dict__ | ||
| if session.creator != username: | ||
| return Response().error("Permission denied").__dict__ | ||
|
|
||
| message_type = ( | ||
| MessageType.GROUP_MESSAGE.value | ||
| if session.is_group | ||
| else MessageType.FRIEND_MESSAGE.value | ||
| ) | ||
| umo = ( | ||
| f"{session.platform_id}:{message_type}:" | ||
| f"{session.platform_id}!{username}!{session_id}" | ||
| ) | ||
|
Comment on lines
+631
to
+634
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The logic for constructing the |
||
| stopped_count = active_event_registry.request_agent_stop_all(umo) | ||
|
|
||
| return Response().ok(data={"stopped_count": stopped_count}).__dict__ | ||
|
|
||
| async def delete_webchat_session(self): | ||
| """Delete a Platform session and all its related data.""" | ||
| session_id = request.args.get("session_id") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
stopcommand lacks permission checks, allowing any user in a shared session (such as a group chat) to interrupt an active agent task initiated by another user. In contrast, other destructive or disruptive commands in the same file, such asresetanddel_conv, implement permission checks that default to requiring administrator privileges in group settings. This inconsistency allows a regular member to perform a denial-of-service-like action against other members' interactions with the bot.