diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 6787370460..bb1b6b55e8 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -1,11 +1,11 @@ """如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。""" import os -from typing import Any, TypedDict +from astrbot.core.computer.booters.cua_defaults import CUA_DEFAULT_CONFIG from astrbot.core.utils.astrbot_path import get_astrbot_data_path -VERSION = "4.23.0" +VERSION = "4.25.1" DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db") PERSONAL_WECHAT_CONFIG_METADATA = { "weixin_oc_base_url": { @@ -111,6 +111,7 @@ "websearch_bocha_key": [], "websearch_brave_key": [], "websearch_baidu_app_builder_key": "", + "websearch_firecrawl_key": [], "web_search_link": False, "display_reasoning_text": False, "identifier": False, @@ -134,6 +135,7 @@ "streaming_response": False, "show_tool_use_status": False, "show_tool_call_result": False, + "buffer_intermediate_messages": False, "sanitize_context_by_modalities": False, "max_quoted_fallback_images": 20, "quoted_message_parser": { @@ -174,6 +176,12 @@ "shipyard_neo_access_token": "", "shipyard_neo_profile": "python-default", "shipyard_neo_ttl": 3600, + "cua_image": CUA_DEFAULT_CONFIG["image"], + "cua_os_type": CUA_DEFAULT_CONFIG["os_type"], + "cua_idle_timeout": CUA_DEFAULT_CONFIG["idle_timeout"], + "cua_telemetry_enabled": CUA_DEFAULT_CONFIG["telemetry_enabled"], + "cua_local": CUA_DEFAULT_CONFIG["local"], + "cua_api_key": CUA_DEFAULT_CONFIG["api_key"], }, "image_compress_enabled": True, "image_compress_options": { @@ -236,7 +244,10 @@ "dashboard": { "enable": True, "username": "astrbot", - "password": "77b90590a8945a7d36c963981a307dc9", + "password": "", + "pbkdf2_password": "", + "password_storage_upgraded": False, + "password_change_required": False, "jwt_secret": "", "host": "0.0.0.0", "port": 6185, @@ -283,27 +294,10 @@ "kb_final_top_k": 5, # 知识库检索最终返回结果数量 "kb_agentic_mode": False, "disable_builtin_commands": False, + "disable_metrics": False, } -class ChatProviderTemplate(TypedDict): - id: str - provider_source_id: str - model: str - modalities: list - custom_extra_body: dict[str, Any] - max_context_tokens: int - - -CHAT_PROVIDER_TEMPLATE = { - "id": "", - "provide_source_id": "", - "model": "", - "modalities": [], - "custom_extra_body": {}, - "max_context_tokens": 0, -} - """ AstrBot v3 时代的配置元数据,目前仅承担以下功能: @@ -324,7 +318,7 @@ class ChatProviderTemplate(TypedDict): "QQ 官方机器人(WebSocket)": { "id": "default", "type": "qq_official", - "enable": False, + "enable": True, "appid": "", "secret": "", "enable_group_c2c": True, @@ -333,7 +327,7 @@ class ChatProviderTemplate(TypedDict): "QQ 官方机器人(Webhook)": { "id": "default", "type": "qq_official_webhook", - "enable": False, + "enable": True, "appid": "", "secret": "", "is_sandbox": False, @@ -345,7 +339,7 @@ class ChatProviderTemplate(TypedDict): "OneBot v11": { "id": "default", "type": "aiocqhttp", - "enable": False, + "enable": True, "ws_reverse_host": "0.0.0.0", "ws_reverse_port": 6199, "ws_reverse_token": "", @@ -353,7 +347,7 @@ class ChatProviderTemplate(TypedDict): "微信公众平台": { "id": "weixin_official_account", "type": "weixin_official_account", - "enable": False, + "enable": True, "appid": "", "secret": "", "token": "", @@ -368,7 +362,7 @@ class ChatProviderTemplate(TypedDict): "企业微信(含微信客服)": { "id": "wecom", "type": "wecom", - "enable": False, + "enable": True, "corpid": "", "secret": "", "token": "", @@ -405,7 +399,7 @@ class ChatProviderTemplate(TypedDict): "个人微信": { "id": "weixin_personal", "type": "weixin_oc", - "enable": False, + "enable": True, "weixin_oc_base_url": "https://ilinkai.weixin.qq.com", "weixin_oc_bot_type": "3", "weixin_oc_qr_poll_interval": 1, @@ -415,8 +409,7 @@ class ChatProviderTemplate(TypedDict): "飞书(Lark)": { "id": "lark", "type": "lark", - "enable": False, - "lark_bot_name": "", + "enable": True, "app_id": "", "app_secret": "", "domain": "https://open.feishu.cn", @@ -428,7 +421,7 @@ class ChatProviderTemplate(TypedDict): "钉钉(DingTalk)": { "id": "dingtalk", "type": "dingtalk", - "enable": False, + "enable": True, "client_id": "", "client_secret": "", "card_template_id": "", @@ -436,7 +429,7 @@ class ChatProviderTemplate(TypedDict): "Telegram": { "id": "telegram", "type": "telegram", - "enable": False, + "enable": True, "telegram_token": "your_bot_token", "start_message": "Hello, I'm AstrBot!", "telegram_api_base_url": "https://api.telegram.org/bot", @@ -449,7 +442,7 @@ class ChatProviderTemplate(TypedDict): "Discord": { "id": "discord", "type": "discord", - "enable": False, + "enable": True, "discord_token": "", "discord_proxy": "", "discord_command_register": True, @@ -459,7 +452,7 @@ class ChatProviderTemplate(TypedDict): "Misskey": { "id": "misskey", "type": "misskey", - "enable": False, + "enable": True, "misskey_instance_url": "https://misskey.example", "misskey_token": "", "misskey_default_visibility": "public", @@ -477,7 +470,7 @@ class ChatProviderTemplate(TypedDict): "Slack": { "id": "slack", "type": "slack", - "enable": False, + "enable": True, "bot_token": "", "app_token": "", "signing_secret": "", @@ -491,7 +484,7 @@ class ChatProviderTemplate(TypedDict): "Line": { "id": "line", "type": "line", - "enable": False, + "enable": True, "channel_access_token": "", "channel_secret": "", "unified_webhook_mode": True, @@ -500,7 +493,7 @@ class ChatProviderTemplate(TypedDict): "Satori": { "id": "satori", "type": "satori", - "enable": False, + "enable": True, "satori_api_base_url": "http://localhost:5140/satori/v1", "satori_endpoint": "ws://localhost:5140/satori/v1/events", "satori_token": "", @@ -511,7 +504,7 @@ class ChatProviderTemplate(TypedDict): "KOOK": { "id": "kook", "type": "kook", - "enable": False, + "enable": True, "kook_bot_token": "", "kook_reconnect_delay": 1, "kook_max_reconnect_delay": 60, @@ -524,7 +517,7 @@ class ChatProviderTemplate(TypedDict): "Mattermost": { "id": "mattermost", "type": "mattermost", - "enable": False, + "enable": True, "mattermost_url": "https://chat.example.com", "mattermost_bot_token": "", "mattermost_reconnect_delay": 5.0, @@ -782,7 +775,7 @@ class ChatProviderTemplate(TypedDict): "appid": { "description": "appid", "type": "string", - "hint": "必填项。QQ 官方机器人平台的 appid。如何获取请参考文档。", + "hint": "必填项。当前消息平台的 AppID。如何获取请参考对应平台接入文档。", }, "secret": { "description": "secret", @@ -895,11 +888,6 @@ class ChatProviderTemplate(TypedDict): "wecom_ai_bot_connection_mode": "long_connection", }, }, - "lark_bot_name": { - "description": "飞书机器人的名字", - "type": "string", - "hint": "请务必填写正确,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。", - }, "discord_token": { "description": "Discord Bot Token", "type": "string", @@ -1206,7 +1194,7 @@ class ChatProviderTemplate(TypedDict): "provider_type": "chat_completion", "enable": True, "key": [], - "api_base": "https://api.kimi.com/coding/", + "api_base": "https://api.kimi.com/coding", "timeout": 120, "proxy": "", "custom_headers": {"User-Agent": "claude-code/0.1.0"}, @@ -1236,6 +1224,19 @@ class ChatProviderTemplate(TypedDict): "proxy": "", "custom_headers": {}, }, + "MiniMax Token Plan": { + "id": "minimax-token-plan", + "provider": "minimax-token-plan", + "type": "minimax_token_plan", + "provider_type": "chat_completion", + "enable": True, + "key": [], + "api_base": "https://api.minimaxi.com/anthropic", + "timeout": 120, + "proxy": "", + "custom_headers": {"User-Agent": "claude-code/0.1.0"}, + "anth_thinking_config": {"type": "", "budget": 0, "effort": ""}, + }, "xAI": { "id": "xai", "provider": "xai", @@ -1796,6 +1797,48 @@ class ChatProviderTemplate(TypedDict): "timeout": 20, "proxy": "", }, + "NVIDIA Embedding": { + "id": "nvidia_embedding", + "type": "nvidia_embedding", + "provider": "nvidia", + "provider_type": "embedding", + "hint": "provider_group.provider.nvidia_embedding.hint", + "enable": True, + "embedding_api_key": "", + "embedding_api_base": "https://integrate.api.nvidia.com/v1", + "embedding_model": "nvidia/llama-nemotron-embed-1b-v2", + "input_type": "passage", + "embedding_dimensions": 1024, + "timeout": 20, + "proxy": "", + }, + "Ollama Embedding": { + "id": "ollama_embedding", + "type": "ollama_embedding", + "provider": "ollama", + "provider_type": "embedding", + "hint": "provider_group.provider.ollama_embedding.hint", + "enable": True, + "embedding_api_base": "http://localhost:11434", + "embedding_model": "nomic-embed-text", + "embedding_dimensions": 768, + "timeout": 60, + "proxy": "", + }, + "vLLM Embedding": { + "id": "vllm_embedding", + "type": "vllm_embedding", + "provider": "vllm", + "provider_type": "embedding", + "hint": "面向 vLLM OpenAI-compatible Embedding 接口。请求时会自动跳过 dimensions,并尝试将模型名对齐到 served-model-name。", + "enable": False, + "embedding_api_key": "", + "embedding_api_base": "", + "embedding_model": "", + "embedding_dimensions": 0, + "timeout": 20, + "proxy": "", + }, "vLLM Rerank": { "id": "vllm_rerank", "type": "vllm_rerank", @@ -1949,13 +1992,13 @@ class ChatProviderTemplate(TypedDict): "options": ["text", "image", "audio", "tool_use"], "labels": ["文本", "图像", "音频", "工具使用"], "render_type": "checkbox", - "hint": "模型支持的模态。如所填写的模型不支持图像,请取消勾选图像。", + "hint": "模型支持的模态及能力。", }, "custom_headers": { - "description": "自定义添加请求头", + "description": "自定义请求头", "type": "dict", "items": {}, - "hint": "此处添加的键值对将被合并到 OpenAI SDK 的 default_headers 中,用于自定义 HTTP 请求头。值必须为字符串。", + "hint": "此处添加的键值对将被合并到 OpenAI SDK 的 default_headers 中,用于自定义 HTTP 请求头。", }, "ollama_disable_thinking": { "description": "关闭思考模式", @@ -1966,7 +2009,7 @@ class ChatProviderTemplate(TypedDict): "description": "自定义请求体参数", "type": "dict", "items": {}, - "hint": "用于在请求时添加额外的参数,如 temperature、top_p、max_tokens 等。", + "hint": "用于在请求时添加额外的参数,如 temperature, top_p, max_tokens, reasoning_effort 等。", "template_schema": { "temperature": { "name": "Temperature", @@ -2609,7 +2652,7 @@ class ChatProviderTemplate(TypedDict): "max_context_tokens": { "description": "模型上下文窗口大小", "type": "int", - "hint": "模型最大上下文 Token 大小。如果为 0,则会自动从模型元数据填充(如有),也可手动修改。", + "hint": "模型最大上下文 Token 大小。如果为 0,则会自动从模型元数据填充(如有)", }, "dify_api_key": { "description": "API Key", @@ -2764,6 +2807,9 @@ class ChatProviderTemplate(TypedDict): "show_tool_call_result": { "type": "bool", }, + "buffer_intermediate_messages": { + "type": "bool", + }, "unsupported_streaming_strategy": { "type": "string", }, @@ -2918,6 +2964,11 @@ class ChatProviderTemplate(TypedDict): "callback_api_base": { "type": "string", }, + "disable_metrics": { + "description": "禁用匿名使用统计", + "type": "bool", + "hint": "禁用后,AstrBot 将不再上传匿名使用统计数据。", + }, "log_level": { "type": "string", "options": ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], @@ -3185,6 +3236,7 @@ class ChatProviderTemplate(TypedDict): "baidu_ai_search", "bocha", "brave", + "firecrawl", ], "condition": { "provider_settings.web_search": True, @@ -3220,12 +3272,23 @@ class ChatProviderTemplate(TypedDict): "provider_settings.web_search": True, }, }, + "provider_settings.websearch_firecrawl_key": { + "description": "Firecrawl API Key", + "type": "list", + "items": {"type": "string"}, + "hint": "可添加多个 Key 进行轮询。", + "condition": { + "provider_settings.websearch_provider": "firecrawl", + "provider_settings.web_search": True, + }, + }, "provider_settings.websearch_baidu_app_builder_key": { "description": "百度千帆智能云 APP Builder API Key", "type": "string", "hint": "参考:https://console.bce.baidu.com/iam/#/iam/apikey/list", "condition": { "provider_settings.websearch_provider": "baidu_ai_search", + "provider_settings.web_search": True, }, }, "provider_settings.web_search_link": { @@ -3261,8 +3324,8 @@ class ChatProviderTemplate(TypedDict): "provider_settings.sandbox.booter": { "description": "沙箱环境驱动器", "type": "string", - "options": ["shipyard_neo", "shipyard"], - "labels": ["Shipyard Neo", "Shipyard"], + "options": ["shipyard_neo", "shipyard", "cua"], + "labels": ["Shipyard Neo", "Shipyard", "CUA"], "condition": { "provider_settings.computer_use_runtime": "sandbox", }, @@ -3288,7 +3351,7 @@ class ChatProviderTemplate(TypedDict): "provider_settings.sandbox.shipyard_neo_profile": { "description": "Shipyard Neo Profile", "type": "string", - "hint": "Shipyard Neo 沙箱 profile,如 python-default。", + "hint": "Shipyard Neo 沙箱 profile,如 python-default。留空时自动选择能力更完整的 profile。", "condition": { "provider_settings.computer_use_runtime": "sandbox", "provider_settings.sandbox.booter": "shipyard_neo", @@ -3303,6 +3366,64 @@ class ChatProviderTemplate(TypedDict): "provider_settings.sandbox.booter": "shipyard_neo", }, }, + "provider_settings.sandbox.cua_image": { + "description": "CUA Image", + "type": "string", + "hint": "CUA 沙箱镜像/系统类型,默认 linux。可填写 linux、macos、windows、android,具体取决于 CUA SDK 支持。", + "condition": { + "provider_settings.computer_use_runtime": "sandbox", + "provider_settings.sandbox.booter": "cua", + }, + }, + "provider_settings.sandbox.cua_os_type": { + "description": "CUA OS Type", + "type": "string", + "options": ["linux", "macos", "windows", "android"], + "labels": ["Linux", "macOS", "Windows", "Android"], + "hint": "CUA 沙箱操作系统类型,默认 linux。", + "condition": { + "provider_settings.computer_use_runtime": "sandbox", + "provider_settings.sandbox.booter": "cua", + }, + }, + "provider_settings.sandbox.cua_idle_timeout": { + "description": "CUA Idle Timeout", + "type": "int", + "hint": "Idle timeout for CUA sandbox sessions in seconds. When greater than 0, AstrBot proactively shuts down an idle CUA sandbox after that amount of inactivity; 0 disables it.", + "condition": { + "provider_settings.computer_use_runtime": "sandbox", + "provider_settings.sandbox.booter": "cua", + }, + }, + "provider_settings.sandbox.cua_telemetry_enabled": { + "description": "CUA Telemetry", + "type": "bool", + "hint": "是否允许 CUA SDK 发送遥测数据。默认关闭。", + "condition": { + "provider_settings.computer_use_runtime": "sandbox", + "provider_settings.sandbox.booter": "cua", + }, + }, + "provider_settings.sandbox.cua_local": { + "description": "CUA Local Sandbox", + "type": "bool", + "hint": "是否优先使用 CUA 本地沙箱。默认开启,避免云端沙箱要求 CUA_API_KEY。关闭后可使用 CUA 云端沙箱。", + "condition": { + "provider_settings.computer_use_runtime": "sandbox", + "provider_settings.sandbox.booter": "cua", + }, + }, + "provider_settings.sandbox.cua_api_key": { + "description": "CUA API Key", + "type": "string", + "hint": "CUA 云端沙箱 API Key。仅在关闭本地沙箱时需要。也可以通过 CUA_API_KEY 环境变量提供。", + "obvious_hint": True, + "condition": { + "provider_settings.computer_use_runtime": "sandbox", + "provider_settings.sandbox.booter": "cua", + "provider_settings.sandbox.cua_local": False, + }, + }, "provider_settings.sandbox.shipyard_endpoint": { "description": "Shipyard API Endpoint", "type": "string", @@ -3451,6 +3572,14 @@ class ChatProviderTemplate(TypedDict): "provider_settings.agent_runner_type": "local", }, }, + "provider_settings.fallback_max_context_tokens": { + "description": "上下文窗口兜底值", + "type": "int", + "hint": "当 max_context_tokens 为 0 且模型不在内置元数据中时,使用此值作为上下文窗口大小。默认 128000。", + "condition": { + "provider_settings.agent_runner_type": "local", + }, + }, }, "condition": { "provider_settings.agent_runner_type": "local", @@ -3530,6 +3659,15 @@ class ChatProviderTemplate(TypedDict): "provider_settings.show_tool_use_status": True, }, }, + "provider_settings.buffer_intermediate_messages": { + "description": "合并 Agent 中间消息", + "type": "bool", + "hint": "开启后,非流式模式下多步工具调用过程中产生的中间文本将缓冲,待 Agent 完成后合并为一条回复发送。", + "condition": { + "provider_settings.agent_runner_type": "local", + "provider_settings.streaming_response": False, + }, + }, "provider_settings.sanitize_context_by_modalities": { "description": "按模型能力清理历史上下文", "type": "bool", @@ -3567,11 +3705,6 @@ class ChatProviderTemplate(TypedDict): "type": "string", "hint": "如果唤醒前缀为 /, 额外聊天唤醒前缀为 chat,则需要 /chat 才会触发 LLM 请求", }, - "provider_settings.prompt_prefix": { - "description": "用户提示词", - "type": "string", - "hint": "可使用 {{prompt}} 作为用户输入的占位符。如果不输入占位符则代表添加在用户输入的前面。", - }, "provider_settings.image_compress_enabled": { "description": "启用图片压缩", "type": "bool", @@ -3595,6 +3728,12 @@ class ChatProviderTemplate(TypedDict): }, "slider": {"min": 1, "max": 100, "step": 1}, }, + "provider_settings.prompt_prefix": { + "description": "用户提示词", + "type": "string", + "hint": "可使用 {{prompt}} 作为用户输入的占位符。如果不输入占位符则代表添加在用户输入的前面。", + "collapsed": True, + }, "provider_tts_settings.dual_output": { "description": "开启 TTS 时同时输出语音和文字内容", "type": "bool", diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index b4ff2742d0..45182db5fb 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -363,6 +363,10 @@ def dynamic_import_provider(self, type: str) -> None: ) case "longcat_chat_completion": from .sources.longcat_source import ProviderLongCat as ProviderLongCat + case "minimax_token_plan": + from .sources.minimax_token_plan_source import ( + ProviderMiniMaxTokenPlan as ProviderMiniMaxTokenPlan, + ) case "zhipu_chat_completion": from .sources.zhipu_source import ProviderZhipu as ProviderZhipu case "groq_chat_completion": @@ -465,6 +469,18 @@ def dynamic_import_provider(self, type: str) -> None: from .sources.gemini_embedding_source import ( GeminiEmbeddingProvider as GeminiEmbeddingProvider, ) + case "nvidia_embedding": + from .sources.nvidia_embedding_source import ( + NvidiaEmbeddingProvider as NvidiaEmbeddingProvider, + ) + case "ollama_embedding": + from .sources.ollama_embedding_source import ( + OllamaEmbeddingProvider as OllamaEmbeddingProvider, + ) + case "vllm_embedding": + from .sources.vllm_embedding_source import ( + VLLMEmbeddingProvider as VLLMEmbeddingProvider, + ) case "vllm_rerank": from .sources.vllm_rerank_source import ( VLLMRerankProvider as VLLMRerankProvider, @@ -566,7 +582,9 @@ async def load_provider(self, provider_config: dict) -> None: return logger.info( - f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商 ...", + "Loading model %s(%s) ...", + provider_config["type"], + provider_config["id"], ) # 动态导入 @@ -587,7 +605,7 @@ async def load_provider(self, provider_config: dict) -> None: if provider_config["type"] not in provider_cls_map: logger.error( - f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。", + f"Provider adapter not found: {provider_config['type']}({provider_config['id']}). Skipped.", exc_info=True, ) return @@ -621,7 +639,7 @@ async def load_provider(self, provider_config: dict) -> None: ): self.curr_stt_provider_inst = inst logger.info( - f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。", + f"Selected {provider_config['type']}({provider_config['id']}) as default STT provider", ) if not self.curr_stt_provider_inst: self.curr_stt_provider_inst = inst @@ -644,7 +662,7 @@ async def load_provider(self, provider_config: dict) -> None: ): self.curr_tts_provider_inst = inst logger.info( - f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。", + f"Selected {provider_config['type']}({provider_config['id']}) as default TTS provider", ) if not self.curr_tts_provider_inst: self.curr_tts_provider_inst = inst @@ -670,7 +688,7 @@ async def load_provider(self, provider_config: dict) -> None: ): self.curr_provider_inst = inst logger.info( - f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。", + f"Selected {provider_config['type']}({provider_config['id']}) as default chat model provider", ) if not self.curr_provider_inst: self.curr_provider_inst = inst diff --git a/astrbot/core/provider/sources/embedding_utils.py b/astrbot/core/provider/sources/embedding_utils.py new file mode 100644 index 0000000000..9b8b0a5803 --- /dev/null +++ b/astrbot/core/provider/sources/embedding_utils.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import Any + +from astrbot import logger + + +COMMON_MODEL_DIMENSIONS = { + "bge-m3": 1024, + "bge-large-en-v1.5": 1024, + "bge-large-zh-v1.5": 1024, + "text-embedding-3-small": 1536, + "text-embedding-3-large": 3072, + "text-embedding-ada-002": 1536, +} + + +def parse_configured_embedding_dimension( + raw_dimension: Any, + *, + provider_label: str, + provider_id: str, +) -> int | None: + if raw_dimension in (None, ""): + return None + + try: + dimension = int(raw_dimension) + except (TypeError, ValueError): + logger.warning( + "[%s] %s 的 embedding_dimensions 不是有效整数: %r", + provider_label, + provider_id, + raw_dimension, + ) + return None + + return dimension if dimension > 0 else None + + +def infer_embedding_dimension_from_model(model_name: Any) -> int | None: + normalized_model = str(model_name or "").strip().lower() + for model_key, dimension in COMMON_MODEL_DIMENSIONS.items(): + if model_key in normalized_model: + return dimension + return None \ No newline at end of file diff --git a/astrbot/core/provider/sources/openai_embedding_source.py b/astrbot/core/provider/sources/openai_embedding_source.py index ae531996ae..5205d69c21 100644 --- a/astrbot/core/provider/sources/openai_embedding_source.py +++ b/astrbot/core/provider/sources/openai_embedding_source.py @@ -6,6 +6,10 @@ from ..entities import ProviderType from ..provider import EmbeddingProvider from ..register import register_provider_adapter +from .embedding_utils import ( + infer_embedding_dimension_from_model, + parse_configured_embedding_dimension, +) @register_provider_adapter( @@ -18,12 +22,14 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config, provider_settings) self.provider_config = provider_config self.provider_settings = provider_settings + proxy = provider_config.get("proxy", "") provider_id = provider_config.get("id", "unknown_id") http_client = None if proxy: logger.info(f"[OpenAI Embedding] {provider_id} Using proxy: {proxy}") http_client = httpx.AsyncClient(proxy=proxy) + api_base = ( provider_config.get("embedding_api_base", "https://api.openai.com/v1") .strip() @@ -33,7 +39,12 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: if api_base and not api_base.endswith("/v1") and not api_base.endswith("/v4"): # /v4 see #5699 api_base = api_base + "/v1" + + # [新增] 保存处理后的 api_base 并转换为小写,用于后续特征比对 + self.api_base_normalized = api_base.lower() + logger.info(f"[OpenAI Embedding] {provider_id} Using API Base: {api_base}") + self.client = AsyncOpenAI( api_key=provider_config.get("embedding_api_key"), base_url=api_base, @@ -41,48 +52,168 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: http_client=http_client, ) self.model = provider_config.get("embedding_model", "text-embedding-3-small") + + # [新增] 运行时状态标记:一旦触发 400 错误将此设为 True + self._is_vllm_detected = False + + def _is_vllm(self) -> bool: + """检测是否是 vLLM(vLLM 不支持 dimensions 参数)""" + # 1. 优先检查运行时已证实的标记 + if self._is_vllm_detected: + return True + + # 2. [核心修改] 检查 API Key 是否为 "vllm" + api_key = self.provider_config.get("embedding_api_key", "") + if api_key and api_key.lower() == "vllm": + logger.info("[OpenAI Embedding] vLLM mode enabled by API Key 'vllm'.") + return True + + # 3. 辅助检查:ID 或 URL 中是否显式包含 "vllm" + provider_id = self.provider_config.get("id", "").lower() + api_base = self.api_base_normalized.lower() + if "vllm" in provider_id or "vllm" in api_base: + logger.info(f"[OpenAI Embedding] Detected vLLM by id/api_base: {provider_id}") + return True + + # 4. 移除对端口 (8000, 8001) 的静态判定,避免误伤其他兼容服务 + return False + + def _mark_as_vllm(self) -> None: + """标记此实例为vLLM(通过运行时错误检测出来的)""" + self._is_vllm_detected = True + logger.info("[OpenAI Embedding] Marked as vLLM (runtime detection via error)") async def get_embedding(self, text: str) -> list[float]: """获取文本的嵌入""" kwargs = self._embedding_kwargs() - embedding = await self.client.embeddings.create( - input=text, - model=self.model, - **kwargs, - ) + embedding = await self._request_with_vllm_retry(text, kwargs, batch=False) return embedding.data[0].embedding async def get_embeddings(self, text: list[str]) -> list[list[float]]: """批量获取文本的嵌入""" kwargs = self._embedding_kwargs() - embeddings = await self.client.embeddings.create( - input=text, - model=self.model, - **kwargs, - ) + embeddings = await self._request_with_vllm_retry(text, kwargs, batch=True) return [item.embedding for item in embeddings.data] + async def _request_with_vllm_retry( + self, + input_data: str | list[str], + kwargs: dict, + *, + batch: bool, + ): + try: + return await self.client.embeddings.create( + input=input_data, + model=self.model, + **kwargs, + ) + except Exception as exc: + if not self._should_retry_without_dimensions(exc, kwargs): + raise + + if batch: + logger.warning( + f"[OpenAI Embedding] Detected vLLM dimensions error in batch mode, retrying without dimensions: {exc}" + ) + else: + logger.warning( + f"[OpenAI Embedding] Detected vLLM dimensions error, retrying without dimensions parameter: {exc}" + ) + + kwargs_retry = {k: v for k, v in kwargs.items() if k != "dimensions"} + try: + embeddings = await self.client.embeddings.create( + input=input_data, + model=self.model, + **kwargs_retry, + ) + except Exception as retry_error: + if batch: + logger.error( + f"[OpenAI Embedding] Batch retry without dimensions also failed: {retry_error}" + ) + else: + logger.error( + f"[OpenAI Embedding] Retry without dimensions also failed: {retry_error}" + ) + raise + + if batch: + logger.info( + "[OpenAI Embedding] Successfully retrieved batch embeddings without dimensions parameter" + ) + else: + logger.info( + "[OpenAI Embedding] Successfully retrieved embedding without dimensions parameter, marking as vLLM" + ) + + self._mark_as_vllm() + return embeddings + + def _should_retry_without_dimensions(self, exc: Exception, kwargs: dict) -> bool: + if not kwargs.get("dimensions"): + return False + + error_msg = str(exc).lower() + return "matryoshka" in error_msg or "dimensions" in error_msg + + def _configured_dimension(self) -> int | None: + provider_id = self.provider_config.get("id", "unknown") + return parse_configured_embedding_dimension( + self.provider_config.get("embedding_dimensions", ""), + provider_label="OpenAI Embedding", + provider_id=provider_id, + ) + def _embedding_kwargs(self) -> dict: """构建嵌入请求的可选参数""" kwargs = {} - if "embedding_dimensions" in self.provider_config: - try: - kwargs["dimensions"] = int(self.provider_config["embedding_dimensions"]) - except (ValueError, TypeError): - logger.warning( - f"embedding_dimensions in embedding configs is not a valid integer: '{self.provider_config['embedding_dimensions']}', ignored." - ) + provider_id = self.provider_config.get("id", "unknown") + embedding_dim_config = self.provider_config.get("embedding_dimensions", "") + # 检查是否是vLLM + is_vllm = self._is_vllm() + if is_vllm: + logger.info( + f"[OpenAI Embedding] {provider_id}: Detected vLLM, skipping dimensions parameter (config value: '{embedding_dim_config}')" + ) + return kwargs + # 非vLLM服务(OpenAI等)支持dimensions,读取配置 + configured_dim = self._configured_dimension() + if configured_dim is not None: + kwargs["dimensions"] = configured_dim + logger.info( + f"[OpenAI Embedding] {provider_id}: Added dimensions parameter: {configured_dim}" + ) + elif embedding_dim_config in (None, ""): + logger.info( + f"[OpenAI Embedding] {provider_id}: No embedding_dimensions configured, API will use default" + ) return kwargs def get_dim(self) -> int: """获取向量的维度""" - if "embedding_dimensions" in self.provider_config: - try: - return int(self.provider_config["embedding_dimensions"]) - except (ValueError, TypeError): - logger.warning( - f"embedding_dimensions in embedding configs is not a valid integer: '{self.provider_config['embedding_dimensions']}', ignored." - ) + provider_id = self.provider_config.get("id", "unknown") + embedding_dim_config = self.provider_config.get("embedding_dimensions", "") + + configured_dim = self._configured_dimension() + if configured_dim is not None: + logger.info( + f"[OpenAI Embedding] {provider_id}: Dimension from config: {configured_dim}" + ) + return configured_dim + + model = self.provider_config.get("embedding_model", "") + inferred_dim = infer_embedding_dimension_from_model(model) + if inferred_dim: + logger.info( + f"[OpenAI Embedding] {provider_id}: Inferred dimension {inferred_dim} from model: {str(model).lower()}" + ) + return inferred_dim + + logger.warning( + f"[OpenAI Embedding] {provider_id}: Could not determine dimension (model: {str(model).lower()}, config: '{embedding_dim_config}')" + ) return 0 async def terminate(self): diff --git a/astrbot/core/provider/sources/vllm_embedding_source.py b/astrbot/core/provider/sources/vllm_embedding_source.py new file mode 100644 index 0000000000..ec09e4b32b --- /dev/null +++ b/astrbot/core/provider/sources/vllm_embedding_source.py @@ -0,0 +1,247 @@ +from __future__ import annotations + +from ipaddress import ip_address +from typing import Any +from urllib.parse import urlparse + +import httpx +from openai import AsyncOpenAI + +from astrbot import logger + +from ..entities import ProviderType +from ..provider import EmbeddingProvider +from ..register import register_provider_adapter +from .embedding_utils import ( + infer_embedding_dimension_from_model, + parse_configured_embedding_dimension, +) + + +@register_provider_adapter( + "vllm_embedding", + "vLLM Embedding 提供商适配器", + provider_type=ProviderType.EMBEDDING, + provider_display_name="vLLM Embedding", +) +class VLLMEmbeddingProvider(EmbeddingProvider): + def __init__(self, provider_config: dict, provider_settings: dict) -> None: + super().__init__(provider_config, provider_settings) + self.provider_config = provider_config + self.provider_settings = provider_settings + self.timeout = int(provider_config.get("timeout", 20) or 20) + self.model = str(provider_config.get("embedding_model", "") or "").strip() + self.set_model(self.model) + self._force_direct_transport = self._should_force_direct_transport() + + self._detected_dimension: int | None = None + self._resolved_request_model: str | None = None + self._direct_client_ready = self._force_direct_transport + + self.client = AsyncOpenAI( + api_key=provider_config.get("embedding_api_key"), + base_url=self._effective_api_base(), + timeout=self.timeout, + http_client=self._build_http_client(), + ) + + async def get_embedding(self, text: str) -> list[float]: + await self._ensure_runtime_ready() + request_model = await self._resolve_request_model() + logger.info( + "[vLLM Embedding] %s 发起单条 embedding 请求,model=%s,text_len=%s,跳过 dimensions。", + self._provider_id(), + request_model, + len(text), + ) + embedding = await self.client.embeddings.create( + input=text, + model=request_model, + ) + vector = embedding.data[0].embedding + self._cache_detected_dimension(len(vector)) + return vector + + async def get_embeddings(self, text: list[str]) -> list[list[float]]: + await self._ensure_runtime_ready() + request_model = await self._resolve_request_model() + total_chars = sum(len(item) for item in text) + logger.info( + "[vLLM Embedding] %s 发起批量 embedding 请求,model=%s,batch=%s,total_chars=%s,跳过 dimensions。", + self._provider_id(), + request_model, + len(text), + total_chars, + ) + embeddings = await self.client.embeddings.create( + input=text, + model=request_model, + ) + vectors = [item.embedding for item in embeddings.data] + if vectors: + self._cache_detected_dimension(len(vectors[0])) + return vectors + + def get_dim(self) -> int: + configured_dim = self._configured_dimension() + if configured_dim: + return configured_dim + if self._detected_dimension: + return self._detected_dimension + inferred_dim = self._infer_dimension_from_model(self.model) + if inferred_dim: + return inferred_dim + return 0 + + async def terminate(self) -> None: + if self.client: + await self.client.close() + + def _build_http_client(self) -> httpx.AsyncClient | None: + proxy = str(self.provider_config.get("proxy", "") or "").strip() + if proxy: + logger.info("[vLLM Embedding] %s 使用显式代理: %s", self._provider_id(), proxy) + return httpx.AsyncClient(proxy=proxy, timeout=self.timeout) + if self._force_direct_transport: + return httpx.AsyncClient(timeout=self.timeout, trust_env=False) + return None + + async def _ensure_runtime_ready(self) -> None: + if self._direct_client_ready or not self._should_force_direct_transport(): + return + + old_client = self.client + self.client = AsyncOpenAI( + api_key=self.provider_config.get("embedding_api_key"), + base_url=self._effective_api_base(), + timeout=self.timeout, + http_client=httpx.AsyncClient(timeout=self.timeout, trust_env=False), + ) + self._direct_client_ready = True + + logger.info( + "[vLLM Embedding] %s 检测到本地/内网端点,已切换为 trust_env=False 的直连 client。", + self._provider_id(), + ) + + if old_client is not None and old_client is not self.client: + try: + await old_client.close() + except Exception: + logger.debug("[vLLM Embedding] %s 关闭旧 client 失败,已忽略。", self._provider_id()) + + async def _resolve_request_model(self) -> str: + if self._resolved_request_model: + return self._resolved_request_model + + configured_model = self.model + if not configured_model: + self._resolved_request_model = configured_model + return configured_model + + available_models = await self._list_vllm_models() + resolved_model = self._match_served_model(configured_model, available_models) + if resolved_model: + self._resolved_request_model = resolved_model + if resolved_model != configured_model: + logger.info( + "[vLLM Embedding] %s 已将模型名 %s 对齐到 served-model-name %s。", + self._provider_id(), + configured_model, + resolved_model, + ) + return resolved_model + + basename_model = configured_model.rsplit("/", 1)[-1].strip() + if basename_model and basename_model != configured_model: + self._resolved_request_model = basename_model + logger.warning( + "[vLLM Embedding] %s 未能从 /models 精确匹配 %s,回退为 %s。", + self._provider_id(), + configured_model, + basename_model, + ) + return basename_model + + self._resolved_request_model = configured_model + return configured_model + + async def _list_vllm_models(self) -> list[dict[str, str]]: + try: + models = await self.client.models.list() + except Exception as exc: + logger.warning( + "[vLLM Embedding] %s 拉取 /models 失败,将直接使用配置模型名: %s", + self._provider_id(), + exc, + ) + return [] + + results: list[dict[str, str]] = [] + for item in getattr(models, "data", []) or []: + model_id = str(getattr(item, "id", "") or "").strip() + model_root = str(getattr(item, "root", "") or "").strip() + if model_id: + results.append({"id": model_id, "root": model_root}) + return results + + def _match_served_model( + self, + configured_model: str, + available_models: list[dict[str, str]], + ) -> str | None: + normalized_configured = configured_model.lower() + basename_model = configured_model.rsplit("/", 1)[-1].strip().lower() + + for item in available_models: + model_id = str(item.get("id", "") or "").strip() + model_root = str(item.get("root", "") or "").strip() + if model_id.lower() == normalized_configured: + return model_id + if model_root and model_root.lower() == normalized_configured: + return model_id + if basename_model and model_id.lower() == basename_model: + return model_id + return None + + def _configured_dimension(self) -> int | None: + return parse_configured_embedding_dimension( + self.provider_config.get("embedding_dimensions", ""), + provider_label="vLLM Embedding", + provider_id=self._provider_id(), + ) + + def _infer_dimension_from_model(self, model_name: Any) -> int | None: + return infer_embedding_dimension_from_model(model_name) + + def _cache_detected_dimension(self, dimension: int) -> None: + if isinstance(dimension, int) and dimension > 0: + self._detected_dimension = dimension + + def _effective_api_base(self) -> str: + api_base = str( + self.provider_config.get("embedding_api_base", "http://127.0.0.1:8000/v1") + or "http://127.0.0.1:8000/v1" + ).strip() + api_base = api_base.removesuffix("/").removesuffix("/embeddings") + if api_base and not api_base.endswith("/v1") and not api_base.endswith("/v4"): + api_base = api_base + "/v1" + return api_base + + def _should_force_direct_transport(self) -> bool: + if str(self.provider_config.get("proxy", "") or "").strip(): + return False + + host = (urlparse(self._effective_api_base()).hostname or "").strip().lower() + if not host: + return False + if host in {"localhost", "127.0.0.1", "::1", "host.docker.internal"}: + return True + try: + parsed_host = ip_address(host) + except ValueError: + return False + return parsed_host.is_loopback or parsed_host.is_private + + def _provider_id(self) -> str: + return str(self.provider_config.get("id", "unknown") or "unknown") \ No newline at end of file diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index bcd7e075c7..8087b7bd83 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -906,6 +906,12 @@ async def get_embedding_dim(self): return Response().ok({"embedding_dimensions": dim}).__dict__ except Exception as e: logger.error(traceback.format_exc()) + err_msg = str(e).lower() + # [新增] 识别 vLLM 的特定报错关键字 + if "matryoshka" in err_msg or "dimensions" in err_msg: + logger.info("Detected vLLM specific error, bypassing...") + # 伪造一个成功的响应,告知前端进入"兼容模式" + return Response().ok({"embedding_dimensions": "vLLM-Adaptive"}).__dict__ return Response().error(f"获取嵌入维度失败: {e!s}").__dict__ async def get_provider_source_models(self): diff --git a/dashboard/src/components/shared/AstrBotConfig.vue b/dashboard/src/components/shared/AstrBotConfig.vue index 33273a36c9..244ce85fe0 100644 --- a/dashboard/src/components/shared/AstrBotConfig.vue +++ b/dashboard/src/components/shared/AstrBotConfig.vue @@ -102,19 +102,40 @@ function saveEditedContent() { dialog.value = false } +function getNumericEmbeddingDimension(value) { + if (typeof value === 'number' && Number.isInteger(value) && value >= 0) { + return value + } + + if (typeof value === 'string') { + const trimmedValue = value.trim() + if (/^\d+$/.test(trimmedValue)) { + return Number(trimmedValue) + } + } + + return null +} + async function getEmbeddingDimensions(providerConfig) { if (loadingEmbeddingDim.value) return - loadingEmbeddingDim.value = true try { const response = await axios.post('/api/config/provider/get_embedding_dim', { provider_config: providerConfig }) - if (response.data.status != "error" && response.data.data?.embedding_dimensions) { - console.log(response.data.data.embedding_dimensions) - providerConfig.embedding_dimensions = response.data.data.embedding_dimensions - useToast().success("获取成功: " + response.data.data.embedding_dimensions) + const detectedDimension = response.data.data.embedding_dimensions + const numericDimension = getNumericEmbeddingDimension(detectedDimension) + + console.log(detectedDimension) + + if (numericDimension !== null) { + providerConfig.embedding_dimensions = numericDimension + useToast().success("获取成功: " + numericDimension) + } else { + useToast().info(`检测到维度: ${detectedDimension}。如需保存,请手动填入后点保存。`) + } } else { useToast().error(response.data.message) } diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index 935abb358e..c5bf0889b7 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -1276,7 +1276,8 @@ "hint": "嵌入模型名称。" }, "embedding_api_key": { - "description": "API Key" + "description": "API Key", + "hint": "使用 vLLM 作为提供商时,请在 API Key 中填写 'vllm' 以启用兼容模式(自动禁用 dimensions 参数)" }, "embedding_api_base": { "description": "API Base URL" @@ -1604,26 +1605,26 @@ }, "deerflow_assistant_id": { "description": "Assistant ID", - "hint": "DeerFlow 2.0 LangGraph assistant_id,默认为 lead_agent。" + "hint": "LangGraph assistant_id,默认为 lead_agent。" }, "deerflow_model_name": { "description": "模型名称覆盖", - "hint": "可选。覆盖 DeerFlow 默认模型(对应运行时 configurable 的 model_name)。" + "hint": "可选。覆盖 DeerFlow 默认模型(对应 runtime context 的 model_name)。" }, "deerflow_thinking_enabled": { "description": "启用思考模式" }, "deerflow_plan_mode": { "description": "启用计划模式", - "hint": "对应 DeerFlow 2.0 运行时 configurable 的 is_plan_mode。" + "hint": "对应 DeerFlow 的 is_plan_mode。" }, "deerflow_subagent_enabled": { "description": "启用子智能体", - "hint": "对应 DeerFlow 2.0 运行时 configurable 的 subagent_enabled。" + "hint": "对应 DeerFlow 的 subagent_enabled。" }, "deerflow_max_concurrent_subagents": { "description": "子智能体最大并发数", - "hint": "对应 DeerFlow 2.0 运行时 configurable 的 max_concurrent_subagents。仅在启用子智能体时生效,默认 3。" + "hint": "对应 DeerFlow 的 max_concurrent_subagents。仅在启用子智能体时生效,默认 3。" }, "deerflow_recursion_limit": { "description": "递归深度上限",