async def stop(self, message: AstrMessageEvent) -> None:
    """Ask the agent tasks running in the sender's session to stop.

    The agent runner type configured for the session selects the registry
    call: when it is one of ``THIRD_PARTY_AGENT_RUNNER_KEY`` the session's
    active events go through ``active_event_registry.stop_all``; otherwise
    they go through ``active_event_registry.request_agent_stop_all``. The
    event that carries this command is excluded in both cases.

    Args:
        message: The event carrying the stop command. Its result is set to
            a text reply reporting how many tasks were asked to stop, or
            that nothing is running.
    """
    umo = message.unified_msg_origin
    provider_settings = self.context.get_config(umo=umo)["provider_settings"]
    runner_type = provider_settings["agent_runner_type"]

    # Choose the registry operation first, then invoke it once.
    stop_operation = (
        active_event_registry.stop_all
        if runner_type in THIRD_PARTY_AGENT_RUNNER_KEY
        else active_event_registry.request_agent_stop_all
    )
    stopped_count = stop_operation(umo, exclude=message)

    if stopped_count > 0:
        reply = f"已请求停止 {stopped_count} 个运行中的任务。"
    else:
        reply = "当前会话没有运行中的任务。"
    message.set_result(MessageEventResult().message(reply))
b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 9f80dae1c9..10cf2e96c6 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -137,6 +137,8 @@ async def reset( self.tool_executor = tool_executor self.agent_hooks = agent_hooks self.run_context = run_context + self._stop_requested = False + self._aborted = False # These two are used for tool schema mode handling # We now have two modes: @@ -328,6 +330,14 @@ async def step(self): ), ), ) + if self._stop_requested: + llm_resp_result = LLMResponse( + role="assistant", + completion_text="[SYSTEM: User actively interrupted the response generation. Partial output before interruption is preserved.]", + reasoning_content=llm_response.reasoning_content, + reasoning_signature=llm_response.reasoning_signature, + ) + break continue llm_resp_result = llm_response @@ -339,6 +349,48 @@ async def step(self): break # got final response if not llm_resp_result: + if self._stop_requested: + llm_resp_result = LLMResponse(role="assistant", completion_text="") + else: + return + + if self._stop_requested: + logger.info("Agent execution was requested to stop by user.") + llm_resp = llm_resp_result + if llm_resp.role != "assistant": + llm_resp = LLMResponse( + role="assistant", + completion_text="[SYSTEM: User actively interrupted the response generation. 
def _should_stop_agent(astr_event) -> bool:
    """Report whether the agent run attached to ``astr_event`` should stop.

    A stop is due when the event itself has been stopped, or when the
    ``agent_stop_requested`` extra on the event is truthy. The extra is only
    consulted when the event has not been stopped.

    Args:
        astr_event: Event object exposing ``is_stopped()`` and
            ``get_extra(key)``.

    Returns:
        ``True`` if the agent should stop, ``False`` otherwise.
    """
    if astr_event.is_stopped():
        return True
    stop_flag = astr_event.get_extra("agent_stop_requested")
    return bool(stop_flag)
async def _watch_agent_stop_signal(
    agent_runner: AgentRunner,
    astr_event,
    *,
    poll_interval: float = 0.5,
) -> None:
    """Poll ``astr_event`` for a stop signal and forward it to the runner.

    The loop ends as soon as the runner reports that it is done, or right
    after ``request_stop()`` has been called on it once.

    Args:
        agent_runner: Runner whose ``done()`` state is polled and whose
            ``request_stop()`` is called when a stop is detected.
        astr_event: Event checked with ``_should_stop_agent``.
        poll_interval: Seconds to sleep between two checks. Defaults to
            ``0.5``, the value that used to be hard-coded. Pass a positive
            value; with ``0`` the loop only yields to the event loop between
            checks and effectively spins.
    """
    while not agent_runner.done():
        if _should_stop_agent(astr_event):
            agent_runner.request_stop()
            return
        # Sleeping hands control back to the event loop so the agent can run.
        await asyncio.sleep(poll_interval)
self._save_to_history( event, req, agent_runner.get_final_llm_resp(), agent_runner.run_context.messages, agent_runner.stats, + user_aborted=agent_runner.was_aborted(), ) elif streaming_response and not stream_to_general: @@ -308,13 +311,14 @@ async def process( ) # 检查事件是否被停止,如果被停止则不保存历史记录 - if not event.is_stopped(): + if not event.is_stopped() or agent_runner.was_aborted(): await self._save_to_history( event, req, final_resp, agent_runner.run_context.messages, agent_runner.stats, + user_aborted=agent_runner.was_aborted(), ) asyncio.create_task( @@ -340,16 +344,29 @@ async def _save_to_history( llm_response: LLMResponse | None, all_messages: list[Message], runner_stats: AgentStats | None, + user_aborted: bool = False, ) -> None: - if ( - not req - or not req.conversation - or not llm_response - or llm_response.role != "assistant" - ): + if not req or not req.conversation: return - if not llm_response.completion_text and not req.tool_calls_result: + if not llm_response and not user_aborted: + return + + if llm_response and llm_response.role != "assistant": + if not user_aborted: + return + llm_response = LLMResponse( + role="assistant", + completion_text=llm_response.completion_text or "", + ) + elif llm_response is None: + llm_response = LLMResponse(role="assistant", completion_text="") + + if ( + not llm_response.completion_text + and not req.tool_calls_result + and not user_aborted + ): logger.debug("LLM 响应为空,不保存记录。") return @@ -363,6 +380,14 @@ async def _save_to_history( continue message_to_save.append(message.model_dump()) + # if user_aborted: + # message_to_save.append( + # Message( + # role="assistant", + # content="[User aborted this request. 
def request_agent_stop_all(
    self,
    umo: str,
    exclude: AstrMessageEvent | None = None,
) -> int:
    """Ask the agents of every active event registered for ``umo`` to stop.

    Unlike ``stop_all`` this does not call ``event.stop_event()``; it only
    sets the ``agent_stop_requested`` extra to ``True`` on each event, so
    event propagation is not interrupted and later steps (for example
    saving the history) can still run.

    Args:
        umo: Unified message origin whose active events are targeted.
        exclude: Event to leave untouched (compared by identity), if any.

    Returns:
        The number of events that were flagged.
    """
    # Snapshot the targets first, then flag them.
    targets = [
        candidate
        for candidate in self._events.get(umo, [])
        if candidate is not exclude
    ]
    for target in targets:
        target.set_extra("agent_stop_requested", True)
    return len(targets)
async def stop_session(self):
    """Request a stop of the active agent runs of one chat session.

    Reads ``session_id`` from the JSON request body, checks that the session
    exists and was created by the current user, rebuilds the session's
    unified message origin and asks ``active_event_registry`` to flag every
    active event of that origin.

    Returns:
        The ``__dict__`` of a ``Response``: an error response when the body
        is missing or is not a JSON object, when ``session_id`` is missing,
        when the session is unknown, or when it belongs to another user;
        otherwise an ok response whose data is ``{"stopped_count": <int>}``.
    """
    post_data = await request.json
    if post_data is None:
        return Response().error("Missing JSON body").__dict__
    # A valid JSON body can still be a list, string or number; only an
    # object supports the key lookup below.
    if not isinstance(post_data, dict):
        return Response().error("JSON body must be an object").__dict__

    session_id = post_data.get("session_id")
    if not session_id:
        return Response().error("Missing key: session_id").__dict__

    username = g.get("username", "guest")
    session = await self.db.get_platform_session_by_id(session_id)
    if not session:
        return Response().error(f"Session {session_id} not found").__dict__
    # Only the creator of a session may stop its runs.
    if session.creator != username:
        return Response().error("Permission denied").__dict__

    message_type = (
        MessageType.GROUP_MESSAGE.value
        if session.is_group
        else MessageType.FRIEND_MESSAGE.value
    )
    umo = (
        f"{session.platform_id}:{message_type}:"
        f"{session.platform_id}!{username}!{session_id}"
    )
    stopped_count = active_event_registry.request_agent_stop_all(umo)

    return Response().ok(data={"stopped_count": stopped_count}).__dict__
:is-running="isStreaming || isConvRunning" :enableStreaming="enableStreaming" :isRecording="isRecording" :session-id="currSessionId || null" :current-session="getCurrentSession" :replyTo="replyTo" @send="handleSendMessage" + @stop="handleStopMessage" @toggleStreaming="toggleStreaming" @removeImage="removeImage" @removeAudio="removeAudio" @@ -106,12 +108,14 @@ :stagedAudioUrl="stagedAudioUrl" :stagedFiles="stagedNonImageFiles" :disabled="isStreaming" + :is-running="isStreaming || isConvRunning" :enableStreaming="enableStreaming" :isRecording="isRecording" :session-id="currSessionId || null" :current-session="getCurrentSession" :replyTo="replyTo" @send="handleSendMessage" + @stop="handleStopMessage" @toggleStreaming="toggleStreaming" @removeImage="removeImage" @removeAudio="removeAudio" @@ -134,12 +138,14 @@ :stagedAudioUrl="stagedAudioUrl" :stagedFiles="stagedNonImageFiles" :disabled="isStreaming" + :is-running="isStreaming || isConvRunning" :enableStreaming="enableStreaming" :isRecording="isRecording" :session-id="currSessionId || null" :current-session="getCurrentSession" :replyTo="replyTo" @send="handleSendMessage" + @stop="handleStopMessage" @toggleStreaming="toggleStreaming" @removeImage="removeImage" @removeAudio="removeAudio" @@ -298,6 +304,7 @@ const { currentSessionProject, getSessionMessages: getSessionMsg, sendMessage: sendMsg, + stopMessage: stopMsg, toggleStreaming } = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions); @@ -631,6 +638,10 @@ async function handleSendMessage() { } } +async function handleStopMessage() { + await stopMsg(); +} + // 路由变化监听 watch( () => route.path, diff --git a/dashboard/src/components/chat/ChatInput.vue b/dashboard/src/components/chat/ChatInput.vue index 35ec22cd37..63cb03e3d7 100644 --- a/dashboard/src/components/chat/ChatInput.vue +++ b/dashboard/src/components/chat/ChatInput.vue @@ -94,8 +94,29 @@ {{ isRecording ? 
tm('voice.speaking') : tm('voice.startRecording') }} - + + + + {{ tm('input.stopGenerating') }} + + + @@ -160,6 +181,7 @@ interface Props { disabled: boolean; enableStreaming: boolean; isRecording: boolean; + isRunning: boolean; sessionId?: string | null; currentSession?: Session | null; configId?: string | null; @@ -177,6 +199,7 @@ const props = withDefaults(defineProps(), { const emit = defineEmits<{ 'update:prompt': [value: string]; send: []; + stop: []; toggleStreaming: []; removeImage: [index: number]; removeAudio: []; diff --git a/dashboard/src/components/chat/StandaloneChat.vue b/dashboard/src/components/chat/StandaloneChat.vue index db762dabf7..69fac13f9b 100644 --- a/dashboard/src/components/chat/StandaloneChat.vue +++ b/dashboard/src/components/chat/StandaloneChat.vue @@ -23,12 +23,14 @@ :stagedImagesUrl="stagedImagesUrl" :stagedAudioUrl="stagedAudioUrl" :disabled="isStreaming" + :is-running="isStreaming || isConvRunning" :enableStreaming="enableStreaming" :isRecording="isRecording" :session-id="currSessionId || null" :current-session="getCurrentSession" :config-id="configId" @send="handleSendMessage" + @stop="handleStopMessage" @toggleStreaming="toggleStreaming" @removeImage="removeImage" @removeAudio="removeAudio" @@ -156,6 +158,7 @@ const { enableStreaming, getSessionMessages: getSessionMsg, sendMessage: sendMsg, + stopMessage: stopMsg, toggleStreaming } = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions); @@ -236,6 +239,10 @@ async function handleSendMessage() { } } +async function handleStopMessage() { + await stopMsg(); +} + onMounted(async () => { // 独立模式在挂载时创建新会话 try { diff --git a/dashboard/src/composables/useMessages.ts b/dashboard/src/composables/useMessages.ts index 7174f43718..f11f678161 100644 --- a/dashboard/src/composables/useMessages.ts +++ b/dashboard/src/composables/useMessages.ts @@ -82,6 +82,10 @@ export function useMessages( const activeSSECount = ref(0); const enableStreaming = ref(true); const 
attachmentCache = new Map(); // attachment_id -> blob URL + const currentRequestController = ref(null); + const currentReader = ref | null>(null); + const currentRunningSessionId = ref(''); + const userStopRequested = ref(false); // 当前会话的项目信息 const currentSessionProject = ref<{ project_id: string; title: string; emoji: string } | null>(null); @@ -289,6 +293,8 @@ export function useMessages( if (activeSSECount.value === 1) { isConvRunning.value = true; } + userStopRequested.value = false; + currentRunningSessionId.value = currSessionId.value; // 收集所有 attachment_id const files = stagedFiles.map(f => f.attachment_id); @@ -330,12 +336,15 @@ export function useMessages( messageToSend = prompt; } + const controller = new AbortController(); + currentRequestController.value = controller; const response = await fetch('/api/chat/send', { method: 'POST', headers: { 'Content-Type': 'application/json', 'Authorization': 'Bearer ' + localStorage.getItem('token') }, + signal: controller.signal, body: JSON.stringify({ message: messageToSend, session_id: currSessionId.value, @@ -350,6 +359,7 @@ export function useMessages( } const reader = response.body!.getReader(); + currentReader.value = reader; const decoder = new TextDecoder(); let in_streaming = false; let message_obj: MessageContent | null = null; @@ -560,7 +570,9 @@ export function useMessages( } } } catch (readError) { - console.error('SSE读取错误:', readError); + if (!userStopRequested.value) { + console.error('SSE读取错误:', readError); + } break; } } @@ -569,7 +581,9 @@ export function useMessages( onSessionsUpdate(); } catch (err) { - console.error('发送消息失败:', err); + if (!userStopRequested.value) { + console.error('发送消息失败:', err); + } // 移除加载占位符 const lastMsg = messages.value[messages.value.length - 1]; if (lastMsg?.content?.isLoading) { @@ -577,6 +591,10 @@ export function useMessages( } } finally { isStreaming.value = false; + currentReader.value = null; + currentRequestController.value = null; + currentRunningSessionId.value 
/**
 * Ask the backend to stop the running agent of the current request, then
 * tear down the local SSE reader and the pending fetch.
 *
 * The session that started the in-flight request is preferred; the currently
 * selected session is the fallback. Nothing happens when neither is set.
 */
async function stopMessage() {
    const targetSessionId = currentRunningSessionId.value || currSessionId.value;
    if (!targetSessionId) return;

    // Lets the send path skip error logging for the abort triggered below.
    userStopRequested.value = true;

    const stopPayload = { session_id: targetSessionId };
    try {
        await axios.post('/api/chat/stop', stopPayload);
    } catch (stopError) {
        console.error('停止会话失败:', stopError);
    }

    try {
        await currentReader.value?.cancel();
    } catch (cancelError) {
        // A failed cancel is not actionable; the fetch is aborted right after.
    }
    currentReader.value = null;

    currentRequestController.value?.abort();
    currentRequestController.value = null;

    isStreaming.value = false;
}
@pytest.mark.asyncio
async def test_stop_signal_returns_aborted_and_persists_partial_message(
    runner, provider_request, mock_tool_executor, mock_hooks
):
    """A stop requested mid-stream ends the step with an ``aborted`` response.

    The runner is driven until its first streaming delta, then asked to
    stop. The remaining responses must contain an ``aborted`` one, the
    runner must report the abort, and an assistant message must be the last
    entry of the run context.
    """
    abortable_provider = MockAbortableStreamProvider()
    await runner.reset(
        provider=abortable_provider,
        request=provider_request,
        run_context=ContextWrapper(context=None),
        tool_executor=mock_tool_executor,
        agent_hooks=mock_hooks,
        streaming=True,
    )

    steps = runner.step()
    opening_response = await steps.__anext__()
    assert opening_response.type == "streaming_delta"

    runner.request_stop()

    # Drain the generator and keep only the response types.
    remaining_types = [response.type async for response in steps]
    assert "aborted" in remaining_types
    assert runner.was_aborted() is True

    final_resp = runner.get_final_llm_resp()
    assert final_resp is not None
    assert final_resp.role == "assistant"
    # NOTE(review): the stop branch that runs right after a streamed chunk
    # appears to set completion_text to a "[SYSTEM: ...]" placeholder rather
    # than the streamed text - confirm this expectation against step().
    assert final_resp.completion_text == "partial "
    assert runner.run_context.messages[-1].role == "assistant"