From 85807356cb1c372dab88b124a2bd66b6e3bb801c Mon Sep 17 00:00:00 2001 From: _floateR <1014300549@qq.com> Date: Mon, 13 Apr 2026 16:50:52 +0800 Subject: [PATCH 1/6] fix(provider): resolve vLLM embedding compatibility and dimension inference --- .../sources/openai_embedding_source.py | 179 ++++++++++++++++-- astrbot/dashboard/routes/config.py | 6 + .../src/components/shared/AstrBotConfig.vue | 6 +- 3 files changed, 168 insertions(+), 23 deletions(-) diff --git a/astrbot/core/provider/sources/openai_embedding_source.py b/astrbot/core/provider/sources/openai_embedding_source.py index ae531996ae..c16064d637 100644 --- a/astrbot/core/provider/sources/openai_embedding_source.py +++ b/astrbot/core/provider/sources/openai_embedding_source.py @@ -1,8 +1,6 @@ import httpx from openai import AsyncOpenAI - from astrbot import logger - from ..entities import ProviderType from ..provider import EmbeddingProvider from ..register import register_provider_adapter @@ -18,12 +16,14 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config, provider_settings) self.provider_config = provider_config self.provider_settings = provider_settings + proxy = provider_config.get("proxy", "") provider_id = provider_config.get("id", "unknown_id") http_client = None if proxy: logger.info(f"[OpenAI Embedding] {provider_id} Using proxy: {proxy}") http_client = httpx.AsyncClient(proxy=proxy) + api_base = ( provider_config.get("embedding_api_base", "https://api.openai.com/v1") .strip() @@ -33,7 +33,12 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: if api_base and not api_base.endswith("/v1") and not api_base.endswith("/v4"): # /v4 see #5699 api_base = api_base + "/v1" + + # [新增] 保存处理后的 api_base 并转换为小写,用于后续特征比对 + self.api_base_normalized = api_base.lower() + logger.info(f"[OpenAI Embedding] {provider_id} Using API Base: {api_base}") + self.client = AsyncOpenAI( api_key=provider_config.get("embedding_api_key"), base_url=api_base, @@ -41,48 +46,182 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: http_client=http_client, ) self.model = provider_config.get("embedding_model", "text-embedding-3-small") + + # [新增] 运行时状态标记:一旦触发 400 错误将此设为 True + self._is_vllm_detected = False + + def _is_vllm(self) -> bool: + """检测是否是vLLM(vLLM不支持dimensions参数)""" + # 先检查运行时检测标志 + if self._is_vllm_detected: + return True + # 方法1:检查provider_id是否包含vllm + provider_id = self.provider_config.get("id", "").lower() + if "vllm" in provider_id: + logger.info(f"[OpenAI Embedding] Detected vLLM by provider id: {provider_id}") + return True + # 方法2:检查api_base中的特征端口或主机名 + api_base = self.api_base_normalized.lower() + if "vllm" in api_base: + logger.info(f"[OpenAI Embedding] Detected vLLM by api_base keyword") + return True + # 方法3:检查常见的vLLM端口(8000, 8001等) + if ":8000" in api_base or ":8001" in api_base or ":8002" in api_base: + logger.info(f"[OpenAI Embedding] Detected vLLM by common port in api_base: {api_base}") + return True + return False + + def _mark_as_vllm(self) -> None: + """标记此实例为vLLM(通过运行时错误检测出来的)""" + self._is_vllm_detected = True + logger.info("[OpenAI Embedding] Marked as vLLM (runtime detection via error)") async def get_embedding(self, text: str) -> list[float]: """获取文本的嵌入""" kwargs = self._embedding_kwargs() - embedding = await self.client.embeddings.create( - input=text, - model=self.model, - **kwargs, - ) - return embedding.data[0].embedding + try: + embedding = await self.client.embeddings.create( + input=text, + model=self.model, + **kwargs, + ) + return embedding.data[0].embedding + except Exception as e: + # 如果包含"matryoshka"或"dimensions"相关的错误,说明vLLM不支持该参数 + # 尝试不带dimensions重试 + error_msg = str(e).lower() + if ("matryoshka" in error_msg or "dimensions" in error_msg) and kwargs.get("dimensions"): + logger.warning( + f"[OpenAI Embedding] Detected vLLM dimensions error, retrying without dimensions parameter: {e}" + ) + kwargs_retry = {k: v for k, v in kwargs.items() if k != "dimensions"} + try: + embedding = await self.client.embeddings.create( + input=text, + model=self.model, + **kwargs_retry, + ) + logger.info( + "[OpenAI Embedding] Successfully retrieved embedding without dimensions parameter, marking as vLLM" + ) + # 标记为vLLM以便后续调用也跳过dimensions + self._mark_as_vllm() + return embedding.data[0].embedding + except Exception as retry_error: + logger.error( + f"[OpenAI Embedding] Retry without dimensions also failed: {retry_error}" + ) + raise retry_error + else: + raise async def get_embeddings(self, text: list[str]) -> list[list[float]]: """批量获取文本的嵌入""" kwargs = self._embedding_kwargs() - embeddings = await self.client.embeddings.create( - input=text, - model=self.model, - **kwargs, - ) - return [item.embedding for item in embeddings.data] + try: + embeddings = await self.client.embeddings.create( + input=text, + model=self.model, + **kwargs, + ) + return [item.embedding for item in embeddings.data] + except Exception as e: + # 如果包含"matryoshka"或"dimensions"相关的错误,说明vLLM不支持该参数 + # 尝试不带dimensions重试 + error_msg = str(e).lower() + if ("matryoshka" in error_msg or "dimensions" in error_msg) and kwargs.get("dimensions"): + logger.warning( + f"[OpenAI Embedding] Detected vLLM dimensions error in batch mode, retrying without dimensions: {e}" + ) + kwargs_retry = {k: v for k, v in kwargs.items() if k != "dimensions"} + try: + embeddings = await self.client.embeddings.create( + input=text, + model=self.model, + **kwargs_retry, + ) + logger.info( + "[OpenAI Embedding] Successfully retrieved batch embeddings without dimensions parameter" + ) + # 标记为vLLM以便后续调用也跳过dimensions + self._mark_as_vllm() + return [item.embedding for item in embeddings.data] + except Exception as retry_error: + logger.error( + f"[OpenAI Embedding] Batch retry without dimensions also failed: {retry_error}" + ) + raise retry_error + else: + raise def _embedding_kwargs(self) -> dict: """构建嵌入请求的可选参数""" kwargs = {} - if "embedding_dimensions" in self.provider_config: + provider_id = self.provider_config.get("id", "unknown") + embedding_dim_config = self.provider_config.get("embedding_dimensions", "") + # 检查是否是vLLM + is_vllm = self._is_vllm() + if is_vllm: + logger.info( + f"[OpenAI Embedding] {provider_id}: Detected vLLM, skipping dimensions parameter (config value: '{embedding_dim_config}')" + ) + return kwargs + # 非vLLM服务(OpenAI等)支持dimensions,读取配置 + if embedding_dim_config and embedding_dim_config != "": try: - kwargs["dimensions"] = int(self.provider_config["embedding_dimensions"]) + dim_value = int(embedding_dim_config) + kwargs["dimensions"] = dim_value + logger.info( + f"[OpenAI Embedding] {provider_id}: Added dimensions parameter: {dim_value}" + ) except (ValueError, TypeError): logger.warning( - f"embedding_dimensions in embedding configs is not a valid integer: '{self.provider_config['embedding_dimensions']}', ignored." + f"[OpenAI Embedding] {provider_id}: embedding_dimensions is not a valid integer: '{embedding_dim_config}', ignored." ) + else: + logger.info( + f"[OpenAI Embedding] {provider_id}: No embedding_dimensions configured, API will use default" + ) return kwargs def get_dim(self) -> int: """获取向量的维度""" - if "embedding_dimensions" in self.provider_config: + provider_id = self.provider_config.get("id", "unknown") + # 首先尝试从config读取 + embedding_dim_config = self.provider_config.get("embedding_dimensions", "") + if embedding_dim_config and embedding_dim_config != "": try: - return int(self.provider_config["embedding_dimensions"]) + dim = int(embedding_dim_config) + if dim > 0: + logger.info( + f"[OpenAI Embedding] {provider_id}: Dimension from config: {dim}" + ) + return dim except (ValueError, TypeError): logger.warning( - f"embedding_dimensions in embedding configs is not a valid integer: '{self.provider_config['embedding_dimensions']}', ignored." + f"[OpenAI Embedding] {provider_id}: embedding_dimensions is not a valid integer: '{embedding_dim_config}', trying model inference" + ) + # config为空或无效时根据模型名推断维度 + # 这样Living Memory可以在自动检测后匹配正确的维度 + model = self.provider_config.get("embedding_model", "").lower() + model_dims = { + "bge-m3": 1024, + "bge-large-en-v1.5": 1024, + "bge-large-zh-v1.5": 1024, + "text-embedding-3-small": 1536, + "text-embedding-3-large": 3072, + "text-embedding-ada-002": 1536, + } + for model_key, dim in model_dims.items(): + if model_key in model: + logger.info( + f"[OpenAI Embedding] {provider_id}: Inferred dimension {dim} from model: {model}" ) + return dim + # 无法推断时返回0(Living Memory会检测实际维度) + logger.warning( + f"[OpenAI Embedding] {provider_id}: Could not determine dimension (model: {model}, config: '{embedding_dim_config}')" + ) return 0 async def terminate(self): diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index bcd7e075c7..8087b7bd83 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -906,6 +906,12 @@ async def get_embedding_dim(self): return Response().ok({"embedding_dimensions": dim}).__dict__ except Exception as e: logger.error(traceback.format_exc()) + err_msg = str(e).lower() + # [新增] 识别 vLLM 的特定报错关键字 + if "matryoshka" in err_msg or "dimensions" in err_msg: + logger.info("Detected vLLM specific error, bypassing...") + # 伪造一个成功的响应,告知前端进入"兼容模式" + return Response().ok({"embedding_dimensions": "vLLM-Adaptive"}).__dict__ return Response().error(f"获取嵌入维度失败: {e!s}").__dict__ async def get_provider_source_models(self): diff --git a/dashboard/src/components/shared/AstrBotConfig.vue b/dashboard/src/components/shared/AstrBotConfig.vue index 33273a36c9..0b53d78f5f 100644 --- a/dashboard/src/components/shared/AstrBotConfig.vue +++ b/dashboard/src/components/shared/AstrBotConfig.vue @@ -104,17 +104,17 @@ function saveEditedContent() { async function getEmbeddingDimensions(providerConfig) { if (loadingEmbeddingDim.value) return - loadingEmbeddingDim.value = true try { const response = await axios.post('/api/config/provider/get_embedding_dim', { provider_config: providerConfig }) - if (response.data.status != "error" && response.data.data?.embedding_dimensions) { console.log(response.data.data.embedding_dimensions) - providerConfig.embedding_dimensions = response.data.data.embedding_dimensions + [已禁用] 不再自动写入配置文件,仅显示提示 + // providerConfig.embedding_dimensions = response.data.data.embedding_dimensions useToast().success("获取成功: " + response.data.data.embedding_dimensions) + useToast().info(`检测到维度: ${response.data.data.embedding_dimensions}。如需保存,请手动填入后点保存。`) } else { useToast().error(response.data.message) } From f08962aa9d28753cbbb956975626fef83d4c6f4b Mon Sep 17 00:00:00 2001 From: _floateR <1014300549@qq.com> Date: Mon, 13 Apr 2026 21:19:35 +0800 Subject: [PATCH 2/6] refactor: use API Key 'vllm' as indicator and add WebUI hints --- .../sources/openai_embedding_source.py | 27 ++++++++++--------- .../src/components/shared/AstrBotConfig.vue | 2 +- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/astrbot/core/provider/sources/openai_embedding_source.py b/astrbot/core/provider/sources/openai_embedding_source.py index c16064d637..dab622db08 100644 --- a/astrbot/core/provider/sources/openai_embedding_source.py +++ b/astrbot/core/provider/sources/openai_embedding_source.py @@ -51,24 +51,25 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: self._is_vllm_detected = False def _is_vllm(self) -> bool: - """检测是否是vLLM(vLLM不支持dimensions参数)""" - # 先检查运行时检测标志 + """检测是否是 vLLM(vLLM 不支持 dimensions 参数)""" + # 1. 优先检查运行时已证实的标记 if self._is_vllm_detected: return True - # 方法1:检查provider_id是否包含vllm - provider_id = self.provider_config.get("id", "").lower() - if "vllm" in provider_id: - logger.info(f"[OpenAI Embedding] Detected vLLM by provider id: {provider_id}") + + # 2. [核心修改] 检查 API Key 是否为 "vllm" + api_key = self.provider_config.get("embedding_api_key", "") + if api_key and api_key.lower() == "vllm": + logger.info("[OpenAI Embedding] vLLM mode enabled by API Key 'vllm'.") return True - # 方法2:检查api_base中的特征端口或主机名 + + # 3. 辅助检查:ID 或 URL 中是否显式包含 "vllm" + provider_id = self.provider_config.get("id", "").lower() api_base = self.api_base_normalized.lower() - if "vllm" in api_base: - logger.info(f"[OpenAI Embedding] Detected vLLM by api_base keyword") - return True - # 方法3:检查常见的vLLM端口(8000, 8001等) - if ":8000" in api_base or ":8001" in api_base or ":8002" in api_base: - logger.info(f"[OpenAI Embedding] Detected vLLM by common port in api_base: {api_base}") + if "vllm" in provider_id or "vllm" in api_base: + logger.info(f"[OpenAI Embedding] Detected vLLM by id/api_base: {provider_id}") return True + + # 4. 移除对端口 (8000, 8001) 的静态判定,避免误伤其他兼容服务 return False def _mark_as_vllm(self) -> None: diff --git a/dashboard/src/components/shared/AstrBotConfig.vue b/dashboard/src/components/shared/AstrBotConfig.vue index 0b53d78f5f..9a332fb0e5 100644 --- a/dashboard/src/components/shared/AstrBotConfig.vue +++ b/dashboard/src/components/shared/AstrBotConfig.vue @@ -111,7 +111,7 @@ async function getEmbeddingDimensions(providerConfig) { }) if (response.data.status != "error" && response.data.data?.embedding_dimensions) { console.log(response.data.data.embedding_dimensions) - [已禁用] 不再自动写入配置文件,仅显示提示 + //[已禁用] 不再自动写入配置文件,仅显示提示 // providerConfig.embedding_dimensions = response.data.data.embedding_dimensions useToast().success("获取成功: " + response.data.data.embedding_dimensions) useToast().info(`检测到维度: ${response.data.data.embedding_dimensions}。如需保存,请手动填入后点保存。`) From b62927a356544d75d2fb288dcf69a0592a79b433 Mon Sep 17 00:00:00 2001 From: _floateR <1014300549@qq.com> Date: Tue, 14 Apr 2026 11:34:30 +0800 Subject: [PATCH 3/6] refactor: move UI hints to i18n metadata and update config defaults --- astrbot/core/config/default.py | 11 ++++++----- .../locales/zh-CN/features/config-metadata.json | 13 +++++++------ 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 6787370460..ad6500058a 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -2131,6 +2131,7 @@ class ChatProviderTemplate(TypedDict): "embedding_api_key": { "description": "API Key", "type": "string", + "hint": "使用 vLLM 作为提供商时,请在 API Key 中填写 'vllm' 以启用兼容模式(自动禁用 dimensions 参数)", }, "embedding_api_base": { "description": "API Base URL", @@ -2671,12 +2672,12 @@ class ChatProviderTemplate(TypedDict): "deerflow_assistant_id": { "description": "Assistant ID", "type": "string", - "hint": "DeerFlow 2.0 LangGraph assistant_id,默认为 lead_agent。", + "hint": "LangGraph assistant_id,默认为 lead_agent。", }, "deerflow_model_name": { "description": "模型名称覆盖", "type": "string", - "hint": "可选。覆盖 DeerFlow 默认模型(对应运行时 configurable 的 model_name)。", + "hint": "可选。覆盖 DeerFlow 默认模型(对应 runtime context 的 model_name)。", }, "deerflow_thinking_enabled": { "description": "启用思考模式", @@ -2685,17 +2686,17 @@ class ChatProviderTemplate(TypedDict): "deerflow_plan_mode": { "description": "启用计划模式", "type": "bool", - "hint": "对应 DeerFlow 2.0 运行时 configurable 的 is_plan_mode。", + "hint": "对应 DeerFlow 的 is_plan_mode。", }, "deerflow_subagent_enabled": { "description": "启用子智能体", "type": "bool", - "hint": "对应 DeerFlow 2.0 运行时 configurable 的 subagent_enabled。", + "hint": "对应 DeerFlow 的 subagent_enabled。", }, "deerflow_max_concurrent_subagents": { "description": "子智能体最大并发数", "type": "int", - "hint": "对应 DeerFlow 2.0 运行时 configurable 的 max_concurrent_subagents。仅在启用子智能体时生效,默认 3。", + "hint": "对应 DeerFlow 的 max_concurrent_subagents。仅在启用子智能体时生效,默认 3。", }, "deerflow_recursion_limit": { "description": "递归深度上限", diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index 935abb358e..c5bf0889b7 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -1276,7 +1276,8 @@ "hint": "嵌入模型名称。" }, "embedding_api_key": { - "description": "API Key" + "description": "API Key", + "hint": "使用 vLLM 作为提供商时,请在 API Key 中填写 'vllm' 以启用兼容模式(自动禁用 dimensions 参数)" }, "embedding_api_base": { "description": "API Base URL" @@ -1604,26 +1605,26 @@ }, "deerflow_assistant_id": { "description": "Assistant ID", - "hint": "DeerFlow 2.0 LangGraph assistant_id,默认为 lead_agent。" + "hint": "LangGraph assistant_id,默认为 lead_agent。" }, "deerflow_model_name": { "description": "模型名称覆盖", - "hint": "可选。覆盖 DeerFlow 默认模型(对应运行时 configurable 的 model_name)。" + "hint": "可选。覆盖 DeerFlow 默认模型(对应 runtime context 的 model_name)。" }, "deerflow_thinking_enabled": { "description": "启用思考模式" }, "deerflow_plan_mode": { "description": "启用计划模式", - "hint": "对应 DeerFlow 2.0 运行时 configurable 的 is_plan_mode。" + "hint": "对应 DeerFlow 的 is_plan_mode。" }, "deerflow_subagent_enabled": { "description": "启用子智能体", - "hint": "对应 DeerFlow 2.0 运行时 configurable 的 subagent_enabled。" + "hint": "对应 DeerFlow 的 subagent_enabled。" }, "deerflow_max_concurrent_subagents": { "description": "子智能体最大并发数", - "hint": "对应 DeerFlow 2.0 运行时 configurable 的 max_concurrent_subagents。仅在启用子智能体时生效,默认 3。" + "hint": "对应 DeerFlow 的 max_concurrent_subagents。仅在启用子智能体时生效,默认 3。" }, "deerflow_recursion_limit": { "description": "递归深度上限", From c281462db1aa5daf9b9054fda52f33a24b69284f Mon Sep 17 00:00:00 2001 From: _floateR <1014300549@qq.com> Date: Tue, 19 May 2026 14:15:23 +0800 Subject: [PATCH 4/6] Integrate vLLM embedding provider into core --- astrbot/core/config/default.py | 266 +++++++++++++----- astrbot/core/provider/manager.py | 28 +- .../provider/sources/vllm_embedding_source.py | 265 +++++++++++++++++ 3 files changed, 490 insertions(+), 69 deletions(-) create mode 100644 astrbot/core/provider/sources/vllm_embedding_source.py diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index ad6500058a..7e5c8d9ccb 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -1,11 +1,11 @@ """如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。""" import os -from typing import Any, TypedDict +from astrbot.core.computer.booters.cua_defaults import CUA_DEFAULT_CONFIG from astrbot.core.utils.astrbot_path import get_astrbot_data_path -VERSION = "4.23.0" +VERSION = "4.25.1" DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db") PERSONAL_WECHAT_CONFIG_METADATA = { "weixin_oc_base_url": { @@ -111,6 +111,7 @@ "websearch_bocha_key": [], "websearch_brave_key": [], "websearch_baidu_app_builder_key": "", + "websearch_firecrawl_key": [], "web_search_link": False, "display_reasoning_text": False, "identifier": False, @@ -134,6 +135,7 @@ "streaming_response": False, "show_tool_use_status": False, "show_tool_call_result": False, + "buffer_intermediate_messages": False, "sanitize_context_by_modalities": False, "max_quoted_fallback_images": 20, "quoted_message_parser": { @@ -174,6 +176,12 @@ "shipyard_neo_access_token": "", "shipyard_neo_profile": "python-default", "shipyard_neo_ttl": 3600, + "cua_image": CUA_DEFAULT_CONFIG["image"], + "cua_os_type": CUA_DEFAULT_CONFIG["os_type"], + "cua_idle_timeout": CUA_DEFAULT_CONFIG["idle_timeout"], + "cua_telemetry_enabled": CUA_DEFAULT_CONFIG["telemetry_enabled"], + "cua_local": CUA_DEFAULT_CONFIG["local"], + "cua_api_key": CUA_DEFAULT_CONFIG["api_key"], }, "image_compress_enabled": True, "image_compress_options": { @@ -236,7 +244,10 @@ "dashboard": { "enable": True, "username": "astrbot", - "password": "77b90590a8945a7d36c963981a307dc9", + "password": "", + "pbkdf2_password": "", + "password_storage_upgraded": False, + "password_change_required": False, "jwt_secret": "", "host": "0.0.0.0", "port": 6185, @@ -283,27 +294,10 @@ "kb_final_top_k": 5, # 知识库检索最终返回结果数量 "kb_agentic_mode": False, "disable_builtin_commands": False, + "disable_metrics": False, } -class ChatProviderTemplate(TypedDict): - id: str - provider_source_id: str - model: str - modalities: list - custom_extra_body: dict[str, Any] - max_context_tokens: int - - -CHAT_PROVIDER_TEMPLATE = { - "id": "", - "provide_source_id": "", - "model": "", - "modalities": [], - "custom_extra_body": {}, - "max_context_tokens": 0, -} - """ AstrBot v3 时代的配置元数据,目前仅承担以下功能: @@ -324,7 +318,7 @@ class ChatProviderTemplate(TypedDict): "QQ 官方机器人(WebSocket)": { "id": "default", "type": "qq_official", - "enable": False, + "enable": True, "appid": "", "secret": "", "enable_group_c2c": True, @@ -333,7 +327,7 @@ class ChatProviderTemplate(TypedDict): "QQ 官方机器人(Webhook)": { "id": "default", "type": "qq_official_webhook", - "enable": False, + "enable": True, "appid": "", "secret": "", "is_sandbox": False, @@ -345,7 +339,7 @@ class ChatProviderTemplate(TypedDict): "OneBot v11": { "id": "default", "type": "aiocqhttp", - "enable": False, + "enable": True, "ws_reverse_host": "0.0.0.0", "ws_reverse_port": 6199, "ws_reverse_token": "", @@ -353,7 +347,7 @@ class ChatProviderTemplate(TypedDict): "微信公众平台": { "id": "weixin_official_account", "type": "weixin_official_account", - "enable": False, + "enable": True, "appid": "", "secret": "", "token": "", @@ -368,7 +362,7 @@ class ChatProviderTemplate(TypedDict): "企业微信(含微信客服)": { "id": "wecom", "type": "wecom", - "enable": False, + "enable": True, "corpid": "", "secret": "", "token": "", @@ -405,7 +399,7 @@ class ChatProviderTemplate(TypedDict): "个人微信": { "id": "weixin_personal", "type": "weixin_oc", - "enable": False, + "enable": True, "weixin_oc_base_url": "https://ilinkai.weixin.qq.com", "weixin_oc_bot_type": "3", "weixin_oc_qr_poll_interval": 1, @@ -415,8 +409,7 @@ class ChatProviderTemplate(TypedDict): "飞书(Lark)": { "id": "lark", "type": "lark", - "enable": False, - "lark_bot_name": "", + "enable": True, "app_id": "", "app_secret": "", "domain": "https://open.feishu.cn", @@ -428,7 +421,7 @@ class ChatProviderTemplate(TypedDict): "钉钉(DingTalk)": { "id": "dingtalk", "type": "dingtalk", - "enable": False, + "enable": True, "client_id": "", "client_secret": "", "card_template_id": "", @@ -436,7 +429,7 @@ class ChatProviderTemplate(TypedDict): "Telegram": { "id": "telegram", "type": "telegram", - "enable": False, + "enable": True, "telegram_token": "your_bot_token", "start_message": "Hello, I'm AstrBot!", "telegram_api_base_url": "https://api.telegram.org/bot", @@ -449,7 +442,7 @@ class ChatProviderTemplate(TypedDict): "Discord": { "id": "discord", "type": "discord", - "enable": False, + "enable": True, "discord_token": "", "discord_proxy": "", "discord_command_register": True, @@ -459,7 +452,7 @@ class ChatProviderTemplate(TypedDict): "Misskey": { "id": "misskey", "type": "misskey", - "enable": False, + "enable": True, "misskey_instance_url": "https://misskey.example", "misskey_token": "", "misskey_default_visibility": "public", @@ -477,7 +470,7 @@ class ChatProviderTemplate(TypedDict): "Slack": { "id": "slack", "type": "slack", - "enable": False, + "enable": True, "bot_token": "", "app_token": "", "signing_secret": "", @@ -491,7 +484,7 @@ class ChatProviderTemplate(TypedDict): "Line": { "id": "line", "type": "line", - "enable": False, + "enable": True, "channel_access_token": "", "channel_secret": "", "unified_webhook_mode": True, @@ -500,7 +493,7 @@ class ChatProviderTemplate(TypedDict): "Satori": { "id": "satori", "type": "satori", - "enable": False, + "enable": True, "satori_api_base_url": "http://localhost:5140/satori/v1", "satori_endpoint": "ws://localhost:5140/satori/v1/events", "satori_token": "", @@ -511,7 +504,7 @@ class ChatProviderTemplate(TypedDict): "KOOK": { "id": "kook", "type": "kook", - "enable": False, + "enable": True, "kook_bot_token": "", "kook_reconnect_delay": 1, "kook_max_reconnect_delay": 60, @@ -524,7 +517,7 @@ class ChatProviderTemplate(TypedDict): "Mattermost": { "id": "mattermost", "type": "mattermost", - "enable": False, + "enable": True, "mattermost_url": "https://chat.example.com", "mattermost_bot_token": "", "mattermost_reconnect_delay": 5.0, @@ -782,7 +775,7 @@ class ChatProviderTemplate(TypedDict): "appid": { "description": "appid", "type": "string", - "hint": "必填项。QQ 官方机器人平台的 appid。如何获取请参考文档。", + "hint": "必填项。当前消息平台的 AppID。如何获取请参考对应平台接入文档。", }, "secret": { "description": "secret", @@ -895,11 +888,6 @@ class ChatProviderTemplate(TypedDict): "wecom_ai_bot_connection_mode": "long_connection", }, }, - "lark_bot_name": { - "description": "飞书机器人的名字", - "type": "string", - "hint": "请务必填写正确,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。", - }, "discord_token": { "description": "Discord Bot Token", "type": "string", @@ -1206,7 +1194,7 @@ class ChatProviderTemplate(TypedDict): "provider_type": "chat_completion", "enable": True, "key": [], - "api_base": "https://api.kimi.com/coding/", + "api_base": "https://api.kimi.com/coding", "timeout": 120, "proxy": "", "custom_headers": {"User-Agent": "claude-code/0.1.0"}, @@ -1236,6 +1224,19 @@ class ChatProviderTemplate(TypedDict): "proxy": "", "custom_headers": {}, }, + "MiniMax Token Plan": { + "id": "minimax-token-plan", + "provider": "minimax-token-plan", + "type": "minimax_token_plan", + "provider_type": "chat_completion", + "enable": True, + "key": [], + "api_base": "https://api.minimaxi.com/anthropic", + "timeout": 120, + "proxy": "", + "custom_headers": {"User-Agent": "claude-code/0.1.0"}, + "anth_thinking_config": {"type": "", "budget": 0, "effort": ""}, + }, "xAI": { "id": "xai", "provider": "xai", @@ -1796,6 +1797,48 @@ class ChatProviderTemplate(TypedDict): "timeout": 20, "proxy": "", }, + "NVIDIA Embedding": { + "id": "nvidia_embedding", + "type": "nvidia_embedding", + "provider": "nvidia", + "provider_type": "embedding", + "hint": "provider_group.provider.nvidia_embedding.hint", + "enable": True, + "embedding_api_key": "", + "embedding_api_base": "https://integrate.api.nvidia.com/v1", + "embedding_model": "nvidia/llama-nemotron-embed-1b-v2", + "input_type": "passage", + "embedding_dimensions": 1024, + "timeout": 20, + "proxy": "", + }, + "Ollama Embedding": { + "id": "ollama_embedding", + "type": "ollama_embedding", + "provider": "ollama", + "provider_type": "embedding", + "hint": "provider_group.provider.ollama_embedding.hint", + "enable": True, + "embedding_api_base": "http://localhost:11434", + "embedding_model": "nomic-embed-text", + "embedding_dimensions": 768, + "timeout": 60, + "proxy": "", + }, + "vLLM Embedding": { + "id": "vllm_embedding", + "type": "vllm_embedding", + "provider": "vllm", + "provider_type": "embedding", + "hint": "面向 vLLM OpenAI-compatible Embedding 接口。请求时会自动跳过 dimensions,并尝试将模型名对齐到 served-model-name。", + "enable": True, + "embedding_api_key": "", + "embedding_api_base": "http://127.0.0.1:8000/v1", + "embedding_model": "BAAI/bge-m3", + "embedding_dimensions": 1024, + "timeout": 20, + "proxy": "", + }, "vLLM Rerank": { "id": "vllm_rerank", "type": "vllm_rerank", @@ -1949,13 +1992,13 @@ class ChatProviderTemplate(TypedDict): "options": ["text", "image", "audio", "tool_use"], "labels": ["文本", "图像", "音频", "工具使用"], "render_type": "checkbox", - "hint": "模型支持的模态。如所填写的模型不支持图像,请取消勾选图像。", + "hint": "模型支持的模态及能力。", }, "custom_headers": { - "description": "自定义添加请求头", + "description": "自定义请求头", "type": "dict", "items": {}, - "hint": "此处添加的键值对将被合并到 OpenAI SDK 的 default_headers 中,用于自定义 HTTP 请求头。值必须为字符串。", + "hint": "此处添加的键值对将被合并到 OpenAI SDK 的 default_headers 中,用于自定义 HTTP 请求头。", }, "ollama_disable_thinking": { "description": "关闭思考模式", @@ -1966,7 +2009,7 @@ class ChatProviderTemplate(TypedDict): "description": "自定义请求体参数", "type": "dict", "items": {}, - "hint": "用于在请求时添加额外的参数,如 temperature、top_p、max_tokens 等。", + "hint": "用于在请求时添加额外的参数,如 temperature, top_p, max_tokens, reasoning_effort 等。", "template_schema": { "temperature": { "name": "Temperature", @@ -2131,7 +2174,6 @@ class ChatProviderTemplate(TypedDict): "embedding_api_key": { "description": "API Key", "type": "string", - "hint": "使用 vLLM 作为提供商时,请在 API Key 中填写 'vllm' 以启用兼容模式(自动禁用 dimensions 参数)", }, "embedding_api_base": { "description": "API Base URL", @@ -2610,7 +2652,7 @@ class ChatProviderTemplate(TypedDict): "max_context_tokens": { "description": "模型上下文窗口大小", "type": "int", - "hint": "模型最大上下文 Token 大小。如果为 0,则会自动从模型元数据填充(如有),也可手动修改。", + "hint": "模型最大上下文 Token 大小。如果为 0,则会自动从模型元数据填充(如有)", }, "dify_api_key": { "description": "API Key", @@ -2672,12 +2714,12 @@ class ChatProviderTemplate(TypedDict): "deerflow_assistant_id": { "description": "Assistant ID", "type": "string", - "hint": "LangGraph assistant_id,默认为 lead_agent。", + "hint": "DeerFlow 2.0 LangGraph assistant_id,默认为 lead_agent。", }, "deerflow_model_name": { "description": "模型名称覆盖", "type": "string", - "hint": "可选。覆盖 DeerFlow 默认模型(对应 runtime context 的 model_name)。", + "hint": "可选。覆盖 DeerFlow 默认模型(对应运行时 configurable 的 model_name)。", }, "deerflow_thinking_enabled": { "description": "启用思考模式", @@ -2686,17 +2728,17 @@ class ChatProviderTemplate(TypedDict): "deerflow_plan_mode": { "description": "启用计划模式", "type": "bool", - "hint": "对应 DeerFlow 的 is_plan_mode。", + "hint": "对应 DeerFlow 2.0 运行时 configurable 的 is_plan_mode。", }, "deerflow_subagent_enabled": { "description": "启用子智能体", "type": "bool", - "hint": "对应 DeerFlow 的 subagent_enabled。", + "hint": "对应 DeerFlow 2.0 运行时 configurable 的 subagent_enabled。", }, "deerflow_max_concurrent_subagents": { "description": "子智能体最大并发数", "type": "int", - "hint": "对应 DeerFlow 的 max_concurrent_subagents。仅在启用子智能体时生效,默认 3。", + "hint": "对应 DeerFlow 2.0 运行时 configurable 的 max_concurrent_subagents。仅在启用子智能体时生效,默认 3。", }, "deerflow_recursion_limit": { "description": "递归深度上限", @@ -2765,6 +2807,9 @@ class ChatProviderTemplate(TypedDict): "show_tool_call_result": { "type": "bool", }, + "buffer_intermediate_messages": { + "type": "bool", + }, "unsupported_streaming_strategy": { "type": "string", }, @@ -2919,6 +2964,11 @@ class ChatProviderTemplate(TypedDict): "callback_api_base": { "type": "string", }, + "disable_metrics": { + "description": "禁用匿名使用统计", + "type": "bool", + "hint": "禁用后,AstrBot 将不再上传匿名使用统计数据。", + }, "log_level": { "type": "string", "options": ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], @@ -3186,6 +3236,7 @@ class ChatProviderTemplate(TypedDict): "baidu_ai_search", "bocha", "brave", + "firecrawl", ], "condition": { "provider_settings.web_search": True, @@ -3221,12 +3272,23 @@ class ChatProviderTemplate(TypedDict): "provider_settings.web_search": True, }, }, + "provider_settings.websearch_firecrawl_key": { + "description": "Firecrawl API Key", + "type": "list", + "items": {"type": "string"}, + "hint": "可添加多个 Key 进行轮询。", + "condition": { + "provider_settings.websearch_provider": "firecrawl", + "provider_settings.web_search": True, + }, + }, "provider_settings.websearch_baidu_app_builder_key": { "description": "百度千帆智能云 APP Builder API Key", "type": "string", "hint": "参考:https://console.bce.baidu.com/iam/#/iam/apikey/list", "condition": { "provider_settings.websearch_provider": "baidu_ai_search", + "provider_settings.web_search": True, }, }, "provider_settings.web_search_link": { @@ -3262,8 +3324,8 @@ class ChatProviderTemplate(TypedDict): "provider_settings.sandbox.booter": { "description": "沙箱环境驱动器", "type": "string", - "options": ["shipyard_neo", "shipyard"], - "labels": ["Shipyard Neo", "Shipyard"], + "options": ["shipyard_neo", "shipyard", "cua"], + "labels": ["Shipyard Neo", "Shipyard", "CUA"], "condition": { "provider_settings.computer_use_runtime": "sandbox", }, @@ -3289,7 +3351,7 @@ class ChatProviderTemplate(TypedDict): "provider_settings.sandbox.shipyard_neo_profile": { "description": "Shipyard Neo Profile", "type": "string", - "hint": "Shipyard Neo 沙箱 profile,如 python-default。", + "hint": "Shipyard Neo 沙箱 profile,如 python-default。留空时自动选择能力更完整的 profile。", "condition": { "provider_settings.computer_use_runtime": "sandbox", "provider_settings.sandbox.booter": "shipyard_neo", @@ -3304,6 +3366,64 @@ class ChatProviderTemplate(TypedDict): "provider_settings.sandbox.booter": "shipyard_neo", }, }, + "provider_settings.sandbox.cua_image": { + "description": "CUA Image", + "type": "string", + "hint": "CUA 沙箱镜像/系统类型,默认 linux。可填写 linux、macos、windows、android,具体取决于 CUA SDK 支持。", + "condition": { + "provider_settings.computer_use_runtime": "sandbox", + "provider_settings.sandbox.booter": "cua", + }, + }, + "provider_settings.sandbox.cua_os_type": { + "description": "CUA OS Type", + "type": "string", + "options": ["linux", "macos", "windows", "android"], + "labels": ["Linux", "macOS", "Windows", "Android"], + "hint": "CUA 沙箱操作系统类型,默认 linux。", + "condition": { + "provider_settings.computer_use_runtime": "sandbox", + "provider_settings.sandbox.booter": "cua", + }, + }, + "provider_settings.sandbox.cua_idle_timeout": { + "description": "CUA Idle Timeout", + "type": "int", + "hint": "Idle timeout for CUA sandbox sessions in seconds. When greater than 0, AstrBot proactively shuts down an idle CUA sandbox after that amount of inactivity; 0 disables it.", + "condition": { + "provider_settings.computer_use_runtime": "sandbox", + "provider_settings.sandbox.booter": "cua", + }, + }, + "provider_settings.sandbox.cua_telemetry_enabled": { + "description": "CUA Telemetry", + "type": "bool", + "hint": "是否允许 CUA SDK 发送遥测数据。默认关闭。", + "condition": { + "provider_settings.computer_use_runtime": "sandbox", + "provider_settings.sandbox.booter": "cua", + }, + }, + "provider_settings.sandbox.cua_local": { + "description": "CUA Local Sandbox", + "type": "bool", + "hint": "是否优先使用 CUA 本地沙箱。默认开启,避免云端沙箱要求 CUA_API_KEY。关闭后可使用 CUA 云端沙箱。", + "condition": { + "provider_settings.computer_use_runtime": "sandbox", + "provider_settings.sandbox.booter": "cua", + }, + }, + "provider_settings.sandbox.cua_api_key": { + "description": "CUA API Key", + "type": "string", + "hint": "CUA 云端沙箱 API Key。仅在关闭本地沙箱时需要。也可以通过 CUA_API_KEY 环境变量提供。", + "obvious_hint": True, + "condition": { + "provider_settings.computer_use_runtime": "sandbox", + "provider_settings.sandbox.booter": "cua", + "provider_settings.sandbox.cua_local": False, + }, + }, "provider_settings.sandbox.shipyard_endpoint": { "description": "Shipyard API Endpoint", "type": "string", @@ -3452,6 +3572,14 @@ class ChatProviderTemplate(TypedDict): "provider_settings.agent_runner_type": "local", }, }, + "provider_settings.fallback_max_context_tokens": { + "description": "上下文窗口兜底值", + "type": "int", + "hint": "当 max_context_tokens 为 0 且模型不在内置元数据中时,使用此值作为上下文窗口大小。默认 128000。", + "condition": { + "provider_settings.agent_runner_type": "local", + }, + }, }, "condition": { "provider_settings.agent_runner_type": "local", @@ -3531,6 +3659,15 @@ class ChatProviderTemplate(TypedDict): "provider_settings.show_tool_use_status": True, }, }, + "provider_settings.buffer_intermediate_messages": { + "description": "合并 Agent 中间消息", + "type": "bool", + "hint": "开启后,非流式模式下多步工具调用过程中产生的中间文本将缓冲,待 Agent 完成后合并为一条回复发送。", + "condition": { + "provider_settings.agent_runner_type": "local", + "provider_settings.streaming_response": False, + }, + }, "provider_settings.sanitize_context_by_modalities": { "description": "按模型能力清理历史上下文", "type": "bool", @@ -3568,11 +3705,6 @@ class ChatProviderTemplate(TypedDict): "type": "string", "hint": "如果唤醒前缀为 /, 额外聊天唤醒前缀为 chat,则需要 /chat 才会触发 LLM 请求", }, - "provider_settings.prompt_prefix": { - "description": "用户提示词", - "type": "string", - "hint": "可使用 {{prompt}} 作为用户输入的占位符。如果不输入占位符则代表添加在用户输入的前面。", - }, "provider_settings.image_compress_enabled": { "description": "启用图片压缩", "type": "bool", @@ -3596,6 +3728,12 @@ class ChatProviderTemplate(TypedDict): }, "slider": {"min": 1, "max": 100, "step": 1}, }, + "provider_settings.prompt_prefix": { + "description": "用户提示词", + "type": "string", + "hint": "可使用 {{prompt}} 作为用户输入的占位符。如果不输入占位符则代表添加在用户输入的前面。", + "collapsed": True, + }, "provider_tts_settings.dual_output": { "description": "开启 TTS 时同时输出语音和文字内容", "type": "bool", diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index b4ff2742d0..45182db5fb 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -363,6 +363,10 @@ def dynamic_import_provider(self, type: str) -> None: ) case "longcat_chat_completion": from .sources.longcat_source import ProviderLongCat as ProviderLongCat + case "minimax_token_plan": + from .sources.minimax_token_plan_source import ( + ProviderMiniMaxTokenPlan as ProviderMiniMaxTokenPlan, + ) case "zhipu_chat_completion": from .sources.zhipu_source import ProviderZhipu as ProviderZhipu case "groq_chat_completion": @@ -465,6 +469,18 @@ def dynamic_import_provider(self, type: str) -> None: from .sources.gemini_embedding_source import ( GeminiEmbeddingProvider as GeminiEmbeddingProvider, ) + case "nvidia_embedding": + from .sources.nvidia_embedding_source import ( + NvidiaEmbeddingProvider as NvidiaEmbeddingProvider, + ) + case "ollama_embedding": + from .sources.ollama_embedding_source import ( + OllamaEmbeddingProvider as OllamaEmbeddingProvider, + ) + case "vllm_embedding": + from .sources.vllm_embedding_source import ( + VLLMEmbeddingProvider as VLLMEmbeddingProvider, + ) case "vllm_rerank": from .sources.vllm_rerank_source import ( VLLMRerankProvider as VLLMRerankProvider, @@ -566,7 +582,9 @@ async def load_provider(self, provider_config: dict) -> None: return logger.info( - f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商 ...", + "Loading model %s(%s) ...", + provider_config["type"], + provider_config["id"], ) # 动态导入 @@ -587,7 +605,7 @@ async def load_provider(self, provider_config: dict) -> None: if provider_config["type"] not in provider_cls_map: logger.error( - f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。", + f"Provider adapter not found: {provider_config['type']}({provider_config['id']}). Skipped.", exc_info=True, ) return @@ -621,7 +639,7 @@ async def load_provider(self, provider_config: dict) -> None: ): self.curr_stt_provider_inst = inst logger.info( - f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。", + f"Selected {provider_config['type']}({provider_config['id']}) as default STT provider", ) if not self.curr_stt_provider_inst: self.curr_stt_provider_inst = inst @@ -644,7 +662,7 @@ async def load_provider(self, provider_config: dict) -> None: ): self.curr_tts_provider_inst = inst logger.info( - f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。", + f"Selected {provider_config['type']}({provider_config['id']}) as default TTS provider", ) if not self.curr_tts_provider_inst: self.curr_tts_provider_inst = inst @@ -670,7 +688,7 @@ async def load_provider(self, provider_config: dict) -> None: ): self.curr_provider_inst = inst logger.info( - f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。", + f"Selected {provider_config['type']}({provider_config['id']}) as default chat model provider", ) if not self.curr_provider_inst: self.curr_provider_inst = inst diff --git a/astrbot/core/provider/sources/vllm_embedding_source.py b/astrbot/core/provider/sources/vllm_embedding_source.py new file mode 100644 index 0000000000..c7e506b942 --- /dev/null +++ b/astrbot/core/provider/sources/vllm_embedding_source.py @@ -0,0 +1,265 @@ +from __future__ import annotations + +from ipaddress import ip_address +from typing import Any +from urllib.parse import urlparse + +import httpx +from openai import AsyncOpenAI + +from astrbot import logger + +from ..entities import ProviderType +from ..provider import EmbeddingProvider +from ..register import register_provider_adapter + + +_COMMON_MODEL_DIMENSIONS = { + "bge-m3": 1024, + "bge-large-en-v1.5": 1024, + "bge-large-zh-v1.5": 1024, + "text-embedding-3-small": 1536, + "text-embedding-3-large": 3072, + "text-embedding-ada-002": 1536, +} + + +@register_provider_adapter( + "vllm_embedding", + "vLLM Embedding 提供商适配器", + provider_type=ProviderType.EMBEDDING, + provider_display_name="vLLM Embedding", +) +class VLLMEmbeddingProvider(EmbeddingProvider): + def __init__(self, provider_config: dict, provider_settings: dict) -> None: + super().__init__(provider_config, provider_settings) + self.provider_config = provider_config + self.provider_settings = provider_settings + self.timeout = int(provider_config.get("timeout", 20) or 20) + self.model = str(provider_config.get("embedding_model", "") or "").strip() + self.set_model(self.model) + self._force_direct_transport = self._should_force_direct_transport() + + self._detected_dimension: int | None = None + self._resolved_request_model: str | None = None + self._direct_client_ready = self._force_direct_transport + + self.client = AsyncOpenAI( + api_key=provider_config.get("embedding_api_key"), + base_url=self._effective_api_base(), + timeout=self.timeout, + http_client=self._build_http_client(), + ) + + async def get_embedding(self, text: str) -> list[float]: + await self._ensure_runtime_ready() + request_model = await self._resolve_request_model() + logger.info( + "[vLLM Embedding] %s 发起单条 embedding 请求,model=%s,text_len=%s,跳过 dimensions。", + self._provider_id(), + request_model, + len(text), + ) + embedding = await self.client.embeddings.create( + input=text, + model=request_model, + ) + vector = embedding.data[0].embedding + self._cache_detected_dimension(len(vector)) + return vector + + async def get_embeddings(self, text: list[str]) -> list[list[float]]: + await self._ensure_runtime_ready() + request_model = await self._resolve_request_model() + total_chars = sum(len(item) for item in text) + logger.info( + "[vLLM Embedding] %s 发起批量 embedding 请求,model=%s,batch=%s,total_chars=%s,跳过 dimensions。", + self._provider_id(), + request_model, + len(text), + total_chars, + ) + embeddings = await self.client.embeddings.create( + input=text, + model=request_model, + ) + vectors = [item.embedding for item in embeddings.data] + if vectors: + self._cache_detected_dimension(len(vectors[0])) + return vectors + + def get_dim(self) -> int: + configured_dim = self._configured_dimension() + if configured_dim: + return configured_dim + if self._detected_dimension: + return self._detected_dimension + inferred_dim = self._infer_dimension_from_model(self.model) + if inferred_dim: + return inferred_dim + return 0 + + async def terminate(self) -> None: + if self.client: + await self.client.close() + + def _build_http_client(self) -> httpx.AsyncClient | None: + proxy = str(self.provider_config.get("proxy", "") or "").strip() + if proxy: + logger.info("[vLLM Embedding] %s 使用显式代理: %s", self._provider_id(), proxy) + return httpx.AsyncClient(proxy=proxy, timeout=self.timeout) + if self._force_direct_transport: + return httpx.AsyncClient(timeout=self.timeout, trust_env=False) + return None + + async def _ensure_runtime_ready(self) -> None: + if self._direct_client_ready or not self._should_force_direct_transport(): + return + + old_client = self.client + self.client = AsyncOpenAI( + api_key=self.provider_config.get("embedding_api_key"), + base_url=self._effective_api_base(), + timeout=self.timeout, + http_client=httpx.AsyncClient(timeout=self.timeout, trust_env=False), + ) + self._direct_client_ready = True + + logger.info( + "[vLLM Embedding] %s 检测到本地/内网端点,已切换为 trust_env=False 的直连 client。", + self._provider_id(), + ) + + if old_client is not None and old_client is not self.client: + try: + await old_client.close() + except Exception: + logger.debug("[vLLM Embedding] %s 关闭旧 client 失败,已忽略。", self._provider_id()) + + async def _resolve_request_model(self) -> str: + if self._resolved_request_model: + return self._resolved_request_model + + configured_model = self.model + if not configured_model: + self._resolved_request_model = configured_model + return configured_model + + available_models = await self._list_vllm_models() + resolved_model = self._match_served_model(configured_model, available_models) + if resolved_model: + self._resolved_request_model = resolved_model + if resolved_model != configured_model: + logger.info( + "[vLLM Embedding] %s 已将模型名 %s 对齐到 served-model-name %s。", + self._provider_id(), + configured_model, + resolved_model, + ) + return resolved_model + + basename_model = configured_model.rsplit("/", 1)[-1].strip() + if basename_model and basename_model != configured_model: + self._resolved_request_model = basename_model + logger.warning( + "[vLLM Embedding] %s 未能从 /models 精确匹配 %s,回退为 %s。", + self._provider_id(), + configured_model, + basename_model, + ) + return basename_model + + self._resolved_request_model = configured_model + return configured_model + + async def _list_vllm_models(self) -> list[dict[str, str]]: + try: + models = await self.client.models.list() + except Exception as exc: + logger.warning( + "[vLLM Embedding] %s 拉取 /models 失败,将直接使用配置模型名: %s", + self._provider_id(), + exc, + ) + return [] + + results: list[dict[str, str]] = [] + for item in getattr(models, "data", []) or []: + model_id = str(getattr(item, "id", "") or "").strip() + model_root = str(getattr(item, "root", "") or "").strip() + if model_id: + results.append({"id": model_id, "root": model_root}) + return results + + def _match_served_model( + self, + configured_model: str, + available_models: list[dict[str, str]], + ) -> str | None: + normalized_configured = configured_model.lower() + basename_model = configured_model.rsplit("/", 1)[-1].strip().lower() + + for item in available_models: + model_id = str(item.get("id", "") or "").strip() + model_root = str(item.get("root", "") or "").strip() + if model_id.lower() == normalized_configured: + return model_id + if model_root and model_root.lower() == normalized_configured: + return model_id + if basename_model and model_id.lower() == basename_model: + return model_id + return None + + def _configured_dimension(self) -> int | None: + raw_dimension = self.provider_config.get("embedding_dimensions", "") + if raw_dimension in (None, ""): + return None + try: + dimension = int(raw_dimension) + except (TypeError, ValueError): + logger.warning( + "[vLLM Embedding] %s 的 embedding_dimensions 不是有效整数: %r", + self._provider_id(), + raw_dimension, + ) + return None + return dimension if dimension > 0 else None + + def _infer_dimension_from_model(self, model_name: Any) -> int | None: + normalized_model = str(model_name or "").strip().lower() + for model_key, dimension in _COMMON_MODEL_DIMENSIONS.items(): + if model_key in normalized_model: + return dimension + return None + + def _cache_detected_dimension(self, dimension: int) -> None: + if isinstance(dimension, int) and dimension > 0: + self._detected_dimension = dimension + + def _effective_api_base(self) -> str: + api_base = str( + self.provider_config.get("embedding_api_base", "http://127.0.0.1:8000/v1") + or "http://127.0.0.1:8000/v1" + ).strip() + api_base = api_base.removesuffix("/").removesuffix("/embeddings") + if api_base and not api_base.endswith("/v1") and not api_base.endswith("/v4"): + api_base = api_base + "/v1" + return api_base + + def _should_force_direct_transport(self) -> bool: + if str(self.provider_config.get("proxy", "") or "").strip(): + return False + + host = (urlparse(self._effective_api_base()).hostname or "").strip().lower() + if not host: + return False + if host in {"localhost", "127.0.0.1", "::1", "host.docker.internal"}: + return True + try: + parsed_host = ip_address(host) + except ValueError: + return False + return parsed_host.is_loopback or parsed_host.is_private + + def _provider_id(self) -> str: + return str(self.provider_config.get("id", "unknown") or "unknown") \ No newline at end of file From 1e586e86219dcc0b846f7c43317986e3f9a3dca5 Mon Sep 17 00:00:00 2001 From: _floateR <1014300549@qq.com> Date: Tue, 19 May 2026 14:21:18 +0800 Subject: [PATCH 5/6] Normalize vLLM embedding provider defaults --- astrbot/core/config/default.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 7e5c8d9ccb..34f714c92b 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -1831,11 +1831,11 @@ "provider": "vllm", "provider_type": "embedding", "hint": "面向 vLLM OpenAI-compatible Embedding 接口。请求时会自动跳过 dimensions,并尝试将模型名对齐到 served-model-name。", - "enable": True, + "enable": False, "embedding_api_key": "", - "embedding_api_base": "http://127.0.0.1:8000/v1", - "embedding_model": "BAAI/bge-m3", - "embedding_dimensions": 1024, + "embedding_api_base": "", + "embedding_model": "", + "embedding_dimensions": "", "timeout": 20, "proxy": "", }, From ac673bece92443839715e4c03a7243ea62c7bf92 Mon Sep 17 00:00:00 2001 From: _floateR <1014300549@qq.com> Date: Tue, 19 May 2026 16:36:49 +0800 Subject: [PATCH 6/6] fix(provider): address embedding review feedback --- astrbot/core/config/default.py | 2 +- .../core/provider/sources/embedding_utils.py | 46 ++++ .../sources/openai_embedding_source.py | 199 +++++++++--------- .../provider/sources/vllm_embedding_source.py | 38 +--- .../src/components/shared/AstrBotConfig.vue | 31 ++- 5 files changed, 178 insertions(+), 138 deletions(-) create mode 100644 astrbot/core/provider/sources/embedding_utils.py diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 34f714c92b..bb1b6b55e8 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -1835,7 +1835,7 @@ "embedding_api_key": "", "embedding_api_base": "", "embedding_model": "", - "embedding_dimensions": "", + "embedding_dimensions": 0, "timeout": 20, "proxy": "", }, diff --git a/astrbot/core/provider/sources/embedding_utils.py b/astrbot/core/provider/sources/embedding_utils.py new file mode 100644 index 0000000000..9b8b0a5803 --- /dev/null +++ b/astrbot/core/provider/sources/embedding_utils.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import Any + +from astrbot import logger + + +COMMON_MODEL_DIMENSIONS = { + "bge-m3": 1024, + "bge-large-en-v1.5": 1024, + "bge-large-zh-v1.5": 1024, + "text-embedding-3-small": 1536, + "text-embedding-3-large": 3072, + "text-embedding-ada-002": 1536, +} + + +def parse_configured_embedding_dimension( + raw_dimension: Any, + *, + provider_label: str, + provider_id: str, +) -> int | None: + if raw_dimension in (None, ""): + return None + + try: + dimension = int(raw_dimension) + except (TypeError, ValueError): + logger.warning( + "[%s] %s 的 embedding_dimensions 不是有效整数: %r", + provider_label, + provider_id, + raw_dimension, + ) + return None + + return dimension if dimension > 0 else None + + +def infer_embedding_dimension_from_model(model_name: Any) -> int | None: + normalized_model = str(model_name or "").strip().lower() + for model_key, dimension in COMMON_MODEL_DIMENSIONS.items(): + if model_key in normalized_model: + return dimension + return None \ No newline at end of file diff --git a/astrbot/core/provider/sources/openai_embedding_source.py b/astrbot/core/provider/sources/openai_embedding_source.py index dab622db08..5205d69c21 100644 --- a/astrbot/core/provider/sources/openai_embedding_source.py +++ b/astrbot/core/provider/sources/openai_embedding_source.py @@ -1,9 +1,15 @@ import httpx from openai import AsyncOpenAI + from astrbot import logger + from ..entities import ProviderType from ..provider import EmbeddingProvider from ..register import register_provider_adapter +from .embedding_utils import ( + infer_embedding_dimension_from_model, + parse_configured_embedding_dimension, +) @register_provider_adapter( @@ -80,81 +86,86 @@ def _mark_as_vllm(self) -> None: async def get_embedding(self, text: str) -> list[float]: """获取文本的嵌入""" kwargs = self._embedding_kwargs() - try: - embedding = await self.client.embeddings.create( - input=text, - model=self.model, - **kwargs, - ) - return embedding.data[0].embedding - except Exception as e: - # 如果包含"matryoshka"或"dimensions"相关的错误,说明vLLM不支持该参数 - # 尝试不带dimensions重试 - error_msg = str(e).lower() - if ("matryoshka" in error_msg or "dimensions" in error_msg) and kwargs.get("dimensions"): - logger.warning( - f"[OpenAI Embedding] Detected vLLM dimensions error, retrying without dimensions parameter: {e}" - ) - kwargs_retry = {k: v for k, v in kwargs.items() if k != "dimensions"} - try: - embedding = await self.client.embeddings.create( - input=text, - model=self.model, - **kwargs_retry, - ) - logger.info( - "[OpenAI Embedding] Successfully retrieved embedding without dimensions parameter, marking as vLLM" - ) - # 标记为vLLM以便后续调用也跳过dimensions - self._mark_as_vllm() - return embedding.data[0].embedding - except Exception as retry_error: - logger.error( - f"[OpenAI Embedding] Retry without dimensions also failed: {retry_error}" - ) - raise retry_error - else: - raise + embedding = await self._request_with_vllm_retry(text, kwargs, batch=False) + return embedding.data[0].embedding async def get_embeddings(self, text: list[str]) -> list[list[float]]: """批量获取文本的嵌入""" kwargs = self._embedding_kwargs() + embeddings = await self._request_with_vllm_retry(text, kwargs, batch=True) + return [item.embedding for item in embeddings.data] + + async def _request_with_vllm_retry( + self, + input_data: str | list[str], + kwargs: dict, + *, + batch: bool, + ): try: - embeddings = await self.client.embeddings.create( - input=text, + return await self.client.embeddings.create( + input=input_data, model=self.model, **kwargs, ) - return [item.embedding for item in embeddings.data] - except Exception as e: - # 如果包含"matryoshka"或"dimensions"相关的错误,说明vLLM不支持该参数 - # 尝试不带dimensions重试 - error_msg = str(e).lower() - if ("matryoshka" in error_msg or "dimensions" in error_msg) and kwargs.get("dimensions"): + except Exception as exc: + if not self._should_retry_without_dimensions(exc, kwargs): + raise + + if batch: logger.warning( - f"[OpenAI Embedding] Detected vLLM dimensions error in batch mode, retrying without dimensions: {e}" + f"[OpenAI Embedding] Detected vLLM dimensions error in batch mode, retrying without dimensions: {exc}" ) - kwargs_retry = {k: v for k, v in kwargs.items() if k != "dimensions"} - try: - embeddings = await self.client.embeddings.create( - input=text, - model=self.model, - **kwargs_retry, - ) - logger.info( - "[OpenAI Embedding] Successfully retrieved batch embeddings without dimensions parameter" - ) - # 标记为vLLM以便后续调用也跳过dimensions - self._mark_as_vllm() - return [item.embedding for item in embeddings.data] - except Exception as retry_error: + else: + logger.warning( + f"[OpenAI Embedding] Detected vLLM dimensions error, retrying without dimensions parameter: {exc}" + ) + + kwargs_retry = {k: v for k, v in kwargs.items() if k != "dimensions"} + try: + embeddings = await self.client.embeddings.create( + input=input_data, + model=self.model, + **kwargs_retry, + ) + except Exception as retry_error: + if batch: logger.error( f"[OpenAI Embedding] Batch retry without dimensions also failed: {retry_error}" ) - raise retry_error - else: + else: + logger.error( + f"[OpenAI Embedding] Retry without dimensions also failed: {retry_error}" + ) raise + if batch: + logger.info( + "[OpenAI Embedding] Successfully retrieved batch embeddings without dimensions parameter" + ) + else: + logger.info( + "[OpenAI Embedding] Successfully retrieved embedding without dimensions parameter, marking as vLLM" + ) + + self._mark_as_vllm() + return embeddings + + def _should_retry_without_dimensions(self, exc: Exception, kwargs: dict) -> bool: + if not kwargs.get("dimensions"): + return False + + error_msg = str(exc).lower() + return "matryoshka" in error_msg or "dimensions" in error_msg + + def _configured_dimension(self) -> int | None: + provider_id = self.provider_config.get("id", "unknown") + return parse_configured_embedding_dimension( + self.provider_config.get("embedding_dimensions", ""), + provider_label="OpenAI Embedding", + provider_id=provider_id, + ) + def _embedding_kwargs(self) -> dict: """构建嵌入请求的可选参数""" kwargs = {} @@ -168,18 +179,13 @@ def _embedding_kwargs(self) -> dict: ) return kwargs # 非vLLM服务(OpenAI等)支持dimensions,读取配置 - if embedding_dim_config and embedding_dim_config != "": - try: - dim_value = int(embedding_dim_config) - kwargs["dimensions"] = dim_value - logger.info( - f"[OpenAI Embedding] {provider_id}: Added dimensions parameter: {dim_value}" - ) - except (ValueError, TypeError): - logger.warning( - f"[OpenAI Embedding] {provider_id}: embedding_dimensions is not a valid integer: '{embedding_dim_config}', ignored." - ) - else: + configured_dim = self._configured_dimension() + if configured_dim is not None: + kwargs["dimensions"] = configured_dim + logger.info( + f"[OpenAI Embedding] {provider_id}: Added dimensions parameter: {configured_dim}" + ) + elif embedding_dim_config in (None, ""): logger.info( f"[OpenAI Embedding] {provider_id}: No embedding_dimensions configured, API will use default" ) @@ -188,40 +194,25 @@ def _embedding_kwargs(self) -> dict: def get_dim(self) -> int: """获取向量的维度""" provider_id = self.provider_config.get("id", "unknown") - # 首先尝试从config读取 embedding_dim_config = self.provider_config.get("embedding_dimensions", "") - if embedding_dim_config and embedding_dim_config != "": - try: - dim = int(embedding_dim_config) - if dim > 0: - logger.info( - f"[OpenAI Embedding] {provider_id}: Dimension from config: {dim}" - ) - return dim - except (ValueError, TypeError): - logger.warning( - f"[OpenAI Embedding] {provider_id}: embedding_dimensions is not a valid integer: '{embedding_dim_config}', trying model inference" - ) - # config为空或无效时根据模型名推断维度 - # 这样Living Memory可以在自动检测后匹配正确的维度 - model = self.provider_config.get("embedding_model", "").lower() - model_dims = { - "bge-m3": 1024, - "bge-large-en-v1.5": 1024, - "bge-large-zh-v1.5": 1024, - "text-embedding-3-small": 1536, - "text-embedding-3-large": 3072, - "text-embedding-ada-002": 1536, - } - for model_key, dim in model_dims.items(): - if model_key in model: - logger.info( - f"[OpenAI Embedding] {provider_id}: Inferred dimension {dim} from model: {model}" - ) - return dim - # 无法推断时返回0(Living Memory会检测实际维度) + + configured_dim = self._configured_dimension() + if configured_dim is not None: + logger.info( + f"[OpenAI Embedding] {provider_id}: Dimension from config: {configured_dim}" + ) + return configured_dim + + model = self.provider_config.get("embedding_model", "") + inferred_dim = infer_embedding_dimension_from_model(model) + if inferred_dim: + logger.info( + f"[OpenAI Embedding] {provider_id}: Inferred dimension {inferred_dim} from model: {str(model).lower()}" + ) + return inferred_dim + logger.warning( - f"[OpenAI Embedding] {provider_id}: Could not determine dimension (model: {model}, config: '{embedding_dim_config}')" + f"[OpenAI Embedding] {provider_id}: Could not determine dimension (model: {str(model).lower()}, config: '{embedding_dim_config}')" ) return 0 diff --git a/astrbot/core/provider/sources/vllm_embedding_source.py b/astrbot/core/provider/sources/vllm_embedding_source.py index c7e506b942..ec09e4b32b 100644 --- a/astrbot/core/provider/sources/vllm_embedding_source.py +++ b/astrbot/core/provider/sources/vllm_embedding_source.py @@ -12,16 +12,10 @@ from ..entities import ProviderType from ..provider import EmbeddingProvider from ..register import register_provider_adapter - - -_COMMON_MODEL_DIMENSIONS = { - "bge-m3": 1024, - "bge-large-en-v1.5": 1024, - "bge-large-zh-v1.5": 1024, - "text-embedding-3-small": 1536, - "text-embedding-3-large": 3072, - "text-embedding-ada-002": 1536, -} +from .embedding_utils import ( + infer_embedding_dimension_from_model, + parse_configured_embedding_dimension, +) @register_provider_adapter( @@ -211,26 +205,14 @@ def _match_served_model( return None def _configured_dimension(self) -> int | None: - raw_dimension = self.provider_config.get("embedding_dimensions", "") - if raw_dimension in (None, ""): - return None - try: - dimension = int(raw_dimension) - except (TypeError, ValueError): - logger.warning( - "[vLLM Embedding] %s 的 embedding_dimensions 不是有效整数: %r", - self._provider_id(), - raw_dimension, - ) - return None - return dimension if dimension > 0 else None + return parse_configured_embedding_dimension( + self.provider_config.get("embedding_dimensions", ""), + provider_label="vLLM Embedding", + provider_id=self._provider_id(), + ) def _infer_dimension_from_model(self, model_name: Any) -> int | None: - normalized_model = str(model_name or "").strip().lower() - for model_key, dimension in _COMMON_MODEL_DIMENSIONS.items(): - if model_key in normalized_model: - return dimension - return None + return infer_embedding_dimension_from_model(model_name) def _cache_detected_dimension(self, dimension: int) -> None: if isinstance(dimension, int) and dimension > 0: diff --git a/dashboard/src/components/shared/AstrBotConfig.vue b/dashboard/src/components/shared/AstrBotConfig.vue index 9a332fb0e5..244ce85fe0 100644 --- a/dashboard/src/components/shared/AstrBotConfig.vue +++ b/dashboard/src/components/shared/AstrBotConfig.vue @@ -102,6 +102,21 @@ function saveEditedContent() { dialog.value = false } +function getNumericEmbeddingDimension(value) { + if (typeof value === 'number' && Number.isInteger(value) && value >= 0) { + return value + } + + if (typeof value === 'string') { + const trimmedValue = value.trim() + if (/^\d+$/.test(trimmedValue)) { + return Number(trimmedValue) + } + } + + return null +} + async function getEmbeddingDimensions(providerConfig) { if (loadingEmbeddingDim.value) return loadingEmbeddingDim.value = true @@ -110,11 +125,17 @@ async function getEmbeddingDimensions(providerConfig) { provider_config: providerConfig }) if (response.data.status != "error" && response.data.data?.embedding_dimensions) { - console.log(response.data.data.embedding_dimensions) - //[已禁用] 不再自动写入配置文件,仅显示提示 - // providerConfig.embedding_dimensions = response.data.data.embedding_dimensions - useToast().success("获取成功: " + response.data.data.embedding_dimensions) - useToast().info(`检测到维度: ${response.data.data.embedding_dimensions}。如需保存,请手动填入后点保存。`) + const detectedDimension = response.data.data.embedding_dimensions + const numericDimension = getNumericEmbeddingDimension(detectedDimension) + + console.log(detectedDimension) + + if (numericDimension !== null) { + providerConfig.embedding_dimensions = numericDimension + useToast().success("获取成功: " + numericDimension) + } else { + useToast().info(`检测到维度: ${detectedDimension}。如需保存,请手动填入后点保存。`) + } } else { useToast().error(response.data.message) }