Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import asyncio
import os
import sys
import time
import uuid
from collections.abc import Awaitable, Callable
from typing import Any, cast

import quart
from requests import Response
from wechatpy import WeChatClient, parse_message
from wechatpy import WeChatClient, create_reply, parse_message
from wechatpy.crypto import WeChatCrypto
from wechatpy.exceptions import InvalidSignatureException
from wechatpy.messages import BaseMessage, ImageMessage, TextMessage, VoiceMessage
Expand Down Expand Up @@ -38,7 +39,12 @@


class WeixinOfficialAccountServer:
def __init__(self, event_queue: asyncio.Queue, config: dict) -> None:
def __init__(
self,
event_queue: asyncio.Queue,
config: dict,
user_buffer: dict[Any, dict[str, Any]],
) -> None:
self.server = quart.Quart(__name__)
self.port = int(cast(int | str, config.get("port")))
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
Expand All @@ -62,6 +68,10 @@ def __init__(self, event_queue: asyncio.Queue, config: dict) -> None:
self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
self.shutdown_event = asyncio.Event()

self._wx_msg_time_out = 4.0 # 微信服务器要求 5 秒内回复
self.user_buffer: dict[str, dict[str, Any]] = user_buffer # from_user -> state
self.active_send_mode = False # 是否启用主动发送模式,启用后 callback 将直接返回回复内容,无需等待微信回调

async def verify(self):
"""内部服务器的 GET 验证入口"""
return await self.handle_verify(quart.request)
Expand Down Expand Up @@ -98,6 +108,22 @@ async def callback_command(self):
"""内部服务器的 POST 回调入口"""
return await self.handle_callback(quart.request)

def _maybe_encrypt(self, xml: str, nonce: str | None, timestamp: str | None) -> str:
if xml and "<Encrypt>" not in xml and nonce and timestamp:
return self.crypto.encrypt_message(xml, nonce, timestamp)
return xml or "success"

def _preview(self, msg: BaseMessage, limit: int = 24) -> str:
"""生成消息预览文本,供占位符使用"""
if isinstance(msg, TextMessage):
t = cast(str, msg.content).strip()
return (t[:limit] + "...") if len(t) > limit else (t or "空消息")
if isinstance(msg, ImageMessage):
return "图片"
if isinstance(msg, VoiceMessage):
return "语音"
return getattr(msg, "type", "未知消息")

async def handle_callback(self, request) -> str:
"""处理回调请求,可被统一 webhook 入口复用

Expand All @@ -123,14 +149,152 @@ async def handle_callback(self, request) -> str:
raise
logger.info(f"解析成功: {msg}")

if self.callback:
if not self.callback:
return "success"

# by pass passive reply logic and return active reply directly.
if self.active_send_mode:
result_xml = await self.callback(msg)
if not result_xml:
return "success"
if isinstance(result_xml, str):
return result_xml

return "success"
# passive reply
from_user = str(getattr(msg, "source", ""))
msg_id = str(cast(str | int, getattr(msg, "id", "")))
state = self.user_buffer.get(from_user)

def _reply_text(text: str) -> str:
reply_obj = create_reply(text, msg)
reply_xml = reply_obj if isinstance(reply_obj, str) else str(reply_obj)
return self._maybe_encrypt(reply_xml, nonce, timestamp)

# if in cached state, return cached result or placeholder
if state:
logger.debug(f"用户消息缓冲状态: user={from_user} state={state}")
cached = state.get("cached_xml")
# send one cached each time, if cached is empty after pop, remove the buffer
if cached and len(cached) > 0:
logger.info(f"wx buffer hit on trigger: user={from_user}")
cached_xml = cached.pop(0)
if len(cached) == 0:
self.user_buffer.pop(from_user, None)
return _reply_text(cached_xml)
else:
return _reply_text(
cached_xml
+ "\n【后续消息还在缓冲中,回复任意文字继续获取】"
)

task: asyncio.Task | None = cast(asyncio.Task | None, state.get("task"))
placeholder = (
f"【正在思考'{state.get('preview', '...')}'中,已思考"
f"{int(time.monotonic() - state.get('started_at', time.monotonic()))}s,回复任意文字尝试获取回复】"
)

# same msgid => WeChat retry: wait a little; new msgid => user trigger: just placeholder
if task and state.get("msg_id") == msg_id:
done, _ = await asyncio.wait(
{task},
timeout=self._wx_msg_time_out,
return_when=asyncio.FIRST_COMPLETED,
)
if done:
try:
cached = state.get("cached_xml")
# send one cached each time, if cached is empty after pop, remove the buffer
if cached and len(cached) > 0:
logger.info(
f"wx buffer hit on retry window: user={from_user}"
)
cached_xml = cached.pop(0)
if len(cached) == 0:
self.user_buffer.pop(from_user, None)
logger.debug(
f"wx finished message sending in passive window: user={from_user} msg_id={msg_id} "
)
return _reply_text(cached_xml)
else:
logger.debug(
f"wx finished message sending in passive window but not final: user={from_user} msg_id={msg_id} "
)
return _reply_text(
cached_xml
+ "\n【后续消息还在缓冲中,回复任意文字继续获取】"
)
logger.info(
f"wx finished in window but not final; return placeholder: user={from_user} msg_id={msg_id} "
)
return _reply_text(placeholder)
except Exception:
logger.critical(
"wx task failed in passive window", exc_info=True
)
self.user_buffer.pop(from_user, None)
return _reply_text("处理消息失败,请稍后再试。")

logger.info(
f"wx passive window timeout: user={from_user} msg_id={msg_id}"
)
return _reply_text(placeholder)

logger.debug(f"wx trigger while thinking: user={from_user}")
return _reply_text(placeholder)

# create new trigger when state is empty, and store state in buffer
logger.debug(f"wx new trigger: user={from_user} msg_id={msg_id}")
preview = self._preview(msg)
placeholder = (
f"【正在思考'{preview}'中,已思考0s,回复任意文字尝试获取回复】"
)
logger.info(
f"wx start task: user={from_user} msg_id={msg_id} preview={preview}"
)

self.user_buffer[from_user] = state = {
"msg_id": msg_id,
"preview": preview,
"task": None, # set later after task created
"cached_xml": [], # for passive reply
"started_at": time.monotonic(),
}
self.user_buffer[from_user]["task"] = task = asyncio.create_task(
self.callback(msg)
)

# immediate return if done
done, _ = await asyncio.wait(
{task},
timeout=self._wx_msg_time_out,
return_when=asyncio.FIRST_COMPLETED,
)
if done:
try:
cached = state.get("cached_xml", None)
# send one cached each time, if cached is empty after pop, remove the buffer
if cached and len(cached) > 0:
logger.info(f"wx buffer hit immediately: user={from_user}")
cached_xml = cached.pop(0)
if len(cached) == 0:
self.user_buffer.pop(from_user, None)
return _reply_text(cached_xml)
else:
return _reply_text(
cached_xml
+ "\n【后续消息还在缓冲中,回复任意文字继续获取】"
)
logger.info(
f"wx not finished in first window; return placeholder: user={from_user} msg_id={msg_id} "
)
return _reply_text(placeholder)
except Exception:
logger.critical("wx task failed in first window", exc_info=True)
self.user_buffer.pop(from_user, None)
return _reply_text("处理消息失败,请稍后再试。")

logger.info(f"wx first window timeout: user={from_user} msg_id={msg_id}")
return _reply_text(placeholder)

async def start_polling(self) -> None:
logger.info(
Expand Down Expand Up @@ -176,7 +340,10 @@ def __init__(
if not self.api_base_url.endswith("/"):
self.api_base_url += "/"

self.server = WeixinOfficialAccountServer(self._event_queue, self.config)
self.user_buffer: dict[str, dict[str, Any]] = {} # from_user -> state
self.server = WeixinOfficialAccountServer(
self._event_queue, self.config, self.user_buffer
)

self.client = WeChatClient(
self.config["appid"].strip(),
Expand All @@ -193,28 +360,33 @@ async def callback(msg: BaseMessage):
try:
if self.active_send_mode:
await self.convert_message(msg, None)
return None

msg_id = str(cast(str | int, msg.id))
future = self.wexin_event_workers.get(msg_id)
if future:
logger.debug(f"duplicate message id checked: {msg.id}")
else:
if str(msg.id) in self.wexin_event_workers:
future = self.wexin_event_workers[str(cast(str | int, msg.id))]
logger.debug(f"duplicate message id checked: {msg.id}")
else:
future = asyncio.get_event_loop().create_future()
self.wexin_event_workers[str(cast(str | int, msg.id))] = future
await self.convert_message(msg, future)
future = asyncio.get_event_loop().create_future()
self.wexin_event_workers[msg_id] = future
await self.convert_message(msg, future)
# I love shield so much!
result = await asyncio.wait_for(
asyncio.shield(future),
60,
) # wait for 60s
logger.debug(f"Got future result: {result}")
self.wexin_event_workers.pop(str(cast(str | int, msg.id)), None)
return result # xml. see weixin_offacc_event.py
180,
) # wait for 180s
logger.debug(f"Got future result: {result}")
return result
except asyncio.TimeoutError:
pass
logger.info(f"callback 处理消息超时: message_id={msg.id}")
return create_reply("处理消息超时,请稍后再试。", msg)
except Exception as e:
logger.error(f"转换消息时出现异常: {e}")
finally:
self.wexin_event_workers.pop(str(cast(str | int, msg.id)), None)

self.server.callback = callback
self.server.active_send_mode = self.active_send_mode

@override
async def send_by_session(
Expand Down Expand Up @@ -336,12 +508,19 @@ async def convert_message(
await self.handle_msg(abm)

async def handle_msg(self, message: AstrBotMessage) -> None:
buffer = self.user_buffer.get(message.sender.user_id, None)
if buffer is None:
logger.critical(
f"用户消息未找到缓冲状态,无法处理消息: user={message.sender.user_id} message_id={message.message_id}"
)
return
message_event = WeixinOfficialAccountPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.client,
message_out=buffer,
)
self.commit_event(message_event)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import asyncio
import os
from typing import cast
from typing import Any, cast

from wechatpy import WeChatClient
from wechatpy.replies import ImageReply, TextReply, VoiceReply
from wechatpy.replies import ImageReply, VoiceReply

from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
Expand All @@ -20,9 +20,11 @@ def __init__(
platform_meta: PlatformMetadata,
session_id: str,
client: WeChatClient,
message_out: dict[Any, Any],
) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
self.message_out = message_out

@staticmethod
async def send_with_client(
Expand All @@ -32,27 +34,27 @@ async def send_with_client(
) -> None:
pass

async def split_plain(self, plain: str) -> list[str]:
"""将长文本分割成多个小文本, 每个小文本长度不超过 2048 字符
async def split_plain(self, plain: str, max_length: int = 1024) -> list[str]:
"""将长文本分割成多个小文本, 每个小文本长度不超过 max_length 字符

Args:
plain (str): 要分割的长文本
Returns:
list[str]: 分割后的文本列表

"""
if len(plain) <= 2048:
if len(plain) <= max_length:
return [plain]
result = []
start = 0
while start < len(plain):
# 剩下的字符串长度<2048时结束
if start + 2048 >= len(plain):
# 剩下的字符串长度<max_length时结束
if start + max_length >= len(plain):
result.append(plain[start:])
break

# 向前搜索分割标点符号
end = min(start + 2048, len(plain))
end = min(start + max_length, len(plain))
cut_position = end
for i in range(end, start, -1):
if i < len(plain) and plain[i - 1] in [
Expand Down Expand Up @@ -87,19 +89,15 @@ async def send(self, message: MessageChain) -> None:
if isinstance(comp, Plain):
# Split long text messages if needed
plain_chunks = await self.split_plain(comp.text)
for chunk in plain_chunks:
if active_send_mode:
if active_send_mode:
for chunk in plain_chunks:
self.client.message.send_text(message_obj.sender.user_id, chunk)
else:
reply = TextReply(
content=chunk,
message=cast(dict, self.message_obj.raw_message)["message"],
)
xml = reply.render()
future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
await asyncio.sleep(0.5) # Avoid sending too fast
else:
# disable passive sending, just store the chunks in
logger.debug(
f"split plain into {len(plain_chunks)} chunks for passive reply. Message not sent."
)
self.message_out["cached_xml"] = plain_chunks
elif isinstance(comp, Image):
img_path = await comp.convert_to_file_path()

Expand Down