Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions astrbot/builtin_stars/builtin_commands/commands/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,30 @@ async def reset(self, message: AstrMessageEvent) -> None:

message.set_result(MessageEventResult().message(ret))

async def stop(self, message: AstrMessageEvent) -> None:
"""停止当前会话正在运行的 Agent"""
cfg = self.context.get_config(umo=message.unified_msg_origin)
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
umo = message.unified_msg_origin

if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
stopped_count = active_event_registry.stop_all(umo, exclude=message)
else:
stopped_count = active_event_registry.request_agent_stop_all(
umo,
exclude=message,
)

if stopped_count > 0:
message.set_result(
MessageEventResult().message(
f"已请求停止 {stopped_count} 个运行中的任务。"
)
)
return

message.set_result(MessageEventResult().message("当前会话没有运行中的任务。"))

async def his(self, message: AstrMessageEvent, page: int = 1) -> None:
"""查看对话记录"""
if not self.context.get_using_provider(message.unified_msg_origin):
Expand Down
5 changes: 5 additions & 0 deletions astrbot/builtin_stars/builtin_commands/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ async def reset(self, message: AstrMessageEvent) -> None:
"""重置 LLM 会话"""
await self.conversation_c.reset(message)

@filter.command("stop")
async def stop(self, message: AstrMessageEvent) -> None:
"""停止当前会话中正在运行的 Agent"""
await self.conversation_c.stop(message)

@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("model")
async def model_ls(
Expand Down
58 changes: 58 additions & 0 deletions astrbot/core/agent/runners/tool_loop_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ async def reset(
self.tool_executor = tool_executor
self.agent_hooks = agent_hooks
self.run_context = run_context
self._stop_requested = False
self._aborted = False

# These two are used for tool schema mode handling
# We now have two modes:
Expand Down Expand Up @@ -328,6 +330,14 @@ async def step(self):
),
),
)
if self._stop_requested:
llm_resp_result = LLMResponse(
role="assistant",
completion_text="[SYSTEM: User actively interrupted the response generation. Partial output before interruption is preserved.]",
reasoning_content=llm_response.reasoning_content,
reasoning_signature=llm_response.reasoning_signature,
)
break
continue
llm_resp_result = llm_response

Expand All @@ -339,6 +349,48 @@ async def step(self):
break # got final response

if not llm_resp_result:
if self._stop_requested:
llm_resp_result = LLMResponse(role="assistant", completion_text="")
else:
return

if self._stop_requested:
logger.info("Agent execution was requested to stop by user.")
llm_resp = llm_resp_result
if llm_resp.role != "assistant":
llm_resp = LLMResponse(
role="assistant",
completion_text="[SYSTEM: User actively interrupted the response generation. Partial output before interruption is preserved.]",
)
self.final_llm_resp = llm_resp
self._aborted = True
self._transition_state(AgentState.DONE)
self.stats.end_time = time.time()

parts = []
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
parts.append(
ThinkPart(
think=llm_resp.reasoning_content,
encrypted=llm_resp.reasoning_signature,
)
)
if llm_resp.completion_text:
parts.append(TextPart(text=llm_resp.completion_text))
if parts:
self.run_context.messages.append(
Message(role="assistant", content=parts)
)

try:
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
except Exception as e:
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)

yield AgentResponse(
type="aborted",
data=AgentResponseData(chain=MessageChain(type="aborted")),
)
return

# 处理 LLM 响应
Expand Down Expand Up @@ -848,5 +900,11 @@ def done(self) -> bool:
"""检查 Agent 是否已完成工作"""
return self._state in (AgentState.DONE, AgentState.ERROR)

def request_stop(self) -> None:
self._stop_requested = True

def was_aborted(self) -> bool:
return self._aborted

def get_final_llm_resp(self) -> LLMResponse | None:
return self.final_llm_resp
44 changes: 43 additions & 1 deletion astrbot/core/astr_agent_run_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]


def _should_stop_agent(astr_event) -> bool:
return astr_event.is_stopped() or bool(astr_event.get_extra("agent_stop_requested"))


async def run_agent(
agent_runner: AgentRunner,
max_step: int = 30,
Expand Down Expand Up @@ -48,10 +52,28 @@ async def run_agent(
)
)

stop_watcher = asyncio.create_task(
_watch_agent_stop_signal(agent_runner, astr_event),
)
try:
async for resp in agent_runner.step():
if astr_event.is_stopped():
if _should_stop_agent(astr_event):
agent_runner.request_stop()

if resp.type == "aborted":
if not stop_watcher.done():
stop_watcher.cancel()
try:
await stop_watcher
except asyncio.CancelledError:
pass
astr_event.set_extra("agent_user_aborted", True)
astr_event.set_extra("agent_stop_requested", False)
return
Comment on lines +63 to 72
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic to cancel the stop_watcher task is duplicated in three places (here, after the loop on lines 145-150, and in the except block on lines 164-169). This increases code complexity and the risk of bugs if one of the locations is missed during future changes.

Consider refactoring this using a try...finally block to centralize the cleanup logic. This ensures stop_watcher is always cancelled, regardless of how the try block is exited (e.g., via return, break, or an exception).

Example:

stop_watcher = asyncio.create_task(...)
try:
    # Main agent loop logic
    async for resp in agent_runner.step():
        # ...
        if resp.type == "aborted":
            # ... set extras
            return
        # ...
    # ...
except Exception as e:
    # ... error handling
    return
finally:
    if not stop_watcher.done():
        stop_watcher.cancel()
        try:
            await stop_watcher
        except asyncio.CancelledError:
            pass


if _should_stop_agent(astr_event):
continue

if resp.type == "tool_call_result":
msg_chain = resp.data["chain"]

Expand Down Expand Up @@ -120,6 +142,12 @@ async def run_agent(
# display the reasoning content only when configured
continue
yield resp.data["chain"] # MessageChain
if not stop_watcher.done():
stop_watcher.cancel()
try:
await stop_watcher
except asyncio.CancelledError:
pass
if agent_runner.done():
# send agent stats to webchat
if astr_event.get_platform_name() == "webchat":
Expand All @@ -133,6 +161,12 @@ async def run_agent(
break

except Exception as e:
if "stop_watcher" in locals() and not stop_watcher.done():
stop_watcher.cancel()
try:
await stop_watcher
except asyncio.CancelledError:
pass
logger.error(traceback.format_exc())

err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在平台日志查看和分享错误详情。\n"
Expand All @@ -155,6 +189,14 @@ async def run_agent(
return


async def _watch_agent_stop_signal(agent_runner: AgentRunner, astr_event) -> None:
while not agent_runner.done():
if _should_stop_agent(astr_event):
agent_runner.request_stop()
return
await asyncio.sleep(0.5)


async def run_live_agent(
agent_runner: AgentRunner,
tts_provider: TTSProvider | None = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,13 +247,16 @@ async def process(
yield

# 保存历史记录
if not event.is_stopped() and agent_runner.done():
if agent_runner.done() and (
not event.is_stopped() or agent_runner.was_aborted()
):
await self._save_to_history(
event,
req,
agent_runner.get_final_llm_resp(),
agent_runner.run_context.messages,
agent_runner.stats,
user_aborted=agent_runner.was_aborted(),
)

elif streaming_response and not stream_to_general:
Expand Down Expand Up @@ -308,13 +311,14 @@ async def process(
)

# 检查事件是否被停止,如果被停止则不保存历史记录
if not event.is_stopped():
if not event.is_stopped() or agent_runner.was_aborted():
await self._save_to_history(
event,
req,
final_resp,
agent_runner.run_context.messages,
agent_runner.stats,
user_aborted=agent_runner.was_aborted(),
)

asyncio.create_task(
Expand All @@ -340,16 +344,29 @@ async def _save_to_history(
llm_response: LLMResponse | None,
all_messages: list[Message],
runner_stats: AgentStats | None,
user_aborted: bool = False,
) -> None:
if (
not req
or not req.conversation
or not llm_response
or llm_response.role != "assistant"
):
if not req or not req.conversation:
return

if not llm_response.completion_text and not req.tool_calls_result:
if not llm_response and not user_aborted:
return

if llm_response and llm_response.role != "assistant":
if not user_aborted:
return
llm_response = LLMResponse(
role="assistant",
completion_text=llm_response.completion_text or "",
)
elif llm_response is None:
llm_response = LLMResponse(role="assistant", completion_text="")

if (
not llm_response.completion_text
and not req.tool_calls_result
and not user_aborted
):
logger.debug("LLM 响应为空,不保存记录。")
return

Expand All @@ -363,6 +380,14 @@ async def _save_to_history(
continue
message_to_save.append(message.model_dump())

# if user_aborted:
# message_to_save.append(
# Message(
# role="assistant",
# content="[User aborted this request. Partial output before abort was preserved.]",
# ).model_dump()
# )

token_usage = None
if runner_stats:
# token_usage = runner_stats.token_usage.total
Expand Down
10 changes: 5 additions & 5 deletions astrbot/core/platform/sources/webchat/webchat_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@

from .webchat_queue_mgr import webchat_queue_mgr

imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
attachments_dir = os.path.join(get_astrbot_data_path(), "attachments")


class WebChatMessageEvent(AstrMessageEvent):
def __init__(self, message_str, message_obj, platform_meta, session_id) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id)
os.makedirs(imgs_dir, exist_ok=True)
os.makedirs(attachments_dir, exist_ok=True)

@staticmethod
async def _send(
Expand Down Expand Up @@ -69,7 +69,7 @@ async def _send(
elif isinstance(comp, Image):
# save image to local
filename = f"{str(uuid.uuid4())}.jpg"
path = os.path.join(imgs_dir, filename)
path = os.path.join(attachments_dir, filename)
image_base64 = await comp.convert_to_base64()
with open(path, "wb") as f:
f.write(base64.b64decode(image_base64))
Expand All @@ -85,7 +85,7 @@ async def _send(
elif isinstance(comp, Record):
# save record to local
filename = f"{str(uuid.uuid4())}.wav"
path = os.path.join(imgs_dir, filename)
path = os.path.join(attachments_dir, filename)
record_base64 = await comp.convert_to_base64()
with open(path, "wb") as f:
f.write(base64.b64decode(record_base64))
Expand All @@ -104,7 +104,7 @@ async def _send(
original_name = comp.name or os.path.basename(file_path)
ext = os.path.splitext(original_name)[1] or ""
filename = f"{uuid.uuid4()!s}{ext}"
dest_path = os.path.join(imgs_dir, filename)
dest_path = os.path.join(attachments_dir, filename)
shutil.copy2(file_path, dest_path)
data = f"[FILE]{filename}"
await web_chat_back_queue.put(
Expand Down
17 changes: 17 additions & 0 deletions astrbot/core/utils/active_event_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,22 @@ def stop_all(
count += 1
return count

def request_agent_stop_all(
self,
umo: str,
exclude: AstrMessageEvent | None = None,
) -> int:
"""请求停止指定 UMO 的所有活跃事件中的 Agent 运行。

与 stop_all 不同,这里不会调用 event.stop_event(),
因此不会中断事件传播,后续流程(如历史记录保存)仍可继续。
"""
count = 0
for event in list(self._events.get(umo, [])):
if event is not exclude:
event.set_extra("agent_stop_requested", True)
count += 1
return count


active_event_registry = ActiveEventRegistry()
Loading