@@ -58,11 +57,11 @@
📕 Table of Contents
- 💡 [What is RAGFlow?](#-RAGFlow-是什么)
-- 🎮 [Demo](#-demo)
+- 🎮 [Quick start](#-快速开始)
- 📌 [Latest updates](#-近期更新)
- 🌟 [Key features](#-主要功能)
- 🔎 [System architecture](#-系统架构)
-- 🎬 [Quick start](#-快速开始)
+- 🎬 [Self-hosting](#-自主托管)
- 🔧 [Configuration](#-系统配置)
- 🔨 [Launch the service from source](#-以源代码启动服务)
- 📚 [Documentation](#-技术文档)
@@ -77,9 +76,9 @@
[RAGFlow](https://ragflow.io/) is a leading open-source Retrieval-Augmented Generation ([RAG](https://ragflow.io/basics/what-is-rag)) engine that combines cutting-edge RAG techniques with Agent capabilities to provide a superior context layer for large language models. It delivers an end-to-end RAG workflow that adapts to enterprises of any scale, and, with its fused [context engine](https://ragflow.io/basics/what-is-agent-context-engine) and built-in Agent templates, helps developers turn complex data into trustworthy, production-grade AI systems with exceptional efficiency and precision.
-## 🎮 Try the demo
+## 🎮 Quick start
-Log in at [https://cloud.ragflow.io](https://cloud.ragflow.io) to try the demo.
+Log in at [https://cloud.ragflow.io](https://cloud.ragflow.io) to experience the cloud service.
@@ -88,8 +87,9 @@
## 🔥 Latest updates
+- 2026-04-24 Adds support for DeepSeek v4.
- 2026-03-24 Releases the [RAGFlow official Skill](https://clawhub.ai/yingfeng/ragflow-skill), providing official Skill-based access to RAGFlow datasets through OpenClaw.
-- 2025-12-26 Adds support for the "memory" capability for AI agents.
+- 2025-12-26 Adds support for the "memory" capability for AI agents.
- 2025-11-19 Adds support for Gemini 3 Pro.
- 2025-11-12 Adds support for data synchronization from Confluence, S3, Notion, Discord, and Google Drive.
- 2025-10-23 Adds support for MinerU and Docling as document parsing methods.
@@ -144,7 +144,7 @@
-## 🎬 Quick start
+## 🎬 Self-hosting
### 📝 Prerequisites
@@ -192,12 +192,12 @@
> Note that all officially provided Docker images are currently built for the x86 architecture; ARM64-based Docker images are not provided.
> If your operating system runs on ARM64, follow [this guide](https://ragflow.io/docs/dev/build_docker_image) to build a Docker image yourself.
- > Running the following commands automatically downloads the RAGFlow Docker image `v0.25.0`. See the table below for descriptions of the different Docker editions. To download a Docker image other than `v0.25.0`, update the `RAGFLOW_IMAGE` variable in **docker/.env** before starting the services with `docker compose`.
+ > Running the following commands automatically downloads the RAGFlow Docker image `v0.25.2`. See the table below for descriptions of the different Docker editions. To download a Docker image other than `v0.25.2`, update the `RAGFLOW_IMAGE` variable in **docker/.env** before starting the services with `docker compose`.
```bash
$ cd ragflow/docker
- # git checkout v0.25.0
+ # git checkout v0.25.2
  # Optional: check out a stable release tag (see releases: https://github.com/infiniflow/ragflow/releases)
  # This step keeps the entrypoint.sh in the code consistent with the version of the Docker image.
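  # A minimal sketch (assumptions: the default docker-compose.yml in this directory,
  # and an image tag name as listed on the releases page): to pin a different image,
  # set RAGFLOW_IMAGE in docker/.env, e.g. RAGFLOW_IMAGE=infiniflow/ragflow:v0.25.2,
  # then start the stack:
  # $ docker compose -f docker-compose.yml up -d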
@@ -410,8 +410,8 @@ docker build --platform linux/amd64 \
## 🏄 Open-source community
-- [Discord](https://discord.gg/zd4qPW6t)
-- [Twitter](https://twitter.com/infiniflowai)
+- [Discord](https://discord.gg/NjYzJD3GM3)
+- [X](https://x.com/infiniflowai)
- [GitHub Discussions](https://github.com/orgs/infiniflow/discussions)
## 🙌 Contributing
diff --git a/admin/client/README.md b/admin/client/README.md
index f71033d6482..cac7425aad8 100644
--- a/admin/client/README.md
+++ b/admin/client/README.md
@@ -48,7 +48,7 @@ It consists of a server-side Service and a command-line client (CLI), both imple
1. Ensure the Admin Service is running.
2. Install ragflow-cli.
```bash
- pip install ragflow-cli==0.25.0
+ pip install ragflow-cli==0.25.2
```
3. Launch the CLI client:
```bash
diff --git a/admin/client/pyproject.toml b/admin/client/pyproject.toml
index 48391a836d8..5f70bb1b188 100644
--- a/admin/client/pyproject.toml
+++ b/admin/client/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "ragflow-cli"
-version = "0.25.0"
+version = "0.25.2"
description = "Admin Service's client of [RAGFlow](https://github.com/infiniflow/ragflow). The Admin Service provides user management and system monitoring. "
authors = [{ name = "Lynn", email = "lynn_inf@hotmail.com" }]
license = { text = "Apache License, Version 2.0" }
diff --git a/admin/client/ragflow_client.py b/admin/client/ragflow_client.py
index b9f04783ced..148af4b45fe 100644
--- a/admin/client/ragflow_client.py
+++ b/admin/client/ragflow_client.py
@@ -1215,12 +1215,12 @@ def chat_on_session(self, command):
# Prepare payload for completion API
# Note: stream parameter is not sent, server defaults to stream=True
payload = {
- "conversation_id": session_id,
+ "session_id": session_id,
"messages": [{"role": "user", "content": message}]
}
- response = self.http_client.request("POST", "/conversation/completion", json_body=payload,
- use_api_base=False, auth_kind="web", stream=True)
+ response = self.http_client.request("POST", "/chat/completions", json_body=payload,
+ use_api_base=True, auth_kind="web", stream=True)
if response.status_code != 200:
print(f"Fail to chat on session, status code: {response.status_code}")
@@ -1325,7 +1325,7 @@ def parse_dataset_docs(self, command_dict):
print(f"Documents {document_names} not found in {dataset_name}")
payload = {"doc_ids": document_ids, "run": 1}
- response = self.http_client.request("POST", "/document/run", json_body=payload, use_api_base=False,
+ response = self.http_client.request("POST", "/documents/ingest", json_body=payload, use_api_base=True,
auth_kind="web")
res_json = response.json()
if response.status_code == 200 and res_json["code"] == 0:
@@ -1351,7 +1351,7 @@ def parse_dataset(self, command_dict):
document_ids.append(doc["id"])
payload = {"doc_ids": document_ids, "run": 1}
- response = self.http_client.request("POST", "/document/run", json_body=payload, use_api_base=False,
+ response = self.http_client.request("POST", "/documents/ingest", json_body=payload, use_api_base=True,
auth_kind="web")
res_json = response.json()
if response.status_code == 200 and res_json["code"] == 0:
diff --git a/admin/client/uv.lock b/admin/client/uv.lock
index 83868d9a20f..0bf404a2308 100644
--- a/admin/client/uv.lock
+++ b/admin/client/uv.lock
@@ -188,7 +188,7 @@ wheels = [
[[package]]
name = "ragflow-cli"
-version = "0.25.0"
+version = "0.25.2"
source = { virtual = "." }
dependencies = [
{ name = "beartype" },
diff --git a/admin/server/auth.py b/admin/server/auth.py
index bd3c0c058ae..0aa96d0e37d 100644
--- a/admin/server/auth.py
+++ b/admin/server/auth.py
@@ -58,7 +58,7 @@ def load_user(web_request):
return None
# Decode JWT to get the UUID access_token
- jwt = Serializer(secret_key=settings.SECRET_KEY)
+ jwt = Serializer(secret_key=settings.get_secret_key())
access_token = str(jwt.loads(jwt_token))
if not access_token or not access_token.strip():
diff --git a/agent/canvas.py b/agent/canvas.py
index 65303ca9e9e..ab6d0ba9ff1 100644
--- a/agent/canvas.py
+++ b/agent/canvas.py
@@ -354,23 +354,21 @@ def reset(self, mem=False):
key = k[4:]
if key in self.variables:
variable = self.variables[key]
- if variable["type"] == "string":
- self.globals[k] = ""
- variable["value"] = ""
- elif variable["type"] == "number":
- self.globals[k] = 0
- variable["value"] = 0
- elif variable["type"] == "boolean":
- self.globals[k] = False
- variable["value"] = False
- elif variable["type"] == "object":
- self.globals[k] = {}
- variable["value"] = {}
- elif variable["type"].startswith("array"):
- self.globals[k] = []
- variable["value"] = []
+ value = variable.get("value")
+ if value is not None:
+ self.globals[k] = value
else:
- self.globals[k] = ""
+ var_type = variable.get("type", "")
+ if var_type == "number":
+ self.globals[k] = 0
+ elif var_type == "boolean":
+ self.globals[k] = False
+ elif var_type == "object":
+ self.globals[k] = {}
+ elif var_type.startswith("array"):
+ self.globals[k] = []
+ else: # "string" or unknown
+ self.globals[k] = ""
else:
self.globals[k] = ""
@@ -381,8 +379,10 @@ async def run(self, **kwargs):
self.message_id = get_uuid()
created_at = int(time.time())
self.add_user_input(kwargs.get("query"))
+ path_set = set(self.path)
for k, cpn in self.components.items():
- self.components[k]["obj"].reset(True)
+ if k in path_set:
+ self.components[k]["obj"].reset(True)
if kwargs.get("webhook_payload"):
for k, cpn in self.components.items():
diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py
index 56f23afe350..859064046d6 100644
--- a/agent/component/agent_with_tools.py
+++ b/agent/component/agent_with_tools.py
@@ -145,7 +145,8 @@ def get_meta(self) -> dict[str, Any]:
self._param.function_name = self._id.split("-->")[-1]
m = super().get_meta()
if hasattr(self._param, "user_prompt") and self._param.user_prompt:
- m["function"]["parameters"]["properties"]["user_prompt"] = self._param.user_prompt
+ # Keep the JSON schema valid; user_prompt is a string field, not a schema node.
+ m["function"]["parameters"]["properties"]["user_prompt"]["default"] = self._param.user_prompt
return m
def get_input_form(self) -> dict[str, dict]:
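To make the schema point concrete, a small sketch with a hypothetical prompt value: the old assignment replaced the `user_prompt` schema node with a bare string, whereas the new code keeps the node and records the text as its `default`:

```python
# Hypothetical property schema as it might look before get_meta() runs.
meta = {"function": {"parameters": {"properties": {"user_prompt": {"type": "string"}}}}}

# Old (broken): the schema node became a raw string, invalidating the JSON schema.
# meta["function"]["parameters"]["properties"]["user_prompt"] = "Summarize the report"

# New: the node survives, carrying the prompt text as its default value.
meta["function"]["parameters"]["properties"]["user_prompt"]["default"] = "Summarize the report"
```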
@@ -276,10 +277,13 @@ async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt=
return
if delta.find("**ERROR**") >= 0:
if self.get_exception_default_value():
- self.set_output("content", self.get_exception_default_value())
- yield self.get_exception_default_value()
+ fallback = self.get_exception_default_value()
+ self.set_output("content", fallback)
+ yield fallback
else:
self.set_output("_ERROR", delta)
+ self.set_output("content", delta)
+ yield delta
return
if not need2cite or cited:
yield delta
diff --git a/agent/component/docs_generator.py b/agent/component/docs_generator.py
index d51b0ea591e..ce7a3abad59 100644
--- a/agent/component/docs_generator.py
+++ b/agent/component/docs_generator.py
@@ -1,3 +1,4 @@
+import base64
import logging
import json
import os
@@ -48,6 +49,7 @@ def __init__(self):
self.watermark_text = ""
self.add_page_numbers = True
self.add_timestamp = True
+ self.include_download_info_in_content = False
self.font_size = 12
self.outputs = {
"download": {"value": "", "type": "string"},
@@ -113,6 +115,7 @@ def _invoke(self, **kwargs):
raise Exception("Document file is empty")
file_size = len(file_bytes)
+ file_base64 = base64.b64encode(file_bytes).decode("utf-8")
doc_id = get_uuid()
settings.STORAGE_IMPL.put(self._canvas.get_tenant_id(), doc_id, file_bytes)
@@ -128,6 +131,8 @@ def _invoke(self, **kwargs):
"filename": filename,
"mime_type": mime_type,
"size": file_size,
+ "base64": file_base64,
+ "include_download_info_in_content": self._param.include_download_info_in_content,
}
self.set_output("download", json.dumps(download_info))
return download_info
diff --git a/agent/component/invoke.py b/agent/component/invoke.py
index 0dce464ebf0..4faaa7d0135 100644
--- a/agent/component/invoke.py
+++ b/agent/component/invoke.py
@@ -179,10 +179,7 @@ def _build_headers(self, kwargs: dict) -> dict:
if not isinstance(headers, dict):
raise ValueError("Invoke headers must be a JSON object.")
- return {
- key: self._resolve_header_text(value, kwargs) if isinstance(value, str) else value
- for key, value in headers.items()
- }
+ return {key: self._resolve_header_text(value, kwargs) if isinstance(value, str) else value for key, value in headers.items()}
def _build_proxies(self) -> dict | None:
if not re.sub(r"https?:?/?/?", "", self._param.proxy):
@@ -215,7 +212,7 @@ def _format_response(self, response) -> str:
# HtmlParser keeps the Invoke output text-focused when the endpoint returns HTML.
sections = HtmlParser()(None, response.content)
return "\n".join(sections)
-
+
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3)))
def _invoke(self, **kwargs):
if self.check_if_canceled("Invoke processing"):
diff --git a/agent/component/list_operations.py b/agent/component/list_operations.py
index 6016f758507..953e1455293 100644
--- a/agent/component/list_operations.py
+++ b/agent/component/list_operations.py
@@ -10,8 +10,9 @@ class ListOperationsParam(ComponentParamBase):
def __init__(self):
super().__init__()
self.query = ""
- self.operations = "topN"
- self.n=0
+ self.operations = "nth"
+ self.n = 0
+ self.strict = False
self.sort_method = "asc"
self.filter = {
"operator": "=",
@@ -34,7 +35,11 @@ def __init__(self):
def check(self):
self.check_empty(self.query, "query")
- self.check_valid_value(self.operations, "Support operations", ["topN","head","tail","filter","sort","drop_duplicates"])
+ self.check_valid_value(
+ self.operations,
+ "Support operations",
+ ["nth", "head", "tail", "filter", "sort", "drop_duplicates"],
+ )
def get_input_form(self) -> dict[str, dict]:
return {}
@@ -51,8 +56,8 @@ def _invoke(self, **kwargs):
if not isinstance(self.inputs, list):
raise TypeError("The input of List Operations should be an array.")
self.set_input_value(inputs, self.inputs)
- if self._param.operations == "topN":
- self._topN()
+ if self._param.operations == "nth":
+ self._nth()
elif self._param.operations == "head":
self._head()
elif self._param.operations == "tail":
@@ -70,35 +75,74 @@ def _coerce_n(self):
return int(getattr(self._param, "n", 0))
except Exception:
return 0
-
+
+ def _is_strict(self):
+ strict = getattr(self._param, "strict", False)
+ if isinstance(strict, str):
+ return strict.strip().lower() in {"1", "true", "yes", "on"}
+ return bool(strict)
+
def _set_outputs(self, outputs):
self._param.outputs["result"]["value"] = outputs
self._param.outputs["first"]["value"] = outputs[0] if outputs else None
self._param.outputs["last"]["value"] = outputs[-1] if outputs else None
-
- def _topN(self):
+
+ def _raise_strict_range_error(self, operation, n):
+ raise ValueError(
+ f"{operation} requires n to be within the valid range in strict mode, got {n}."
+ )
+
+ def _nth(self):
n = self._coerce_n()
- if n < 1:
+ strict = self._is_strict()
+ if n == 0:
+ if strict:
+ self._raise_strict_range_error("nth", n)
outputs = []
+ elif n > 0:
+ if n <= len(self.inputs):
+ outputs = [self.inputs[n - 1]]
+ elif strict:
+ self._raise_strict_range_error("nth", n)
+ else:
+ outputs = []
else:
- n = min(n, len(self.inputs))
- outputs = self.inputs[:n]
+ if abs(n) <= len(self.inputs):
+ outputs = [self.inputs[n]]
+ elif strict:
+ self._raise_strict_range_error("nth", n)
+ else:
+ outputs = []
self._set_outputs(outputs)
def _head(self):
n = self._coerce_n()
- if 1 <= n <= len(self.inputs):
- outputs = [self.inputs[n - 1]]
+ strict = self._is_strict()
+ if strict:
+ if 1 <= n <= len(self.inputs):
+ outputs = self.inputs[:n]
+ else:
+ self._raise_strict_range_error("head", n)
else:
- outputs = []
+ if n < 1:
+ outputs = []
+ else:
+ outputs = self.inputs[:n]
self._set_outputs(outputs)
def _tail(self):
n = self._coerce_n()
- if 1 <= n <= len(self.inputs):
- outputs = [self.inputs[-n]]
+ strict = self._is_strict()
+ if strict:
+ if 1 <= n <= len(self.inputs):
+ outputs = self.inputs[-n:]
+ else:
+ self._raise_strict_range_error("tail", n)
else:
- outputs = []
+ if n < 1:
+ outputs = []
+ else:
+ outputs = self.inputs[-n:]
self._set_outputs(outputs)
def _filter(self):
@@ -107,7 +151,7 @@ def _filter(self):
def _norm(self,v):
s = "" if v is None else str(v)
return s
-
+
def _eval(self, v, operator, value):
if operator == "=":
return v == value
@@ -163,6 +207,6 @@ def _hashable(self,x):
if isinstance(x, set):
return tuple(sorted(self._hashable(v) for v in x))
return x
-
+
def thoughts(self) -> str:
return "ListOperation in progress"
diff --git a/agent/component/message.py b/agent/component/message.py
index 8db4eedbd14..a52741f6b36 100644
--- a/agent/component/message.py
+++ b/agent/component/message.py
@@ -75,6 +75,22 @@ def _is_download_info(value: Any) -> bool:
key in value for key in ("doc_id", "filename", "mime_type")
)
+ @staticmethod
+ def _download_info_includes_content(value: Any) -> bool:
+ return isinstance(value, dict) and bool(value.get("include_download_info_in_content"))
+
+ @staticmethod
+ def _normalize_download_info(value: Any) -> Any:
+ if isinstance(value, list):
+ return [Message._normalize_download_info(item) for item in value]
+
+ if not isinstance(value, dict):
+ return value
+
+ normalized = value.copy()
+ normalized.pop("include_download_info_in_content", None)
+ return normalized
+
def _extract_downloads(self, value: Any) -> list[dict[str, Any]]:
if isinstance(value, str):
try:
@@ -100,7 +116,19 @@ def _stringify_message_value(
extracted_downloads = self._extract_downloads(value)
if extracted_downloads:
if downloads is not None:
- downloads.extend(extracted_downloads)
+ downloads.extend(self._normalize_download_info(item) for item in extracted_downloads)
+ if any(self._download_info_includes_content(item) for item in extracted_downloads):
+ if isinstance(value, str):
+ try:
+ value = json.loads(value)
+ except Exception:
+ return value
+ try:
+ return json.dumps(self._normalize_download_info(value), ensure_ascii=False)
+ except Exception:
+ if fallback_to_str:
+ return str(value)
+ return ""
return ""
if value is None:
diff --git a/agent/component/variable_assigner.py b/agent/component/variable_assigner.py
index 08b28334312..dd6182c7ce0 100644
--- a/agent/component/variable_assigner.py
+++ b/agent/component/variable_assigner.py
@@ -141,20 +141,18 @@ def _extend(self,variable,parameter):
return variable + parameter
def _remove_first(self,variable):
- if len(variable)==0:
- return variable
if not isinstance(variable,list):
return "ERROR:VARIABLE_NOT_LIST"
- else:
- return variable[1:]
-
- def _remove_last(self,variable):
if len(variable)==0:
return variable
+ return variable[1:]
+
+ def _remove_last(self,variable):
if not isinstance(variable,list):
return "ERROR:VARIABLE_NOT_LIST"
- else:
- return variable[:-1]
+ if len(variable)==0:
+ return variable
+ return variable[:-1]
def is_number(self, value):
if isinstance(value, bool):
diff --git a/agent/sandbox/client.py b/agent/sandbox/client.py
index 4d49ae734c6..9ca51cc8e3a 100644
--- a/agent/sandbox/client.py
+++ b/agent/sandbox/client.py
@@ -23,11 +23,12 @@
import json
import logging
+import os
from typing import Dict, Any, Optional
from api.db.services.system_settings_service import SystemSettingsService
from agent.sandbox.providers import ProviderManager
-from agent.sandbox.providers.base import ExecutionResult
+from agent.sandbox.providers.base import ExecutionResult, SandboxProviderConfigError
logger = logging.getLogger(__name__)
@@ -59,8 +60,8 @@ def _load_provider_from_settings() -> None:
"""
Load sandbox provider from system settings and configure the provider manager.
- This function reads the system settings to determine which provider is active
- and initializes it with the appropriate configuration.
+ This function resolves the active provider type, then loads configuration
+ from system settings with environment overrides for that provider.
"""
global _provider_manager
@@ -68,41 +69,27 @@ def _load_provider_from_settings() -> None:
return
try:
- # Get active provider type
- provider_type_settings = SystemSettingsService.get_by_name("sandbox.provider_type")
- if not provider_type_settings:
- raise RuntimeError(
- "Sandbox provider type not configured. Please set 'sandbox.provider_type' in system settings."
- )
- provider_type = provider_type_settings[0].value
-
- # Get provider configuration
- provider_config_settings = SystemSettingsService.get_by_name(f"sandbox.{provider_type}")
-
- if not provider_config_settings:
- logger.warning(f"No configuration found for provider: {provider_type}")
- config = {}
- else:
- try:
- config = json.loads(provider_config_settings[0].value)
- except json.JSONDecodeError as e:
- logger.error(f"Failed to parse sandbox config for {provider_type}: {e}")
- config = {}
+ provider_type, provider_type_from_env = _resolve_provider_type()
+ config = _load_provider_config(provider_type)
# Import and instantiate the provider
from agent.sandbox.providers import (
SelfManagedProvider,
AliyunCodeInterpreterProvider,
E2BProvider,
+ LocalProvider,
)
provider_classes = {
"self_managed": SelfManagedProvider,
"aliyun_codeinterpreter": AliyunCodeInterpreterProvider,
"e2b": E2BProvider,
+ "local": LocalProvider,
}
if provider_type not in provider_classes:
+ if provider_type_from_env:
+ raise SandboxProviderConfigError(f"Unknown sandbox provider type: {provider_type}")
logger.error(f"Unknown provider type: {provider_type}")
return
@@ -111,19 +98,97 @@ def _load_provider_from_settings() -> None:
# Initialize the provider
if not provider.initialize(config):
- logger.error(f"Failed to initialize sandbox provider: {provider_type}. Config keys: {list(config.keys())}")
+ message = f"Failed to initialize sandbox provider: {provider_type}. Config keys: {list(config.keys())}"
+ if provider_type == "local" or provider_type_from_env:
+ raise SandboxProviderConfigError(message)
+ logger.error(message)
return
# Set the active provider
_provider_manager.set_provider(provider_type, provider)
logger.info(f"Sandbox provider '{provider_type}' initialized successfully")
+ except SandboxProviderConfigError:
+ raise
except Exception as e:
logger.error(f"Failed to load sandbox provider from settings: {e}")
import traceback
traceback.print_exc()
+def _load_provider_config_from_settings(provider_type: str) -> Dict[str, Any]:
+ provider_config_settings = SystemSettingsService.get_by_name(f"sandbox.{provider_type}")
+ if not provider_config_settings:
+ logger.warning(f"No configuration found for provider: {provider_type}")
+ return {}
+
+ try:
+ return json.loads(provider_config_settings[0].value)
+ except json.JSONDecodeError as e:
+ logger.error(f"Failed to parse sandbox config for {provider_type}: {e}")
+ return {}
+
+
+def _resolve_provider_type() -> tuple[str, bool]:
+ provider_type = os.environ.get("SANDBOX_PROVIDER_TYPE", "").strip()
+ if provider_type:
+ return provider_type, True
+
+ provider_type_settings = SystemSettingsService.get_by_name("sandbox.provider_type")
+ if not provider_type_settings:
+ raise RuntimeError(
+ "Sandbox provider type not configured. Please set 'sandbox.provider_type' in system settings."
+ )
+ return provider_type_settings[0].value, False
+
+
+def _load_provider_config(provider_type: str) -> Dict[str, Any]:
+ config = _load_provider_config_from_settings(provider_type)
+ env_config = _load_provider_config_from_env(provider_type)
+ if env_config:
+ config.update(env_config)
+ return config
+
+
+def _load_provider_config_from_env(provider_type: str) -> Dict[str, Any]:
+ if provider_type == "local":
+ return _load_local_provider_config_from_env()
+ if provider_type == "self_managed":
+ return _load_self_managed_provider_config_from_env()
+ return {}
+
+
+def _load_local_provider_config_from_env() -> Dict[str, Any]:
+ env_to_config = {
+ "SANDBOX_LOCAL_PYTHON_BIN": "python_bin",
+ "SANDBOX_LOCAL_NODE_BIN": "node_bin",
+ "SANDBOX_LOCAL_WORK_DIR": "work_dir",
+ "SANDBOX_LOCAL_TIMEOUT": "timeout",
+ "SANDBOX_LOCAL_MAX_MEMORY_MB": "max_memory_mb",
+ "SANDBOX_LOCAL_MAX_OUTPUT_BYTES": "max_output_bytes",
+ "SANDBOX_LOCAL_MAX_ARTIFACTS": "max_artifacts",
+ "SANDBOX_LOCAL_MAX_ARTIFACT_BYTES": "max_artifact_bytes",
+ }
+ config = {}
+ for env_name, config_name in env_to_config.items():
+ if env_name in os.environ:
+ config[config_name] = os.environ[env_name]
+ return config
+
+
+def _load_self_managed_provider_config_from_env() -> Dict[str, Any]:
+ host = os.environ.get("SANDBOX_HOST", "").strip()
+ port = os.environ.get("SANDBOX_EXECUTOR_MANAGER_PORT", "").strip()
+ pool_size = os.environ.get("SANDBOX_EXECUTOR_MANAGER_POOL_SIZE", "").strip()
+
+ config = {}
+ if host:
+ config["endpoint"] = f"http://{host}:{port or '9385'}"
+ if pool_size:
+ config["pool_size"] = pool_size
+ return config
+
+
def reload_provider() -> None:
"""
Reload the sandbox provider from system settings.
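The net resolution order: an explicit `SANDBOX_PROVIDER_TYPE` environment variable takes precedence over system settings, and per-provider environment overrides are merged on top of the stored JSON config. A minimal sketch of driving the selection purely from the environment, using the variable names introduced in this diff:

```python
import os

# Choosing the provider via the environment (wins over system settings).
os.environ["SANDBOX_PROVIDER_TYPE"] = "local"
# Hard gate enforced by LocalProvider.initialize().
os.environ["SANDBOX_LOCAL_ENABLED"] = "true"
# Merged over whatever the sandbox.local settings JSON contains.
os.environ["SANDBOX_LOCAL_TIMEOUT"] = "20"
```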
diff --git a/agent/sandbox/providers/__init__.py b/agent/sandbox/providers/__init__.py
index 7be1463b9ca..e7cfc2ddc9c 100644
--- a/agent/sandbox/providers/__init__.py
+++ b/agent/sandbox/providers/__init__.py
@@ -24,20 +24,24 @@
- aliyun_codeinterpreter.py: Aliyun Code Interpreter provider implementation
Official Documentation: https://help.aliyun.com/zh/functioncompute/fc/sandbox-sandbox-code-interepreter
- e2b.py: E2B provider implementation
+- local.py: Local process provider implementation
"""
-from .base import SandboxProvider, SandboxInstance, ExecutionResult
+from .base import SandboxProvider, SandboxInstance, ExecutionResult, SandboxProviderConfigError
from .manager import ProviderManager
from .self_managed import SelfManagedProvider
from .aliyun_codeinterpreter import AliyunCodeInterpreterProvider
from .e2b import E2BProvider
+from .local import LocalProvider
__all__ = [
"SandboxProvider",
"SandboxInstance",
"ExecutionResult",
+ "SandboxProviderConfigError",
"ProviderManager",
"SelfManagedProvider",
"AliyunCodeInterpreterProvider",
"E2BProvider",
+ "LocalProvider",
]
diff --git a/agent/sandbox/providers/aliyun_codeinterpreter.py b/agent/sandbox/providers/aliyun_codeinterpreter.py
index 8ee99ed1ecc..bbec2a26820 100644
--- a/agent/sandbox/providers/aliyun_codeinterpreter.py
+++ b/agent/sandbox/providers/aliyun_codeinterpreter.py
@@ -30,7 +30,6 @@
import logging
import os
import time
-import base64
import json
from typing import Dict, Any, List, Optional
from datetime import datetime, timezone
@@ -39,10 +38,10 @@
from agentrun.utils.config import Config
from agentrun.utils.exception import ServerError
+from agent.sandbox.result_protocol import build_javascript_wrapper, build_python_wrapper, extract_structured_result
from .base import SandboxProvider, SandboxInstance, ExecutionResult
logger = logging.getLogger(__name__)
-RESULT_MARKER_PREFIX = "__RAGFLOW_RESULT__:"
class AliyunCodeInterpreterProvider(SandboxProvider):
@@ -234,9 +233,9 @@ def execute_code(self, instance_id: str, code: str, language: str, timeout: int
# Matches self_managed provider behavior: call main(**arguments)
args_json = json.dumps(arguments or {})
wrapped_code = (
- self._build_python_wrapper(code, args_json)
+ build_python_wrapper(code, args_json)
if normalized_lang == "python"
- else self._build_javascript_wrapper(code, args_json)
+ else build_javascript_wrapper(code, args_json)
)
logger.debug(f"Aliyun Code Interpreter: Wrapped code (first 200 chars): {wrapped_code[:200]}")
@@ -284,7 +283,7 @@ def execute_code(self, instance_id: str, code: str, language: str, timeout: int
stdout = "\n".join(stdout_parts)
stderr = "\n".join(stderr_parts)
- stdout, structured_result = self._extract_structured_result(stdout)
+ stdout, structured_result = extract_structured_result(stdout)
logger.info(f"Aliyun Code Interpreter: stdout length={len(stdout)}, stderr length={len(stderr)}, exit_code={exit_code}")
if stdout:
@@ -364,71 +363,6 @@ def health_check(self) -> bool:
# If we get any response (even an error), the service is reachable
return "connection" not in str(e).lower()
- @staticmethod
- def _build_python_wrapper(code: str, args_json: str) -> str:
- marker = RESULT_MARKER_PREFIX
- return f'''{code}
-
-if __name__ == "__main__":
- import base64
- import json
-
- result = main(**{args_json})
- payload = json.dumps({{"present": True, "value": result, "type": "json"}}, ensure_ascii=False, separators=(",", ":"))
- print("{marker}" + base64.b64encode(payload.encode("utf-8")).decode("ascii"))
-'''
-
- @staticmethod
- def _build_javascript_wrapper(code: str, args_json: str) -> str:
- marker = RESULT_MARKER_PREFIX
- return f'''{code}
-
-const __ragflowArgs = {args_json};
-
-(async () => {{
- try {{
- const output = await Promise.resolve(main(__ragflowArgs));
- if (typeof output === 'undefined') {{
- throw new Error('main() must return a value. Use null for an empty result.');
- }}
- const payload = JSON.stringify({{ present: true, value: output, type: 'json' }});
- if (typeof payload === 'undefined') {{
- throw new Error('main() returned a non-JSON-serializable value.');
- }}
- console.log('{marker}' + Buffer.from(payload, 'utf8').toString('base64'));
- }} catch (err) {{
- console.error(err instanceof Error ? err.stack || err.message : String(err));
- }}
-}})();
-'''
-
- @staticmethod
- def _extract_structured_result(stdout: str) -> tuple[str, Dict[str, Any]]:
- if not stdout:
- return "", {}
-
- cleaned_lines: list[str] = []
- structured_result: Dict[str, Any] = {}
-
- for line in str(stdout).splitlines():
- if line.startswith(RESULT_MARKER_PREFIX):
- payload_b64 = line[len(RESULT_MARKER_PREFIX) :].strip()
- if not payload_b64:
- continue
- try:
- payload = base64.b64decode(payload_b64).decode("utf-8")
- structured_result = json.loads(payload)
- except Exception as exc:
- logger.warning(f"Aliyun Code Interpreter: failed to decode structured result marker: {exc}")
- cleaned_lines.append(line)
- continue
- cleaned_lines.append(line)
-
- cleaned_stdout = "\n".join(cleaned_lines)
- if stdout.endswith("\n") and cleaned_stdout and not cleaned_stdout.endswith("\n"):
- cleaned_stdout += "\n"
- return cleaned_stdout, structured_result
-
def get_supported_languages(self) -> List[str]:
"""
Get list of supported programming languages.
diff --git a/agent/sandbox/providers/base.py b/agent/sandbox/providers/base.py
index c21b583e02b..8f9c04aaa42 100644
--- a/agent/sandbox/providers/base.py
+++ b/agent/sandbox/providers/base.py
@@ -26,6 +26,10 @@
from typing import Dict, Any, Optional, List
+class SandboxProviderConfigError(Exception):
+ """Raised when the selected provider is explicitly configured but unusable."""
+
+
@dataclass
class SandboxInstance:
"""Represents a sandbox execution instance"""
@@ -209,4 +213,4 @@ def validate_config(self, config: Dict[str, Any]) -> tuple[bool, Optional[str]]:
>>> return True, None
"""
# Default implementation: no custom validation
- return True, None
\ No newline at end of file
+ return True, None
diff --git a/agent/sandbox/providers/local.py b/agent/sandbox/providers/local.py
new file mode 100644
index 00000000000..b8057fa5b43
--- /dev/null
+++ b/agent/sandbox/providers/local.py
@@ -0,0 +1,296 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import base64
+import json
+import mimetypes
+import os
+import shutil
+import signal
+import subprocess
+import time
+import uuid
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from agent.sandbox.result_protocol import build_javascript_wrapper, build_python_wrapper, extract_structured_result
+from .base import ExecutionResult, SandboxInstance, SandboxProvider, SandboxProviderConfigError
+
+
+ALLOWED_ARTIFACT_EXTENSIONS = {
+ ".csv",
+ ".html",
+ ".jpeg",
+ ".jpg",
+ ".json",
+ ".pdf",
+ ".png",
+ ".svg",
+}
+
+
+def _env_enabled(name: str) -> bool:
+ return os.environ.get(name, "").strip().lower() in {"1", "true", "yes", "on"}
+
+
+class LocalProvider(SandboxProvider):
+ """
+ Execute code as a local child process.
+
+ This provider is intentionally gated by SANDBOX_LOCAL_ENABLED because it is
+ not a sandbox boundary. Use a low-privilege runtime account.
+ """
+
+ def __init__(self):
+ self.python_bin = "python3"
+ self.node_bin = "node"
+ self.work_dir = Path("/tmp/ragflow-codeexec")
+ self.timeout = 30
+ self.max_memory_mb = 512
+ self.max_output_bytes = 1024 * 1024
+ self.max_artifacts = 20
+ self.max_artifact_bytes = 10 * 1024 * 1024
+ self._initialized = False
+ self._instances: dict[str, Path] = {}
+
+ def initialize(self, config: Dict[str, Any]) -> bool:
+ if not _env_enabled("SANDBOX_LOCAL_ENABLED"):
+ raise SandboxProviderConfigError("Local code execution is disabled. Set SANDBOX_LOCAL_ENABLED=true to enable it.")
+
+ self.python_bin = str(self._resolve_config_value(config, "python_bin", "SANDBOX_LOCAL_PYTHON_BIN", "python3"))
+ self.node_bin = str(self._resolve_config_value(config, "node_bin", "SANDBOX_LOCAL_NODE_BIN", "node"))
+ self.work_dir = Path(self._resolve_config_value(config, "work_dir", "SANDBOX_LOCAL_WORK_DIR", "/tmp/ragflow-codeexec")).resolve()
+ self.timeout = int(self._resolve_config_value(config, "timeout", "SANDBOX_LOCAL_TIMEOUT", 30))
+ self.max_memory_mb = int(self._resolve_config_value(config, "max_memory_mb", "SANDBOX_LOCAL_MAX_MEMORY_MB", 512))
+ self.max_output_bytes = int(self._resolve_config_value(config, "max_output_bytes", "SANDBOX_LOCAL_MAX_OUTPUT_BYTES", 1024 * 1024))
+ self.max_artifacts = int(self._resolve_config_value(config, "max_artifacts", "SANDBOX_LOCAL_MAX_ARTIFACTS", 20))
+ self.max_artifact_bytes = int(self._resolve_config_value(config, "max_artifact_bytes", "SANDBOX_LOCAL_MAX_ARTIFACT_BYTES", 10 * 1024 * 1024))
+
+ self._validate_limits()
+ self.work_dir.mkdir(parents=True, exist_ok=True, mode=0o700)
+ self._initialized = True
+ return True
+
+ def create_instance(self, template: str = "python") -> SandboxInstance:
+ if not self._initialized:
+ raise RuntimeError("Provider not initialized. Call initialize() first.")
+
+ language = self._normalize_language(template)
+ instance_id = str(uuid.uuid4())
+ instance_dir = self.work_dir / instance_id
+ instance_dir.mkdir(mode=0o700)
+ (instance_dir / "artifacts").mkdir(mode=0o700)
+ self._instances[instance_id] = instance_dir
+
+ return SandboxInstance(
+ instance_id=instance_id,
+ provider="local",
+ status="running",
+ metadata={"language": language, "work_dir": str(instance_dir)},
+ )
+
+ def execute_code(
+ self,
+ instance_id: str,
+ code: str,
+ language: str,
+ timeout: int = 10,
+ arguments: Optional[Dict[str, Any]] = None,
+ ) -> ExecutionResult:
+ if not self._initialized:
+ raise RuntimeError("Provider not initialized. Call initialize() first.")
+
+ normalized_lang = self._normalize_language(language)
+ instance_dir = self._instances[instance_id]
+ args_json = json.dumps(arguments or {}, ensure_ascii=False)
+ command, script_path = self._prepare_script(instance_dir, normalized_lang, code, args_json)
+ requested_timeout = self.timeout if timeout is None else int(timeout)
+ if requested_timeout <= 0:
+ raise RuntimeError(f"Execution timeout must be greater than 0 seconds, got {requested_timeout}.")
+ exec_timeout = min(requested_timeout, self.timeout)
+
+ start_time = time.time()
+ process = subprocess.Popen(
+ command,
+ cwd=instance_dir,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True,
+ encoding="utf-8",
+ errors="replace",
+ env=self._build_child_env(instance_dir),
+ preexec_fn=self._limit_child_process if os.name == "posix" else None,
+ start_new_session=os.name == "posix",
+ )
+
+ try:
+ stdout, stderr = process.communicate(timeout=exec_timeout)
+ except subprocess.TimeoutExpired:
+ if os.name == "posix":
+ os.killpg(process.pid, signal.SIGKILL)
+ else:
+ process.kill()
+ process.communicate()
+ raise TimeoutError(f"Execution timed out after {exec_timeout} seconds")
+
+ execution_time = time.time() - start_time
+ self._validate_output_size(stdout, stderr)
+ stdout, structured_result = extract_structured_result(stdout)
+
+ return ExecutionResult(
+ stdout=stdout,
+ stderr=stderr,
+ exit_code=process.returncode,
+ execution_time=execution_time,
+ metadata={
+ "instance_id": instance_id,
+ "language": normalized_lang,
+ "script_path": str(script_path),
+ "status": "ok" if process.returncode == 0 else "error",
+ "timeout": exec_timeout,
+ "artifacts": self._collect_artifacts(instance_dir / "artifacts"),
+ "result_present": structured_result.get("present", False),
+ "result_value": structured_result.get("value"),
+ "result_type": structured_result.get("type"),
+ },
+ )
+
+ def destroy_instance(self, instance_id: str) -> bool:
+ if not self._initialized:
+ raise RuntimeError("Provider not initialized. Call initialize() first.")
+
+ instance_dir = self._instances.pop(instance_id)
+ shutil.rmtree(instance_dir)
+ return True
+
+ def health_check(self) -> bool:
+ return self._initialized and self.work_dir.exists() and os.access(self.work_dir, os.W_OK)
+
+ def get_supported_languages(self) -> List[str]:
+ return ["python", "javascript", "nodejs"]
+
+ @staticmethod
+ def get_config_schema() -> Dict[str, Dict]:
+ return {
+ "python_bin": {"type": "string", "required": False, "default": "python3"},
+ "node_bin": {"type": "string", "required": False, "default": "node"},
+ "work_dir": {"type": "string", "required": False, "default": "/tmp/ragflow-codeexec"},
+ "timeout": {"type": "integer", "required": False, "default": 30},
+ "max_memory_mb": {"type": "integer", "required": False, "default": 512},
+ "max_output_bytes": {"type": "integer", "required": False, "default": 1048576},
+ "max_artifacts": {"type": "integer", "required": False, "default": 20},
+ "max_artifact_bytes": {"type": "integer", "required": False, "default": 10485760},
+ }
+
+ def _validate_limits(self) -> None:
+ if self.timeout <= 0:
+ raise SandboxProviderConfigError("SANDBOX_LOCAL_TIMEOUT must be greater than 0.")
+ if self.max_memory_mb <= 0:
+ raise SandboxProviderConfigError("SANDBOX_LOCAL_MAX_MEMORY_MB must be greater than 0.")
+ if self.max_output_bytes <= 0:
+ raise SandboxProviderConfigError("SANDBOX_LOCAL_MAX_OUTPUT_BYTES must be greater than 0.")
+ if self.max_artifacts < 0:
+ raise SandboxProviderConfigError("SANDBOX_LOCAL_MAX_ARTIFACTS must be greater than or equal to 0.")
+ if self.max_artifact_bytes <= 0:
+ raise SandboxProviderConfigError("SANDBOX_LOCAL_MAX_ARTIFACT_BYTES must be greater than 0.")
+
+ def _prepare_script(self, instance_dir: Path, language: str, code: str, args_json: str) -> tuple[list[str], Path]:
+ if language == "python":
+ script_path = instance_dir / "main.py"
+ script_path.write_text(build_python_wrapper(code, args_json), encoding="utf-8")
+ return [self.python_bin, str(script_path)], script_path
+ if language in {"javascript", "nodejs"}:
+ script_path = instance_dir / "main.js"
+ script_path.write_text(build_javascript_wrapper(code, args_json), encoding="utf-8")
+ return [self.node_bin, str(script_path)], script_path
+ raise RuntimeError(f"Unsupported language for local provider: {language}")
+
+ @staticmethod
+ def _resolve_config_value(config: Dict[str, Any], key: str, env_name: str, default: Any) -> Any:
+ value = config.get(key)
+ if value is not None:
+ return value
+ return os.environ.get(env_name, default)
+
+ def _build_child_env(self, instance_dir: Path) -> dict[str, str]:
+ return {
+ "HOME": str(instance_dir),
+ "MPLBACKEND": "Agg",
+ "PATH": os.environ.get("PATH", ""),
+ "PYTHONUNBUFFERED": "1",
+ "TMPDIR": str(instance_dir),
+ }
+
+ def _limit_child_process(self) -> None:
+ import resource
+
+ self._set_resource_limit(resource.RLIMIT_CPU, self.timeout + 1)
+ self._set_resource_limit(resource.RLIMIT_AS, self.max_memory_mb * 1024 * 1024)
+ self._set_resource_limit(resource.RLIMIT_FSIZE, self.max_artifact_bytes)
+ self._set_resource_limit(resource.RLIMIT_NOFILE, 64)
+
+ @staticmethod
+ def _set_resource_limit(kind: int, value: int) -> None:
+ import resource
+
+ _, hard = resource.getrlimit(kind)
+ limit = value if hard == resource.RLIM_INFINITY else min(value, hard)
+ resource.setrlimit(kind, (limit, limit))
+
+ def _validate_output_size(self, stdout: str, stderr: str) -> None:
+ output_size = len((stdout or "").encode("utf-8")) + len((stderr or "").encode("utf-8"))
+ if output_size > self.max_output_bytes:
+ raise RuntimeError(f"Local execution output exceeded {self.max_output_bytes} bytes.")
+
+ def _collect_artifacts(self, artifacts_dir: Path) -> list[dict[str, Any]]:
+ artifacts: list[dict[str, Any]] = []
+ for path in sorted(artifacts_dir.rglob("*")):
+ if path.is_symlink():
+ raise RuntimeError(f"Artifact symlinks are not allowed: {path.name}")
+ if path.is_dir():
+ continue
+ if not path.is_file():
+ raise RuntimeError(f"Unsupported artifact entry: {path.name}")
+
+ if len(artifacts) >= self.max_artifacts:
+ raise RuntimeError(f"Local execution produced more than {self.max_artifacts} artifacts.")
+
+ size = path.stat().st_size
+ if size > self.max_artifact_bytes:
+ raise RuntimeError(f"Artifact exceeds {self.max_artifact_bytes} bytes: {path.name}")
+
+ ext = path.suffix.lower()
+ if ext not in ALLOWED_ARTIFACT_EXTENSIONS:
+ raise RuntimeError(f"Unsupported artifact type: {path.name}")
+
+ artifacts.append(
+ {
+ "name": path.relative_to(artifacts_dir).as_posix(),
+ "content_b64": base64.b64encode(path.read_bytes()).decode("ascii"),
+ "mime_type": mimetypes.guess_type(path.name)[0] or "application/octet-stream",
+ "size": size,
+ }
+ )
+ return artifacts
+
+ @staticmethod
+ def _normalize_language(language: str) -> str:
+ lang_lower = (language or "python").lower()
+ if lang_lower in {"python", "python3"}:
+ return "python"
+ if lang_lower in {"javascript", "nodejs"}:
+ return "nodejs"
+ return lang_lower
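A hedged smoke test of the provider, assuming a POSIX host with `python3` on the PATH and the `SANDBOX_LOCAL_ENABLED` gate set:

```python
import os

os.environ["SANDBOX_LOCAL_ENABLED"] = "true"

from agent.sandbox.providers.local import LocalProvider

provider = LocalProvider()
provider.initialize({"work_dir": "/tmp/ragflow-codeexec-demo", "timeout": 10})
inst = provider.create_instance("python")
result = provider.execute_code(
    inst.instance_id,
    "def main(**kwargs):\n    return kwargs.get('x', 0) + 1",
    "python",
    timeout=5,
    arguments={"x": 41},
)
print(result.metadata["result_value"])  # 42, recovered via the result marker protocol
provider.destroy_instance(inst.instance_id)
```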
diff --git a/agent/sandbox/result_protocol.py b/agent/sandbox/result_protocol.py
new file mode 100644
index 00000000000..f71e5f49968
--- /dev/null
+++ b/agent/sandbox/result_protocol.py
@@ -0,0 +1,85 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import base64
+import json
+from typing import Any
+
+
+RESULT_MARKER_PREFIX = "__RAGFLOW_RESULT__:"
+
+
+def build_python_wrapper(code: str, args_json: str) -> str:
+ return f'''{code}
+
+if __name__ == "__main__":
+ import base64
+ import json
+
+ result = main(**{args_json})
+ payload = json.dumps({{"present": True, "value": result, "type": "json"}}, ensure_ascii=False, separators=(",", ":"))
+ print("{RESULT_MARKER_PREFIX}" + base64.b64encode(payload.encode("utf-8")).decode("ascii"))
+'''
+
+
+def build_javascript_wrapper(code: str, args_json: str) -> str:
+ return f'''{code}
+
+const __ragflowArgs = {args_json};
+
+(async () => {{
+ const __ragflowMain = typeof main !== 'undefined' ? main : module.exports && module.exports.main;
+ if (typeof __ragflowMain !== 'function') {{
+ throw new Error('main() must be defined or exported.');
+ }}
+ const output = await Promise.resolve(__ragflowMain(__ragflowArgs));
+ if (typeof output === 'undefined') {{
+ throw new Error('main() must return a value. Use null for an empty result.');
+ }}
+ const payload = JSON.stringify({{ present: true, value: output, type: 'json' }});
+ if (typeof payload === 'undefined') {{
+ throw new Error('main() returned a non-JSON-serializable value.');
+ }}
+ console.log('{RESULT_MARKER_PREFIX}' + Buffer.from(payload, 'utf8').toString('base64'));
+}})();
+'''
+
+
+def extract_structured_result(stdout: str) -> tuple[str, dict[str, Any]]:
+ if not stdout:
+ return "", {}
+
+ cleaned_lines: list[str] = []
+ structured_result: dict[str, Any] = {}
+
+ for line in str(stdout).splitlines():
+ if line.startswith(RESULT_MARKER_PREFIX):
+ payload_b64 = line[len(RESULT_MARKER_PREFIX) :].strip()
+ if not payload_b64:
+ cleaned_lines.append(line)
+ continue
+ try:
+ payload = base64.b64decode(payload_b64, validate=True).decode("utf-8")
+ structured_result = json.loads(payload)
+ except Exception:
+ cleaned_lines.append(line)
+ continue
+ cleaned_lines.append(line)
+
+ cleaned_stdout = "\n".join(cleaned_lines)
+ if stdout.endswith("\n") and cleaned_stdout and not cleaned_stdout.endswith("\n"):
+ cleaned_stdout += "\n"
+ return cleaned_stdout, structured_result
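A small round trip of the marker protocol, assuming this module is importable, showing how the marker line is stripped from user-visible stdout and decoded into a structured result:

```python
import base64
import json

from agent.sandbox.result_protocol import RESULT_MARKER_PREFIX, extract_structured_result

payload = json.dumps({"present": True, "value": 42, "type": "json"}, separators=(",", ":"))
stdout = "hello\n" + RESULT_MARKER_PREFIX + base64.b64encode(payload.encode("utf-8")).decode("ascii") + "\n"

cleaned, result = extract_structured_result(stdout)
assert cleaned == "hello\n"    # the marker line is removed from user-visible stdout
assert result["value"] == 42   # the structured result is recovered from the marker
```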
diff --git a/agent/templates/ingestion_pipeline_Book.json b/agent/templates/ingestion_pipeline_book.json
similarity index 100%
rename from agent/templates/ingestion_pipeline_Book.json
rename to agent/templates/ingestion_pipeline_book.json
diff --git a/agent/templates/ingestion_pipeline_General.json b/agent/templates/ingestion_pipeline_general.json
similarity index 100%
rename from agent/templates/ingestion_pipeline_General.json
rename to agent/templates/ingestion_pipeline_general.json
diff --git a/agent/templates/ingestion_pipeline_Laws.json b/agent/templates/ingestion_pipeline_laws.json
similarity index 100%
rename from agent/templates/ingestion_pipeline_Laws.json
rename to agent/templates/ingestion_pipeline_laws.json
diff --git a/agent/templates/ingestion_pipeline_Manual.json b/agent/templates/ingestion_pipeline_manual.json
similarity index 100%
rename from agent/templates/ingestion_pipeline_Manual.json
rename to agent/templates/ingestion_pipeline_manual.json
diff --git a/agent/templates/ingestion_pipeline_One.json b/agent/templates/ingestion_pipeline_one.json
similarity index 100%
rename from agent/templates/ingestion_pipeline_One.json
rename to agent/templates/ingestion_pipeline_one.json
diff --git a/agent/templates/ingestion_pipeline_Paper.json b/agent/templates/ingestion_pipeline_paper.json
similarity index 100%
rename from agent/templates/ingestion_pipeline_Paper.json
rename to agent/templates/ingestion_pipeline_paper.json
diff --git a/agent/templates/ingestion_pipeline_Resume.json b/agent/templates/ingestion_pipeline_resume.json
similarity index 98%
rename from agent/templates/ingestion_pipeline_Resume.json
rename to agent/templates/ingestion_pipeline_resume.json
index 7b8d9899577..cb35eb2043e 100644
--- a/agent/templates/ingestion_pipeline_Resume.json
+++ b/agent/templates/ingestion_pipeline_resume.json
@@ -242,13 +242,14 @@
"include_heading_content": false,
"levels": [
[
- "^\\s*(?i:(?:\\d+[\\.\\)]\\s*)?(?:EDUCATION|ACADEMIC\\s*BACKGROUND|ACADEMIC\\s*HISTORY|EDUCATIONAL\\s*BACKGROUND|RELEVANT\\s*COURSEWORK|COURSEWORK|EXPERIENCE|WORK\\s*EXPERIENCE|PROFESSIONAL\\s*EXPERIENCE|RELEVANT\\s*EXPERIENCE|EMPLOYMENT\\s*HISTORY|CAREER\\s*HISTORY|INTERNSHIP\\s*EXPERIENCE|PROJECTS|PROJECT\\s*EXPERIENCE|ACADEMIC\\s*PROJECTS|PROFESSIONAL\\s*PROJECTS|SKILLS|TECHNICAL\\s*SKILLS|CORE\\s*COMPETENCIES|COMPETENCIES|QUALIFICATIONS|SUMMARY\\s*OF\\s*QUALIFICATIONS|CERTIFICATIONS|LICENSES|CERTIFICATES|AWARDS|HONORS|HONOURS|ACHIEVEMENTS|PUBLICATIONS|RESEARCH|RESEARCH\\s*EXPERIENCE|LEADERSHIP|LEADERSHIP\\s*EXPERIENCE|ACTIVITIES|EXTRACURRICULAR\\s*ACTIVITIES|ACTIVITIES\\s*(?:&|AND)\\s*SKILLS|INVOLVEMENT|CAMPUS\\s*INVOLVEMENT|VOLUNTEER\\s*EXPERIENCE|VOLUNTEERING|COMMUNITY\\s*SERVICE|LANGUAGES|INTERESTS|HOBBIES|PROFILE|PROFESSIONAL\\s*PROFILE|SUMMARY|PROFESSIONAL\\s*SUMMARY|CAREER\\s*SUMMARY|OBJECTIVE|CAREER\\s*OBJECTIVE|PERSONAL\\s*INFORMATION|CONTACT\\s*INFORMATION|ADDITIONAL\\s*INFORMATION|TRAINING))\\s*[:\uff1a]?\\s*$"
+ "^\\s*(?i:(?:\\d+[\\.\\)]\\s*)?(?:EDUCATION|ACADEMIC\\s*BACKGROUND|ACADEMIC\\s*HISTORY|EDUCATIONAL\\s*BACKGROUND|RELEVANT\\s*COURSEWORK|COURSEWORK|EXPERIENCE|WORK\\s*EXPERIENCE|PROFESSIONAL\\s*EXPERIENCE|RELEVANT\\s*EXPERIENCE|EMPLOYMENT\\s*HISTORY|CAREER\\s*HISTORY|INTERNSHIP\\s*EXPERIENCE|PROJECTS|PROJECT\\s*EXPERIENCE|ACADEMIC\\s*PROJECTS|PROFESSIONAL\\s*PROJECTS|SKILLS|TECHNICAL\\s*SKILLS|CORE\\s*COMPETENCIES|COMPETENCIES|QUALIFICATIONS|SUMMARY\\s*OF\\s*QUALIFICATIONS|CERTIFICATIONS|LICENSES|CERTIFICATES|AWARDS|HONORS|HONOURS|ACHIEVEMENTS|PUBLICATIONS|RESEARCH|RESEARCH\\s*EXPERIENCE|LEADERSHIP|LEADERSHIP\\s*EXPERIENCE|ACTIVITIES|EXTRACURRICULAR\\s*ACTIVITIES|ACTIVITIES\\s*(?:&|AND)\\s*SKILLS|INVOLVEMENT|CAMPUS\\s*INVOLVEMENT|VOLUNTEER\\s*EXPERIENCE|VOLUNTEERING|COMMUNITY\\s*SERVICE|LANGUAGES|INTERESTS|HOBBIES|PROFILE|PROFESSIONAL\\s*PROFILE|SUMMARY|PROFESSIONAL\\s*SUMMARY|CAREER\\s*SUMMARY|OBJECTIVE|CAREER\\s*OBJECTIVE|PERSONAL\\s*INFORMATION|CONTACT\\s*INFORMATION|ADDITIONAL\\s*INFORMATION|TRAINING))\\s*[:\uff1a]?\\s*$"
],
[
"^\\s*(?:\\d+[\\.\u3001\\)]\\s*)?(?:\u6559\u80b2\u80cc\u666f|\u6559\u80b2\u7ecf\u5386|\u5b66\u5386\u80cc\u666f|\u5b66\u672f\u80cc\u666f|\u6280\u672f\u80cc\u666f|\u5de5\u4f5c\u7ecf\u5386|\u5de5\u4f5c\u7ecf\u9a8c|\u5b9e\u4e60\u7ecf\u5386|\u9879\u76ee\u7ecf\u5386|\u9879\u76ee\u7ecf\u9a8c|\u79d1\u7814\u7ecf\u5386|\u7814\u7a76\u7ecf\u5386|\u6821\u56ed\u7ecf\u5386|\u5b9e\u8df5\u7ecf\u5386|\u4e13\u4e1a\u7ecf\u5386|\u804c\u4e1a\u7ecf\u5386|\u6280\u80fd|\u4e13\u4e1a\u6280\u80fd|\u6280\u80fd\u7279\u957f|\u6838\u5fc3\u6280\u80fd|\u6280\u672f\u6808|\u4e2a\u4eba\u6280\u80fd|\u5de5\u4f5c\u6280\u80fd|\u804c\u4e1a\u6280\u80fd|\u6280\u80fd\u4e0e\u8bc4\u4ef7|\u6280\u80fd\u4e0e\u81ea\u6211\u8bc4\u4ef7|\u5de5\u4f5c\u6280\u80fd\u4e0e\u81ea\u6211\u8bc4\u4ef7|\u804c\u4e1a\u6280\u80fd\u4e0e\u81ea\u6211\u8bc4\u4ef7|\u8bc1\u4e66|\u8d44\u683c\u8bc1\u4e66|\u804c\u4e1a\u8d44\u683c|\u8d44\u8d28\u8bc1\u4e66|\u83b7\u5956\u60c5\u51b5|\u83b7\u5956\u7ecf\u5386|\u8363\u8a89|\u8363\u8a89\u5956\u9879|\u5956\u9879|\u79d1\u7814\u6210\u679c|\u8bba\u6587\u53d1\u8868|\u53d1\u8868\u8bba\u6587|\u9886\u5bfc\u7ecf\u5386|\u5b66\u751f\u5de5\u4f5c|\u6821\u56ed\u6d3b\u52a8|\u793e\u56e2\u7ecf\u5386|\u6d3b\u52a8\u7ecf\u5386|\u5fd7\u613f\u7ecf\u5386|\u5fd7\u613f\u670d\u52a1|\u793e\u4f1a\u5b9e\u8df5|\u8bed\u8a00\u80fd\u529b|\u8bed\u8a00|\u81ea\u6211\u8bc4\u4ef7|\u4e2a\u4eba\u8bc4\u4ef7|\u81ea\u6211\u603b\u7ed3|\u4e2a\u4eba\u603b\u7ed3|\u4e2a\u4eba\u4f18\u52bf|\u4e2a\u4eba\u7b80\u4ecb|\u4e2a\u4eba\u4fe1\u606f|\u57fa\u672c\u4fe1\u606f|\u8054\u7cfb\u65b9\u5f0f|\u6c42\u804c\u610f\u5411|\u5e94\u8058\u610f\u5411|\u804c\u4e1a\u76ee\u6807|\u6c42\u804c\u76ee\u6807|\u5174\u8da3\u7231\u597d|\u5174\u8da3\u7279\u957f|\u57f9\u8bad\u7ecf\u5386|\u5176\u4ed6\u4fe1\u606f|\u9644\u52a0\u4fe1\u606f)\\s*[:\uff1a]?\\s*$"
]
],
- "method": "hierarchy"
+ "method": "hierarchy",
+ "root_chunk_as_heading": true
}
},
"upstream": [
@@ -303,21 +304,24 @@
"data": {
"isHovered": false
},
- "id": "xy-edge__TitleChunker:FlatMiceFixstart-Extractor:ThreeDrinksActend",
- "source": "TitleChunker:FlatMiceFix",
+ "id": "xy-edge__Extractor:ThreeDrinksActstart-Tokenizer:KindHandsWinend",
+ "markerEnd": "logo",
+ "source": "Extractor:ThreeDrinksAct",
"sourceHandle": "start",
- "target": "Extractor:ThreeDrinksAct",
- "targetHandle": "end"
+ "target": "Tokenizer:KindHandsWin",
+ "targetHandle": "end",
+ "type": "buttonEdge",
+ "zIndex": 1001
},
{
"data": {
"isHovered": false
},
- "id": "xy-edge__Extractor:ThreeDrinksActstart-Tokenizer:KindHandsWinend",
+ "id": "xy-edge__TitleChunker:FlatMiceFixstart-Extractor:ThreeDrinksActend",
"markerEnd": "logo",
- "source": "Extractor:ThreeDrinksAct",
+ "source": "TitleChunker:FlatMiceFix",
"sourceHandle": "start",
- "target": "Tokenizer:KindHandsWin",
+ "target": "Extractor:ThreeDrinksAct",
"targetHandle": "end",
"type": "buttonEdge",
"zIndex": 1001
@@ -331,7 +335,7 @@
},
"id": "File",
"measured": {
- "height": 50,
+ "height": 49,
"width": 200
},
"position": {
@@ -460,7 +464,7 @@
"dragging": false,
"id": "Parser:HipSignsRhyme",
"measured": {
- "height": 198,
+ "height": 197,
"width": 200
},
"position": {
@@ -489,12 +493,12 @@
"dragging": false,
"id": "Tokenizer:KindHandsWin",
"measured": {
- "height": 114,
+ "height": 113,
"width": 200
},
"position": {
- "x": 876.4654525205967,
- "y": 189.1906747329592
+ "x": 883.0243372012395,
+ "y": 156.39625132974524
},
"selected": false,
"sourcePosition": "right",
@@ -514,6 +518,7 @@
}
},
"promote_first_heading_to_root": false,
+ "root_chunk_as_heading": true,
"rules": [
{
"levels": [
@@ -537,14 +542,14 @@
"dragging": false,
"id": "TitleChunker:FlatMiceFix",
"measured": {
- "height": 74,
+ "height": 73,
"width": 200
},
"position": {
"x": 572.7908769627791,
"y": 141.55515313482098
},
- "selected": false,
+ "selected": true,
"sourcePosition": "right",
"targetPosition": "left",
"type": "chunkerNode"
@@ -580,12 +585,12 @@
"dragging": false,
"id": "Extractor:ThreeDrinksAct",
"measured": {
- "height": 90,
+ "height": 89,
"width": 200
},
"position": {
- "x": 583.3659219536569,
- "y": 274.7600100230409
+ "x": 623.8123774842874,
+ "y": 236.49984938595793
},
"selected": false,
"sourcePosition": "right",
diff --git a/agent/tools/base.py b/agent/tools/base.py
index f5a42de4d10..194b47fceec 100644
--- a/agent/tools/base.py
+++ b/agent/tools/base.py
@@ -67,6 +67,19 @@ async def tool_call_async(self, name: str, arguments: dict[str, Any]) -> Any:
else:
resp = await thread_pool_exec(tool_obj.invoke, **arguments)
+ if resp is None and hasattr(tool_obj, "output") and callable(tool_obj.output):
+ try:
+ fallback_output = tool_obj.output()
+ if isinstance(fallback_output, dict) and fallback_output.get("content") not in (None, ""):
+ resp = fallback_output["content"]
+ elif fallback_output not in (None, ""):
+ resp = fallback_output
+ else:
+ resp = fallback_output
+ logging.warning(f"[ToolCall] resp is None, fallback to output name={name} output_keys={list(fallback_output.keys()) if isinstance(fallback_output, dict) else type(fallback_output).__name__}")
+ except Exception as e:
+ logging.warning(f"[ToolCall] resp is None and output fallback failed name={name} err={e}")
+
elapsed = timer() - st
logging.info(f"[ToolCall] done name={name} elapsed={elapsed:.2f}s result={str(resp)[:200]}")
self.callback(name, arguments, resp, elapsed_time=elapsed)
diff --git a/agent/tools/code_exec.py b/agent/tools/code_exec.py
index 5d65a2e33ae..ece67d97fc9 100644
--- a/agent/tools/code_exec.py
+++ b/agent/tools/code_exec.py
@@ -357,6 +357,7 @@ def _execute_code(self, language: str, code: str, arguments: dict):
# Try using the new sandbox provider system first
try:
from agent.sandbox.client import execute_code as sandbox_execute_code
+ from agent.sandbox.providers.base import SandboxProviderConfigError
if self.check_if_canceled("CodeExec execution"):
return
@@ -376,8 +377,16 @@ def _execute_code(self, language: str, code: str, arguments: dict):
execution_metadata=result.metadata,
)
- except (ImportError, RuntimeError) as provider_error:
- # Provider system not available or not configured, fall back to HTTP
+ except SandboxProviderConfigError as provider_error:
+ self.set_output("_ERROR", str(provider_error))
+ return self.output()
+ except ImportError as provider_error:
+ # Provider modules are unavailable, fall back to legacy HTTP sandbox.
+ logging.info(f"[CodeExec]: Provider system not available, using HTTP fallback: {provider_error}")
+ except RuntimeError as provider_error:
+ if not self._should_fallback_to_http(provider_error):
+ self.set_output("_ERROR", f"Provider system execution failed: {provider_error}")
+ return self.output()
logging.info(f"[CodeExec]: Provider system not available, using HTTP fallback: {provider_error}")
# Fallback to direct HTTP request
@@ -487,6 +496,15 @@ def _resolve_execution_result_value(self, stdout: str, execution_metadata: Mappi
return metadata.get("result_value"), False
return self._deserialize_stdout(stdout), True
+ @staticmethod
+ def _should_fallback_to_http(provider_error: RuntimeError) -> bool:
+ message = str(provider_error).lower()
+ fallback_markers = (
+ "no sandbox provider configured",
+ "sandbox provider type not configured",
+ )
+ return any(marker in message for marker in fallback_markers)
+
@classmethod
def _ensure_bucket_lifecycle(cls):
if cls._lifecycle_configured:
@@ -533,7 +551,7 @@ def _upload_artifacts(self, artifacts: list) -> list[dict]:
settings.STORAGE_IMPL.put(SANDBOX_ARTIFACT_BUCKET, storage_name, binary)
- url = f"/v1/document/artifact/{storage_name}"
+ url = f"/api/v1/documents/artifact/{storage_name}"
uploaded.append(
{
"name": name,
diff --git a/agent/tools/crawler.py b/agent/tools/crawler.py
index e4d049e1bdd..6558c524f0a 100644
--- a/agent/tools/crawler.py
+++ b/agent/tools/crawler.py
@@ -19,7 +19,6 @@
from agent.tools.base import ToolParamBase, ToolBase
-
class CrawlerParam(ToolParamBase):
"""
Define the Crawler component parameters.
@@ -31,20 +30,26 @@ def __init__(self):
self.extract_type = "markdown"
def check(self):
- self.check_valid_value(self.extract_type, "Type of content from the crawler", ['html', 'markdown', 'content'])
+ self.check_valid_value(self.extract_type, "Type of content from the crawler", ["html", "markdown", "content"])
class Crawler(ToolBase, ABC):
component_name = "Crawler"
def _run(self, history, **kwargs):
- from api.utils.web_utils import is_valid_url
+ from common.ssrf_guard import assert_url_is_safe, pin_dns_global
+
ans = self.get_input()
ans = " - ".join(ans["content"]) if "content" in ans else ""
- if not is_valid_url(ans):
+ try:
+ _ssrf_hostname, _ssrf_ip = assert_url_is_safe(ans)
+ except ValueError:
return Crawler.be_output("URL not valid")
try:
- result = asyncio.run(self.get_web(ans))
+ # pin_dns_global is used (not thread-local) because crawl4ai resolves
+ # DNS in asyncio executor threads that don't share thread-local state.
+ with pin_dns_global(_ssrf_hostname, _ssrf_ip):
+ result = asyncio.run(self.get_web(ans))
return Crawler.be_output(result)
@@ -57,18 +62,15 @@ async def get_web(self, url):
proxy = self._param.proxy if self._param.proxy else None
async with AsyncWebCrawler(verbose=True, proxy=proxy) as crawler:
- result = await crawler.arun(
- url=url,
- bypass_cache=True
- )
+ result = await crawler.arun(url=url, bypass_cache=True)
if self.check_if_canceled("Crawler async operation"):
return
- if self._param.extract_type == 'html':
+ if self._param.extract_type == "html":
return result.cleaned_html
- elif self._param.extract_type == 'markdown':
+ elif self._param.extract_type == "markdown":
return result.markdown
- elif self._param.extract_type == 'content':
+ elif self._param.extract_type == "content":
return result.extracted_content
return result.markdown
diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py
index 912a5c34850..4496f497aef 100644
--- a/agent/tools/retrieval.py
+++ b/agent/tools/retrieval.py
@@ -135,7 +135,11 @@ async def _retrieve_kb(self, query_text: str):
doc_ids = []
if self._param.meta_data_filter != {}:
- metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
+ # Defer the (potentially expensive) metadata table load — manual
+ # filters served by ES push-down never need it. The loader is
+ # invoked at most once per request by ``apply_meta_data_filter``.
+ def _load_metas() -> dict:
+ return DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
def _resolve_manual_filter(flt: dict) -> dict:
pat = re.compile(self.variable_ref_patt)
@@ -174,11 +178,13 @@ def _resolve_manual_filter(flt: dict) -> dict:
doc_ids = await apply_meta_data_filter(
self._param.meta_data_filter,
- metas,
+ None,
query,
chat_mdl,
doc_ids,
_resolve_manual_filter if self._param.meta_data_filter.get("method") == "manual" else None,
+ kb_ids=kb_ids,
+ metas_loader=_load_metas,
)
if self._param.cross_languages:
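
The retrieval change swaps the eager metadata fetch for a `metas_loader` callable that `apply_meta_data_filter` may invoke on demand. A small sketch of that deferred-loader contract; everything except the `metas_loader` keyword and the "manual filters skip the load" behavior is a hypothetical stand-in:

```python
# Hypothetical stand-in for apply_meta_data_filter to show the deferred-loader
# contract: the expensive load runs only when the filter actually needs the
# metadata table, and at most once per request.
from typing import Callable, Optional

def apply_filter(filter_spec: dict, metas: Optional[dict], metas_loader: Optional[Callable[[], dict]]) -> dict:
    if filter_spec.get("method") == "manual":
        return {}                                  # served by index push-down, loader untouched
    if metas is None and metas_loader is not None:
        metas = metas_loader()                     # lazy, single invocation
    return metas or {}

load_count = 0
def load_metas() -> dict:
    global load_count
    load_count += 1                                # count how often the expensive load runs
    return {"doc-1": {"author": "alice"}}

apply_filter({"method": "manual"}, None, load_metas)    # no load
apply_filter({"method": "semantic"}, None, load_metas)  # one load
assert load_count == 1
```
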
diff --git a/agent/tools/searxng.py b/agent/tools/searxng.py
index fdc7bea525c..ef03375b306 100644
--- a/agent/tools/searxng.py
+++ b/agent/tools/searxng.py
@@ -20,6 +20,7 @@
import requests
from agent.tools.base import ToolMeta, ToolParamBase, ToolBase
from common.connection_utils import timeout
+from common.ssrf_guard import assert_url_is_safe, pin_dns
class SearXNGParam(ToolParamBase):
@@ -36,15 +37,15 @@ def __init__(self):
"type": "string",
"description": "The search keywords to execute with SearXNG. The keywords should be the most important words/terms(includes synonyms) from the original request.",
"default": "{sys.query}",
- "required": True
+ "required": True,
},
"searxng_url": {
"type": "string",
"description": "The base URL of your SearXNG instance (e.g., http://localhost:4000). This is required to connect to your SearXNG server.",
"required": False,
- "default": ""
- }
- }
+ "default": "",
+ },
+ },
}
super().__init__()
self.top_n = 10
@@ -61,17 +62,7 @@ def check(self):
self.check_positive_integer(self.top_n, "Top N")
def get_input_form(self) -> dict[str, dict]:
- return {
- "query": {
- "name": "Query",
- "type": "line"
- },
- "searxng_url": {
- "name": "SearXNG URL",
- "type": "line",
- "placeholder": "http://localhost:4000"
- }
- }
+ return {"query": {"name": "Query", "type": "line"}, "searxng_url": {"name": "SearXNG URL", "type": "line", "placeholder": "http://localhost:4000"}}
class SearXNG(ToolBase, ABC):
@@ -94,26 +85,22 @@ def _invoke(self, **kwargs):
self.set_output("formalized_content", "")
return ""
+ try:
+ _ssrf_hostname, _ssrf_ip = assert_url_is_safe(searxng_url)
+ except ValueError as e:
+ self.set_output("_ERROR", str(e))
+ return f"SearXNG error: SSRF guard blocked {searxng_url!r}: {e}"
+
last_e = ""
- for _ in range(self._param.max_retries+1):
+ for _ in range(self._param.max_retries + 1):
if self.check_if_canceled("SearXNG processing"):
return
try:
- search_params = {
- 'q': query,
- 'format': 'json',
- 'categories': 'general',
- 'language': 'auto',
- 'safesearch': 1,
- 'pageno': 1
- }
-
- response = requests.get(
- f"{searxng_url}/search",
- params=search_params,
- timeout=10
- )
+ search_params = {"q": query, "format": "json", "categories": "general", "language": "auto", "safesearch": 1, "pageno": 1}
+
+ with pin_dns(_ssrf_hostname, _ssrf_ip):
+ response = requests.get(f"{searxng_url}/search", params=search_params, timeout=10)
response.raise_for_status()
if self.check_if_canceled("SearXNG processing"):
@@ -128,15 +115,12 @@ def _invoke(self, **kwargs):
if not isinstance(results, list):
raise ValueError("Invalid results format from SearXNG")
- results = results[:self._param.top_n]
+ results = results[: self._param.top_n]
if self.check_if_canceled("SearXNG processing"):
return
- self._retrieve_chunks(results,
- get_title=lambda r: r.get("title", ""),
- get_url=lambda r: r.get("url", ""),
- get_content=lambda r: r.get("content", ""))
+ self._retrieve_chunks(results, get_title=lambda r: r.get("title", ""), get_url=lambda r: r.get("url", ""), get_content=lambda r: r.get("content", ""))
self.set_output("json", results)
return self.output("formalized_content")
diff --git a/api/apps/__init__.py b/api/apps/__init__.py
index 9139954115c..e05bbb03d42 100644
--- a/api/apps/__init__.py
+++ b/api/apps/__init__.py
@@ -79,8 +79,8 @@ def _unauthorized_message(error):
app.config["MAX_CONTENT_LENGTH"] = int(
os.environ.get("MAX_CONTENT_LENGTH", 1024 * 1024 * 1024)
)
-app.config['SECRET_KEY'] = settings.SECRET_KEY
-app.secret_key = settings.SECRET_KEY
+app.config['SECRET_KEY'] = settings.get_secret_key()
+app.secret_key = settings.get_secret_key()
commands.register_commands(app)
from functools import wraps
@@ -93,7 +93,7 @@ def _unauthorized_message(error):
def _load_user():
- jwt = Serializer(secret_key=settings.SECRET_KEY)
+ jwt = Serializer(secret_key=settings.get_secret_key())
authorization = request.headers.get("Authorization")
g.user = None
if not authorization:
@@ -301,6 +301,10 @@ def register_page(page_path):
register_page(path) for directory in pages_dir for path in search_pages_path(directory)
]
+# Register backward compatibility routes for deprecated APIs
+from api.apps.backward_compat import register_backward_compat_routes
+register_backward_compat_routes(app)
+
@app.errorhandler(404)
async def not_found(error):
diff --git a/api/apps/auth/README.md b/api/apps/auth/README.md
index 372e75cfbd8..8edab999f82 100644
--- a/api/apps/auth/README.md
+++ b/api/apps/auth/README.md
@@ -20,7 +20,7 @@ oauth_config = {
"authorization_url": "https://your-oauth-provider.com/oauth/authorize",
"token_url": "https://your-oauth-provider.com/oauth/token",
"userinfo_url": "https://your-oauth-provider.com/oauth/userinfo",
- "redirect_uri": "https://your-app.com/v1/user/oauth/callback/
"
+ "redirect_uri": "https://your-app.com/api/v1/auth/oauth//callback"
}
# OIDC configuration
@@ -29,7 +29,7 @@ oidc_config = {
"issuer": "https://your-oauth-provider.com/oidc",
"client_id": "your_client_id",
"client_secret": "your_client_secret",
- "redirect_uri": "https://your-app.com/v1/user/oauth/callback/"
+ "redirect_uri": "https://your-app.com/api/v1/auth/oauth//callback"
}
# Github OAuth configuration
diff --git a/api/apps/backward_compat.py b/api/apps/backward_compat.py
new file mode 100644
index 00000000000..a2c950158e6
--- /dev/null
+++ b/api/apps/backward_compat.py
@@ -0,0 +1,522 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Backward compatibility layer for deprecated API endpoints.
+
+This module adds support for old API routes that were deprecated during the
+RESTful API migration. Each deprecated route forwards to the corresponding
+new API implementation.
+
+Deprecated APIs and their replacements:
+- POST /api/v1/agents/{agent_id}/completions -> POST /api/v1/agents/chat/completions
+- POST /api/v1/chats/{chat_id}/completions -> POST /api/v1/chat/completions
+- POST /api/v1/chats_openai/{chat_id}/chat/completions -> POST /api/v1/openai/{chat_id}/chat/completions
+- PUT /api/v1/chats/{chat_id}/sessions/{session_id} -> PATCH /api/v1/chats/{chat_id}/sessions/{session_id}
+- DELETE /api/v1/chats -> DELETE /api/v1/chats/{chat_id} (with body)
+- POST /api/v1/file/convert -> POST /api/v1/files/link-to-datasets
+- GET /api/v1/file/* -> GET /api/v1/files*
+- POST /api/v1/file/* -> POST /api/v1/files*
+- GET /api/v1/document/get/{doc_id} -> GET /api/v1/documents/{doc_id}/preview
+- GET /api/v1/document/download/{doc_id} -> GET /api/v1/documents/{doc_id}/download
+- GET /v1/document/download/{attachment_id} -> GET /api/v1/documents/{attachment_id}/download
+- GET /v1/system/healthz -> GET /api/v1/system/healthz
+- POST /api/v1/sessions/related_questions -> POST /api/v1/chat/recommendation
+- PUT (chunk update) -> PATCH (chunk update)
+"""
+import logging
+
+from quart import Blueprint, jsonify, request
+
+from api.apps import login_required
+from api.apps.restful_apis import chat_api, file_api, file2document_api, chunk_api, openai_api, document_api
+from api.apps.restful_apis.system_api import run_health_checks
+from api.apps.restful_apis import agent_api
+from api.apps.services import file_api_service
+from api.utils.api_utils import get_data_error_result, get_json_result, add_tenant_id_to_kwargs
+
+manager = Blueprint("backward_compat", __name__)
+legacy_v1_manager = Blueprint("backward_compat_legacy_v1", __name__)
+
+
+# =============================================================================
+# System APIs
+# =============================================================================
+
+@legacy_v1_manager.route("/system/healthz", methods=["GET"])
+async def deprecated_system_healthz():
+ """
+ Deprecated: Use GET /api/v1/system/healthz instead.
+
+ Old path: GET /v1/system/healthz
+ New path: GET /api/v1/system/healthz
+ """
+ logging.warning(
+ "API endpoint /v1/system/healthz is deprecated. "
+ "Please use /api/v1/system/healthz instead."
+ )
+ result, all_ok = run_health_checks()
+ return jsonify(result), (200 if all_ok else 500)
+
+# =============================================================================
+# Chat Completion APIs
+# =============================================================================
+
+@manager.route("/chats//completions", methods=["POST"])
+@login_required
+async def deprecated_chat_completions(chat_id):
+ """
+ Deprecated: Use POST /api/v1/chat/completions instead.
+
+ Old path: POST /api/v1/chats/{chat_id}/completions
+ New path: POST /api/v1/chat/completions
+ """
+ logging.warning(
+ "API endpoint /api/v1/chats/%s/completions is deprecated. "
+ "Please use /api/v1/chat/completions instead.",
+ chat_id,
+ )
+ # Forward to the new API implementation
+ return await chat_api.session_completion(chat_id)
+
+
+@manager.route("/chats_openai//chat/completions", methods=["POST"])
+@login_required
+async def deprecated_openai_chat_completions(chat_id):
+ """
+ Deprecated: Use POST /api/v1/openai/{chat_id}/chat/completions instead.
+
+ Old path: POST /api/v1/chats_openai/{chat_id}/chat/completions
+ New path: POST /api/v1/openai/{chat_id}/chat/completions
+ """
+ logging.warning(
+ "API endpoint /api/v1/chats_openai/%s/chat/completions is deprecated. "
+ "Please use /api/v1/openai/%s/chat/completions instead.",
+ chat_id, chat_id,
+ )
+ # Forward to the new API implementation
+ return await openai_api.openai_chat_completions(chat_id)
+
+
+# =============================================================================
+# Chat Session APIs
+# =============================================================================
+
+@manager.route("/chats//sessions/", methods=["PUT"])
+@login_required
+async def deprecated_update_session(chat_id, session_id):
+ """
+ Deprecated: Use PATCH /api/v1/chats/{chat_id}/sessions/{session_id} instead.
+
+ Old path: PUT /api/v1/chats/{chat_id}/sessions/{session_id}
+ New path: PATCH /api/v1/chats/{chat_id}/sessions/{session_id}
+ """
+ logging.warning(
+ "API endpoint PUT /api/v1/chats/%s/sessions/%s is deprecated. "
+ "Please use PATCH /api/v1/chats/%s/sessions/%s instead.",
+ chat_id, session_id, chat_id, session_id,
+ )
+ # Forward to the new API implementation
+ return await chat_api.update_session(chat_id, session_id)
+
+
+# =============================================================================
+# File APIs (Old /api/v1/file/* -> New /api/v1/files*)
+# =============================================================================
+
+@manager.route("/file/get/", methods=["GET"])
+@login_required
+async def deprecated_file_get(file_id):
+ """
+ Deprecated: Use GET /api/v1/files/{file_id} instead.
+
+ Old path: GET /api/v1/file/get/{file_id}
+ New path: GET /api/v1/files/{file_id}
+ """
+ logging.warning(
+ "API endpoint /api/v1/file/get/%s is deprecated. "
+ "Please use /api/v1/files/%s instead.",
+ file_id, file_id,
+ )
+ # Forward to the new API implementation (download)
+ return await file_api.download(file_id=file_id)
+
+
+@manager.route("/file/list", methods=["GET"])
+@login_required
+async def deprecated_file_list():
+ """
+ Deprecated: Use GET /api/v1/files instead.
+
+ Old path: GET /api/v1/file/list?...
+ New path: GET /api/v1/files?...
+ """
+ logging.warning(
+ "API endpoint /api/v1/file/list is deprecated. "
+ "Please use /api/v1/files instead."
+ )
+ # Forward to the new API implementation
+ return await file_api.list_files()
+
+
+@manager.route("/file/all_parent_folder", methods=["GET"])
+@login_required
+async def deprecated_file_all_parent_folder():
+ """
+ Deprecated: Use GET /api/v1/files/{file_id}/ancestors instead.
+
+ Old path: GET /api/v1/file/all_parent_folder?file_id=...
+ New path: GET /api/v1/files/{file_id}/ancestors
+ """
+ file_id = request.args.get("file_id")
+ if not file_id:
+ return get_data_error_result(message="`file_id` query parameter is required")
+ logging.warning(
+ "API endpoint /api/v1/file/all_parent_folder is deprecated. "
+ "Please use /api/v1/files/%s/ancestors instead.",
+ file_id,
+ )
+ # Forward to the new API implementation
+ return await file_api.ancestors(file_id=file_id)
+
+
+@manager.route("/file/parent_folder", methods=["GET"])
+@login_required
+async def deprecated_file_parent_folder():
+ """
+ Deprecated: Use GET /api/v1/files/{file_id}/parent instead.
+
+ Old path: GET /api/v1/file/parent_folder?file_id=...
+ New path: GET /api/v1/files/{file_id}/parent
+ """
+ file_id = request.args.get("file_id")
+ if not file_id:
+ return get_data_error_result(message="`file_id` query parameter is required")
+ logging.warning(
+ "API endpoint /api/v1/file/parent_folder is deprecated. "
+ "Please use /api/v1/files/%s/parent instead.",
+ file_id,
+ )
+ # Forward to the new API implementation
+ return await file_api.parent_folder(file_id=file_id)
+
+
+@manager.route("/file/root_folder", methods=["GET"])
+@login_required
+async def deprecated_file_root_folder():
+ """
+ Deprecated: Root folder is now accessible via GET /api/v1/files with parent_id=...
+
+ Old path: GET /api/v1/file/root_folder
+ New path: GET /api/v1/files?parent_id=
+ """
+ logging.warning(
+ "API endpoint /api/v1/file/root_folder is deprecated. "
+ "Please use /api/v1/files with appropriate parent_id instead."
+ )
+ # Forward to the new API implementation with empty parent_id to get root
+ return await file_api.list_files()
+
+
+@manager.route("/file/create", methods=["POST"])
+@login_required
+@add_tenant_id_to_kwargs
+async def deprecated_file_create(tenant_id=None):
+ """
+ Deprecated: Use POST /api/v1/files instead.
+
+ Old path: POST /api/v1/file/create
+ New path: POST /api/v1/files
+ """
+ logging.warning(
+ "API endpoint /api/v1/file/create is deprecated. "
+ "Please use POST /api/v1/files instead."
+ )
+ # Forward to the new API implementation
+ return await file_api.create_or_upload(tenant_id=tenant_id)
+
+
+@manager.route("/file/upload", methods=["POST"])
+@login_required
+@add_tenant_id_to_kwargs
+async def deprecated_file_upload(tenant_id=None):
+ """
+ Deprecated: Use POST /api/v1/files (with multipart/form-data) instead.
+
+ Old path: POST /api/v1/file/upload
+ New path: POST /api/v1/files
+ """
+ logging.warning(
+ "API endpoint /api/v1/file/upload is deprecated. "
+ "Please use POST /api/v1/files with multipart/form-data instead."
+ )
+ # Forward to the new API implementation
+ return await file_api.create_or_upload(tenant_id=tenant_id)
+
+
+@manager.route("/file/convert", methods=["POST"])
+@login_required
+async def deprecated_file_convert():
+ """
+ Deprecated: Use POST /api/v1/files/link-to-datasets instead.
+
+ Old path: POST /api/v1/file/convert
+ New path: POST /api/v1/files/link-to-datasets
+ """
+ logging.warning(
+ "API endpoint /api/v1/file/convert is deprecated. "
+ "Please use POST /api/v1/files/link-to-datasets instead."
+ )
+ return await file2document_api.convert()
+
+
+@manager.route("/file/mv", methods=["POST"])
+@login_required
+@add_tenant_id_to_kwargs
+async def deprecated_file_mv(tenant_id=None):
+ """
+ Deprecated: Use POST /api/v1/files/move instead.
+
+ Old path: POST /api/v1/file/mv
+ New path: POST /api/v1/files/move
+ """
+ logging.warning(
+ "API endpoint /api/v1/file/mv is deprecated. "
+ "Please use POST /api/v1/files/move instead."
+ )
+ # Forward to the new API implementation
+ return await file_api.move(tenant_id=tenant_id)
+
+
+@manager.route("/file/rename", methods=["POST"])
+@login_required
+@add_tenant_id_to_kwargs
+async def deprecated_file_rename(tenant_id=None):
+ """
+ Deprecated: Use POST /api/v1/files/move with new_name instead.
+
+ Old path: POST /api/v1/file/rename
+ New path: POST /api/v1/files/move
+ """
+ logging.warning(
+ "API endpoint /api/v1/file/rename is deprecated. "
+ "Please use POST /api/v1/files/move with `new_name` instead."
+ )
+ # Transform the old API format to new format
+ req = await request.get_json()
+ # Old API used `file_id` and `name`, new API uses `src_file_ids` and `new_name`
+ src_file_ids = [req.get("file_id")]
+ new_name = req.get("name")
+ # Call the underlying service directly with transformed data
+ try:
+ success, result = await file_api_service.move_files(
+ tenant_id, src_file_ids, None, new_name
+ )
+ if success:
+ return get_json_result(data=result)
+ else:
+ return get_data_error_result(message=result)
+ except Exception as e:
+ logging.exception(e)
+ return get_data_error_result(message="Internal server error")
+
+
+@manager.route("/file/rm", methods=["POST"])
+@login_required
+@add_tenant_id_to_kwargs
+async def deprecated_file_rm(tenant_id=None):
+ """
+ Deprecated: Use DELETE /api/v1/files instead.
+
+ Old path: POST /api/v1/file/rm
+ New path: DELETE /api/v1/files
+ """
+ logging.warning(
+ "API endpoint /api/v1/file/rm is deprecated. "
+ "Please use DELETE /api/v1/files instead."
+ )
+ # Transform POST with body to DELETE behavior
+ # The new API expects a JSON body with `ids`
+ return await file_api.delete(tenant_id=tenant_id)
+
+
+# =============================================================================
+# Related Questions API
+# =============================================================================
+
+@manager.route("/sessions/related_questions", methods=["POST"])
+@login_required
+async def deprecated_related_questions():
+ """
+ Deprecated: Use POST /api/v1/chat/recommendation instead.
+
+ Old path: POST /api/v1/sessions/related_questions
+ New path: POST /api/v1/chat/recommendation
+ """
+ logging.warning(
+ "API endpoint /api/v1/sessions/related_questions is deprecated. "
+ "Please use /api/v1/chat/recommendation instead."
+ )
+ # Forward to the new API implementation
+ return await chat_api.recommendation()
+
+
+# =============================================================================
+# Chunk Update API (PUT -> PATCH)
+# =============================================================================
+
+@manager.route("/datasets//documents//chunks/", methods=["PUT"])
+@login_required
+async def deprecated_update_chunk(dataset_id, document_id, chunk_id):
+ """
+ Deprecated: Use PATCH /api/v1/datasets/{dataset_id}/documents/{document_id}/chunks/{chunk_id} instead.
+
+ Old path: PUT /api/v1/datasets/{dataset_id}/documents/{document_id}/chunks/{chunk_id}
+ New path: PATCH /api/v1/datasets/{dataset_id}/documents/{document_id}/chunks/{chunk_id}
+ """
+ logging.warning(
+ "API endpoint PUT /api/v1/datasets/%s/documents/%s/chunks/%s is deprecated. "
+ "Please use PATCH instead.",
+ dataset_id, document_id, chunk_id,
+ )
+ # Forward to the new API implementation
+ return await chunk_api.update_chunk(dataset_id=dataset_id, document_id=document_id, chunk_id=chunk_id)
+
+
+# =============================================================================
+# File Upload Info API
+# =============================================================================
+
+@manager.route("/file/upload_info", methods=["POST"])
+@login_required
+async def deprecated_file_upload_info():
+ """
+ Deprecated: Use POST /api/v1/documents/upload instead.
+
+ Old path: POST /api/v1/file/upload_info
+ New path: POST /api/v1/documents/upload
+ """
+ from api.apps import current_user
+
+ logging.warning(
+ "API endpoint /api/v1/file/upload_info is deprecated. "
+ "Please use POST /api/v1/documents/upload instead."
+ )
+ # Forward to the new API implementation
+ # Need to pass tenant_id explicitly since we're calling the function directly
+ tenant_id = current_user.id
+ return await document_api.upload_info(tenant_id=tenant_id)
+
+
+# =============================================================================
+# Document APIs
+# =============================================================================
+
+@manager.route("/datasets//documents/", methods=["PUT"])
+@login_required
+async def deprecated_update_document(dataset_id, document_id):
+ """
+ Deprecated: Use PATCH /api/v1/datasets/{dataset_id}/documents/{document_id} instead.
+
+ Old path: PUT /api/v1/datasets/{dataset_id}/documents/{document_id}
+ New path: PATCH /api/v1/datasets/{dataset_id}/documents/{document_id}
+ """
+ logging.warning(
+ "API endpoint PUT /api/v1/datasets/%s/documents/%s is deprecated. "
+ "Please use PATCH instead.",
+ dataset_id, document_id,
+ )
+ # Forward to the new API implementation
+ return await document_api.update_document(dataset_id=dataset_id, document_id=document_id)
+
+
+@manager.route("/document/get/", methods=["GET"])
+@login_required
+async def deprecated_document_get(doc_id):
+ """
+ Deprecated: Use GET /api/v1/documents/{doc_id}/preview instead.
+
+ Old path: GET /api/v1/document/get/{doc_id}
+ New path: GET /api/v1/documents/{doc_id}/preview
+ """
+ logging.warning(
+ "API endpoint /api/v1/document/get/%s is deprecated. "
+ "Please use /api/v1/documents/%s/preview instead.",
+ doc_id, doc_id,
+ )
+ return await document_api.get(doc_id)
+
+
+@manager.route("/document/download/", methods=["GET"])
+@login_required
+async def deprecated_document_download(doc_id):
+ """
+ Deprecated: Use GET /api/v1/documents/{doc_id}/download instead.
+
+ Old path: GET /api/v1/document/download/{doc_id}
+ New path: GET /api/v1/documents/{doc_id}/download
+ """
+ logging.warning(
+ "API endpoint /api/v1/document/download/%s is deprecated. "
+ "Please use /api/v1/documents/%s/download instead.",
+ doc_id, doc_id,
+ )
+ return await document_api.download_attachment(doc_id=doc_id)
+
+
+@legacy_v1_manager.route("/document/download/", methods=["GET"])
+@login_required
+async def document_download_v1(attachment_id):
+ """
+ Compatibility alias for document download under /v1.
+
+ Old path: GET /v1/document/download/{attachment_id}
+ New path: GET /api/v1/documents/{attachment_id}/download
+ """
+ logging.warning(
+ "API endpoint /v1/document/download/%s is deprecated. "
+ "Please use /api/v1/documents/%s/download instead.",
+ attachment_id, attachment_id,
+ )
+ return await document_api.download_attachment(attachment_id=attachment_id)
+
+# =============================================================================
+# Agent Chat API
+# =============================================================================
+
+@manager.route("/agents//completions", methods=["POST"])
+@login_required
+@add_tenant_id_to_kwargs
+async def deprecated_agent_completions(agent_id, tenant_id=None):
+ """
+ Deprecated: Use POST /api/v1/agents/chat/completions instead.
+
+ Old path: POST /api/v1/agents/{agent_id}/completions
+ New path: POST /api/v1/agents/chat/completions
+ """
+ logging.warning(
+ "API endpoint /api/v1/agents/%s/completions is deprecated. "
+ "Please use /api/v1/agents/chat/completions instead.",
+ agent_id,
+ )
+ return await agent_api.agent_chat_completion(tenant_id=tenant_id, agent_id=agent_id)
+
+def register_backward_compat_routes(app_instance):
+ """
+ Register all backward compatibility routes with the app.
+ """
+ app_instance.register_blueprint(manager, url_prefix="/api/v1")
+ app_instance.register_blueprint(legacy_v1_manager, url_prefix="/v1")
+ logging.info("Backward compatibility routes registered successfully.")
diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py
deleted file mode 100644
index 8c896e36add..00000000000
--- a/api/apps/canvas_app.py
+++ /dev/null
@@ -1,755 +0,0 @@
-#
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-import copy
-import inspect
-import json
-import logging
-from functools import partial
-from quart import request, Response, make_response
-from agent.component import LLM
-from api.db import CanvasCategory
-from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
-from api.db.services.document_service import DocumentService
-from api.db.services.file_service import FileService
-from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
-from api.db.services.task_service import queue_dataflow, CANVAS_DEBUG_DOC_ID, TaskService
-from api.db.services.user_service import TenantService
-from api.db.services.user_canvas_version import UserCanvasVersionService
-from common.constants import RetCode
-from common.misc_utils import get_uuid, thread_pool_exec
-from api.utils.api_utils import (
- get_json_result,
- server_error_response,
- validate_request,
- get_data_error_result,
- get_request_json,
-)
-from agent.canvas import Canvas
-from agent.dsl_migration import normalize_chunker_dsl
-from peewee import MySQLDatabase, PostgresqlDatabase
-from api.db.db_models import APIToken, Task
-
-from rag.flow.pipeline import Pipeline
-from rag.nlp import search
-from rag.utils.redis_conn import REDIS_CONN
-from common import settings
-from api.apps import login_required, current_user
-from api.apps.services.canvas_replica_service import CanvasReplicaService
-from api.db.services.canvas_service import completion as agent_completion
-
-
-@manager.route('/templates', methods=['GET']) # noqa: F821
-@login_required
-def templates():
- return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.get_all()])
-
-
-@manager.route('/rm', methods=['POST']) # noqa: F821
-@validate_request("canvas_ids")
-@login_required
-async def rm():
- req = await get_request_json()
- for i in req["canvas_ids"]:
- if not UserCanvasService.accessible(i, current_user.id):
- return get_json_result(
- data=False, message='Only owner of canvas authorized for this operation.',
- code=RetCode.OPERATING_ERROR)
- UserCanvasService.delete_by_id(i)
- return get_json_result(data=True)
-
-
-@manager.route('/set', methods=['POST']) # noqa: F821
-@validate_request("dsl", "title")
-@login_required
-async def save():
- req = await get_request_json()
- req['release'] = bool(req.get("release", ""))
- try:
- req["dsl"] = CanvasReplicaService.normalize_dsl(req["dsl"])
- except ValueError as e:
- return get_data_error_result(message=str(e))
- cate = req.get("canvas_category", CanvasCategory.Agent)
- if "id" not in req:
- req["user_id"] = current_user.id
- if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=cate):
- return get_data_error_result(message=f"{req['title'].strip()} already exists.")
- req["id"] = get_uuid()
- if not UserCanvasService.save(**req):
- return get_data_error_result(message="Fail to save canvas.")
- else:
- if not UserCanvasService.accessible(req["id"], current_user.id):
- return get_json_result(
- data=False, message='Only owner of canvas authorized for this operation.',
- code=RetCode.OPERATING_ERROR)
- UserCanvasService.update_by_id(req["id"], req)
- # save version
- UserCanvasVersionService.save_or_replace_latest(
- user_canvas_id=req["id"],
- dsl=req["dsl"],
- title=UserCanvasVersionService.build_version_title(getattr(current_user, "nickname", current_user.id), req.get("title")),
- release=req.get("release"),
- )
- replica_ok = CanvasReplicaService.replace_for_set(
- canvas_id=req["id"],
- tenant_id=str(current_user.id),
- runtime_user_id=str(current_user.id),
- dsl=req["dsl"],
- canvas_category=req.get("canvas_category", cate),
- title=req.get("title", ""),
- )
- if not replica_ok:
- return get_data_error_result(message="canvas saved, but replica sync failed.")
- return get_json_result(data=req)
-
-
-@manager.route('/get/<canvas_id>', methods=['GET']) # noqa: F821
-@login_required
-def get(canvas_id):
- if not UserCanvasService.accessible(canvas_id, current_user.id):
- return get_data_error_result(message="canvas not found.")
- e, c = UserCanvasService.get_by_canvas_id(canvas_id)
- if not e:
- return get_data_error_result(message="canvas not found.")
- try:
- # DELETE
- CanvasReplicaService.bootstrap(
- canvas_id=canvas_id,
- tenant_id=str(current_user.id),
- runtime_user_id=str(current_user.id),
- dsl=c.get("dsl"),
- canvas_category=c.get("canvas_category", CanvasCategory.Agent),
- title=c.get("title", ""),
- )
- except ValueError as e:
- return get_data_error_result(message=str(e))
-
- # Get the last publication time (latest released version's update_time)
- last_publish_time = None
- versions = UserCanvasVersionService.list_by_canvas_id(canvas_id)
- if versions:
- released_versions = [v for v in versions if v.release]
- if released_versions:
- # Sort by update_time descending and get the latest
- released_versions.sort(key=lambda x: x.update_time, reverse=True)
- last_publish_time = released_versions[0].update_time
-
- # Add last_publish_time to response data
- if isinstance(c, dict):
- c["dsl"] = normalize_chunker_dsl(c.get("dsl", {}))
- c["last_publish_time"] = last_publish_time
- else:
- # If c is a model object, convert to dict first
- c = c.to_dict()
- c["dsl"] = normalize_chunker_dsl(c.get("dsl", {}))
- c["last_publish_time"] = last_publish_time
-
- # For pipeline type, get associated datasets
- if c.get("canvas_category") == CanvasCategory.DataFlow:
- datasets = list(KnowledgebaseService.query(pipeline_id=canvas_id))
- c["datasets"] = [{"id": d.id, "name": d.name, "avatar": d.avatar} for d in datasets]
-
- return get_json_result(data=c)
-
-
-@manager.route('/getsse/<canvas_id>', methods=['GET']) # type: ignore # noqa: F821
-def getsse(canvas_id):
- token = request.headers.get('Authorization').split()
- if len(token) != 2:
- return get_data_error_result(message='Authorization is not valid!')
- token = token[1]
- objs = APIToken.query(beta=token)
- if not objs:
- return get_data_error_result(message='Authentication error: API key is invalid!"')
- tenant_id = objs[0].tenant_id
- if not UserCanvasService.query(user_id=tenant_id, id=canvas_id):
- return get_json_result(
- data=False,
- message='Only owner of canvas authorized for this operation.',
- code=RetCode.OPERATING_ERROR
- )
- e, c = UserCanvasService.get_by_id(canvas_id)
- if not e or c.user_id != tenant_id:
- return get_data_error_result(message="canvas not found.")
- return get_json_result(data=c.to_dict())
-
-
-@manager.route('/completion', methods=['POST']) # noqa: F821
-@validate_request("id")
-@login_required
-async def run():
- req = await get_request_json()
- query = req.get("query", "")
- files = req.get("files", [])
- inputs = req.get("inputs", {})
- tenant_id = str(current_user.id)
- runtime_user_id = req.get("user_id") or tenant_id
- user_id = str(runtime_user_id)
- if not await thread_pool_exec(UserCanvasService.accessible, req["id"], tenant_id):
- return get_json_result(
- data=False, message='Only owner of canvas authorized for this operation.',
- code=RetCode.OPERATING_ERROR)
-
- replica_payload = CanvasReplicaService.load_for_run(
- canvas_id=req["id"],
- tenant_id=tenant_id,
- runtime_user_id=user_id,
- )
-
- if not replica_payload:
- return get_data_error_result(message="canvas replica not found, please call /get/ first.")
-
- replica_dsl = replica_payload.get("dsl", {})
- canvas_title = replica_payload.get("title", "")
- canvas_category = replica_payload.get("canvas_category", CanvasCategory.Agent)
- dsl_str = json.dumps(replica_dsl, ensure_ascii=False)
-
- _, cvs = await thread_pool_exec(UserCanvasService.get_by_id, req["id"])
- if cvs.canvas_category == CanvasCategory.DataFlow:
- task_id = get_uuid()
- Pipeline(dsl_str, tenant_id=tenant_id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
- ok, error_message = await thread_pool_exec(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
- if not ok:
- return get_data_error_result(message=error_message)
- return get_json_result(data={"message_id": task_id})
-
- try:
- canvas = Canvas(dsl_str, tenant_id, canvas_id=req["id"])
- except Exception as e:
- return server_error_response(e)
-
- async def sse():
- nonlocal canvas, user_id
- try:
- async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
- yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
-
- commit_ok = CanvasReplicaService.commit_after_run(
- canvas_id=req["id"],
- tenant_id=tenant_id,
- runtime_user_id=user_id,
- dsl=json.loads(str(canvas)),
- canvas_category=canvas_category,
- title=canvas_title,
- )
- if not commit_ok:
- logging.error(
- "Canvas runtime replica commit failed: canvas_id=%s tenant_id=%s runtime_user_id=%s",
- req["id"],
- tenant_id,
- user_id,
- )
-
- except Exception as e:
- logging.exception(e)
- canvas.cancel_task()
- yield "data:" + json.dumps({"code": 500, "message": str(e), "data": False}, ensure_ascii=False) + "\n\n"
-
- resp = Response(sse(), mimetype="text/event-stream")
- resp.headers.add_header("Cache-control", "no-cache")
- resp.headers.add_header("Connection", "keep-alive")
- resp.headers.add_header("X-Accel-Buffering", "no")
- resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
- #resp.call_on_close(lambda: canvas.cancel_task())
- return resp
-
-
-@manager.route("//completion", methods=["POST"]) # noqa: F821
-@login_required
-async def exp_agent_completion(canvas_id):
- tenant_id = current_user.id
- req = await get_request_json()
- return_trace = bool(req.get("return_trace", False))
- async def generate():
- trace_items = []
- async for answer in agent_completion(tenant_id=tenant_id, agent_id=canvas_id, **req):
- if isinstance(answer, str):
- try:
- ans = json.loads(answer[5:]) # remove "data:"
- except Exception:
- continue
-
- event = ans.get("event")
- if event == "node_finished":
- if return_trace:
- data = ans.get("data", {})
- trace_items.append(
- {
- "component_id": data.get("component_id"),
- "trace": [copy.deepcopy(data)],
- }
- )
- ans.setdefault("data", {})["trace"] = trace_items
- answer = "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
- yield answer
-
- if event not in ["message", "message_end"]:
- continue
-
- yield answer
-
- yield "data:[DONE]\n\n"
-
- resp = Response(generate(), mimetype="text/event-stream")
- resp.headers.add_header("Cache-control", "no-cache")
- resp.headers.add_header("Connection", "keep-alive")
- resp.headers.add_header("X-Accel-Buffering", "no")
- resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
- return resp
-
-
-@manager.route('/rerun', methods=['POST']) # noqa: F821
-@validate_request("id", "dsl", "component_id")
-@login_required
-async def rerun():
- req = await get_request_json()
- doc = PipelineOperationLogService.get_documents_info(req["id"])
- if not doc:
- return get_data_error_result(message="Document not found.")
- doc = doc[0]
- if 0 < doc["progress"] < 1:
- return get_data_error_result(message=f"`{doc['name']}` is processing...")
-
- if settings.docStoreConn.index_exist(search.index_name(current_user.id), doc["kb_id"]):
- settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
- doc["progress_msg"] = ""
- doc["chunk_num"] = 0
- doc["token_num"] = 0
- DocumentService.clear_chunk_num_when_rerun(doc["id"])
- DocumentService.update_by_id(id, doc)
- TaskService.filter_delete([Task.doc_id == id])
-
- dsl = req["dsl"]
- dsl["path"] = [req["component_id"]]
- PipelineOperationLogService.update_by_id(req["id"], {"dsl": dsl})
- queue_dataflow(tenant_id=current_user.id, flow_id=req["id"], task_id=get_uuid(), doc_id=doc["id"], priority=0, rerun=True)
- return get_json_result(data=True)
-
-
-@manager.route('/cancel/<task_id>', methods=['PUT']) # noqa: F821
-@login_required
-def cancel(task_id):
- try:
- REDIS_CONN.set(f"{task_id}-cancel", "x")
- except Exception as e:
- logging.exception(e)
- return get_json_result(data=True)
-
-
-@manager.route('/reset', methods=['POST']) # noqa: F821
-@validate_request("id")
-@login_required
-async def reset():
- req = await get_request_json()
- if not UserCanvasService.accessible(req["id"], current_user.id):
- return get_json_result(
- data=False, message='Only owner of canvas authorized for this operation.',
- code=RetCode.OPERATING_ERROR)
- try:
- e, user_canvas = UserCanvasService.get_by_id(req["id"])
- if not e:
- return get_data_error_result(message="canvas not found.")
-
- canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id, canvas_id=user_canvas.id)
- canvas.reset()
- req["dsl"] = json.loads(str(canvas))
- UserCanvasService.update_by_id(req["id"], {"dsl": req["dsl"]})
- return get_json_result(data=req["dsl"])
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route("/upload/", methods=["POST"]) # noqa: F821
-async def upload(canvas_id):
- e, cvs = UserCanvasService.get_by_canvas_id(canvas_id)
- if not e:
- return get_data_error_result(message="canvas not found.")
-
- user_id = cvs["user_id"]
- files = await request.files
- file_objs = files.getlist("file") if files and files.get("file") else []
- try:
- if len(file_objs) == 1:
- return get_json_result(data=FileService.upload_info(user_id, file_objs[0], request.args.get("url")))
- results = [FileService.upload_info(user_id, f) for f in file_objs]
- return get_json_result(data=results)
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route('/input_form', methods=['GET']) # noqa: F821
-@login_required
-def input_form():
- cvs_id = request.args.get("id")
- cpn_id = request.args.get("component_id")
- try:
- e, user_canvas = UserCanvasService.get_by_id(cvs_id)
- if not e:
- return get_data_error_result(message="canvas not found.")
- if not UserCanvasService.query(user_id=current_user.id, id=cvs_id):
- return get_json_result(
- data=False, message='Only owner of canvas authorized for this operation.',
- code=RetCode.OPERATING_ERROR)
-
- canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id, canvas_id=user_canvas.id)
- return get_json_result(data=canvas.get_component_input_form(cpn_id))
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route('/debug', methods=['POST']) # noqa: F821
-@validate_request("id", "component_id", "params")
-@login_required
-async def debug():
- req = await get_request_json()
- if not UserCanvasService.accessible(req["id"], current_user.id):
- return get_json_result(
- data=False, message='Only owner of canvas authorized for this operation.',
- code=RetCode.OPERATING_ERROR)
- try:
- e, user_canvas = UserCanvasService.get_by_id(req["id"])
- canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id, canvas_id=user_canvas.id)
- canvas.reset()
- canvas.message_id = get_uuid()
- component = canvas.get_component(req["component_id"])["obj"]
- component.reset()
-
- if isinstance(component, LLM):
- component.set_debug_inputs(req["params"])
- component.invoke(**{k: o["value"] for k,o in req["params"].items()})
- outputs = component.output()
- for k in outputs.keys():
- if isinstance(outputs[k], partial):
- txt = ""
- iter_obj = outputs[k]()
- if inspect.isasyncgen(iter_obj):
- async for c in iter_obj:
- txt += c
- else:
- for c in iter_obj:
- txt += c
- outputs[k] = txt
- return get_json_result(data=outputs)
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route('/test_db_connect', methods=['POST']) # noqa: F821
-@validate_request("db_type", "database", "username", "host", "port", "password")
-@login_required
-async def test_db_connect():
- req = await get_request_json()
- try:
- if req["db_type"] in ["mysql", "mariadb"]:
- db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
- password=req["password"])
- elif req["db_type"] == "oceanbase":
- db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
- password=req["password"], charset="utf8mb4")
- elif req["db_type"] == 'postgres':
- db = PostgresqlDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
- password=req["password"])
- elif req["db_type"] == 'mssql':
- import pyodbc
- connection_string = (
- f"DRIVER={{ODBC Driver 17 for SQL Server}};"
- f"SERVER={req['host']},{req['port']};"
- f"DATABASE={req['database']};"
- f"UID={req['username']};"
- f"PWD={req['password']};"
- )
- db = pyodbc.connect(connection_string)
- cursor = db.cursor()
- cursor.execute("SELECT 1")
- cursor.close()
- elif req["db_type"] == 'IBM DB2':
- import ibm_db
- conn_str = (
- f"DATABASE={req['database']};"
- f"HOSTNAME={req['host']};"
- f"PORT={req['port']};"
- f"PROTOCOL=TCPIP;"
- f"UID={req['username']};"
- f"PWD={req['password']};"
- )
- redacted_conn_str = (
- f"DATABASE={req['database']};"
- f"HOSTNAME={req['host']};"
- f"PORT={req['port']};"
- f"PROTOCOL=TCPIP;"
- f"UID={req['username']};"
- f"PWD=****;"
- )
- logging.info(redacted_conn_str)
- conn = ibm_db.connect(conn_str, "", "")
- stmt = ibm_db.exec_immediate(conn, "SELECT 1 FROM sysibm.sysdummy1")
- ibm_db.fetch_assoc(stmt)
- ibm_db.close(conn)
- return get_json_result(data="Database Connection Successful!")
- elif req["db_type"] == 'trino':
- def _parse_catalog_schema(db_name: str):
- if not db_name:
- return None, None
- if "." in db_name:
- catalog_name, schema_name = db_name.split(".", 1)
- elif "/" in db_name:
- catalog_name, schema_name = db_name.split("/", 1)
- else:
- catalog_name, schema_name = db_name, "default"
- return catalog_name, schema_name
- try:
- import trino
- import os
- except Exception as e:
- return server_error_response(f"Missing dependency 'trino'. Please install: pip install trino, detail: {e}")
-
- catalog, schema = _parse_catalog_schema(req["database"])
- if not catalog:
- return server_error_response("For Trino, 'database' must be 'catalog.schema' or at least 'catalog'.")
-
- http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http"
-
- auth = None
- if http_scheme == "https" and req.get("password"):
- auth = trino.BasicAuthentication(req.get("username") or "ragflow", req["password"])
-
- conn = trino.dbapi.connect(
- host=req["host"],
- port=int(req["port"] or 8080),
- user=req["username"] or "ragflow",
- catalog=catalog,
- schema=schema or "default",
- http_scheme=http_scheme,
- auth=auth
- )
- cur = conn.cursor()
- cur.execute("SELECT 1")
- cur.fetchall()
- cur.close()
- conn.close()
- return get_json_result(data="Database Connection Successful!")
- else:
- return server_error_response("Unsupported database type.")
- if req["db_type"] != 'mssql':
- db.connect()
- db.close()
-
- return get_json_result(data="Database Connection Successful!")
- except Exception as e:
- return server_error_response(e)
-
-
-#api get list version dsl of canvas
-@manager.route('/getlistversion/<canvas_id>', methods=['GET']) # noqa: F821
-@login_required
-def getlistversion(canvas_id):
- try:
- versions =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1)
- return get_json_result(data=versions)
- except Exception as e:
- return get_data_error_result(message=f"Error getting history files: {e}")
-
-
-#api get version dsl of canvas
-@manager.route('/getversion/<version_id>', methods=['GET']) # noqa: F821
-@login_required
-def getversion( version_id):
- try:
- e, version = UserCanvasVersionService.get_by_id(version_id)
- if version:
- return get_json_result(data=version.to_dict())
- except Exception as e:
- return get_json_result(data=f"Error getting history file: {e}")
-
-
-@manager.route('/list', methods=['GET']) # noqa: F821
-@login_required
-def list_canvas():
- keywords = request.args.get("keywords", "")
- page_number = int(request.args.get("page", 0))
- items_per_page = int(request.args.get("page_size", 0))
- orderby = request.args.get("orderby", "create_time")
- canvas_category = request.args.get("canvas_category")
- if request.args.get("desc", "true").lower() == "false":
- desc = False
- else:
- desc = True
- owner_ids = [id for id in request.args.get("owner_ids", "").strip().split(",") if id]
- if not owner_ids:
- tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
- tenants = [m["tenant_id"] for m in tenants]
- tenants.append(current_user.id)
- canvas, total = UserCanvasService.get_by_tenant_ids(
- tenants, current_user.id, page_number,
- items_per_page, orderby, desc, keywords, canvas_category)
- else:
- tenants = owner_ids
- canvas, total = UserCanvasService.get_by_tenant_ids(
- tenants, current_user.id, 0,
- 0, orderby, desc, keywords, canvas_category)
- return get_json_result(data={"canvas": canvas, "total": total})
-
-
-@manager.route('/setting', methods=['POST']) # noqa: F821
-@validate_request("id", "title", "permission")
-@login_required
-async def setting():
- req = await get_request_json()
- req["user_id"] = current_user.id
-
- if not UserCanvasService.accessible(req["id"], current_user.id):
- return get_json_result(
- data=False, message='Only owner of canvas authorized for this operation.',
- code=RetCode.OPERATING_ERROR)
-
- e,flow = UserCanvasService.get_by_id(req["id"])
- if not e:
- return get_data_error_result(message="canvas not found.")
- flow = flow.to_dict()
- flow["title"] = req["title"]
-
- for key in ["description", "permission", "avatar"]:
- if value := req.get(key):
- flow[key] = value
-
- num= UserCanvasService.update_by_id(req["id"], flow)
- return get_json_result(data=num)
-
-
-@manager.route('/trace', methods=['GET']) # noqa: F821
-def trace():
- cvs_id = request.args.get("canvas_id")
- msg_id = request.args.get("message_id")
- try:
- binary = REDIS_CONN.get(f"{cvs_id}-{msg_id}-logs")
- if not binary:
- return get_json_result(data={})
-
- return get_json_result(data=json.loads(binary.encode("utf-8")))
- except Exception as e:
- logging.exception(e)
-
-
-@manager.route('/<canvas_id>/sessions', methods=['GET']) # noqa: F821
-@login_required
-def sessions(canvas_id):
- tenant_id = current_user.id
- if not UserCanvasService.accessible(canvas_id, tenant_id):
- return get_json_result(
- data=False, message='Only owner of canvas authorized for this operation.',
- code=RetCode.OPERATING_ERROR)
-
- user_id = request.args.get("user_id")
- page_number = int(request.args.get("page", 1))
- items_per_page = int(request.args.get("page_size", 30))
- keywords = request.args.get("keywords")
- from_date = request.args.get("from_date")
- to_date = request.args.get("to_date")
- orderby = request.args.get("orderby", "update_time")
- exp_user_id = request.args.get("exp_user_id")
- if request.args.get("desc") == "False" or request.args.get("desc") == "false":
- desc = False
- else:
- desc = True
-
- if exp_user_id:
- sess = API4ConversationService.get_names(canvas_id, exp_user_id)
- return get_json_result(data={"total": len(sess), "sessions": sess})
-
- # dsl defaults to True in all cases except for False and false
- include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
- total, sess = API4ConversationService.get_list(canvas_id, tenant_id, page_number, items_per_page, orderby, desc,
- None, user_id, include_dsl, keywords, from_date, to_date, exp_user_id=exp_user_id)
- try:
- return get_json_result(data={"total": total, "sessions": sess})
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route('/<canvas_id>/sessions', methods=['PUT']) # noqa: F821
-@login_required
-async def set_session(canvas_id):
- req = await get_request_json()
- tenant_id = current_user.id
- e, cvs = UserCanvasService.get_by_id(canvas_id)
- assert e, "Agent not found."
- if not isinstance(cvs.dsl, str):
- cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
- session_id=get_uuid()
- canvas = Canvas(cvs.dsl, tenant_id, canvas_id, canvas_id=cvs.id)
- canvas.reset()
- # Get the version title for this canvas (using latest, not necessarily released)
- version_title = UserCanvasVersionService.get_latest_version_title(cvs.id, release_mode=False)
- conv = {
- "id": session_id,
- "name": req.get("name", ""),
- "dialog_id": cvs.id,
- "user_id": tenant_id,
- "exp_user_id": tenant_id,
- "message": [],
- "source": "agent",
- "dsl": cvs.dsl,
- "reference": [],
- "version_title": version_title
- }
- API4ConversationService.save(**conv)
- return get_json_result(data=conv)
-
-
-@manager.route('/<canvas_id>/sessions/<session_id>', methods=['GET']) # noqa: F821
-@login_required
-def get_session(canvas_id, session_id):
- tenant_id = current_user.id
- if not UserCanvasService.accessible(canvas_id, tenant_id):
- return get_json_result(
- data=False, message='Only owner of canvas authorized for this operation.',
- code=RetCode.OPERATING_ERROR)
- _, conv = API4ConversationService.get_by_id(session_id)
- return get_json_result(data=conv.to_dict())
-
-
-@manager.route('/<canvas_id>/sessions/<session_id>', methods=['DELETE']) # noqa: F821
-@login_required
-def del_session(canvas_id, session_id):
- tenant_id = current_user.id
- if not UserCanvasService.accessible(canvas_id, tenant_id):
- return get_json_result(
- data=False, message='Only owner of canvas authorized for this operation.',
- code=RetCode.OPERATING_ERROR)
- return get_json_result(data=API4ConversationService.delete_by_id(session_id))
-
-
-@manager.route('/prompts', methods=['GET']) # noqa: F821
-@login_required
-def prompts():
- from rag.prompts.generator import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
-
- return get_json_result(data={
- "task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER,
- "plan_generation": NEXT_STEP,
- "reflection": REFLECT,
- #"context_summary": SUMMARY4MEMORY,
- #"context_ranking": RANK_MEMORY,
- "citation_guidelines": CITATION_PROMPT_TEMPLATE
- })
-
-
-@manager.route('/download', methods=['GET']) # noqa: F821
-async def download():
- id = request.args.get("id")
- created_by = request.args.get("created_by")
- blob = FileService.get_blob(created_by, id)
- return await make_response(blob)
diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py
deleted file mode 100644
index e6ceb66e695..00000000000
--- a/api/apps/chunk_app.py
+++ /dev/null
@@ -1,580 +0,0 @@
-#
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-import base64
-import datetime
-import json
-import logging
-import re
-import xxhash
-from quart import request
-
-from api.db.services.document_service import DocumentService
-from api.db.services.doc_metadata_service import DocMetadataService
-from api.utils.image_utils import store_chunk_image
-from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.db.services.llm_service import LLMBundle
-from common.metadata_utils import apply_meta_data_filter
-from api.db.services.search_service import SearchService
-from api.db.services.user_service import UserTenantService
-from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_tenant_default_model_by_type, get_model_config_by_type_and_name
-from api.utils.api_utils import (
- get_data_error_result,
- get_json_result,
- server_error_response,
- validate_request,
- get_request_json,
-)
-from common.misc_utils import thread_pool_exec
-from common.tag_feature_utils import validate_tag_features
-from rag.app.qa import beAdoc, rmPrefix
-from rag.app.tag import label_question
-from rag.nlp import rag_tokenizer, search
-from rag.prompts.generator import cross_languages, keyword_extraction
-from common.string_utils import is_content_empty, remove_redundant_spaces
-from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD
-from common import settings
-from api.apps import login_required, current_user
-
-@manager.route('/list', methods=['POST']) # noqa: F821
-@login_required
-@validate_request("doc_id")
-async def list_chunk():
- req = await get_request_json()
- doc_id = req["doc_id"]
- page = int(req.get("page", 1))
- size = int(req.get("size", 30))
- question = req.get("keywords", "")
- try:
- tenant_id = DocumentService.get_tenant_id(req["doc_id"])
- if not tenant_id:
- return get_data_error_result(message="Tenant not found!")
- e, doc = DocumentService.get_by_id(doc_id)
- if not e:
- return get_data_error_result(message="Document not found!")
- kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
- query = {
- "doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True
- }
- if "available_int" in req:
- query["available_int"] = int(req["available_int"])
- sres = await settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
- res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
- for id in sres.ids:
- d = {
- "chunk_id": id,
- "content_with_weight": remove_redundant_spaces(sres.highlight[id]) if question and id in sres.highlight else sres.field[
- id].get(
- "content_with_weight", ""),
- "doc_id": sres.field[id]["doc_id"],
- "docnm_kwd": sres.field[id]["docnm_kwd"],
- "important_kwd": sres.field[id].get("important_kwd", []),
- "question_kwd": sres.field[id].get("question_kwd", []),
- "image_id": sres.field[id].get("img_id", ""),
- "available_int": int(sres.field[id].get("available_int", 1)),
- "positions": sres.field[id].get("position_int", []),
- "doc_type_kwd": sres.field[id].get("doc_type_kwd")
- }
- assert isinstance(d["positions"], list)
- assert len(d["positions"]) == 0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
- res["chunks"].append(d)
- return get_json_result(data=res)
- except Exception as e:
- if str(e).find("not_found") > 0:
- return get_json_result(data=False, message='No chunk found!',
- code=RetCode.DATA_ERROR)
- return server_error_response(e)
-
-
-@manager.route('/get', methods=['GET']) # noqa: F821
-@login_required
-def get():
- chunk_id = request.args["chunk_id"]
- try:
- chunk = None
- tenants = UserTenantService.query(user_id=current_user.id)
- if not tenants:
- return get_data_error_result(message="Tenant not found!")
- for tenant in tenants:
- kb_ids = KnowledgebaseService.get_kb_ids(tenant.tenant_id)
- chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids)
- if chunk:
- break
- if chunk is None:
- return server_error_response(Exception("Chunk not found"))
-
- k = []
- for n in chunk.keys():
- if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
- k.append(n)
- for n in k:
- del chunk[n]
-
- return get_json_result(data=chunk)
- except Exception as e:
- if str(e).find("NotFoundError") >= 0:
- return get_json_result(data=False, message='Chunk not found!',
- code=RetCode.DATA_ERROR)
- return server_error_response(e)
-
-
-@manager.route('/set', methods=['POST']) # noqa: F821
-@login_required
-@validate_request("doc_id", "chunk_id", "content_with_weight")
-async def set():
- req = await get_request_json()
- content_with_weight = req["content_with_weight"]
- if not isinstance(content_with_weight, (str, bytes)):
- raise TypeError("expected string or bytes-like object")
- if isinstance(content_with_weight, bytes):
- content_with_weight = content_with_weight.decode("utf-8", errors="ignore")
- if is_content_empty(content_with_weight):
- return get_data_error_result(message="`content_with_weight` is required")
- d = {
- "id": req["chunk_id"],
- "content_with_weight": content_with_weight}
- d["content_ltks"] = rag_tokenizer.tokenize(content_with_weight)
- d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
- if "important_kwd" in req:
- if not isinstance(req["important_kwd"], list):
- return get_data_error_result(message="`important_kwd` should be a list")
- d["important_kwd"] = req["important_kwd"]
- d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
- if "question_kwd" in req:
- if not isinstance(req["question_kwd"], list):
- return get_data_error_result(message="`question_kwd` should be a list")
- d["question_kwd"] = req["question_kwd"]
- d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"]))
- if "tag_kwd" in req:
- if not isinstance(req["tag_kwd"], list):
- return get_data_error_result(message="`tag_kwd` should be a list")
- if not all(isinstance(t, str) for t in req["tag_kwd"]):
- return get_data_error_result(message="`tag_kwd` must be a list of strings")
- d["tag_kwd"] = req["tag_kwd"]
- if "tag_feas" in req:
- try:
- d["tag_feas"] = validate_tag_features(req["tag_feas"])
- except ValueError as exc:
- return get_data_error_result(message=f"`tag_feas` {exc}")
- if "available_int" in req:
- d["available_int"] = req["available_int"]
-
- try:
- def _set_sync():
- tenant_id = DocumentService.get_tenant_id(req["doc_id"])
- if not tenant_id:
- return get_data_error_result(message="Tenant not found!")
-
- e, doc = DocumentService.get_by_id(req["doc_id"])
- if not e:
- return get_data_error_result(message="Document not found!")
-
- tenant_embd_id = DocumentService.get_tenant_embd_id(req["doc_id"])
- if tenant_embd_id:
- embd_model_config = get_model_config_by_id(tenant_embd_id)
- else:
- embd_id = DocumentService.get_embd_id(req["doc_id"])
- if embd_id:
- embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_id)
- else:
- embd_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.EMBEDDING)
- embd_mdl = LLMBundle(tenant_id, embd_model_config)
-
- _d = d
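-            # QA-parsed documents hold question/answer pairs: re-split the
-            # edited content into question and answer before re-embedding.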
- if doc.parser_id == ParserType.QA:
-                arr = [t for t in re.split(r"[\n\t]", req["content_with_weight"]) if len(t) > 1]
-                q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:]))
-                _d = beAdoc(d, q, a, not any(rag_tokenizer.is_chinese(t) for t in q + a))
-
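-            # Embed the document name and the chunk content (or its question
-            # keywords), then blend the two vectors 1:9 in favour of the
-            # content; QA chunks keep only the second vector.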
- v, c = embd_mdl.encode([doc.name, content_with_weight if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
- v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
- _d["q_%d_vec" % len(v)] = v.tolist()
- settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id)
-
- # update image
- image_base64 = req.get("image_base64", None)
- img_id = req.get("img_id", "")
- if image_base64 and img_id and "-" in img_id:
- bkt, name = img_id.split("-", 1)
- image_binary = base64.b64decode(image_base64)
- settings.STORAGE_IMPL.put(bkt, name, image_binary)
- return get_json_result(data=True)
-
- return await thread_pool_exec(_set_sync)
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route('/switch', methods=['POST']) # noqa: F821
-@login_required
-@validate_request("chunk_ids", "available_int", "doc_id")
-async def switch():
- req = await get_request_json()
- try:
- def _switch_sync():
- e, doc = DocumentService.get_by_id(req["doc_id"])
- if not e:
- return get_data_error_result(message="Document not found!")
- for cid in req["chunk_ids"]:
- if not settings.docStoreConn.update({"id": cid},
- {"available_int": int(req["available_int"])},
- search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
- doc.kb_id):
- return get_data_error_result(message="Index updating failure")
- return get_json_result(data=True)
-
- return await thread_pool_exec(_switch_sync)
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route('/rm', methods=['POST']) # noqa: F821
-@login_required
-@validate_request("doc_id")
-async def rm():
- req = await get_request_json()
- try:
- def _rm_sync():
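-            # chunk_ids may arrive as a list, a single id, or not at all
-            # (delete_all mode); normalize to a de-duplicated list first.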
- deleted_chunk_ids = req.get("chunk_ids")
- if isinstance(deleted_chunk_ids, list):
- unique_chunk_ids = list(dict.fromkeys(deleted_chunk_ids))
- has_ids = len(unique_chunk_ids) > 0
- elif deleted_chunk_ids is not None:
- unique_chunk_ids = [deleted_chunk_ids]
- has_ids = deleted_chunk_ids not in (None, "")
- else:
- unique_chunk_ids = []
- has_ids = False
- if not has_ids:
- if req.get("delete_all") is True:
- e, doc = DocumentService.get_by_id(req["doc_id"])
- if not e:
- return get_data_error_result(message="Document not found!")
- tenant_id = DocumentService.get_tenant_id(req["doc_id"])
- # Clean up storage assets while index rows still exist for discovery
- DocumentService.delete_chunk_images(doc, tenant_id)
- condition = {"doc_id": req["doc_id"]}
- try:
- deleted_count = settings.docStoreConn.delete(condition, search.index_name(tenant_id), doc.kb_id)
- except Exception:
- return get_data_error_result(message="Chunk deleting failure")
- if deleted_count > 0:
- DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, deleted_count, 0)
- return get_json_result(data=True)
- return get_json_result(data=True)
-
- e, doc = DocumentService.get_by_id(req["doc_id"])
- if not e:
- return get_data_error_result(message="Document not found!")
- condition = {"id": req["chunk_ids"], "doc_id": req["doc_id"]}
- try:
- deleted_count = settings.docStoreConn.delete(condition,
- search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
- doc.kb_id)
- except Exception:
- return get_data_error_result(message="Chunk deleting failure")
- if has_ids and deleted_count == 0:
- return get_data_error_result(message="Index updating failure")
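-            # A partial hit falls back to deleting every remaining chunk of the
-            # document; the extra deletions count into the decrement below.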
- if deleted_count > 0 and deleted_count < len(unique_chunk_ids):
- deleted_count += settings.docStoreConn.delete({"doc_id": req["doc_id"]},
- search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
- doc.kb_id)
- chunk_number = deleted_count
- DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
- for cid in deleted_chunk_ids:
- if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):
- settings.STORAGE_IMPL.rm(doc.kb_id, cid)
- return get_json_result(data=True)
-
- return await thread_pool_exec(_rm_sync)
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route('/create', methods=['POST']) # noqa: F821
-@login_required
-@validate_request("doc_id", "content_with_weight")
-async def create():
- req = await get_request_json()
- req_id = request.headers.get("X-Request-ID")
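-    # Derive a deterministic chunk id from content + doc id, so identical
-    # content in the same document always maps to the same id.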
-    chunk_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest()
-    d = {"id": chunk_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
- "content_with_weight": req["content_with_weight"]}
- d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
- d["important_kwd"] = req.get("important_kwd", [])
- if not isinstance(d["important_kwd"], list):
- return get_data_error_result(message="`important_kwd` is required to be a list")
- d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
- d["question_kwd"] = req.get("question_kwd", [])
- if not isinstance(d["question_kwd"], list):
- return get_data_error_result(message="`question_kwd` is required to be a list")
- d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
- d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
- d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
- if "tag_kwd" in req:
- if not isinstance(req["tag_kwd"], list):
- return get_data_error_result(message="`tag_kwd` is required to be a list")
- if not all(isinstance(t, str) for t in req["tag_kwd"]):
- return get_data_error_result(message="`tag_kwd` must be a list of strings")
- d["tag_kwd"] = req["tag_kwd"]
- if "tag_feas" in req:
- try:
- d["tag_feas"] = validate_tag_features(req["tag_feas"])
- except ValueError as exc:
- return get_data_error_result(message=f"`tag_feas` {exc}")
- image_base64 = req.get("image_base64", None)
-
- try:
- def _log_response(resp, code, message):
- logging.info(
- "chunk_create response req_id=%s status=%s code=%s message=%s",
- req_id,
- getattr(resp, "status_code", None),
- code,
- message,
- )
-
- def _create_sync():
- e, doc = DocumentService.get_by_id(req["doc_id"])
- if not e:
- resp = get_data_error_result(message="Document not found!")
- _log_response(resp, RetCode.DATA_ERROR, "Document not found!")
- return resp
- d["kb_id"] = [doc.kb_id]
- d["docnm_kwd"] = doc.name
- d["title_tks"] = rag_tokenizer.tokenize(doc.name)
- d["doc_id"] = doc.id
-
- tenant_id = DocumentService.get_tenant_id(req["doc_id"])
- if not tenant_id:
- resp = get_data_error_result(message="Tenant not found!")
- _log_response(resp, RetCode.DATA_ERROR, "Tenant not found!")
- return resp
-
- e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
- if not e:
- resp = get_data_error_result(message="Knowledgebase not found!")
- _log_response(resp, RetCode.DATA_ERROR, "Knowledgebase not found!")
- return resp
- if kb.pagerank:
- d[PAGERANK_FLD] = kb.pagerank
-
- tenant_embd_id = DocumentService.get_tenant_embd_id(req["doc_id"])
- if tenant_embd_id:
- embd_model_config = get_model_config_by_id(tenant_embd_id)
- else:
- embd_id = DocumentService.get_embd_id(req["doc_id"])
- if embd_id:
- embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_id)
- else:
- embd_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.EMBEDDING)
- embd_mdl = LLMBundle(tenant_id, embd_model_config)
-
- if image_base64:
-                d["img_id"] = "{}-{}".format(doc.kb_id, chunk_id)
- d["doc_type_kwd"] = "image"
-
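-            # Embed the document name and the chunk content (or its question
-            # keywords) and blend the two vectors 1:9 in favour of the content.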
- v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
- v = 0.1 * v[0] + 0.9 * v[1]
- d["q_%d_vec" % len(v)] = v.tolist()
- settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
-
- if image_base64:
-                store_chunk_image(doc.kb_id, chunk_id, base64.b64decode(image_base64))
-
- DocumentService.increment_chunk_num(
- doc.id, doc.kb_id, c, 1, 0)
-            resp = get_json_result(data={"chunk_id": chunk_id, "image_id": d.get("img_id", "")})
- _log_response(resp, RetCode.SUCCESS, "success")
- return resp
-
- return await thread_pool_exec(_create_sync)
- except Exception as e:
- logging.info("chunk_create exception req_id=%s error=%r", req_id, e)
- return server_error_response(e)
-
-
-@manager.route('/retrieval_test', methods=['POST']) # noqa: F821
-@login_required
-@validate_request("kb_id", "question")
-async def retrieval_test():
- req = await get_request_json()
- page = int(req.get("page", 1))
- size = int(req.get("size", 30))
- question = req["question"]
- kb_ids = req["kb_id"]
- if isinstance(kb_ids, str):
- kb_ids = [kb_ids]
- if not kb_ids:
-        return get_json_result(data=False, message='Please specify a dataset first.',
- code=RetCode.DATA_ERROR)
-
- doc_ids = req.get("doc_ids", [])
- use_kg = req.get("use_kg", False)
- top = int(req.get("top_k", 1024))
- langs = req.get("cross_languages", [])
- user_id = current_user.id
-
- async def _retrieval():
- local_doc_ids = list(doc_ids) if doc_ids else []
- tenant_ids = []
-
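-        # The metadata filter comes either from a saved search config or from
-        # the request itself; "auto"/"semi_auto" methods need a chat model to
-        # derive the filter from the question.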
- meta_data_filter = {}
- chat_mdl = None
- if req.get("search_id", ""):
- search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
- meta_data_filter = search_config.get("meta_data_filter", {})
- if meta_data_filter.get("method") in ["auto", "semi_auto"]:
- chat_id = search_config.get("chat_id", "")
- if chat_id:
- chat_model_config = get_model_config_by_type_and_name(user_id, LLMType.CHAT, search_config["chat_id"])
- else:
- chat_model_config = get_tenant_default_model_by_type(user_id, LLMType.CHAT)
- chat_mdl = LLMBundle(user_id, chat_model_config)
- else:
- meta_data_filter = req.get("meta_data_filter") or {}
- if meta_data_filter.get("method") in ["auto", "semi_auto"]:
- chat_model_config = get_tenant_default_model_by_type(user_id, LLMType.CHAT)
- chat_mdl = LLMBundle(user_id, chat_model_config)
-
- if meta_data_filter:
- metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
- local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, local_doc_ids)
-
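-        # Authorization: every requested kb must belong to one of the user's
-        # tenants; the for/else returns an error when no tenant matches a kb_id.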
- tenants = UserTenantService.query(user_id=user_id)
- for kb_id in kb_ids:
- for tenant in tenants:
- if KnowledgebaseService.query(
- tenant_id=tenant.tenant_id, id=kb_id):
- tenant_ids.append(tenant.tenant_id)
- break
- else:
- return get_json_result(
- data=False, message='Only owner of dataset authorized for this operation.',
- code=RetCode.OPERATING_ERROR)
-
- e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
- if not e:
- return get_data_error_result(message="Knowledgebase not found!")
-
- _question = question
- if langs:
- _question = await cross_languages(kb.tenant_id, None, _question, langs)
- if kb.tenant_embd_id:
- embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
- elif kb.embd_id:
- embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
- else:
- embd_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.EMBEDDING)
- embd_mdl = LLMBundle(kb.tenant_id, embd_model_config)
-
- rerank_mdl = None
- if req.get("tenant_rerank_id"):
- rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"])
- rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
- elif req.get("rerank_id"):
- rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"])
- rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
-
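-        # Optional query expansion: extract keywords with the default chat
-        # model and append them to the question before retrieval.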
- if req.get("keyword", False):
- default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
- chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config)
- _question += await keyword_extraction(chat_mdl, _question)
-
- labels = label_question(_question, [kb])
- ranks = await settings.retriever.retrieval(
- _question,
- embd_mdl,
- tenant_ids,
- kb_ids,
- page,
- size,
- float(req.get("similarity_threshold", 0.0)),
- float(req.get("vector_similarity_weight", 0.3)),
- doc_ids=local_doc_ids,
- top=top,
- rerank_mdl=rerank_mdl,
- rank_feature=labels
- )
-
- if use_kg:
- default_chat_model_config = get_tenant_default_model_by_type(user_id, LLMType.CHAT)
- ck = await settings.kg_retriever.retrieval(_question,
- tenant_ids,
- kb_ids,
- embd_mdl,
- LLMBundle(kb.tenant_id, default_chat_model_config))
- if ck["content_with_weight"]:
- ranks["chunks"].insert(0, ck)
- ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids)
-
- for c in ranks["chunks"]:
- c.pop("vector", None)
- ranks["labels"] = labels
-
- return get_json_result(data=ranks)
-
- try:
- return await _retrieval()
- except Exception as e:
-        if "not_found" in str(e):
-            return get_json_result(data=False, message='No chunk found! Please check the chunk status!',
- code=RetCode.DATA_ERROR)
- return server_error_response(e)
-
-
-@manager.route('/knowledge_graph', methods=['GET']) # noqa: F821
-@login_required
-async def knowledge_graph():
- doc_id = request.args["doc_id"]
- tenant_id = DocumentService.get_tenant_id(doc_id)
- kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
- req = {
- "doc_ids": [doc_id],
- "knowledge_graph_kwd": ["graph", "mind_map"]
- }
- sres = await settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
- obj = {"graph": {}, "mind_map": {}}
- for id in sres.ids[:2]:
- ty = sres.field[id]["knowledge_graph_kwd"]
- try:
- content_json = json.loads(sres.field[id]["content_with_weight"])
- except Exception:
- continue
-
- if ty == 'mind_map':
- node_dict = {}
-
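-            # Mind-map nodes may reuse the same id; append a running counter to
-            # duplicates so the rendered tree stays unambiguous.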
- def repeat_deal(content_json, node_dict):
- if 'id' in content_json:
- if content_json['id'] in node_dict:
- node_name = content_json['id']
- content_json['id'] += f"({node_dict[content_json['id']]})"
- node_dict[node_name] += 1
- else:
- node_dict[content_json['id']] = 1
- if 'children' in content_json and content_json['children']:
- for item in content_json['children']:
- repeat_deal(item, node_dict)
-
- repeat_deal(content_json, node_dict)
-
- obj[ty] = content_json
-
- return get_json_result(data=obj)
diff --git a/api/apps/document_app.py b/api/apps/document_app.py
deleted file mode 100644
index 9a9cafb9b1c..00000000000
--- a/api/apps/document_app.py
+++ /dev/null
@@ -1,716 +0,0 @@
-#
-# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License
-#
-import os.path
-import re
-from pathlib import Path, PurePosixPath, PureWindowsPath
-
-from quart import make_response, request
-
-from api.apps import current_user, login_required
-from api.common.check_team_permission import check_kb_team_permission
-from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
-from api.db import VALID_FILE_TYPES, FileType
-from api.db.db_models import Task
-from api.db.services import duplicate_name
-from api.db.services.doc_metadata_service import DocMetadataService
-from api.db.services.document_service import DocumentService, doc_upload_and_parse
-from api.db.services.file2document_service import File2DocumentService
-from api.db.services.file_service import FileService
-from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.db.services.task_service import TaskService, cancel_all_task_of
-from api.db.services.user_service import UserTenantService
-from api.utils.api_utils import (
- get_data_error_result,
- get_json_result,
- get_request_json,
- server_error_response,
- validate_request,
-)
-from api.utils.file_utils import filename_type, thumbnail
-from api.utils.web_utils import CONTENT_TYPE_MAP, apply_safe_file_response_headers, html2pdf, is_valid_url
-from common import settings
-from common.constants import SANDBOX_ARTIFACT_BUCKET, VALID_TASK_STATUS, ParserType, RetCode, TaskStatus
-from common.file_utils import get_project_base_directory
-from common.misc_utils import get_uuid, thread_pool_exec
-from deepdoc.parser.html_parser import RAGFlowHtmlParser
-from rag.nlp import search
-
-
-def _is_safe_download_filename(name: str) -> bool:
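-    # Reject empty names, NUL bytes, over-long names and anything carrying
-    # POSIX or Windows path separators before the name is joined onto a path.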
- if not name or name in {".", ".."}:
- return False
- if "\x00" in name or len(name) > 255:
- return False
- if name != PurePosixPath(name).name:
- return False
- if name != PureWindowsPath(name).name:
- return False
- return True
-
-
-@manager.route("/web_crawl", methods=["POST"]) # noqa: F821
-@login_required
-@validate_request("kb_id", "name", "url")
-async def web_crawl():
- form = await request.form
- kb_id = form.get("kb_id")
- if not kb_id:
- return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
- name = form.get("name")
- url = form.get("url")
- if not is_valid_url(url):
- return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR)
- e, kb = KnowledgebaseService.get_by_id(kb_id)
- if not e:
- raise LookupError("Can't find this dataset!")
- if not check_kb_team_permission(kb, current_user.id):
- return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
-
- blob = html2pdf(url)
- if not blob:
- return server_error_response(ValueError("Download failure."))
-
- root_folder = FileService.get_root_folder(current_user.id)
- pf_id = root_folder["id"]
- FileService.init_knowledgebase_docs(pf_id, current_user.id)
- kb_root_folder = FileService.get_kb_folder(current_user.id)
- kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
-
- try:
- filename = duplicate_name(DocumentService.query, name=name + ".pdf", kb_id=kb.id)
- filetype = filename_type(filename)
- if filetype == FileType.OTHER.value:
- raise RuntimeError("This type of file has not been supported yet!")
-
- location = filename
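-        # Avoid overwriting an existing object: append "_" until the key is free.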
- while settings.STORAGE_IMPL.obj_exist(kb_id, location):
- location += "_"
- settings.STORAGE_IMPL.put(kb_id, location, blob)
- doc = {
- "id": get_uuid(),
- "kb_id": kb.id,
- "parser_id": kb.parser_id,
- "parser_config": kb.parser_config,
- "created_by": current_user.id,
- "type": filetype,
- "name": filename,
- "location": location,
- "size": len(blob),
- "thumbnail": thumbnail(filename, blob),
- "suffix": Path(filename).suffix.lstrip("."),
- }
- if doc["type"] == FileType.VISUAL:
- doc["parser_id"] = ParserType.PICTURE.value
- if doc["type"] == FileType.AURAL:
- doc["parser_id"] = ParserType.AUDIO.value
- if re.search(r"\.(ppt|pptx|pages)$", filename):
- doc["parser_id"] = ParserType.PRESENTATION.value
- if re.search(r"\.(eml)$", filename):
- doc["parser_id"] = ParserType.EMAIL.value
- DocumentService.insert(doc)
- FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
- except Exception as e:
- return server_error_response(e)
- return get_json_result(data=True)
-
-
-@manager.route("/create", methods=["POST"]) # noqa: F821
-@login_required
-@validate_request("name", "kb_id")
-async def create():
- req = await get_request_json()
- kb_id = req["kb_id"]
- if not kb_id:
- return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
- if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT:
- return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
-
- if req["name"].strip() == "":
- return get_json_result(data=False, message="File name can't be empty.", code=RetCode.ARGUMENT_ERROR)
- req["name"] = req["name"].strip()
-
- try:
- e, kb = KnowledgebaseService.get_by_id(kb_id)
- if not e:
- return get_data_error_result(message="Can't find this dataset!")
-
- if DocumentService.query(name=req["name"], kb_id=kb_id):
- return get_data_error_result(message="Duplicated document name in the same dataset.")
-
- kb_root_folder = FileService.get_kb_folder(kb.tenant_id)
- if not kb_root_folder:
- return get_data_error_result(message="Cannot find the root folder.")
- kb_folder = FileService.new_a_file_from_kb(
- kb.tenant_id,
- kb.name,
- kb_root_folder["id"],
- )
- if not kb_folder:
- return get_data_error_result(message="Cannot find the kb folder for this file.")
-
- doc = DocumentService.insert(
- {
- "id": get_uuid(),
- "kb_id": kb.id,
- "parser_id": kb.parser_id,
- "pipeline_id": kb.pipeline_id,
- "parser_config": kb.parser_config,
- "created_by": current_user.id,
- "type": FileType.VIRTUAL,
- "name": req["name"],
- "suffix": Path(req["name"]).suffix.lstrip("."),
- "location": "",
- "size": 0,
- }
- )
-
- FileService.add_file_from_kb(doc.to_dict(), kb_folder["id"], kb.tenant_id)
-
- return get_json_result(data=doc.to_json())
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route("/filter", methods=["POST"]) # noqa: F821
-@login_required
-async def get_filter():
- req = await get_request_json()
-
- kb_id = req.get("kb_id")
- if not kb_id:
- return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
- tenants = UserTenantService.query(user_id=current_user.id)
- for tenant in tenants:
- if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
- break
- else:
- return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR)
-
- keywords = req.get("keywords", "")
-
- suffix = req.get("suffix", [])
-
- run_status = req.get("run_status", [])
- if run_status:
- invalid_status = {s for s in run_status if s not in VALID_TASK_STATUS}
- if invalid_status:
- return get_data_error_result(message=f"Invalid filter run status conditions: {', '.join(invalid_status)}")
-
- types = req.get("types", [])
- if types:
- invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
- if invalid_types:
- return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}")
-
- try:
- filter, total = DocumentService.get_filter_by_kb_id(kb_id, keywords, run_status, types, suffix)
- return get_json_result(data={"total": total, "filter": filter})
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route("/infos", methods=["POST"]) # noqa: F821
-@login_required
-async def doc_infos():
- req = await get_request_json()
- doc_ids = req["doc_ids"]
- for doc_id in doc_ids:
- if not DocumentService.accessible(doc_id, current_user.id):
- return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
- docs = DocumentService.get_by_ids(doc_ids)
- docs_list = list(docs.dicts())
- # Add meta_fields for each document
- for doc in docs_list:
- doc["meta_fields"] = DocMetadataService.get_document_metadata(doc["id"])
- return get_json_result(data=docs_list)
-
-
-@manager.route("/metadata/update", methods=["POST"]) # noqa: F821
-@login_required
-@validate_request("doc_ids")
-async def metadata_update():
- req = await get_request_json()
- kb_id = req.get("kb_id")
- document_ids = req.get("doc_ids")
- updates = req.get("updates", []) or []
- deletes = req.get("deletes", []) or []
-
- if not kb_id:
- return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
-
- if not isinstance(updates, list) or not isinstance(deletes, list):
- return get_json_result(data=False, message="updates and deletes must be lists.", code=RetCode.ARGUMENT_ERROR)
-
- for upd in updates:
- if not isinstance(upd, dict) or not upd.get("key") or "value" not in upd:
- return get_json_result(data=False, message="Each update requires key and value.", code=RetCode.ARGUMENT_ERROR)
- for d in deletes:
- if not isinstance(d, dict) or not d.get("key"):
- return get_json_result(data=False, message="Each delete requires key.", code=RetCode.ARGUMENT_ERROR)
-
- updated = DocMetadataService.batch_update_metadata(kb_id, document_ids, updates, deletes)
- return get_json_result(data={"updated": updated, "matched_docs": len(document_ids)})
-
-
-@manager.route("/update_metadata_setting", methods=["POST"]) # noqa: F821
-@login_required
-@validate_request("doc_id", "metadata")
-async def update_metadata_setting():
- req = await get_request_json()
- if not DocumentService.accessible(req["doc_id"], current_user.id):
- return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
-
- e, doc = DocumentService.get_by_id(req["doc_id"])
- if not e:
- return get_data_error_result(message="Document not found!")
-
- DocumentService.update_parser_config(doc.id, {"metadata": req["metadata"]})
- e, doc = DocumentService.get_by_id(doc.id)
- if not e:
- return get_data_error_result(message="Document not found!")
-
- return get_json_result(data=doc.to_dict())
-
-
-@manager.route("/thumbnails", methods=["GET"]) # noqa: F821
-# @login_required
-def thumbnails():
- doc_ids = request.args.getlist("doc_ids")
- if not doc_ids:
- return get_json_result(data=False, message='Lack of "Document ID"', code=RetCode.ARGUMENT_ERROR)
-
- try:
- docs = DocumentService.get_thumbnails(doc_ids)
-
- for doc_item in docs:
- if doc_item["thumbnail"] and not doc_item["thumbnail"].startswith(IMG_BASE64_PREFIX):
- doc_item["thumbnail"] = f"/v1/document/image/{doc_item['kb_id']}-{doc_item['thumbnail']}"
-
- return get_json_result(data={d["id"]: d["thumbnail"] for d in docs})
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route("/change_status", methods=["POST"]) # noqa: F821
-@login_required
-@validate_request("doc_ids", "status")
-async def change_status():
- req = await get_request_json()
- doc_ids = req.get("doc_ids", [])
- status = str(req.get("status", ""))
-
- if status not in ["0", "1"]:
- return get_json_result(data=False, message='"Status" must be either 0 or 1!', code=RetCode.ARGUMENT_ERROR)
-
- result = {}
- has_error = False
- for doc_id in doc_ids:
- if not DocumentService.accessible(doc_id, current_user.id):
- result[doc_id] = {"error": "No authorization."}
- has_error = True
- continue
-
- try:
- e, doc = DocumentService.get_by_id(doc_id)
- if not e:
- result[doc_id] = {"error": "No authorization."}
- has_error = True
- continue
- e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
- if not e:
- result[doc_id] = {"error": "Can't find this dataset!"}
- has_error = True
- continue
- current_status = str(doc.status)
- if current_status == status:
- result[doc_id] = {"status": status}
- continue
- if not DocumentService.update_by_id(doc_id, {"status": str(status)}):
- result[doc_id] = {"error": "Database error (Document update)!"}
- has_error = True
- continue
-
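-            # Propagate the new status to the doc store so retrieval honours it,
-            # but only when the document actually has indexed chunks.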
- status_int = int(status)
- if getattr(doc, "chunk_num", 0) > 0:
- try:
- ok = settings.docStoreConn.update(
- {"doc_id": doc_id},
- {"available_int": status_int},
- search.index_name(kb.tenant_id),
- doc.kb_id,
- )
- except Exception as exc:
- msg = str(exc)
- if "3022" in msg:
- result[doc_id] = {"error": "Document store table missing."}
- else:
- result[doc_id] = {"error": f"Document store update failed: {msg}"}
- has_error = True
- continue
- if not ok:
- result[doc_id] = {"error": "Database error (docStore update)!"}
- has_error = True
- continue
- result[doc_id] = {"status": status}
- except Exception as e:
- result[doc_id] = {"error": f"Internal server error: {str(e)}"}
- has_error = True
-
- if has_error:
- return get_json_result(data=result, message="Partial failure", code=RetCode.SERVER_ERROR)
- return get_json_result(data=result)
-
-
-@manager.route("/rm", methods=["POST"]) # noqa: F821
-@login_required
-@validate_request("doc_id")
-async def rm():
- req = await get_request_json()
- doc_ids = req["doc_id"]
- if isinstance(doc_ids, str):
- doc_ids = [doc_ids]
-
- for doc_id in doc_ids:
- if not DocumentService.accessible4deletion(doc_id, current_user.id):
- return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
-
- errors = await thread_pool_exec(FileService.delete_docs, doc_ids, current_user.id)
-
- if errors:
- return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
-
- return get_json_result(data=True)
-
-
-@manager.route("/run", methods=["POST"]) # noqa: F821
-@login_required
-@validate_request("doc_ids", "run")
-async def run():
- req = await get_request_json()
- uid = current_user.id
- try:
-
- def _run_sync():
- for doc_id in req["doc_ids"]:
- if not DocumentService.accessible(doc_id, uid):
- return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
-
- kb_table_num_map = {}
- for id in req["doc_ids"]:
- info = {"run": str(req["run"]), "progress": 0}
- if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False):
- info["progress_msg"] = ""
- info["chunk_num"] = 0
- info["token_num"] = 0
-
- tenant_id = DocumentService.get_tenant_id(id)
- if not tenant_id:
- return get_data_error_result(message="Tenant not found!")
- e, doc = DocumentService.get_by_id(id)
- if not e:
- return get_data_error_result(message="Document not found!")
-
- if str(req["run"]) == TaskStatus.CANCEL.value:
- tasks = list(TaskService.query(doc_id=id))
- has_unfinished_task = any((task.progress or 0) < 1 for task in tasks)
- if str(doc.run) in [TaskStatus.RUNNING.value, TaskStatus.CANCEL.value] or has_unfinished_task:
- cancel_all_task_of(id)
- else:
- return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status")
- if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]):
- DocumentService.clear_chunk_num_when_rerun(doc.id)
-
- DocumentService.update_by_id(id, info)
- if req.get("delete", False):
- TaskService.filter_delete([Task.doc_id == id])
- if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
- settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
-
- if str(req["run"]) == TaskStatus.RUNNING.value:
- if req.get("apply_kb"):
- e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
- if not e:
- raise LookupError("Can't find this dataset!")
- doc.parser_config["llm_id"] = kb.parser_config.get("llm_id")
- doc.parser_config["enable_metadata"] = kb.parser_config.get("enable_metadata", False)
- doc.parser_config["metadata"] = kb.parser_config.get("metadata", {})
- DocumentService.update_parser_config(doc.id, doc.parser_config)
- doc_dict = doc.to_dict()
- DocumentService.run(tenant_id, doc_dict, kb_table_num_map)
-
- return get_json_result(data=True)
-
- return await thread_pool_exec(_run_sync)
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route("/get/<doc_id>", methods=["GET"]) # noqa: F821
-@login_required
-async def get(doc_id):
- try:
- e, doc = DocumentService.get_by_id(doc_id)
- if not e:
- return get_data_error_result(message="Document not found!")
-
- b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
- data = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n)
- response = await make_response(data)
-
- ext = re.search(r"\.([^.]+)$", doc.name.lower())
- ext = ext.group(1) if ext else None
- content_type = None
- if ext:
- fallback_prefix = "image" if doc.type == FileType.VISUAL.value else "application"
- content_type = CONTENT_TYPE_MAP.get(ext, f"{fallback_prefix}/{ext}")
- apply_safe_file_response_headers(response, content_type, ext)
- return response
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route("/download/<attachment_id>", methods=["GET"]) # noqa: F821
-@login_required
-async def download_attachment(attachment_id):
- try:
- ext = request.args.get("ext", "markdown")
- data = await thread_pool_exec(settings.STORAGE_IMPL.get, current_user.id, attachment_id)
- response = await make_response(data)
- content_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}")
- apply_safe_file_response_headers(response, content_type, ext)
-
- return response
-
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route("/change_parser", methods=["POST"]) # noqa: F821
-@login_required
-@validate_request("doc_id")
-async def change_parser():
- req = await get_request_json()
- if not DocumentService.accessible(req["doc_id"], current_user.id):
- return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
-
- e, doc = DocumentService.get_by_id(req["doc_id"])
- if not e:
- return get_data_error_result(message="Document not found!")
-
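-    # Switching parser/pipeline invalidates earlier results: roll back chunk
-    # and token counters, drop stored chunk images and delete indexed chunks.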
- def reset_doc():
- nonlocal doc
- e = DocumentService.update_by_id(doc.id, {"pipeline_id": req["pipeline_id"], "parser_id": req["parser_id"], "progress": 0, "progress_msg": "", "run": TaskStatus.UNSTART.value})
- if not e:
- return get_data_error_result(message="Document not found!")
- if doc.token_num > 0:
- e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, doc.process_duration * -1)
- if not e:
- return get_data_error_result(message="Document not found!")
- tenant_id = DocumentService.get_tenant_id(req["doc_id"])
- if not tenant_id:
- return get_data_error_result(message="Tenant not found!")
- DocumentService.delete_chunk_images(doc, tenant_id)
- if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
- settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
- return None
-
- try:
- if "pipeline_id" in req and req["pipeline_id"] != "":
- if doc.pipeline_id == req["pipeline_id"]:
- return get_json_result(data=True)
- DocumentService.update_by_id(doc.id, {"pipeline_id": req["pipeline_id"]})
- reset_doc()
- return get_json_result(data=True)
-
- if doc.parser_id.lower() == req["parser_id"].lower():
- if "parser_config" in req:
- if req["parser_config"] == doc.parser_config:
- return get_json_result(data=True)
- else:
- return get_json_result(data=True)
-
- if (doc.type == FileType.VISUAL and req["parser_id"] != "picture") or (re.search(r"\.(ppt|pptx|pages)$", doc.name) and req["parser_id"] != "presentation"):
- return get_data_error_result(message="Not supported yet!")
- if "parser_config" in req:
- DocumentService.update_parser_config(doc.id, req["parser_config"])
- reset_doc()
- return get_json_result(data=True)
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route("/image/<image_id>", methods=["GET"]) # noqa: F821
-# @login_required
-async def get_image(image_id):
- try:
- arr = image_id.split("-")
- if len(arr) != 2:
- return get_data_error_result(message="Image not found.")
- bkt, nm = image_id.split("-")
- data = await thread_pool_exec(settings.STORAGE_IMPL.get, bkt, nm)
- response = await make_response(data)
-        response.headers.set("Content-Type", "image/jpeg")
- return response
- except Exception as e:
- return server_error_response(e)
-
-
-ARTIFACT_CONTENT_TYPES = {
- ".png": "image/png",
- ".jpg": "image/jpeg",
- ".jpeg": "image/jpeg",
- ".svg": "image/svg+xml",
- ".pdf": "application/pdf",
- ".csv": "text/csv",
- ".json": "application/json",
- ".html": "text/html",
-}
-
-
-@manager.route("/artifact/<filename>", methods=["GET"]) # noqa: F821
-@login_required
-async def get_artifact(filename):
- try:
- bucket = SANDBOX_ARTIFACT_BUCKET
- # Validate filename: must be uuid hex + allowed extension, nothing else
- basename = os.path.basename(filename)
- if basename != filename or "/" in filename or "\\" in filename:
- return get_data_error_result(message="Invalid filename.")
- ext = os.path.splitext(basename)[1].lower()
- if ext not in ARTIFACT_CONTENT_TYPES:
- return get_data_error_result(message="Invalid file type.")
- data = await thread_pool_exec(settings.STORAGE_IMPL.get, bucket, basename)
- if not data:
- return get_data_error_result(message="Artifact not found.")
- content_type = ARTIFACT_CONTENT_TYPES.get(ext, "application/octet-stream")
- response = await make_response(data)
- safe_filename = re.sub(r"[^\w.\-]", "_", basename)
- apply_safe_file_response_headers(response, content_type, ext)
- if not response.headers.get("Content-Disposition"):
- response.headers.set("Content-Disposition", f'inline; filename="{safe_filename}"')
- return response
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route("/upload_and_parse", methods=["POST"]) # noqa: F821
-@login_required
-@validate_request("conversation_id")
-async def upload_and_parse():
- files = await request.files
- if "file" not in files:
- return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
-
- file_objs = files.getlist("file")
- for file_obj in file_objs:
- if file_obj.filename == "":
- return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR)
-
- form = await request.form
- doc_ids = doc_upload_and_parse(form.get("conversation_id"), file_objs, current_user.id)
- return get_json_result(data=doc_ids)
-
-
-@manager.route("/parse", methods=["POST"]) # noqa: F821
-@login_required
-async def parse():
- req = await get_request_json()
- url = req.get("url", "")
- if url:
- if not is_valid_url(url):
- return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR)
- download_path = os.path.join(get_project_base_directory(), "logs/downloads")
- os.makedirs(download_path, exist_ok=True)
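-        # selenium-wire records every request the page makes; the captured
-        # response headers are used below to tell a rendered HTML page apart
-        # from a bare file download.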
- from seleniumwire.webdriver import Chrome, ChromeOptions
-
- options = ChromeOptions()
- options.add_argument("--headless")
- options.add_argument("--disable-gpu")
- options.add_argument("--no-sandbox")
- options.add_argument("--disable-dev-shm-usage")
- options.add_experimental_option("prefs", {"download.default_directory": download_path, "download.prompt_for_download": False, "download.directory_upgrade": True, "safebrowsing.enabled": True})
- driver = Chrome(options=options)
- driver.get(url)
- res_headers = [r.response.headers for r in driver.requests if r and r.response]
- if len(res_headers) > 1:
- sections = RAGFlowHtmlParser().parser_txt(driver.page_source)
- driver.quit()
- return get_json_result(data="\n".join(sections))
-
- class File:
- filename: str
- filepath: str
-
- def __init__(self, filename, filepath):
- self.filename = filename
- self.filepath = filepath
-
- def read(self):
- with open(self.filepath, "rb") as f:
- return f.read()
-
- r = re.search(r"filename=\"([^\"]+)\"", str(res_headers))
- if not r or not r.group(1):
-            return get_json_result(data=False, message="Cannot identify the downloaded file", code=RetCode.ARGUMENT_ERROR)
- filename = r.group(1).strip()
- if not _is_safe_download_filename(filename):
- return get_json_result(data=False, message="Invalid downloaded filename", code=RetCode.ARGUMENT_ERROR)
- filepath = os.path.join(download_path, filename)
- f = File(filename, filepath)
- txt = FileService.parse_docs([f], current_user.id)
- return get_json_result(data=txt)
-
- files = await request.files
- if "file" not in files:
- return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
-
- file_objs = files.getlist("file")
- txt = FileService.parse_docs(file_objs, current_user.id)
-
- return get_json_result(data=txt)
-
-
-@manager.route("/upload_info", methods=["POST"]) # noqa: F821
-@login_required
-async def upload_info():
- files = await request.files
- file_objs = files.getlist("file") if files and files.get("file") else []
- url = request.args.get("url")
-
- if file_objs and url:
- return get_json_result(
- data=False,
- message="Provide either multipart file(s) or ?url=..., not both.",
- code=RetCode.BAD_REQUEST,
- )
-
- if not file_objs and not url:
- return get_json_result(
- data=False,
- message="Missing input: provide multipart file(s) or url",
- code=RetCode.BAD_REQUEST,
- )
-
- try:
- if url and not file_objs:
- return get_json_result(data=FileService.upload_info(current_user.id, None, url))
-
- if len(file_objs) == 1:
- return get_json_result(data=FileService.upload_info(current_user.id, file_objs[0], None))
-
- results = [FileService.upload_info(current_user.id, f, None) for f in file_objs]
- return get_json_result(data=results)
- except Exception as e:
- return server_error_response(e)
diff --git a/api/apps/evaluation_app.py b/api/apps/evaluation_app.py
deleted file mode 100644
index b33db26da17..00000000000
--- a/api/apps/evaluation_app.py
+++ /dev/null
@@ -1,479 +0,0 @@
-#
-# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""
-RAG Evaluation API Endpoints
-
-Provides REST API for RAG evaluation functionality including:
-- Dataset management
-- Test case management
-- Evaluation execution
-- Results retrieval
-- Configuration recommendations
-"""
-
-from quart import request
-from api.apps import login_required, current_user
-from api.db.services.evaluation_service import EvaluationService
-from api.utils.api_utils import (
- get_data_error_result,
- get_json_result,
- get_request_json,
- server_error_response,
- validate_request
-)
-from common.constants import RetCode
-
-
-# ==================== Dataset Management ====================
-
-@manager.route('/dataset/create', methods=['POST']) # noqa: F821
-@login_required
-@validate_request("name", "kb_ids")
-async def create_dataset():
- """
- Create a new evaluation dataset.
-
- Request body:
- {
- "name": "Dataset name",
- "description": "Optional description",
- "kb_ids": ["kb_id1", "kb_id2"]
- }
- """
- try:
- req = await get_request_json()
- name = req.get("name", "").strip()
- description = req.get("description", "")
- kb_ids = req.get("kb_ids", [])
-
- if not name:
- return get_data_error_result(message="Dataset name cannot be empty")
-
- if not kb_ids or not isinstance(kb_ids, list):
- return get_data_error_result(message="kb_ids must be a non-empty list")
-
- success, result = EvaluationService.create_dataset(
- name=name,
- description=description,
- kb_ids=kb_ids,
- tenant_id=current_user.id,
- user_id=current_user.id
- )
-
- if not success:
- return get_data_error_result(message=result)
-
- return get_json_result(data={"dataset_id": result})
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route('/dataset/list', methods=['GET']) # noqa: F821
-@login_required
-async def list_datasets():
- """
- List evaluation datasets for current tenant.
-
- Query params:
- - page: Page number (default: 1)
- - page_size: Items per page (default: 20)
- """
- try:
- page = int(request.args.get("page", 1))
- page_size = int(request.args.get("page_size", 20))
-
- result = EvaluationService.list_datasets(
- tenant_id=current_user.id,
- user_id=current_user.id,
- page=page,
- page_size=page_size
- )
-
- return get_json_result(data=result)
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route('/dataset/<dataset_id>', methods=['GET']) # noqa: F821
-@login_required
-async def get_dataset(dataset_id):
- """Get dataset details by ID"""
- try:
- dataset = EvaluationService.get_dataset(dataset_id)
- if not dataset:
- return get_data_error_result(
- message="Dataset not found",
- code=RetCode.DATA_ERROR
- )
-
- return get_json_result(data=dataset)
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route('/dataset/<dataset_id>', methods=['PUT']) # noqa: F821
-@login_required
-async def update_dataset(dataset_id):
- """
- Update dataset.
-
- Request body:
- {
- "name": "New name",
- "description": "New description",
- "kb_ids": ["kb_id1", "kb_id2"]
- }
- """
- try:
- req = await get_request_json()
-
- # Remove fields that shouldn't be updated
- req.pop("id", None)
- req.pop("tenant_id", None)
- req.pop("created_by", None)
- req.pop("create_time", None)
-
- success = EvaluationService.update_dataset(dataset_id, **req)
-
- if not success:
- return get_data_error_result(message="Failed to update dataset")
-
- return get_json_result(data={"dataset_id": dataset_id})
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route('/dataset/<dataset_id>', methods=['DELETE']) # noqa: F821
-@login_required
-async def delete_dataset(dataset_id):
- """Delete dataset (soft delete)"""
- try:
- success = EvaluationService.delete_dataset(dataset_id)
-
- if not success:
- return get_data_error_result(message="Failed to delete dataset")
-
- return get_json_result(data={"dataset_id": dataset_id})
- except Exception as e:
- return server_error_response(e)
-
-
-# ==================== Test Case Management ====================
-
-@manager.route('/dataset/<dataset_id>/case/add', methods=['POST']) # noqa: F821
-@login_required
-@validate_request("question")
-async def add_test_case(dataset_id):
- """
- Add a test case to a dataset.
-
- Request body:
- {
- "question": "Test question",
- "reference_answer": "Optional ground truth answer",
- "relevant_doc_ids": ["doc_id1", "doc_id2"],
- "relevant_chunk_ids": ["chunk_id1", "chunk_id2"],
- "metadata": {"key": "value"}
- }
- """
- try:
- req = await get_request_json()
- question = req.get("question", "").strip()
-
- if not question:
- return get_data_error_result(message="Question cannot be empty")
-
- success, result = EvaluationService.add_test_case(
- dataset_id=dataset_id,
- question=question,
- reference_answer=req.get("reference_answer"),
- relevant_doc_ids=req.get("relevant_doc_ids"),
- relevant_chunk_ids=req.get("relevant_chunk_ids"),
- metadata=req.get("metadata")
- )
-
- if not success:
- return get_data_error_result(message=result)
-
- return get_json_result(data={"case_id": result})
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route('/dataset/<dataset_id>/case/import', methods=['POST']) # noqa: F821
-@login_required
-@validate_request("cases")
-async def import_test_cases(dataset_id):
- """
- Bulk import test cases.
-
- Request body:
- {
- "cases": [
- {
- "question": "Question 1",
- "reference_answer": "Answer 1",
- ...
- },
- {
- "question": "Question 2",
- ...
- }
- ]
- }
- """
- try:
- req = await get_request_json()
- cases = req.get("cases", [])
-
- if not cases or not isinstance(cases, list):
- return get_data_error_result(message="cases must be a non-empty list")
-
- success_count, failure_count = EvaluationService.import_test_cases(
- dataset_id=dataset_id,
- cases=cases
- )
-
- return get_json_result(data={
- "success_count": success_count,
- "failure_count": failure_count,
- "total": len(cases)
- })
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route('/dataset/<dataset_id>/cases', methods=['GET']) # noqa: F821
-@login_required
-async def get_test_cases(dataset_id):
- """Get all test cases for a dataset"""
- try:
- cases = EvaluationService.get_test_cases(dataset_id)
- return get_json_result(data={"cases": cases, "total": len(cases)})
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route('/case/<case_id>', methods=['DELETE']) # noqa: F821
-@login_required
-async def delete_test_case(case_id):
- """Delete a test case"""
- try:
- success = EvaluationService.delete_test_case(case_id)
-
- if not success:
- return get_data_error_result(message="Failed to delete test case")
-
- return get_json_result(data={"case_id": case_id})
- except Exception as e:
- return server_error_response(e)
-
-
-# ==================== Evaluation Execution ====================
-
-@manager.route('/run/start', methods=['POST']) # noqa: F821
-@login_required
-@validate_request("dataset_id", "dialog_id")
-async def start_evaluation():
- """
- Start an evaluation run.
-
- Request body:
- {
- "dataset_id": "dataset_id",
- "dialog_id": "dialog_id",
- "name": "Optional run name"
- }
- """
- try:
- req = await get_request_json()
- dataset_id = req.get("dataset_id")
- dialog_id = req.get("dialog_id")
- name = req.get("name")
-
- success, result = EvaluationService.start_evaluation(
- dataset_id=dataset_id,
- dialog_id=dialog_id,
- user_id=current_user.id,
- name=name
- )
-
- if not success:
- return get_data_error_result(message=result)
-
- return get_json_result(data={"run_id": result})
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route('/run/<run_id>', methods=['GET']) # noqa: F821
-@login_required
-async def get_evaluation_run(run_id):
- """Get evaluation run details"""
- try:
- result = EvaluationService.get_run_results(run_id)
-
- if not result:
- return get_data_error_result(
- message="Evaluation run not found",
- code=RetCode.DATA_ERROR
- )
-
- return get_json_result(data=result)
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route('/run/<run_id>/results', methods=['GET']) # noqa: F821
-@login_required
-async def get_run_results(run_id):
- """Get detailed results for an evaluation run"""
- try:
- result = EvaluationService.get_run_results(run_id)
-
- if not result:
- return get_data_error_result(
- message="Evaluation run not found",
- code=RetCode.DATA_ERROR
- )
-
- return get_json_result(data=result)
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route('/run/list', methods=['GET']) # noqa: F821
-@login_required
-async def list_evaluation_runs():
- """
- List evaluation runs.
-
- Query params:
- - dataset_id: Filter by dataset (optional)
- - dialog_id: Filter by dialog (optional)
- - page: Page number (default: 1)
- - page_size: Items per page (default: 20)
- """
- try:
- # TODO: Implement list_runs in EvaluationService
- return get_json_result(data={"runs": [], "total": 0})
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route('/run/<run_id>', methods=['DELETE']) # noqa: F821
-@login_required
-async def delete_evaluation_run(run_id):
- """Delete an evaluation run"""
- try:
- # TODO: Implement delete_run in EvaluationService
- return get_json_result(data={"run_id": run_id})
- except Exception as e:
- return server_error_response(e)
-
-
-# ==================== Analysis & Recommendations ====================
-
-@manager.route('/run/<run_id>/recommendations', methods=['GET']) # noqa: F821
-@login_required
-async def get_recommendations(run_id):
- """Get configuration recommendations based on evaluation results"""
- try:
- recommendations = EvaluationService.get_recommendations(run_id)
- return get_json_result(data={"recommendations": recommendations})
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route('/compare', methods=['POST']) # noqa: F821
-@login_required
-@validate_request("run_ids")
-async def compare_runs():
- """
- Compare multiple evaluation runs.
-
- Request body:
- {
- "run_ids": ["run_id1", "run_id2", "run_id3"]
- }
- """
- try:
- req = await get_request_json()
- run_ids = req.get("run_ids", [])
-
- if not run_ids or not isinstance(run_ids, list) or len(run_ids) < 2:
- return get_data_error_result(
- message="run_ids must be a list with at least 2 run IDs"
- )
-
- # TODO: Implement compare_runs in EvaluationService
- return get_json_result(data={"comparison": {}})
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route('/run/<run_id>/export', methods=['GET']) # noqa: F821
-@login_required
-async def export_results(run_id):
- """Export evaluation results as JSON/CSV"""
- try:
- # format_type = request.args.get("format", "json") # TODO: Use for CSV export
-
- result = EvaluationService.get_run_results(run_id)
-
- if not result:
- return get_data_error_result(
- message="Evaluation run not found",
- code=RetCode.DATA_ERROR
- )
-
- # TODO: Implement CSV export
- return get_json_result(data=result)
- except Exception as e:
- return server_error_response(e)
-
-
-# ==================== Real-time Evaluation ====================
-
-@manager.route('/evaluate_single', methods=['POST']) # noqa: F821
-@login_required
-@validate_request("question", "dialog_id")
-async def evaluate_single():
- """
- Evaluate a single question-answer pair in real-time.
-
- Request body:
- {
- "question": "Test question",
- "dialog_id": "dialog_id",
- "reference_answer": "Optional ground truth",
- "relevant_chunk_ids": ["chunk_id1", "chunk_id2"]
- }
- """
- try:
- # req = await get_request_json() # TODO: Use for single evaluation implementation
-
- # TODO: Implement single evaluation
- # This would execute the RAG pipeline and return metrics immediately
-
- return get_json_result(data={
- "answer": "",
- "metrics": {},
- "retrieved_chunks": []
- })
- except Exception as e:
- return server_error_response(e)
diff --git a/api/apps/file_app.py b/api/apps/file_app.py
deleted file mode 100644
index 172b49ff850..00000000000
--- a/api/apps/file_app.py
+++ /dev/null
@@ -1,464 +0,0 @@
-# #
-# # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
-# #
-# # Licensed under the Apache License, Version 2.0 (the "License");
-# # you may not use this file except in compliance with the License.
-# # You may obtain a copy of the License at
-# #
-# # http://www.apache.org/licenses/LICENSE-2.0
-# #
-# # Unless required by applicable law or agreed to in writing, software
-# # distributed under the License is distributed on an "AS IS" BASIS,
-# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# # See the License for the specific language governing permissions and
-# # limitations under the License
-# #
-# import logging
-# import os
-# import pathlib
-# import re
-# from quart import request, make_response
-# from api.apps import login_required, current_user
-#
-# from api.common.check_team_permission import check_file_team_permission
-# from api.db.services.document_service import DocumentService
-# from api.db.services.file2document_service import File2DocumentService
-# from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
-# from common.misc_utils import get_uuid, thread_pool_exec
-# from common.constants import RetCode, FileSource
-# from api.db import FileType
-# from api.db.services import duplicate_name
-# from api.db.services.file_service import FileService
-# from api.utils.api_utils import get_json_result, get_request_json
-# from api.utils.file_utils import filename_type
-# from api.utils.web_utils import CONTENT_TYPE_MAP, apply_safe_file_response_headers
-# from common import settings
-#
-# @manager.route('/upload', methods=['POST']) # noqa: F821
-# @login_required
-# # @validate_request("parent_id")
-# async def upload():
-# form = await request.form
-# pf_id = form.get("parent_id")
-#
-# if not pf_id:
-# root_folder = FileService.get_root_folder(current_user.id)
-# pf_id = root_folder["id"]
-#
-# files = await request.files
-# if 'file' not in files:
-# return get_json_result(
-# data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
-# file_objs = files.getlist('file')
-#
-# for file_obj in file_objs:
-# if file_obj.filename == '':
-# return get_json_result(
-# data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
-# file_res = []
-# try:
-# e, pf_folder = FileService.get_by_id(pf_id)
-# if not e:
-# return get_data_error_result( message="Can't find this folder!")
-#
-# async def _handle_single_file(file_obj):
-# MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
-# if 0 < MAX_FILE_NUM_PER_USER <= await thread_pool_exec(DocumentService.get_doc_count, current_user.id):
-# return get_data_error_result( message="Exceed the maximum file number of a free user!")
-#
-# # split file name path
-# if not file_obj.filename:
-# file_obj_names = [pf_folder.name, file_obj.filename]
-# else:
-# full_path = '/' + file_obj.filename
-# file_obj_names = full_path.split('/')
-# file_len = len(file_obj_names)
-#
-# # get folder
-# file_id_list = await thread_pool_exec(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id])
-# len_id_list = len(file_id_list)
-#
-# # create folder
-# if file_len != len_id_list:
-# e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 1])
-# if not e:
-# return get_data_error_result(message="Folder not found!")
-# last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names,
-# len_id_list)
-# else:
-# e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 2])
-# if not e:
-# return get_data_error_result(message="Folder not found!")
-# last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names,
-# len_id_list)
-#
-# # file type
-# filetype = filename_type(file_obj_names[file_len - 1])
-# location = file_obj_names[file_len - 1]
-# while await thread_pool_exec(settings.STORAGE_IMPL.obj_exist, last_folder.id, location):
-# location += "_"
-# blob = await thread_pool_exec(file_obj.read)
-# filename = await thread_pool_exec(
-# duplicate_name,
-# FileService.query,
-# name=file_obj_names[file_len - 1],
-# parent_id=last_folder.id)
-# await thread_pool_exec(settings.STORAGE_IMPL.put, last_folder.id, location, blob)
-# file_data = {
-# "id": get_uuid(),
-# "parent_id": last_folder.id,
-# "tenant_id": current_user.id,
-# "created_by": current_user.id,
-# "type": filetype,
-# "name": filename,
-# "location": location,
-# "size": len(blob),
-# }
-# inserted = await thread_pool_exec(FileService.insert, file_data)
-# return inserted.to_json()
-#
-# for file_obj in file_objs:
-# res = await _handle_single_file(file_obj)
-# file_res.append(res)
-#
-# return get_json_result(data=file_res)
-# except Exception as e:
-# return server_error_response(e)
-#
-#
-# @manager.route('/create', methods=['POST']) # noqa: F821
-# @login_required
-# @validate_request("name")
-# async def create():
-# req = await get_request_json()
-# pf_id = req.get("parent_id")
-# input_file_type = req.get("type")
-# if not pf_id:
-# root_folder = FileService.get_root_folder(current_user.id)
-# pf_id = root_folder["id"]
-#
-# try:
-# if not FileService.is_parent_folder_exist(pf_id):
-# return get_json_result(
-# data=False, message="Parent Folder Doesn't Exist!", code=RetCode.OPERATING_ERROR)
-# if FileService.query(name=req["name"], parent_id=pf_id):
-# return get_data_error_result(
-# message="Duplicated folder name in the same folder.")
-#
-# if input_file_type == FileType.FOLDER.value:
-# file_type = FileType.FOLDER.value
-# else:
-# file_type = FileType.VIRTUAL.value
-#
-# file = FileService.insert({
-# "id": get_uuid(),
-# "parent_id": pf_id,
-# "tenant_id": current_user.id,
-# "created_by": current_user.id,
-# "name": req["name"],
-# "location": "",
-# "size": 0,
-# "type": file_type
-# })
-#
-# return get_json_result(data=file.to_json())
-# except Exception as e:
-# return server_error_response(e)
-#
-#
-# @manager.route('/list', methods=['GET']) # noqa: F821
-# @login_required
-# def list_files():
-# pf_id = request.args.get("parent_id")
-#
-# keywords = request.args.get("keywords", "")
-#
-# page_number = int(request.args.get("page", 1))
-# items_per_page = int(request.args.get("page_size", 15))
-# orderby = request.args.get("orderby", "create_time")
-# desc = request.args.get("desc", True)
-# if not pf_id:
-# root_folder = FileService.get_root_folder(current_user.id)
-# pf_id = root_folder["id"]
-# FileService.init_knowledgebase_docs(pf_id, current_user.id)
-# try:
-# e, file = FileService.get_by_id(pf_id)
-# if not e:
-# return get_data_error_result(message="Folder not found!")
-#
-# files, total = FileService.get_by_pf_id(
-# current_user.id, pf_id, page_number, items_per_page, orderby, desc, keywords)
-#
-# parent_folder = FileService.get_parent_folder(pf_id)
-# if not parent_folder:
-# return get_json_result(message="File not found!")
-#
-# return get_json_result(data={"total": total, "files": files, "parent_folder": parent_folder.to_json()})
-# except Exception as e:
-# return server_error_response(e)
-#
-#
-# @manager.route('/root_folder', methods=['GET']) # noqa: F821
-# @login_required
-# def get_root_folder():
-# try:
-# root_folder = FileService.get_root_folder(current_user.id)
-# return get_json_result(data={"root_folder": root_folder})
-# except Exception as e:
-# return server_error_response(e)
-#
-#
-# @manager.route('/parent_folder', methods=['GET']) # noqa: F821
-# @login_required
-# def get_parent_folder():
-# file_id = request.args.get("file_id")
-# try:
-# e, file = FileService.get_by_id(file_id)
-# if not e:
-# return get_data_error_result(message="Folder not found!")
-#
-# parent_folder = FileService.get_parent_folder(file_id)
-# return get_json_result(data={"parent_folder": parent_folder.to_json()})
-# except Exception as e:
-# return server_error_response(e)
-#
-#
-# @manager.route('/all_parent_folder', methods=['GET']) # noqa: F821
-# @login_required
-# def get_all_parent_folders():
-# file_id = request.args.get("file_id")
-# try:
-# e, file = FileService.get_by_id(file_id)
-# if not e:
-# return get_data_error_result(message="Folder not found!")
-#
-# parent_folders = FileService.get_all_parent_folders(file_id)
-# parent_folders_res = []
-# for parent_folder in parent_folders:
-# parent_folders_res.append(parent_folder.to_json())
-# return get_json_result(data={"parent_folders": parent_folders_res})
-# except Exception as e:
-# return server_error_response(e)
-#
-#
-# @manager.route("/rm", methods=["POST"]) # noqa: F821
-# @login_required
-# @validate_request("file_ids")
-# async def rm():
-# req = await get_request_json()
-# file_ids = req["file_ids"]
-# uid = current_user.id
-#
-# try:
-# def _delete_single_file(file):
-# try:
-# if file.location:
-# settings.STORAGE_IMPL.rm(file.parent_id, file.location)
-# except Exception as e:
-# logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")
-#
-# informs = File2DocumentService.get_by_file_id(file.id)
-# for inform in informs:
-# doc_id = inform.document_id
-# e, doc = DocumentService.get_by_id(doc_id)
-# if e and doc:
-# tenant_id = DocumentService.get_tenant_id(doc_id)
-# if tenant_id:
-# DocumentService.remove_document(doc, tenant_id)
-# File2DocumentService.delete_by_file_id(file.id)
-#
-# FileService.delete(file)
-#
-# def _delete_folder_recursive(folder, tenant_id):
-# sub_files = FileService.list_all_files_by_parent_id(folder.id)
-# for sub_file in sub_files:
-# if sub_file.type == FileType.FOLDER.value:
-# _delete_folder_recursive(sub_file, tenant_id)
-# else:
-# _delete_single_file(sub_file)
-#
-# FileService.delete(folder)
-#
-# def _rm_sync():
-# for file_id in file_ids:
-# e, file = FileService.get_by_id(file_id)
-# if not e or not file:
-# return get_data_error_result(message="File or Folder not found!")
-# if not file.tenant_id:
-# return get_data_error_result(message="Tenant not found!")
-# if not check_file_team_permission(file, uid):
-# return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
-#
-# if file.source_type == FileSource.KNOWLEDGEBASE:
-# continue
-#
-# if file.type == FileType.FOLDER.value:
-# _delete_folder_recursive(file, uid)
-# continue
-#
-# _delete_single_file(file)
-#
-# return get_json_result(data=True)
-#
-# return await thread_pool_exec(_rm_sync)
-#
-# except Exception as e:
-# return server_error_response(e)
-#
-#
-# @manager.route('/rename', methods=['POST']) # noqa: F821
-# @login_required
-# @validate_request("file_id", "name")
-# async def rename():
-# req = await get_request_json()
-# try:
-# e, file = FileService.get_by_id(req["file_id"])
-# if not e:
-# return get_data_error_result(message="File not found!")
-# if not check_file_team_permission(file, current_user.id):
-# return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)
-# if file.type != FileType.FOLDER.value \
-# and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
-# file.name.lower()).suffix:
-# return get_json_result(
-# data=False,
-# message="The extension of file can't be changed",
-# code=RetCode.ARGUMENT_ERROR)
-# for file in FileService.query(name=req["name"], pf_id=file.parent_id):
-# if file.name == req["name"]:
-# return get_data_error_result(
-# message="Duplicated file name in the same folder.")
-#
-# if not FileService.update_by_id(
-# req["file_id"], {"name": req["name"]}):
-# return get_data_error_result(
-# message="Database error (File rename)!")
-#
-# informs = File2DocumentService.get_by_file_id(req["file_id"])
-# if informs:
-# if not DocumentService.update_by_id(
-# informs[0].document_id, {"name": req["name"]}):
-# return get_data_error_result(
-# message="Database error (Document rename)!")
-#
-# return get_json_result(data=True)
-# except Exception as e:
-# return server_error_response(e)
-#
-#
-# @manager.route('/get/<file_id>', methods=['GET']) # noqa: F821
-# @login_required
-# async def get(file_id):
-# try:
-# e, file = FileService.get_by_id(file_id)
-# if not e:
-# return get_data_error_result(message="Document not found!")
-# if not check_file_team_permission(file, current_user.id):
-# return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR)
-#
-# blob = await thread_pool_exec(settings.STORAGE_IMPL.get, file.parent_id, file.location)
-# if not blob:
-# b, n = File2DocumentService.get_storage_address(file_id=file_id)
-# blob = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n)
-#
-# response = await make_response(blob)
-# ext = re.search(r"\.([^.]+)$", file.name.lower())
-# ext = ext.group(1) if ext else None
-# content_type = None
-# if ext:
-# fallback_prefix = "image" if file.type == FileType.VISUAL.value else "application"
-# content_type = CONTENT_TYPE_MAP.get(ext, f"{fallback_prefix}/{ext}")
-# apply_safe_file_response_headers(response, content_type, ext)
-# return response
-# except Exception as e:
-# return server_error_response(e)
-#
-#
-# @manager.route("/mv", methods=["POST"]) # noqa: F821
-# @login_required
-# @validate_request("src_file_ids", "dest_file_id")
-# async def move():
-# req = await get_request_json()
-# try:
-# file_ids = req["src_file_ids"]
-# dest_parent_id = req["dest_file_id"]
-#
-# ok, dest_folder = FileService.get_by_id(dest_parent_id)
-# if not ok or not dest_folder:
-# return get_data_error_result(message="Parent folder not found!")
-#
-# files = FileService.get_by_ids(file_ids)
-# if not files:
-# return get_data_error_result(message="Source files not found!")
-#
-# files_dict = {f.id: f for f in files}
-#
-# for file_id in file_ids:
-# file = files_dict.get(file_id)
-# if not file:
-# return get_data_error_result(message="File or folder not found!")
-# if not file.tenant_id:
-# return get_data_error_result(message="Tenant not found!")
-# if not check_file_team_permission(file, current_user.id):
-# return get_json_result(
-# data=False,
-# message="No authorization.",
-# code=RetCode.AUTHENTICATION_ERROR,
-# )
-#
-# def _move_entry_recursive(source_file_entry, dest_folder):
-# if source_file_entry.type == FileType.FOLDER.value:
-# existing_folder = FileService.query(name=source_file_entry.name, parent_id=dest_folder.id)
-# if existing_folder:
-# new_folder = existing_folder[0]
-# else:
-# new_folder = FileService.insert(
-# {
-# "id": get_uuid(),
-# "parent_id": dest_folder.id,
-# "tenant_id": source_file_entry.tenant_id,
-# "created_by": current_user.id,
-# "name": source_file_entry.name,
-# "location": "",
-# "size": 0,
-# "type": FileType.FOLDER.value,
-# }
-# )
-#
-# sub_files = FileService.list_all_files_by_parent_id(source_file_entry.id)
-# for sub_file in sub_files:
-# _move_entry_recursive(sub_file, new_folder)
-#
-# FileService.delete_by_id(source_file_entry.id)
-# return
-#
-# old_parent_id = source_file_entry.parent_id
-# old_location = source_file_entry.location
-# filename = source_file_entry.name
-#
-# new_location = filename
-# while settings.STORAGE_IMPL.obj_exist(dest_folder.id, new_location):
-# new_location += "_"
-#
-# try:
-# settings.STORAGE_IMPL.move(old_parent_id, old_location, dest_folder.id, new_location)
-# except Exception as storage_err:
-# raise RuntimeError(f"Move file failed at storage layer: {str(storage_err)}")
-#
-# FileService.update_by_id(
-# source_file_entry.id,
-# {
-# "parent_id": dest_folder.id,
-# "location": new_location,
-# },
-# )
-#
-# def _move_sync():
-# for file in files:
-# _move_entry_recursive(file, dest_folder)
-# return get_json_result(data=True)
-#
-# return await thread_pool_exec(_move_sync)
-#
-# except Exception as e:
-# return server_error_response(e)
diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py
deleted file mode 100644
index 730d63c66ca..00000000000
--- a/api/apps/kb_app.py
+++ /dev/null
@@ -1,1012 +0,0 @@
-#
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-import logging
-import random
-import re
-
-from common.metadata_utils import turn2jsonschema
-from quart import request
-import numpy as np
-
-from api.db.services.connector_service import Connector2KbService
-from api.db.services.llm_service import LLMBundle
-from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
-from api.db.services.doc_metadata_service import DocMetadataService
-from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
-from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
-from api.db.services.user_service import UserTenantService
-from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name, get_model_config_by_id
-from api.utils.api_utils import (
- get_error_data_result,
- server_error_response,
- get_data_error_result,
- validate_request,
- get_request_json,
-)
-from api.db import VALID_FILE_TYPES
-from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.utils.api_utils import get_json_result
-from rag.nlp import search
-from rag.utils.redis_conn import REDIS_CONN
-from common.constants import RetCode, PipelineTaskType, VALID_TASK_STATUS, LLMType
-from common import settings
-from common.doc_store.doc_store_base import OrderByExpr
-from api.apps import login_required, current_user
-
-"""
-Deprecated, todo delete
-@manager.route('/create', methods=['post']) # noqa: F821
-@login_required
-@validate_request("name")
-async def create():
- req = await get_request_json()
- create_dict = ensure_tenant_model_id_for_params(current_user.id, req)
- e, res = KnowledgebaseService.create_with_name(
- name = create_dict.pop("name", None),
- tenant_id = current_user.id,
- parser_id = create_dict.pop("parser_id", None),
- **create_dict
- )
-
- if not e:
- return res
-
- try:
- if not KnowledgebaseService.save(**res):
- return get_data_error_result()
- return get_json_result(data={"kb_id":res["id"]})
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route('/update', methods=['post']) # noqa: F821
-@login_required
-@validate_request("kb_id", "name", "description", "parser_id")
-@not_allowed_parameters("id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date")
-async def update():
- req = await get_request_json()
- update_dict = ensure_tenant_model_id_for_params(current_user.id, req)
- if not isinstance(update_dict["name"], str):
- return get_data_error_result(message="Dataset name must be string.")
- if update_dict["name"].strip() == "":
- return get_data_error_result(message="Dataset name can't be empty.")
- if len(update_dict["name"].encode("utf-8")) > DATASET_NAME_LIMIT:
- return get_data_error_result(
-            message=f"Dataset name length is {len(update_dict['name'])} which is larger than {DATASET_NAME_LIMIT}")
- update_dict["name"] = update_dict["name"].strip()
- if settings.DOC_ENGINE_INFINITY:
- parser_id = update_dict.get("parser_id")
- if isinstance(parser_id, str) and parser_id.lower() == "tag":
- return get_json_result(
- code=RetCode.OPERATING_ERROR,
- message="The chunking method Tag has not been supported by Infinity yet.",
- data=False,
- )
- if "pagerank" in update_dict and update_dict["pagerank"] > 0:
- return get_json_result(
- code=RetCode.DATA_ERROR,
- message="'pagerank' can only be set when doc_engine is elasticsearch",
- data=False,
- )
-
- if not KnowledgebaseService.accessible4deletion(update_dict["kb_id"], current_user.id):
- return get_json_result(
- data=False,
- message='No authorization.',
- code=RetCode.AUTHENTICATION_ERROR
- )
- try:
- if not KnowledgebaseService.query(
- created_by=current_user.id, id=update_dict["kb_id"]):
- return get_json_result(
- data=False, message='Only owner of dataset authorized for this operation.',
- code=RetCode.OPERATING_ERROR)
-
- e, kb = KnowledgebaseService.get_by_id(update_dict["kb_id"])
-
- # Rename folder in FileService
- if e and update_dict["name"].lower() != kb.name.lower():
- FileService.filter_update(
- [
- File.tenant_id == kb.tenant_id,
- File.source_type == FileSource.KNOWLEDGEBASE,
- File.type == "folder",
- File.name == kb.name,
- ],
- {"name": update_dict["name"]},
- )
-
- if not e:
- return get_data_error_result(
- message="Can't find this dataset!")
-
- if update_dict["name"].lower() != kb.name.lower() \
- and len(
- KnowledgebaseService.query(name=update_dict["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) >= 1:
- return get_data_error_result(
- message="Duplicated dataset name.")
-
- del update_dict["kb_id"]
- connectors = []
- if "connectors" in update_dict:
- connectors = update_dict["connectors"]
- del update_dict["connectors"]
- if not KnowledgebaseService.update_by_id(kb.id, update_dict):
- return get_data_error_result()
-
- if kb.pagerank != update_dict.get("pagerank", 0):
- if update_dict.get("pagerank", 0) > 0:
- await thread_pool_exec(
- settings.docStoreConn.update,
- {"kb_id": kb.id},
- {PAGERANK_FLD: update_dict["pagerank"]},
- search.index_name(kb.tenant_id),
- kb.id,
- )
- else:
- # Elasticsearch requires PAGERANK_FLD be non-zero!
- await thread_pool_exec(
- settings.docStoreConn.update,
- {"exists": PAGERANK_FLD},
- {"remove": PAGERANK_FLD},
- search.index_name(kb.tenant_id),
- kb.id,
- )
-
- e, kb = KnowledgebaseService.get_by_id(kb.id)
- if not e:
- return get_data_error_result(
- message="Database error (Knowledgebase rename)!")
- errors = Connector2KbService.link_connectors(kb.id, [conn for conn in connectors], current_user.id)
- if errors:
-        logging.error(f"Link KB errors: {errors}")
- kb = kb.to_dict()
- kb.update(update_dict)
- kb["connectors"] = connectors
-
- return get_json_result(data=kb)
- except Exception as e:
- return server_error_response(e)
-"""
-
-@manager.route('/update_metadata_setting', methods=['post']) # noqa: F821
-@login_required
-@validate_request("kb_id", "metadata")
-async def update_metadata_setting():
- req = await get_request_json()
- e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
- if not e:
- return get_data_error_result(
-            message="Can't find this dataset!")
- kb = kb.to_dict()
- kb["parser_config"]["metadata"] = req["metadata"]
- kb["parser_config"]["enable_metadata"] = req.get("enable_metadata", True)
- KnowledgebaseService.update_by_id(kb["id"], kb)
- return get_json_result(data=kb)
-
-
-@manager.route('/detail', methods=['GET']) # noqa: F821
-@login_required
-def detail():
- kb_id = request.args["kb_id"]
- try:
- tenants = UserTenantService.query(user_id=current_user.id)
- for tenant in tenants:
- if KnowledgebaseService.query(
- tenant_id=tenant.tenant_id, id=kb_id):
- break
- else:
- return get_json_result(
- data=False, message='Only owner of dataset authorized for this operation.',
- code=RetCode.OPERATING_ERROR)
- kb = KnowledgebaseService.get_detail(kb_id)
- if not kb:
- return get_data_error_result(
- message="Can't find this dataset!")
- kb["size"] = DocumentService.get_total_size_by_kb_id(kb_id=kb["id"],keywords="", run_status=[], types=[])
- kb["connectors"] = Connector2KbService.list_connectors(kb_id)
- if kb["parser_config"].get("metadata"):
- kb["parser_config"]["metadata"] = turn2jsonschema(kb["parser_config"]["metadata"])
-
- for key in ["graphrag_task_finish_at", "raptor_task_finish_at", "mindmap_task_finish_at"]:
- if finish_at := kb.get(key):
- kb[key] = finish_at.strftime("%Y-%m-%d %H:%M:%S")
- return get_json_result(data=kb)
- except Exception as e:
- return server_error_response(e)
-
-"""
-Deprecated, todo delete
-@manager.route('/list', methods=['POST']) # noqa: F821
-@login_required
-async def list_kbs():
- args = request.args
- keywords = args.get("keywords", "")
- page_number = int(args.get("page", 0))
- items_per_page = int(args.get("page_size", 0))
- parser_id = args.get("parser_id")
- orderby = args.get("orderby", "create_time")
- if args.get("desc", "true").lower() == "false":
- desc = False
- else:
- desc = True
-
- req = await get_request_json()
- owner_ids = req.get("owner_ids", [])
- try:
- if not owner_ids:
- tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
- tenants = [m["tenant_id"] for m in tenants]
- kbs, total = KnowledgebaseService.get_by_tenant_ids(
- tenants, current_user.id, page_number,
- items_per_page, orderby, desc, keywords, parser_id)
- else:
- tenants = owner_ids
- kbs, total = KnowledgebaseService.get_by_tenant_ids(
- tenants, current_user.id, 0,
- 0, orderby, desc, keywords, parser_id)
- kbs = [kb for kb in kbs if kb["tenant_id"] in tenants]
- total = len(kbs)
- if page_number and items_per_page:
- kbs = kbs[(page_number-1)*items_per_page:page_number*items_per_page]
- return get_json_result(data={"kbs": kbs, "total": total})
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route('/rm', methods=['post']) # noqa: F821
-@login_required
-@validate_request("kb_id")
-async def rm():
- req = await get_request_json()
- uid = current_user.id
- if not KnowledgebaseService.accessible4deletion(req["kb_id"], uid):
- return get_json_result(
- data=False,
- message='No authorization.',
- code=RetCode.AUTHENTICATION_ERROR
- )
- try:
- kbs = KnowledgebaseService.query(
- created_by=uid, id=req["kb_id"])
- if not kbs:
- return get_json_result(
- data=False, message='Only owner of dataset authorized for this operation.',
- code=RetCode.OPERATING_ERROR)
-
- def _rm_sync():
- for doc in DocumentService.query(kb_id=req["kb_id"]):
- if not DocumentService.remove_document(doc, kbs[0].tenant_id):
- return get_data_error_result(
- message="Database error (Document removal)!")
- f2d = File2DocumentService.get_by_document_id(doc.id)
- if f2d:
- FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
- File2DocumentService.delete_by_document_id(doc.id)
- FileService.filter_delete(
- [
- File.tenant_id == kbs[0].tenant_id,
- File.source_type == FileSource.KNOWLEDGEBASE,
- File.type == "folder",
- File.name == kbs[0].name,
- ]
- )
- # Delete the table BEFORE deleting the database record
- for kb in kbs:
- try:
- settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
- settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id)
- logging.info(f"Dropped index for dataset {kb.id}")
- except Exception as e:
- logging.error(f"Failed to drop index for dataset {kb.id}: {e}")
-
- if not KnowledgebaseService.delete_by_id(req["kb_id"]):
- return get_data_error_result(
- message="Database error (Knowledgebase removal)!")
- for kb in kbs:
- if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
- settings.STORAGE_IMPL.remove_bucket(kb.id)
- return get_json_result(data=True)
-
- return await thread_pool_exec(_rm_sync)
- except Exception as e:
- return server_error_response(e)
-"""
-
-@manager.route('/<kb_id>/tags', methods=['GET']) # noqa: F821
-@login_required
-def list_tags(kb_id):
- if not KnowledgebaseService.accessible(kb_id, current_user.id):
- return get_json_result(
- data=False,
- message='No authorization.',
- code=RetCode.AUTHENTICATION_ERROR
- )
-
- tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
- tags = []
- for tenant in tenants:
- tags += settings.retriever.all_tags(tenant["tenant_id"], [kb_id])
- return get_json_result(data=tags)
-
-
-@manager.route('/tags', methods=['GET']) # noqa: F821
-@login_required
-def list_tags_from_kbs():
- kb_ids = request.args.get("kb_ids", "").split(",")
- for kb_id in kb_ids:
- if not KnowledgebaseService.accessible(kb_id, current_user.id):
- return get_json_result(
- data=False,
- message='No authorization.',
- code=RetCode.AUTHENTICATION_ERROR
- )
-
- tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
- tags = []
- for tenant in tenants:
- tags += settings.retriever.all_tags(tenant["tenant_id"], kb_ids)
- return get_json_result(data=tags)
-
-
-@manager.route('/<kb_id>/rm_tags', methods=['POST']) # noqa: F821
-@login_required
-async def rm_tags(kb_id):
- req = await get_request_json()
- if not KnowledgebaseService.accessible(kb_id, current_user.id):
- return get_json_result(
- data=False,
- message='No authorization.',
- code=RetCode.AUTHENTICATION_ERROR
- )
- e, kb = KnowledgebaseService.get_by_id(kb_id)
-
- for t in req["tags"]:
- settings.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
- {"remove": {"tag_kwd": t}},
- search.index_name(kb.tenant_id),
- kb_id)
- return get_json_result(data=True)
-
-
-@manager.route('/<kb_id>/rename_tag', methods=['POST']) # noqa: F821
-@login_required
-async def rename_tags(kb_id):
- req = await get_request_json()
- if not KnowledgebaseService.accessible(kb_id, current_user.id):
- return get_json_result(
- data=False,
- message='No authorization.',
- code=RetCode.AUTHENTICATION_ERROR
- )
- e, kb = KnowledgebaseService.get_by_id(kb_id)
-
- settings.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]},
- {"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}},
- search.index_name(kb.tenant_id),
- kb_id)
- return get_json_result(data=True)
-
-"""
-Deprecated, todo delete
-@manager.route('/<kb_id>/knowledge_graph', methods=['GET']) # noqa: F821
-@login_required
-async def knowledge_graph(kb_id):
- if not KnowledgebaseService.accessible(kb_id, current_user.id):
- return get_json_result(
- data=False,
- message='No authorization.',
- code=RetCode.AUTHENTICATION_ERROR
- )
- _, kb = KnowledgebaseService.get_by_id(kb_id)
- req = {
- "kb_id": [kb_id],
- "knowledge_graph_kwd": ["graph"]
- }
-
- obj = {"graph": {}, "mind_map": {}}
- if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), kb_id):
- return get_json_result(data=obj)
- sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
- if not len(sres.ids):
- return get_json_result(data=obj)
-
- for id in sres.ids[:1]:
- ty = sres.field[id]["knowledge_graph_kwd"]
- try:
- content_json = json.loads(sres.field[id]["content_with_weight"])
- except Exception:
- continue
-
- obj[ty] = content_json
-
- if "nodes" in obj["graph"]:
- obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256]
- if "edges" in obj["graph"]:
- node_id_set = { o["id"] for o in obj["graph"]["nodes"] }
- filtered_edges = [o for o in obj["graph"]["edges"] if o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set]
- obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128]
- return get_json_result(data=obj)
-
-
-@manager.route('/<kb_id>/knowledge_graph', methods=['DELETE']) # noqa: F821
-@login_required
-def delete_knowledge_graph(kb_id):
- if not KnowledgebaseService.accessible(kb_id, current_user.id):
- return get_json_result(
- data=False,
- message='No authorization.',
- code=RetCode.AUTHENTICATION_ERROR
- )
- _, kb = KnowledgebaseService.get_by_id(kb_id)
- settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
-
- return get_json_result(data=True)
-"""
-
-@manager.route("/get_meta", methods=["GET"]) # noqa: F821
-@login_required
-def get_meta():
- kb_ids = request.args.get("kb_ids", "").split(",")
- for kb_id in kb_ids:
- if not KnowledgebaseService.accessible(kb_id, current_user.id):
- return get_json_result(
- data=False,
- message='No authorization.',
- code=RetCode.AUTHENTICATION_ERROR
- )
- return get_json_result(data=DocMetadataService.get_flatted_meta_by_kbs(kb_ids))
-
-
-@manager.route("/basic_info", methods=["GET"]) # noqa: F821
-@login_required
-def get_basic_info():
- kb_id = request.args.get("kb_id", "")
- if not KnowledgebaseService.accessible(kb_id, current_user.id):
- return get_json_result(
- data=False,
- message='No authorization.',
- code=RetCode.AUTHENTICATION_ERROR
- )
-
- basic_info = DocumentService.knowledgebase_basic_info(kb_id)
-
- return get_json_result(data=basic_info)
-
-
-@manager.route("/list_pipeline_logs", methods=["POST"]) # noqa: F821
-@login_required
-async def list_pipeline_logs():
- kb_id = request.args.get("kb_id")
- if not kb_id:
- return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
-
- keywords = request.args.get("keywords", "")
-
- page_number = int(request.args.get("page", 0))
- items_per_page = int(request.args.get("page_size", 0))
- orderby = request.args.get("orderby", "create_time")
- if request.args.get("desc", "true").lower() == "false":
- desc = False
- else:
- desc = True
- create_date_from = request.args.get("create_date_from", "")
- create_date_to = request.args.get("create_date_to", "")
-    if create_date_from and create_date_to and create_date_from > create_date_to:
-        return get_data_error_result(message="Create date filter is abnormal.")
-
- req = await get_request_json()
-
- operation_status = req.get("operation_status", [])
- if operation_status:
- invalid_status = {s for s in operation_status if s not in VALID_TASK_STATUS}
- if invalid_status:
-            return get_data_error_result(message=f"Invalid operation_status filter conditions: {', '.join(invalid_status)}")
-
- types = req.get("types", [])
- if types:
- invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
- if invalid_types:
- return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}")
-
- suffix = req.get("suffix", [])
-
- try:
- logs, count = PipelineOperationLogService.get_file_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix, create_date_from, create_date_to)
- return get_json_result(data={"total": count, "logs": logs})
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route("/list_pipeline_dataset_logs", methods=["POST"]) # noqa: F821
-@login_required
-async def list_pipeline_dataset_logs():
- kb_id = request.args.get("kb_id")
- if not kb_id:
- return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
-
- page_number = int(request.args.get("page", 0))
- items_per_page = int(request.args.get("page_size", 0))
- orderby = request.args.get("orderby", "create_time")
- if request.args.get("desc", "true").lower() == "false":
- desc = False
- else:
- desc = True
- create_date_from = request.args.get("create_date_from", "")
- create_date_to = request.args.get("create_date_to", "")
-    if create_date_from and create_date_to and create_date_from > create_date_to:
-        return get_data_error_result(message="Create date filter is abnormal.")
-
- req = await get_request_json()
-
- operation_status = req.get("operation_status", [])
- if operation_status:
- invalid_status = {s for s in operation_status if s not in VALID_TASK_STATUS}
- if invalid_status:
-            return get_data_error_result(message=f"Invalid operation_status filter conditions: {', '.join(invalid_status)}")
-
- try:
- logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from, create_date_to)
- return get_json_result(data={"total": tol, "logs": logs})
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route("/delete_pipeline_logs", methods=["POST"]) # noqa: F821
-@login_required
-async def delete_pipeline_logs():
- kb_id = request.args.get("kb_id")
- if not kb_id:
- return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
-
- req = await get_request_json()
- log_ids = req.get("log_ids", [])
-
- PipelineOperationLogService.delete_by_ids(log_ids)
-
- return get_json_result(data=True)
-
-
-@manager.route("/pipeline_log_detail", methods=["GET"]) # noqa: F821
-@login_required
-def pipeline_log_detail():
- log_id = request.args.get("log_id")
- if not log_id:
- return get_json_result(data=False, message='Lack of "Pipeline log ID"', code=RetCode.ARGUMENT_ERROR)
-
- ok, log = PipelineOperationLogService.get_by_id(log_id)
- if not ok:
- return get_data_error_result(message="Invalid pipeline log ID")
-
- return get_json_result(data=log.to_dict())
-
-
-"""
-Deprecated, todo delete
-@manager.route("/run_graphrag", methods=["POST"]) # noqa: F821
-@login_required
-async def run_graphrag():
- req = await get_request_json()
-
- kb_id = req.get("kb_id", "")
- if not kb_id:
- return get_error_data_result(message='Lack of "KB ID"')
-
- ok, kb = KnowledgebaseService.get_by_id(kb_id)
- if not ok:
- return get_error_data_result(message="Invalid Knowledgebase ID")
-
- task_id = kb.graphrag_task_id
- if task_id:
- ok, task = TaskService.get_by_id(task_id)
- if not ok:
- logging.warning(f"A valid GraphRAG task id is expected for kb {kb_id}")
-
- if task and task.progress not in [-1, 1]:
- return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.")
-
- documents, _ = DocumentService.get_by_kb_id(
- kb_id=kb_id,
- page_number=0,
- items_per_page=0,
- orderby="create_time",
- desc=False,
- keywords="",
- run_status=[],
- types=[],
- suffix=[],
- )
- if not documents:
- return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}")
-
- sample_document = documents[0]
- document_ids = [document["id"] for document in documents]
-
- task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
-
- if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
- logging.warning(f"Cannot save graphrag_task_id for kb {kb_id}")
-
- return get_json_result(data={"graphrag_task_id": task_id})
-
-
-@manager.route("/trace_graphrag", methods=["GET"]) # noqa: F821
-@login_required
-def trace_graphrag():
- kb_id = request.args.get("kb_id", "")
- if not kb_id:
- return get_error_data_result(message='Lack of "KB ID"')
-
- ok, kb = KnowledgebaseService.get_by_id(kb_id)
- if not ok:
- return get_error_data_result(message="Invalid Knowledgebase ID")
-
- task_id = kb.graphrag_task_id
- if not task_id:
- return get_json_result(data={})
-
- ok, task = TaskService.get_by_id(task_id)
- if not ok:
- return get_json_result(data={})
-
- return get_json_result(data=task.to_dict())
-
-
-@manager.route("/run_raptor", methods=["POST"]) # noqa: F821
-@login_required
-async def run_raptor():
- req = await get_request_json()
-
- kb_id = req.get("kb_id", "")
- if not kb_id:
- return get_error_data_result(message='Lack of "KB ID"')
-
- ok, kb = KnowledgebaseService.get_by_id(kb_id)
- if not ok:
- return get_error_data_result(message="Invalid Knowledgebase ID")
-
- task_id = kb.raptor_task_id
- if task_id:
- ok, task = TaskService.get_by_id(task_id)
- if not ok:
- logging.warning(f"A valid RAPTOR task id is expected for kb {kb_id}")
-
- if task and task.progress not in [-1, 1]:
- return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.")
-
- documents, _ = DocumentService.get_by_kb_id(
- kb_id=kb_id,
- page_number=0,
- items_per_page=0,
- orderby="create_time",
- desc=False,
- keywords="",
- run_status=[],
- types=[],
- suffix=[],
- )
- if not documents:
- return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}")
-
- sample_document = documents[0]
- document_ids = [document["id"] for document in documents]
-
- task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
-
- if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}):
- logging.warning(f"Cannot save raptor_task_id for kb {kb_id}")
-
- return get_json_result(data={"raptor_task_id": task_id})
-
-
-@manager.route("/trace_raptor", methods=["GET"]) # noqa: F821
-@login_required
-def trace_raptor():
- kb_id = request.args.get("kb_id", "")
- if not kb_id:
- return get_error_data_result(message='Lack of "KB ID"')
-
- ok, kb = KnowledgebaseService.get_by_id(kb_id)
- if not ok:
- return get_error_data_result(message="Invalid Knowledgebase ID")
-
- task_id = kb.raptor_task_id
- if not task_id:
- return get_json_result(data={})
-
- ok, task = TaskService.get_by_id(task_id)
- if not ok:
- return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred")
-
- return get_json_result(data=task.to_dict())
-"""
-
-@manager.route("/run_mindmap", methods=["POST"]) # noqa: F821
-@login_required
-async def run_mindmap():
- req = await get_request_json()
-
- kb_id = req.get("kb_id", "")
- if not kb_id:
- return get_error_data_result(message='Lack of "KB ID"')
-
- ok, kb = KnowledgebaseService.get_by_id(kb_id)
- if not ok:
- return get_error_data_result(message="Invalid Knowledgebase ID")
-
- task_id = kb.mindmap_task_id
- if task_id:
- ok, task = TaskService.get_by_id(task_id)
- if not ok:
- logging.warning(f"A valid Mindmap task id is expected for kb {kb_id}")
-
- if task and task.progress not in [-1, 1]:
- return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Mindmap Task is already running.")
-
- documents, _ = DocumentService.get_by_kb_id(
- kb_id=kb_id,
- page_number=0,
- items_per_page=0,
- orderby="create_time",
- desc=False,
- keywords="",
- run_status=[],
- types=[],
- suffix=[],
- )
- if not documents:
- return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}")
-
- sample_document = documents[0]
- document_ids = [document["id"] for document in documents]
-
- task_id = queue_raptor_o_graphrag_tasks(sample_doc=sample_document, ty="mindmap", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
-
- if not KnowledgebaseService.update_by_id(kb.id, {"mindmap_task_id": task_id}):
- logging.warning(f"Cannot save mindmap_task_id for kb {kb_id}")
-
- return get_json_result(data={"mindmap_task_id": task_id})
-
-
-@manager.route("/trace_mindmap", methods=["GET"]) # noqa: F821
-@login_required
-def trace_mindmap():
- kb_id = request.args.get("kb_id", "")
- if not kb_id:
- return get_error_data_result(message='Lack of "KB ID"')
-
- ok, kb = KnowledgebaseService.get_by_id(kb_id)
- if not ok:
- return get_error_data_result(message="Invalid Knowledgebase ID")
-
- task_id = kb.mindmap_task_id
- if not task_id:
- return get_json_result(data={})
-
- ok, task = TaskService.get_by_id(task_id)
- if not ok:
- return get_error_data_result(message="Mindmap Task Not Found or Error Occurred")
-
- return get_json_result(data=task.to_dict())
-
-
-@manager.route("/unbind_task", methods=["DELETE"]) # noqa: F821
-@login_required
-def delete_kb_task():
- kb_id = request.args.get("kb_id", "")
- if not kb_id:
- return get_error_data_result(message='Lack of "KB ID"')
- ok, kb = KnowledgebaseService.get_by_id(kb_id)
- if not ok:
- return get_json_result(data=True)
-
- pipeline_task_type = request.args.get("pipeline_task_type", "")
- if not pipeline_task_type or pipeline_task_type not in [PipelineTaskType.GRAPH_RAG, PipelineTaskType.RAPTOR, PipelineTaskType.MINDMAP]:
- return get_error_data_result(message="Invalid task type")
-
- def cancel_task(task_id):
- REDIS_CONN.set(f"{task_id}-cancel", "x")
-
- kb_task_id_field: str = ""
- kb_task_finish_at: str = ""
- match pipeline_task_type:
- case PipelineTaskType.GRAPH_RAG:
- kb_task_id_field = "graphrag_task_id"
- task_id = kb.graphrag_task_id
- kb_task_finish_at = "graphrag_task_finish_at"
- cancel_task(task_id)
- settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
- case PipelineTaskType.RAPTOR:
- kb_task_id_field = "raptor_task_id"
- task_id = kb.raptor_task_id
- kb_task_finish_at = "raptor_task_finish_at"
- cancel_task(task_id)
- settings.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), kb_id)
- case PipelineTaskType.MINDMAP:
- kb_task_id_field = "mindmap_task_id"
- task_id = kb.mindmap_task_id
- kb_task_finish_at = "mindmap_task_finish_at"
- cancel_task(task_id)
- case _:
- return get_error_data_result(message="Internal Error: Invalid task type")
-
-
- ok = KnowledgebaseService.update_by_id(kb_id, {kb_task_id_field: "", kb_task_finish_at: None})
- if not ok:
- return server_error_response(f"Internal error: cannot delete task {pipeline_task_type}")
-
- return get_json_result(data=True)
-
-@manager.route("/check_embedding", methods=["post"]) # noqa: F821
-@login_required
-async def check_embedding():
-
- def _guess_vec_field(src: dict) -> str | None:
- for k in src or {}:
- if k.endswith("_vec"):
- return k
- return None
-
- def _as_float_vec(v):
- if v is None:
- return []
- if isinstance(v, str):
- return [float(x) for x in v.split("\t") if x != ""]
- if isinstance(v, (list, tuple, np.ndarray)):
- return [float(x) for x in v]
- return []
-
- def _to_1d(x):
- a = np.asarray(x, dtype=np.float32)
- return a.reshape(-1)
-
- def _cos_sim(a, b, eps=1e-12):
- a = _to_1d(a)
- b = _to_1d(b)
- na = np.linalg.norm(a)
- nb = np.linalg.norm(b)
- if na < eps or nb < eps:
- return 0.0
- return float(np.dot(a, b) / (na * nb))
-
- def sample_random_chunks_with_vectors(
- docStoreConn,
- tenant_id: str,
- kb_id: str,
- n: int = 5,
- base_fields=("docnm_kwd","doc_id","content_with_weight","page_num_int","position_int","top_int"),
- ):
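-        # A minimal sanity-sampling scheme: probe once for the total chunk count,
-        # then fetch one chunk at each of n random offsets (capped to the first
-        # 1000 chunks) together with its stored embedding vector.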
- index_nm = search.index_name(tenant_id)
-
- res0 = docStoreConn.search(
- select_fields=[], highlight_fields=[],
- condition={"kb_id": kb_id, "available_int": 1},
- match_expressions=[], order_by=OrderByExpr(),
- offset=0, limit=1,
- index_names=index_nm, knowledgebase_ids=[kb_id]
- )
- total = docStoreConn.get_total(res0)
- if total <= 0:
- return []
-
- n = min(n, total)
- offsets = sorted(random.sample(range(min(total,1000)), n))
- out = []
-
- for off in offsets:
- res1 = docStoreConn.search(
- select_fields=list(base_fields),
- highlight_fields=[],
- condition={"kb_id": kb_id, "available_int": 1},
- match_expressions=[], order_by=OrderByExpr(),
- offset=off, limit=1,
- index_names=index_nm, knowledgebase_ids=[kb_id]
- )
- ids = docStoreConn.get_doc_ids(res1)
- if not ids:
- continue
-
- cid = ids[0]
- full_doc = docStoreConn.get(cid, index_nm, [kb_id]) or {}
- vec_field = _guess_vec_field(full_doc)
- vec = _as_float_vec(full_doc.get(vec_field))
-
- out.append({
- "chunk_id": cid,
- "kb_id": kb_id,
- "doc_id": full_doc.get("doc_id"),
- "doc_name": full_doc.get("docnm_kwd"),
- "vector_field": vec_field,
- "vector_dim": len(vec),
- "vector": vec,
- "page_num_int": full_doc.get("page_num_int"),
- "position_int": full_doc.get("position_int"),
- "top_int": full_doc.get("top_int"),
- "content_with_weight": full_doc.get("content_with_weight") or "",
- "question_kwd": full_doc.get("question_kwd") or []
- })
- return out
-
- def _clean(s: str) -> str:
-        s = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "")
- return s if s else "None"
- req = await get_request_json()
- kb_id = req.get("kb_id", "")
- tenant_embd_id = req.get("tenant_embd_id")
- embd_id = req.get("embd_id", "")
- n = int(req.get("check_num", 5))
- _, kb = KnowledgebaseService.get_by_id(kb_id)
- tenant_id = kb.tenant_id
- if tenant_embd_id:
- embd_model_config = get_model_config_by_id(tenant_embd_id)
- elif embd_id:
- embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_id)
- else:
- return get_error_data_result("`tenant_embd_id` or `embd_id` is required.")
- emb_mdl = LLMBundle(tenant_id, embd_model_config)
- samples = sample_random_chunks_with_vectors(settings.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n)
-
-    results, eff_sims = [], []
-    mode = "content_only"  # default, in case no sampled chunk yields a valid comparison
- for ck in samples:
- title = ck.get("doc_name") or "Title"
- txt_in = "\n".join(ck.get("question_kwd") or []) or ck.get("content_with_weight") or ""
- txt_in = _clean(txt_in)
- if not txt_in:
- results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"})
- continue
-
- if not ck.get("vector"):
- results.append({"chunk_id": ck["chunk_id"], "reason": "no_stored_vector"})
- continue
-
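-        # Score the fresh embedding against the stored vector two ways: content
-        # only, and a 0.1/0.9 title/content mix; keep the higher similarity.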
- try:
- v, _ = emb_mdl.encode([title, txt_in])
- assert len(v[1]) == len(ck["vector"]), f"The dimension ({len(v[1])}) of given embedding model is different from the original ({len(ck['vector'])})"
- sim_content = _cos_sim(v[1], ck["vector"])
- title_w = 0.1
- qv_mix = title_w * v[0] + (1 - title_w) * v[1]
- sim_mix = _cos_sim(qv_mix, ck["vector"])
- sim = sim_content
- mode = "content_only"
- if sim_mix > sim:
- sim = sim_mix
- mode = "title+content"
- except Exception as e:
- return get_error_data_result(message=f"Embedding failure. {e}")
-
- eff_sims.append(sim)
- results.append({
- "chunk_id": ck["chunk_id"],
- "doc_id": ck["doc_id"],
- "doc_name": ck["doc_name"],
- "vector_field": ck["vector_field"],
- "vector_dim": ck["vector_dim"],
- "cos_sim": round(sim, 6),
- })
-
- summary = {
- "kb_id": kb_id,
- "model": embd_id,
- "sampled": len(samples),
- "valid": len(eff_sims),
- "avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6),
- "min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6),
- "max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6),
- "match_mode": mode,
- }
- if summary["avg_cos_sim"] > 0.9:
- return get_json_result(data={"summary": summary, "results": results})
- return get_json_result(code=RetCode.NOT_EFFECTIVE, message="Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.", data={"summary": summary, "results": results})
diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py
index 91c20fddfa7..583e05af7c9 100644
--- a/api/apps/llm_app.py
+++ b/api/apps/llm_app.py
@@ -29,6 +29,23 @@
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel, Seq2txtModel
+def _resolve_my_llm_is_tools(o_dict: dict) -> bool:
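+    # Resolve whether a tenant model supports tool calling: prefer the flag
+    # encoded into the stored api_key config, then fall back to the LLM registry.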
+ decode_api_key_config = getattr(TenantLLMService, "_decode_api_key_config", None)
+ if callable(decode_api_key_config):
+ _, is_tools, _ = decode_api_key_config(o_dict.get("api_key", ""))
+ if is_tools is not None:
+ return bool(is_tools)
+
+ try:
+ base_name, fid = TenantLLMService.split_model_name_and_factory(o_dict["llm_name"])
+ llm_cfg = LLMService.query(llm_name=base_name, fid=fid) if fid else LLMService.query(llm_name=base_name)
+ if not llm_cfg and fid:
+ llm_cfg = LLMService.query(llm_name=base_name)
+ return bool(llm_cfg[0].is_tools) if llm_cfg else False
+ except Exception:
+ return False
+
+
@manager.route("/factories", methods=["GET"]) # noqa: F821
@login_required
def factories():
@@ -185,7 +202,9 @@ def apikey_json(keys):
elif factory == "Bedrock":
# For Bedrock, due to its special authentication method
# Assemble bedrock_ak, bedrock_sk, bedrock_region
- api_key = apikey_json(["auth_mode", "bedrock_ak", "bedrock_sk", "bedrock_region", "aws_role_arn"])
+ # Write into req["api_key"] to prevent the "existing key" override logic from replacing it
+ req["api_key"] = apikey_json(["auth_mode", "bedrock_ak", "bedrock_sk", "bedrock_region", "aws_role_arn"])
+ api_key = req["api_key"]
elif factory == "LocalAI":
llm_name += "___LocalAI"
@@ -226,6 +245,22 @@ def apikey_json(keys):
elif factory == "PaddleOCR":
api_key = apikey_json(["api_key", "provider_order"])
+ elif factory == "OpenDataLoader":
+ api_key = apikey_json(["api_key", "provider_order"])
+
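+    # If the request omits api_key, reuse the key already stored for this
+    # tenant/factory/model pair; "x" marks key-less (e.g. local) deployments.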
+ existing_llm = None
+ existing_api_key = None
+ if req.get("api_key") is None:
+ existing_llms = TenantLLMService.query(tenant_id=current_user.id, llm_factory=factory, llm_name=llm_name)
+ if existing_llms:
+ existing_llm = existing_llms[0]
+ existing_api_key, _, existing_api_key_payload = TenantLLMService._decode_api_key_config(existing_llm.api_key)
+ if existing_api_key_payload is not None:
+ existing_api_key = existing_api_key_payload
+
+ if req.get("api_key") is None:
+ api_key = existing_api_key if existing_api_key is not None else "x"
+
llm = {
"tenant_id": current_user.id,
"llm_factory": factory,
@@ -350,6 +385,9 @@ def drain_tts():
if msg:
return get_data_error_result(message=msg)
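+    # Encode the tool-calling flag into the stored API-key config so it
+    # survives round-trips through the tenant LLM table.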
+ if "is_tools" in req:
+ llm["api_key"] = TenantLLMService._encode_api_key_config(llm["api_key"], bool(req["is_tools"]))
+
if not TenantLLMService.filter_update([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm["llm_name"]], llm):
TenantLLMService.save(**llm)
@@ -390,6 +428,7 @@ async def delete_factory():
def my_llms():
try:
TenantLLMService.ensure_mineru_from_env(current_user.id)
+ TenantLLMService.ensure_opendataloader_from_env(current_user.id)
include_details = request.args.get("include_details", "false").lower() == "true"
if include_details:
@@ -417,6 +456,7 @@ def my_llms():
"api_base": o_dict["api_base"] or "",
"max_tokens": o_dict["max_tokens"] or 8192,
"status": o_dict["status"] or "1",
+ "is_tools": _resolve_my_llm_is_tools(o_dict),
}
)
else:
diff --git a/api/apps/restful_apis/agent_api.py b/api/apps/restful_apis/agent_api.py
new file mode 100644
index 00000000000..c0c6c604af7
--- /dev/null
+++ b/api/apps/restful_apis/agent_api.py
@@ -0,0 +1,1892 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import asyncio
+import base64
+import copy
+import hashlib
+import hmac
+import inspect
+import ipaddress
+import json
+import logging
+import time
+from functools import partial, wraps
+
+import jwt
+from quart import Response, jsonify, request
+
+from agent.canvas import Canvas
+from agent.component import LLM
+from agent.dsl_migration import normalize_chunker_dsl
+from api.apps import current_user, login_required
+from api.apps.services.canvas_replica_service import CanvasReplicaService
+from api.db import CanvasCategory
+from api.db.db_models import Task
+from api.db.services.api_service import API4ConversationService
+from api.db.services.canvas_service import (
+ CanvasTemplateService,
+ UserCanvasService,
+ completion as agent_completion,
+ completion_openai,
+)
+from api.db.services.document_service import DocumentService
+from api.db.services.file_service import FileService
+from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
+from api.db.services.task_service import CANVAS_DEBUG_DOC_ID, TaskService, queue_dataflow
+from api.db.services.user_service import TenantService, UserService
+from api.db.services.user_canvas_version import UserCanvasVersionService
+from api.utils.api_utils import (
+ add_tenant_id_to_kwargs,
+ get_data_error_result,
+ get_json_result,
+ get_result,
+ get_request_json,
+ server_error_response,
+ validate_request,
+)
+from common import settings
+from common.constants import RetCode
+from common.misc_utils import get_uuid, thread_pool_exec
+from peewee import MySQLDatabase, PostgresqlDatabase
+from rag.flow.pipeline import Pipeline
+from rag.nlp import search
+from rag.utils.redis_conn import REDIS_CONN
+
+
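+# Access-control helpers for the agent routes below: the sync decorators wrap
+# plain handlers, while the async variant offloads the permission check to the
+# thread pool so the event loop is not blocked.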
+def _require_canvas_access_sync(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ if not UserCanvasService.accessible(kwargs.get('agent_id'), kwargs.get('tenant_id')):
+ return get_json_result(data=False, message="Make sure you have permission to access the agent.", code=RetCode.OPERATING_ERROR)
+ return func(*args, **kwargs)
+ return wrapper
+
+
+def _require_canvas_access_async(func):
+ @wraps(func)
+ async def wrapper(*args, **kwargs):
+ agent_id = kwargs.get('agent_id')
+ tenant_id = kwargs.get('tenant_id')
+ if not await thread_pool_exec(UserCanvasService.accessible, agent_id, tenant_id):
+ return get_json_result(data=False, message="Make sure you have permission to access the agent.", code=RetCode.OPERATING_ERROR)
+ return await func(*args, **kwargs)
+ return wrapper
+
+
+def _require_canvas_owner_sync(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ if not UserCanvasService.query(user_id=kwargs.get('tenant_id'), id=kwargs.get('agent_id')):
+ return get_json_result(data=False, message="Only the owner of the agent is authorized for this operation.", code=RetCode.OPERATING_ERROR)
+ return func(*args, **kwargs)
+ return wrapper
+
+
+def _get_user_nickname(user_id: str) -> str:
+ exists, user = UserService.get_by_id(user_id)
+ if not exists:
+ return user_id
+ return str(getattr(user, "nickname", "") or user_id)
+
+
+def _build_sse_response(body):
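+    # Server-sent events: disable caching and proxy buffering (X-Accel-Buffering
+    # is honored by nginx) so each event is flushed to the client immediately.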
+ resp = Response(body, mimetype="text/event-stream")
+ resp.headers.add_header("Cache-control", "no-cache")
+ resp.headers.add_header("Connection", "keep-alive")
+ resp.headers.add_header("X-Accel-Buffering", "no")
+ resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
+ return resp
+
+
+def _normalize_agent_session(conv):
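+    # Map a stored conversation row onto the public API shape: rename
+    # message -> messages and dialog_id -> agent_id, then fold each turn's
+    # retrieval references into the corresponding assistant message.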
+ conv["messages"] = conv.pop("message")
+ for info in conv["messages"]:
+ if "prompt" in info:
+ info.pop("prompt")
+ conv["agent_id"] = conv.pop("dialog_id")
+ if isinstance(conv["reference"], dict):
+ if "chunks" in conv["reference"]:
+ conv["reference"] = [conv["reference"]]
+ else:
+ conv["reference"] = [value for _, value in sorted(conv["reference"].items(), key=lambda item: int(item[0]))]
+
+ if conv["reference"]:
+ messages = [message for i, message in enumerate(conv["messages"]) if i != 0 and message["role"] != "user"]
+ for message, reference in zip(messages, conv["reference"]):
+ chunks = reference["chunks"]
+ message["reference"] = [
+ {
+ "id": chunk.get("chunk_id", chunk.get("id")),
+ "content": chunk.get("content_with_weight", chunk.get("content")),
+ "document_id": chunk.get("doc_id", chunk.get("document_id")),
+ "document_name": chunk.get("docnm_kwd", chunk.get("document_name")),
+ "dataset_id": chunk.get("kb_id", chunk.get("dataset_id")),
+ "image_id": chunk.get("image_id", chunk.get("img_id")),
+ "positions": chunk.get("positions", chunk.get("position_int")),
+ }
+ for chunk in chunks
+ ]
+ del conv["reference"]
+ return conv
+
+
+def _agent_session_list_result(data, total):
+ return jsonify({"code": RetCode.SUCCESS, "message": "success", "data": data, "total": total})
+
+
+@manager.route("/agents//sessions", methods=["GET"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+@_require_canvas_access_sync
+def list_agent_sessions(agent_id, tenant_id):
+ session_id = request.args.get("id")
+ user_id = request.args.get("user_id")
+ page_number = int(request.args.get("page", 1))
+ items_per_page = int(request.args.get("page_size", 30))
+ keywords = request.args.get("keywords")
+ from_date = request.args.get("from_date")
+ to_date = request.args.get("to_date")
+ orderby = request.args.get("orderby", "update_time")
+ exp_user_id = request.args.get("exp_user_id")
+ desc = request.args.get("desc") not in {"False", "false"}
+
+ if exp_user_id:
+ sessions = API4ConversationService.get_names(agent_id, exp_user_id)
+ return _agent_session_list_result(sessions, len(sessions))
+
+ include_dsl = request.args.get("dsl") not in {"False", "false"}
+ total, sessions = API4ConversationService.get_list(
+ agent_id,
+ tenant_id,
+ page_number,
+ items_per_page,
+ orderby,
+ desc,
+ session_id,
+ user_id,
+ include_dsl,
+ keywords,
+ from_date,
+ to_date,
+ exp_user_id=exp_user_id,
+ )
+ sessions = [_normalize_agent_session(session) for session in sessions]
+ return _agent_session_list_result(sessions, total)
+
+
+@manager.route("/agents//sessions", methods=["POST"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+@_require_canvas_access_async
+async def create_agent_session(agent_id, tenant_id):
+ req = await get_request_json()
+ user_id = req.get("user_id") or request.args.get("user_id", tenant_id)
+ release_mode = bool(req.get("release", request.args.get("release", False)))
+
+ try:
+ cvs, dsl = UserCanvasService.get_agent_dsl_with_release(agent_id, release_mode, tenant_id)
+ except LookupError:
+ return get_data_error_result(message="Agent not found.")
+ except PermissionError as e:
+ return get_data_error_result(message=str(e))
+
+ session_id = get_uuid()
+ canvas = Canvas(dsl, tenant_id, agent_id, canvas_id=cvs.id)
+ canvas.reset()
+
+ cvs.dsl = json.loads(str(canvas))
+ version_title = UserCanvasVersionService.get_latest_version_title(cvs.id, release_mode=release_mode)
+ conv = {
+ "id": session_id,
+ "name": req.get("name", ""),
+ "dialog_id": cvs.id,
+ "user_id": user_id,
+ "exp_user_id": user_id,
+ "message": [{"role": "assistant", "content": canvas.get_prologue()}],
+ "source": "agent",
+ "dsl": cvs.dsl,
+ "reference": [],
+ "version_title": version_title,
+ }
+ API4ConversationService.save(**conv)
+ return get_result(data=_normalize_agent_session(conv))
+
+
+@manager.route("/agents//sessions/", methods=["GET"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+@_require_canvas_access_sync
+def get_agent_session(agent_id, session_id, tenant_id):
+ _, conv = API4ConversationService.get_by_id(session_id)
+ return get_json_result(data=conv.to_dict())
+
+
+@manager.route("/agents//sessions/", methods=["DELETE"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+@_require_canvas_access_sync
+def delete_agent_session_item(agent_id, session_id, tenant_id):
+ return get_json_result(data=API4ConversationService.delete_by_id(session_id))
+
+
+@manager.route("/agents/download", methods=["GET"]) # noqa: F821
+async def download_agent_file():
+ id = request.args.get("id")
+ created_by = request.args.get("created_by")
+ blob = FileService.get_blob(created_by, id)
+ return Response(blob)
+
+
+async def _iter_session_completion_events(tenant_id, agent_id, req, return_trace):
+ # Stream and non-stream session completions share the same event parsing and trace injection.
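+    # Each yielded item is either an already-parsed event dict or an SSE-framed string
+    # ("data:{json}"); the 5-character "data:" prefix is stripped before json.loads.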
+ trace_items = []
+ async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
+ if isinstance(answer, str):
+ try:
+ ans = json.loads(answer[5:])
+ except Exception:
+ continue
+ else:
+ ans = answer
+
+ event = ans.get("event")
+ if event == "node_finished":
+ if return_trace:
+ data = ans.get("data", {})
+ trace_items.append(
+ {
+ "component_id": data.get("component_id"),
+ "trace": [copy.deepcopy(data)],
+ }
+ )
+ ans.setdefault("data", {})["trace"] = trace_items
+ yield ans
+ continue
+
+ if event in ["message", "message_end"]:
+ yield ans
+
+
+@manager.route("/agents/templates", methods=["GET"]) # noqa: F821
+@login_required
+def list_agent_template():
+ return get_json_result(data=[item.to_dict() for item in CanvasTemplateService.get_all()])
+
+
+@manager.route("/agents/prompts", methods=["GET"]) # noqa: F821
+@login_required
+def prompts():
+ from rag.prompts.generator import (
+ ANALYZE_TASK_SYSTEM,
+ ANALYZE_TASK_USER,
+ CITATION_PROMPT_TEMPLATE,
+ NEXT_STEP,
+ REFLECT,
+ )
+
+ return get_json_result(
+ data={
+ "task_analysis": f"{ANALYZE_TASK_SYSTEM}\n\n{ANALYZE_TASK_USER}",
+ "plan_generation": NEXT_STEP,
+ "reflection": REFLECT,
+ "citation_guidelines": CITATION_PROMPT_TEMPLATE,
+ }
+ )
+
+
+@manager.route("/agents", methods=["GET"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+def list_agents(tenant_id):
+ keywords = request.args.get("keywords", "")
+ canvas_category = request.args.get("canvas_category")
+ owner_ids = [item for item in request.args.get("owner_ids", "").strip().split(",") if item]
+
+ page_number = int(request.args.get("page", 0))
+ items_per_page = int(request.args.get("page_size", 0))
+ order_by = request.args.get("orderby", "create_time")
+ desc = str(request.args.get("desc", "true")).lower() != "false"
+ tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
+ authorized_owner_ids = {member["tenant_id"] for member in tenants}
+ authorized_owner_ids.add(tenant_id)
+
+ if owner_ids:
+ requested_owner_ids = set(owner_ids)
+ unauthorized_owner_ids = requested_owner_ids - authorized_owner_ids
+ if unauthorized_owner_ids:
+ return get_json_result(
+ data=False,
+ message="Only authorized owner_ids can be queried.",
+ code=RetCode.OPERATING_ERROR,
+ )
+ effective_owner_ids = list(requested_owner_ids)
+ else:
+ effective_owner_ids = list(authorized_owner_ids)
+
+ canvas, total = UserCanvasService.get_by_tenant_ids(
+ effective_owner_ids,
+ tenant_id,
+ page_number,
+ items_per_page,
+ order_by,
+ desc,
+ keywords,
+ canvas_category,
+ )
+
+ return get_json_result(data={"canvas": canvas, "total": total})
+
+
+@manager.route("/agents", methods=["POST"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+async def create_agent(tenant_id):
+ req = {k: v for k, v in (await get_request_json()).items() if v is not None}
+ req["user_id"] = tenant_id
+ req["canvas_category"] = req.get("canvas_category") or CanvasCategory.Agent
+ req["release"] = bool(req.get("release", ""))
+
+ if req.get("dsl") is None:
+ return get_json_result(
+ data=False,
+ message="No DSL data in request.",
+ code=RetCode.ARGUMENT_ERROR,
+ )
+
+ try:
+ req["dsl"] = CanvasReplicaService.normalize_dsl(req["dsl"])
+ except ValueError as exc:
+ return get_json_result(
+ data=False,
+ message=str(exc),
+ code=RetCode.ARGUMENT_ERROR,
+ )
+
+ if req.get("title") is None:
+ return get_json_result(
+ data=False,
+ message="No title in request.",
+ code=RetCode.ARGUMENT_ERROR,
+ )
+
+ req["title"] = req["title"].strip()
+ if UserCanvasService.query(
+ user_id=tenant_id,
+ title=req["title"],
+ canvas_category=req["canvas_category"],
+ ):
+ return get_data_error_result(message=f"{req['title']} already exists.")
+
+ req["id"] = get_uuid()
+ if not UserCanvasService.save(**req):
+ return get_data_error_result(message="Fail to create agent.")
+
+ owner_nickname = _get_user_nickname(tenant_id)
+ UserCanvasVersionService.save_or_replace_latest(
+ user_canvas_id=req["id"],
+ title=UserCanvasVersionService.build_version_title(owner_nickname, req.get("title")),
+ dsl=req["dsl"],
+ release=req.get("release"),
+ )
+ replica_ok = CanvasReplicaService.replace_for_set(
+ canvas_id=req["id"],
+ tenant_id=str(tenant_id),
+ runtime_user_id=str(tenant_id),
+ dsl=req["dsl"],
+ canvas_category=req["canvas_category"],
+ title=req.get("title", ""),
+ )
+ if not replica_ok:
+ return get_data_error_result(message="canvas saved, but replica sync failed.")
+
+ exists, created_agent = UserCanvasService.get_by_canvas_id(req["id"])
+ if not exists:
+ return get_data_error_result(message="Fail to create agent.")
+ return get_json_result(data=created_agent)
+
+
+@manager.route("/agents//upload", methods=["POST"]) # noqa: F821
+async def upload_agent_file(agent_id):
+ exists, canvas = UserCanvasService.get_by_canvas_id(agent_id)
+ if not exists:
+ return get_data_error_result(message="canvas not found.")
+
+ user_id = canvas["user_id"]
+ files = await request.files
+ file_objs = files.getlist("file") if files and files.get("file") else []
+ try:
+ if len(file_objs) == 1:
+ return get_json_result(
+ data=FileService.upload_info(user_id, file_objs[0], request.args.get("url"))
+ )
+ results = [FileService.upload_info(user_id, file_obj) for file_obj in file_objs]
+ return get_json_result(data=results)
+ except Exception as exc:
+ return server_error_response(exc)
+
+
+@manager.route("/agents//components//input-form", methods=["GET"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+@_require_canvas_access_sync
+def get_agent_component_input_form(agent_id, component_id, tenant_id):
+ try:
+ exists, user_canvas = UserCanvasService.get_by_id(agent_id)
+ if not exists:
+ return get_data_error_result(message="canvas not found.")
+ canvas = Canvas(json.dumps(user_canvas.dsl), tenant_id, canvas_id=user_canvas.id)
+ return get_json_result(data=canvas.get_component_input_form(component_id))
+ except Exception as exc:
+ return server_error_response(exc)
+
+
+@manager.route("/agents//components//debug", methods=["POST"]) # noqa: F821
+@validate_request("params")
+@login_required
+@add_tenant_id_to_kwargs
+@_require_canvas_access_async
+async def debug_agent_component(agent_id, component_id, tenant_id):
+ req = await get_request_json()
+ try:
+ _, user_canvas = UserCanvasService.get_by_id(agent_id)
+ canvas = Canvas(json.dumps(user_canvas.dsl), tenant_id, canvas_id=user_canvas.id)
+ canvas.reset()
+ canvas.message_id = get_uuid()
+ component = canvas.get_component(component_id)["obj"]
+ component.reset()
+
+ if isinstance(component, LLM):
+ component.set_debug_inputs(req["params"])
+ component.invoke(**{k: o["value"] for k, o in req["params"].items()})
+ outputs = component.output()
+ for k in outputs.keys():
+ if isinstance(outputs[k], partial):
+ txt = ""
+ iter_obj = outputs[k]()
+ if inspect.isasyncgen(iter_obj):
+ async for c in iter_obj:
+ txt += c
+ else:
+ for c in iter_obj:
+ txt += c
+ outputs[k] = txt
+ return get_json_result(data=outputs)
+ except Exception as exc:
+ return server_error_response(exc)
+
+
+@manager.route("/agents/", methods=["GET"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+def get_agent(agent_id, tenant_id):
+ if not UserCanvasService.accessible(agent_id, tenant_id):
+ return get_data_error_result(message="canvas not found.")
+
+ exists, canvas = UserCanvasService.get_by_canvas_id(agent_id)
+ if not exists:
+ return get_data_error_result(message="canvas not found.")
+
+ try:
+ CanvasReplicaService.bootstrap(
+ canvas_id=agent_id,
+ tenant_id=str(tenant_id),
+ runtime_user_id=str(tenant_id),
+ dsl=canvas.get("dsl"),
+ canvas_category=canvas.get("canvas_category", CanvasCategory.Agent),
+ title=canvas.get("title", ""),
+ )
+ except ValueError as exc:
+ return get_data_error_result(message=str(exc))
+
+ last_publish_time = None
+ versions = UserCanvasVersionService.list_by_canvas_id(agent_id)
+ if versions:
+ released_versions = [version for version in versions if version.release]
+ if released_versions:
+ released_versions.sort(key=lambda version: version.update_time, reverse=True)
+ last_publish_time = released_versions[0].update_time
+
+ canvas["dsl"] = normalize_chunker_dsl(canvas.get("dsl", {}))
+ canvas["last_publish_time"] = last_publish_time
+
+ if canvas.get("canvas_category") == CanvasCategory.DataFlow:
+ datasets = list(KnowledgebaseService.query(pipeline_id=agent_id))
+ canvas["datasets"] = [{"id": item.id, "name": item.name, "avatar": item.avatar} for item in datasets]
+
+ return get_json_result(data=canvas)
+
+
+@manager.route("/agents//versions", methods=["GET"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+@_require_canvas_access_sync
+def list_agent_versions(agent_id, tenant_id):
+ try:
+ versions = sorted(
+ [item.to_dict() for item in UserCanvasVersionService.list_by_canvas_id(agent_id)],
+ key=lambda item: item["update_time"] * -1,
+ )
+ return get_json_result(data=versions)
+ except Exception as exc:
+ return get_data_error_result(message=f"Error getting history files: {exc}")
+
+
+@manager.route("/agents//versions/", methods=["GET"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+@_require_canvas_access_sync
+def get_agent_version(agent_id, version_id, tenant_id):
+ try:
+ exists, version = UserCanvasVersionService.get_by_id(version_id)
+ if not exists or not version or str(version.user_canvas_id) != str(agent_id):
+ return get_data_error_result(message="Version not found.")
+ return get_json_result(data=version.to_dict())
+ except Exception as exc:
+ return get_data_error_result(message=f"Error getting history file: {exc}")
+
+
+@manager.route("/agents//logs/", methods=["GET"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+@_require_canvas_access_sync
+def get_agent_logs(agent_id, message_id, tenant_id):
+ try:
+ binary = REDIS_CONN.get(f"{agent_id}-{message_id}-logs")
+ if not binary:
+ return get_json_result(data={})
+
+ return get_json_result(data=json.loads(binary.encode("utf-8")))
+ except Exception as exc:
+ logging.exception(exc)
+ return server_error_response(exc)
+
+
+@manager.route("/agents/", methods=["DELETE"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+@_require_canvas_owner_sync
+def delete_agent(agent_id, tenant_id):
+ UserCanvasService.delete_by_id(agent_id)
+ return get_json_result(data=True)
+
+
+@manager.route("/agents/", methods=["PUT"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+@_require_canvas_access_async
+async def update_agent(agent_id, tenant_id):
+ req = {k: v for k, v in (await get_request_json()).items() if v is not None}
+ req["release"] = bool(req.get("release", ""))
+
+ if req.get("dsl") is not None:
+ try:
+ req["dsl"] = CanvasReplicaService.normalize_dsl(req["dsl"])
+ except ValueError as exc:
+ return get_json_result(
+ data=False,
+ message=str(exc),
+ code=RetCode.ARGUMENT_ERROR,
+ )
+
+ if req.get("title") is not None:
+ req["title"] = req["title"].strip()
+
+ _, current_agent = UserCanvasService.get_by_id(agent_id)
+ agent_title_for_version = req.get("title") or (current_agent.title if current_agent else "")
+ canvas_category = (
+ req.get("canvas_category")
+ or (current_agent.canvas_category if current_agent else CanvasCategory.Agent)
+ )
+ owner_nickname = _get_user_nickname(tenant_id)
+ UserCanvasService.update_by_id(agent_id, req)
+
+ if req.get("dsl") is not None:
+ UserCanvasVersionService.save_or_replace_latest(
+ user_canvas_id=agent_id,
+ title=UserCanvasVersionService.build_version_title(owner_nickname, agent_title_for_version),
+ dsl=req["dsl"],
+ release=req.get("release"),
+ )
+ replica_ok = CanvasReplicaService.replace_for_set(
+ canvas_id=agent_id,
+ tenant_id=str(tenant_id),
+ runtime_user_id=str(tenant_id),
+ dsl=req["dsl"],
+ canvas_category=canvas_category,
+ title=agent_title_for_version,
+ )
+ if not replica_ok:
+ return get_data_error_result(message="agent saved, but replica sync failed.")
+
+ return get_json_result(data=True)
+
+
+@manager.route("/agents//reset", methods=["POST"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+@_require_canvas_access_async
+async def reset_agent(agent_id, tenant_id):
+ try:
+ exists, user_canvas = UserCanvasService.get_by_id(agent_id)
+ if not exists:
+ return get_data_error_result(message="canvas not found.")
+
+ canvas = Canvas(json.dumps(user_canvas.dsl), tenant_id, canvas_id=user_canvas.id)
+ canvas.reset()
+ dsl = json.loads(str(canvas))
+ UserCanvasService.update_by_id(agent_id, {"dsl": dsl})
+ replica_ok = CanvasReplicaService.replace_for_set(
+ canvas_id=agent_id,
+ tenant_id=str(tenant_id),
+ runtime_user_id=str(tenant_id),
+ dsl=dsl,
+ canvas_category=user_canvas.canvas_category,
+ title=user_canvas.title,
+ )
+ if not replica_ok:
+ return get_data_error_result(message="agent reset, but replica sync failed.")
+ return get_json_result(data=dsl)
+ except Exception as exc:
+ return server_error_response(exc)
+
+
+@manager.route("/agents/rerun", methods=["POST"]) # noqa: F821
+@validate_request("id", "dsl", "component_id")
+@login_required
+@add_tenant_id_to_kwargs
+async def rerun_agent(tenant_id):
+ req = await get_request_json()
+ doc = PipelineOperationLogService.get_documents_info(req["id"])
+ if not doc:
+ return get_data_error_result(message="Document not found.")
+ doc = doc[0]
+ if 0 < doc["progress"] < 1:
+ return get_data_error_result(message=f"`{doc['name']}` is processing...")
+
+ if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc["kb_id"]):
+ settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(tenant_id), doc["kb_id"])
+ doc["progress_msg"] = ""
+ doc["chunk_num"] = 0
+ doc["token_num"] = 0
+ DocumentService.clear_chunk_num_when_rerun(doc["id"])
+ DocumentService.update_by_id(doc["id"], doc)
+ TaskService.filter_delete([Task.doc_id == doc["id"]])
+
+ dsl = req["dsl"]
+ dsl["path"] = [req["component_id"]]
+ PipelineOperationLogService.update_by_id(req["id"], {"dsl": dsl})
+ queue_dataflow(
+ tenant_id=tenant_id,
+ flow_id=req["id"],
+ task_id=get_uuid(),
+ doc_id=doc["id"],
+ priority=0,
+ rerun=True,
+ )
+ return get_json_result(data=True)
+
+
+@manager.route("/agents/test_db_connection", methods=["POST"]) # noqa: F821
+@validate_request("db_type", "database", "username", "host", "port", "password")
+@login_required
+async def test_db_connection():
+ req = await get_request_json()
+ try:
+ if req["db_type"] in ["mysql", "mariadb"]:
+ db = MySQLDatabase(
+ req["database"],
+ user=req["username"],
+ host=req["host"],
+ port=req["port"],
+ password=req["password"],
+ )
+ elif req["db_type"] == "oceanbase":
+ db = MySQLDatabase(
+ req["database"],
+ user=req["username"],
+ host=req["host"],
+ port=req["port"],
+ password=req["password"],
+ charset="utf8mb4",
+ )
+ elif req["db_type"] == "postgres":
+ db = PostgresqlDatabase(
+ req["database"],
+ user=req["username"],
+ host=req["host"],
+ port=req["port"],
+ password=req["password"],
+ )
+ elif req["db_type"] == "mssql":
+ import pyodbc
+
+ connection_string = (
+ f"DRIVER={{ODBC Driver 17 for SQL Server}};"
+ f"SERVER={req['host']},{req['port']};"
+ f"DATABASE={req['database']};"
+ f"UID={req['username']};"
+ f"PWD={req['password']};"
+ )
+ db = pyodbc.connect(connection_string)
+            cursor = db.cursor()
+            cursor.execute("SELECT 1")
+            cursor.close()
+            # pyodbc connections are skipped by the shared connect/close path below, so close here.
+            db.close()
+ elif req["db_type"] == "IBM DB2":
+ import ibm_db
+
+ conn_str = (
+ f"DATABASE={req['database']};"
+ f"HOSTNAME={req['host']};"
+ f"PORT={req['port']};"
+ f"PROTOCOL=TCPIP;"
+ f"UID={req['username']};"
+ f"PWD={req['password']};"
+ )
+ logging.info(
+ "DATABASE=%s;HOSTNAME=%s;PORT=%s;PROTOCOL=TCPIP;UID=%s;PWD=****;",
+ req["database"],
+ req["host"],
+ req["port"],
+ req["username"],
+ )
+ conn = ibm_db.connect(conn_str, "", "")
+ stmt = ibm_db.exec_immediate(conn, "SELECT 1 FROM sysibm.sysdummy1")
+ ibm_db.fetch_assoc(stmt)
+ ibm_db.close(conn)
+ return get_json_result(data="Database Connection Successful!")
+ elif req["db_type"] == "trino":
+ import os
+ import trino
+
+ db_name = req["database"]
+ if "." in db_name:
+ catalog, schema = db_name.split(".", 1)
+ elif "/" in db_name:
+ catalog, schema = db_name.split("/", 1)
+ else:
+ catalog, schema = db_name, "default"
+
+ http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http"
+ auth = None
+ if http_scheme == "https" and req.get("password"):
+ auth = trino.BasicAuthentication(req.get("username") or "ragflow", req["password"])
+
+ conn = trino.dbapi.connect(
+ host=req["host"],
+ port=int(req["port"] or 8080),
+ user=req["username"] or "ragflow",
+ catalog=catalog,
+ schema=schema or "default",
+ http_scheme=http_scheme,
+ auth=auth,
+ )
+ cur = conn.cursor()
+ cur.execute("SELECT 1")
+ cur.fetchall()
+ cur.close()
+ conn.close()
+ return get_json_result(data="Database Connection Successful!")
+ else:
+ return server_error_response("Unsupported database type.")
+
+ if req["db_type"] != "mssql":
+ db.connect()
+ db.close()
+ return get_json_result(data="Database Connection Successful!")
+ except Exception as exc:
+ return server_error_response(exc)
+
+
+@manager.route("/agents/chat/completion", methods=["POST"]) # noqa: F821
+@manager.route("/agents/chat/completions", methods=["POST"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+async def agent_chat_completion(tenant_id, agent_id=None):
+ # This endpoint serves two execution modes:
+ # 1. Draft/runtime execution without session state. The request runs against the caller's
+ # runtime replica, which is populated from the editable canvas state.
+ # 2. Session continuation with an existing session_id. The request resumes from the stored
+ # API4Conversation state and must stay bound to the same agent and an accessible canvas.
+ #
+ # Security constraints:
+ # - agent_id is always supplied at the route layer and is not forwarded downstream as a free-form kwarg.
+ # - New runs without session_id must pass UserCanvasService.accessible(...) before the runtime replica is loaded.
+ # - Existing sessions are validated here at the route layer before handing control to the lower-level
+ # completion functions, so canvas_service only executes a pre-authorized session payload.
+ #
+ # Response modes:
+ # - Regular mode emits internal agent events.
+ # - openai-compatible mode reshapes the same execution into an OpenAI-like wire format.
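+    #
+    # Illustrative request body (keys mirror what is parsed below; all values hypothetical):
+    #   {"agent_id": "...", "session_id": "...", "query": "hi", "stream": true,
+    #    "openai-compatible": false, "return_trace": false}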
+ req = await get_request_json()
+ agent_id = agent_id or req.get("agent_id")
+ openai_compatible = bool(req.get("openai-compatible", False))
+ if not agent_id:
+ return get_json_result(
+ data=False,
+ message="`agent_id` is required.",
+ code=RetCode.ARGUMENT_ERROR,
+ )
+ # Route-level selectors should not be forwarded into the lower-level completion functions.
+ req = dict(req)
+ req.pop("agent_id", None)
+ req.pop("openai-compatible", None)
+ session_id = req.get("session_id")
+ if session_id:
+ exists, conv = API4ConversationService.get_by_id(session_id)
+ if not exists:
+ return get_data_error_result(message="Session not found!")
+ if conv.dialog_id != agent_id:
+ return get_json_result(
+ data=False,
+ message="Session does not belong to the requested agent.",
+ code=RetCode.OPERATING_ERROR,
+ )
+ if not UserCanvasService.accessible(agent_id, tenant_id):
+ return get_json_result(
+ data=False,
+ message="Only authorized users can access this agent session.",
+ code=RetCode.OPERATING_ERROR,
+ )
+
+ if openai_compatible:
+ # OpenAI-compatible mode uses a different wire format, keep it separate from regular agent events.
+ messages = req.get("messages", [])
+ if not messages:
+ return get_data_error_result(message="You must provide at least one message.")
+ question = next((m.get("content", "") for m in reversed(messages) if m.get("role") == "user"), "")
+ stream = req.pop("stream", False)
+ session_id = req.pop("session_id", req.get("id", "")) or req.get("metadata", {}).get("id", "")
+ if stream:
+ return _build_sse_response(
+ completion_openai(
+ tenant_id,
+ agent_id,
+ question,
+ session_id=session_id,
+ stream=True,
+ **req,
+ )
+ )
+
+ async for response in completion_openai(
+ tenant_id,
+ agent_id,
+ question,
+ session_id=session_id,
+ stream=False,
+ **req,
+ ):
+ return jsonify(response)
+ return None
+
+ if not session_id:
+ # Without session state, run against the runtime replica that tracks draft edits.
+ query = req.get("query", "") or req.get("question", "")
+ files = req.get("files", [])
+ inputs = req.get("inputs", {})
+ runtime_user_id = req.get("user_id") or tenant_id
+ user_id = str(runtime_user_id)
+ custom_header = req.get("custom_header", "")
+
+ if not UserCanvasService.accessible(agent_id, tenant_id):
+ return get_json_result(
+ data=False,
+ message="Make sure you have permission to access the agent.",
+ code=RetCode.OPERATING_ERROR,
+ )
+
+ _, cvs = await thread_pool_exec(UserCanvasService.get_by_id, agent_id)
+ if not cvs:
+ return get_data_error_result(message="canvas not found.")
+
+ replica_payload = CanvasReplicaService.load_for_run(
+ canvas_id=agent_id,
+ tenant_id=str(tenant_id),
+ runtime_user_id=user_id,
+ )
+ if not replica_payload:
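+        # No runtime replica for this caller yet — seed one from the persisted canvas DSL.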
+ try:
+ replica_payload = CanvasReplicaService.bootstrap(
+ canvas_id=agent_id,
+ tenant_id=str(tenant_id),
+ runtime_user_id=user_id,
+ dsl=cvs.dsl,
+ canvas_category=getattr(cvs, "canvas_category", CanvasCategory.Agent),
+ title=getattr(cvs, "title", ""),
+ )
+ except ValueError as exc:
+ return get_data_error_result(message=str(exc))
+ if not replica_payload:
+ return get_data_error_result(message="canvas replica not found, please fetch the agent first.")
+
+ replica_dsl = replica_payload.get("dsl", {})
+ canvas_title = replica_payload.get("title", "")
+ canvas_category = replica_payload.get("canvas_category", CanvasCategory.Agent)
+ dsl_str = json.dumps(replica_dsl, ensure_ascii=False)
+
+    if cvs.canvas_category == CanvasCategory.DataFlow:
+        # Dataflow debug runs operate on an uploaded file; guard against an empty list
+        # before files[0] is dereferenced below.
+        if not files:
+            return get_data_error_result(message="`files` is required to run a dataflow.")
+        task_id = get_uuid()
+ Pipeline(
+ dsl_str,
+ tenant_id=str(tenant_id),
+ doc_id=CANVAS_DEBUG_DOC_ID,
+ task_id=task_id,
+ flow_id=agent_id,
+ )
+ ok, error_message = await thread_pool_exec(
+ queue_dataflow,
+ user_id,
+ agent_id,
+ task_id,
+ CANVAS_DEBUG_DOC_ID,
+ files[0],
+ 0,
+ )
+ if not ok:
+ return get_data_error_result(message=error_message)
+ return get_json_result(data={"message_id": task_id})
+
+ try:
+ canvas = Canvas(dsl_str, str(tenant_id), canvas_id=agent_id, custom_header=custom_header)
+ except Exception as exc:
+ return server_error_response(exc)
+
+ async def commit_runtime_replica():
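+        # Persist the post-run canvas state back to the caller's runtime replica. Failures are
+        # logged rather than surfaced, since the answer has already been produced at this point.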
+ commit_ok = CanvasReplicaService.commit_after_run(
+ canvas_id=agent_id,
+ tenant_id=str(tenant_id),
+ runtime_user_id=user_id,
+ dsl=json.loads(str(canvas)),
+ canvas_category=canvas_category,
+ title=canvas_title,
+ )
+ if not commit_ok:
+ logging.error(
+ "Canvas runtime replica commit failed: canvas_id=%s tenant_id=%s runtime_user_id=%s",
+ agent_id,
+ tenant_id,
+ user_id,
+ )
+
+ if req.get("stream", True):
+ async def sse():
+ nonlocal canvas
+ try:
+ async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
+ yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
+
+ await commit_runtime_replica()
+ except Exception as exc:
+ logging.exception(exc)
+ canvas.cancel_task()
+ yield (
+ "data:"
+ + json.dumps({"code": 500, "message": str(exc), "data": False}, ensure_ascii=False)
+ + "\n\n"
+ )
+
+ return _build_sse_response(sse())
+
+ full_content = ""
+ reference = {}
+ final_ans = {}
+ trace_items = []
+ structured_output = {}
+ try:
+ async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
+ if ans.get("event") == "message":
+ full_content += ans.get("data", {}).get("content", "")
+ if ans.get("data", {}).get("reference", None):
+ reference.update(ans["data"]["reference"])
+ if ans.get("event") == "node_finished":
+ data = ans.get("data", {})
+ node_out = data.get("outputs", {})
+ component_id = data.get("component_id")
+ if component_id is not None and "structured" in node_out:
+ structured_output[component_id] = copy.deepcopy(node_out["structured"])
+ if req.get("return_trace", False):
+ trace_items.append(
+ {
+ "component_id": data.get("component_id"),
+ "trace": [copy.deepcopy(data)],
+ }
+ )
+ final_ans = ans
+ except Exception as exc:
+ logging.exception(exc)
+ canvas.cancel_task()
+ return get_result(data=f"**ERROR**: {str(exc)}")
+
+ if not final_ans:
+ await commit_runtime_replica()
+ return get_result(data={})
+
+ if "data" not in final_ans or not isinstance(final_ans["data"], dict):
+ final_ans["data"] = {}
+ final_ans["data"]["content"] = full_content
+ final_ans["data"]["reference"] = reference
+ if structured_output:
+ final_ans["data"]["structured"] = structured_output
+ if trace_items:
+ final_ans["data"]["trace"] = trace_items
+
+ await commit_runtime_replica()
+ return get_result(data=final_ans)
+
+ return_trace = bool(req.get("return_trace", False))
+ if req.get("stream", True):
+
+ async def generate():
+ async for ans in _iter_session_completion_events(tenant_id, agent_id, req, return_trace):
+ yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
+ yield "data:[DONE]\n\n"
+
+ return _build_sse_response(generate())
+
+ full_content = ""
+ reference = {}
+ final_ans = {}
+ trace_items = []
+ structured_output = {}
+ async for ans in _iter_session_completion_events(tenant_id, agent_id, req, return_trace):
+ try:
+ if ans["event"] == "message":
+ full_content += ans["data"]["content"]
+ if ans.get("data", {}).get("reference", None):
+ reference.update(ans["data"]["reference"])
+ if ans.get("event") == "node_finished":
+ data = ans.get("data", {})
+ node_out = data.get("outputs", {})
+ component_id = data.get("component_id")
+ if component_id is not None and "structured" in node_out:
+ structured_output[component_id] = copy.deepcopy(node_out["structured"])
+ if return_trace:
+ trace_items.append(
+ {
+ "component_id": data.get("component_id"),
+ "trace": [copy.deepcopy(data)],
+ }
+ )
+ final_ans = ans
+ except Exception as exc:
+ return get_result(data=f"**ERROR**: {str(exc)}")
+
+ if not final_ans:
+ return get_result(data={})
+
+ if "data" not in final_ans or not isinstance(final_ans["data"], dict):
+ final_ans["data"] = {}
+ final_ans["data"]["content"] = full_content
+ final_ans["data"]["reference"] = reference
+ if structured_output:
+ final_ans["data"]["structured"] = structured_output
+ if return_trace and final_ans:
+ final_ans["data"]["trace"] = trace_items
+ return get_result(data=final_ans)
+
+
+@manager.route("/agents//webhook", methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"]) # noqa: F821
+@manager.route("/agents//webhook/test",methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"],) # noqa: F821
+async def webhook(agent_id: str):
+ is_test = request.path.startswith(f"/api/v1/agents/{agent_id}/webhook/test")
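+    # Test invocations (".../webhook/test") additionally record a per-run event trace in Redis
+    # via append_webhook_trace, which the webhook trace endpoint below polls.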
+ start_ts = time.time()
+
+ # 1. Fetch canvas by agent_id
+ exists, cvs = UserCanvasService.get_by_id(agent_id)
+ if not exists:
+ return get_data_error_result(code=RetCode.BAD_REQUEST,message="Canvas not found."),RetCode.BAD_REQUEST
+
+ # 2. Check canvas category
+ if cvs.canvas_category == CanvasCategory.DataFlow:
+ return get_data_error_result(code=RetCode.BAD_REQUEST,message="Dataflow can not be triggered by webhook."),RetCode.BAD_REQUEST
+
+ # 3. Load DSL from canvas
+ dsl = getattr(cvs, "dsl", None)
+ if not isinstance(dsl, dict):
+ return get_data_error_result(code=RetCode.BAD_REQUEST,message="Invalid DSL format."),RetCode.BAD_REQUEST
+
+ # 4. Check webhook configuration in DSL
+ webhook_cfg = {}
+ components = dsl.get("components", {})
+    for cpn in components.values():
+        cpn_obj = cpn["obj"]
+        if cpn_obj["component_name"].lower() == "begin" and cpn_obj["params"].get("mode") == "Webhook":
+            webhook_cfg = cpn_obj["params"]
+
+ if not webhook_cfg:
+ return get_data_error_result(code=RetCode.BAD_REQUEST,message="Webhook not configured for this agent."),RetCode.BAD_REQUEST
+
+ # 5. Validate request method against webhook_cfg.methods
+ allowed_methods = webhook_cfg.get("methods", [])
+ request_method = request.method.upper()
+ if allowed_methods and request_method not in allowed_methods:
+        return get_data_error_result(
+            code=RetCode.BAD_REQUEST, message=f"HTTP method '{request_method}' not allowed for this webhook."
+        ), RetCode.BAD_REQUEST
+
+ # 6. Validate webhook security
+ async def validate_webhook_security(security_cfg: dict):
+ """Validate webhook security rules based on security configuration."""
+
+ if not security_cfg:
+ return # No security config → allowed by default
+
+ # 1. Validate max body size
+ await _validate_max_body_size(security_cfg)
+
+ # 2. Validate IP whitelist
+ _validate_ip_whitelist(security_cfg)
+
+        # 3. Validate rate limiting
+ _validate_rate_limit(security_cfg)
+
+ # 4. Validate authentication
+ auth_type = security_cfg.get("auth_type", "none")
+
+ if auth_type == "none":
+ return
+
+ if auth_type == "token":
+ _validate_token_auth(security_cfg)
+
+ elif auth_type == "basic":
+ _validate_basic_auth(security_cfg)
+
+ elif auth_type == "jwt":
+ _validate_jwt_auth(security_cfg)
+
+ else:
+ raise Exception(f"Unsupported auth_type: {auth_type}")
+
+ async def _validate_max_body_size(security_cfg):
+ """Check request size does not exceed max_body_size."""
+ max_size = security_cfg.get("max_body_size")
+ if not max_size:
+ return
+
+ # Convert "10MB" → bytes
+ units = {"kb": 1024, "mb": 1024**2}
+ size_str = max_size.lower()
+
+ for suffix, factor in units.items():
+ if size_str.endswith(suffix):
+ limit = int(size_str.replace(suffix, "")) * factor
+ break
+ else:
+ raise Exception("Invalid max_body_size format")
+ MAX_LIMIT = 10 * 1024 * 1024 # 10MB
+ if limit > MAX_LIMIT:
+ raise Exception("max_body_size exceeds maximum allowed size (10MB)")
+
+ content_length = request.content_length or 0
+ if content_length > limit:
+ raise Exception(f"Request body too large: {content_length} > {limit}")
+
+ def _validate_ip_whitelist(security_cfg):
+ """Allow only IPs listed in ip_whitelist."""
+ whitelist = security_cfg.get("ip_whitelist", [])
+ if not whitelist:
+ return
+
+ client_ip = request.remote_addr
+
+ for rule in whitelist:
+ if "/" in rule:
+ # CIDR notation
+ if ipaddress.ip_address(client_ip) in ipaddress.ip_network(rule, strict=False):
+ return
+ else:
+ # Single IP
+ if client_ip == rule:
+ return
+
+ raise Exception(f"IP {client_ip} is not allowed by whitelist")
+
+ def _validate_rate_limit(security_cfg):
+ """Simple in-memory rate limiting."""
+ rl = security_cfg.get("rate_limit")
+ if not rl:
+ return
+
+ limit = int(rl.get("limit", 60))
+ if limit <= 0:
+ raise Exception("rate_limit.limit must be > 0")
+ per = rl.get("per", "minute")
+
+ window = {
+ "second": 1,
+ "minute": 60,
+ "hour": 3600,
+ "day": 86400,
+ }.get(per)
+
+ if not window:
+ raise Exception(f"Invalid rate_limit.per: {per}")
+
+ capacity = limit
+ rate = limit / window
+ cost = 1
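+        # Worked example: limit=60, per="minute" -> capacity 60 tokens refilled at 1 token/sec,
+        # i.e. bursts of up to 60 requests with a sustained rate of one request per second.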
+
+ key = f"rl:tb:{agent_id}"
+ now = time.time()
+
+        try:
+            res = REDIS_CONN.lua_token_bucket(
+                keys=[key],
+                args=[capacity, rate, now, cost],
+                client=REDIS_CONN.REDIS,
+            )
+        except Exception as e:
+            raise Exception(f"Rate limit error: {e}")
+
+        # Only Redis failures are wrapped above; a depleted bucket is reported as-is.
+        allowed = int(res[0])
+        if allowed != 1:
+            raise Exception("Too many requests (rate limit exceeded)")
+
+    def _validate_token_auth(security_cfg):
+        """Validate header-based token authentication."""
+        token_cfg = security_cfg.get("token", {})
+        header = token_cfg.get("token_header")
+        token_value = token_cfg.get("token_value")
+
+        # Refuse to authenticate against a missing header/value pair; otherwise an
+        # unconfigured token (None) would match a request without the header (also None).
+        if not header or not token_value:
+            raise Exception("Token authentication is not configured properly")
+
+        provided = request.headers.get(header)
+        if provided != token_value:
+            raise Exception("Invalid token authentication")
+
+ def _validate_basic_auth(security_cfg):
+ """Validate HTTP Basic Auth credentials."""
+ auth_cfg = security_cfg.get("basic_auth", {})
+ username = auth_cfg.get("username")
+ password = auth_cfg.get("password")
+
+ auth = request.authorization
+ if not auth or auth.username != username or auth.password != password:
+ raise Exception("Invalid Basic Auth credentials")
+
+ def _validate_jwt_auth(security_cfg):
+ """Validate JWT token in Authorization header."""
+ jwt_cfg = security_cfg.get("jwt", {})
+ secret = jwt_cfg.get("secret")
+ if not secret:
+ raise Exception("JWT secret not configured")
+
+ auth_header = request.headers.get("Authorization", "")
+ if not auth_header.startswith("Bearer "):
+ raise Exception("Missing Bearer token")
+
+ token = auth_header[len("Bearer "):].strip()
+ if not token:
+ raise Exception("Empty Bearer token")
+
+ alg = (jwt_cfg.get("algorithm") or "HS256").upper()
+
+ decode_kwargs = {
+ "key": secret,
+ "algorithms": [alg],
+ }
+ options = {}
+ if jwt_cfg.get("audience"):
+ decode_kwargs["audience"] = jwt_cfg["audience"]
+ options["verify_aud"] = True
+ else:
+ options["verify_aud"] = False
+
+ if jwt_cfg.get("issuer"):
+ decode_kwargs["issuer"] = jwt_cfg["issuer"]
+ options["verify_iss"] = True
+ else:
+ options["verify_iss"] = False
+ try:
+ decoded = jwt.decode(
+ token,
+ options=options,
+ **decode_kwargs,
+ )
+ except Exception as e:
+ raise Exception(f"Invalid JWT: {str(e)}")
+
+ raw_required_claims = jwt_cfg.get("required_claims", [])
+ if isinstance(raw_required_claims, str):
+ required_claims = [raw_required_claims]
+ elif isinstance(raw_required_claims, (list, tuple, set)):
+ required_claims = list(raw_required_claims)
+ else:
+ required_claims = []
+
+ required_claims = [
+ c for c in required_claims
+ if isinstance(c, str) and c.strip()
+ ]
+
+ RESERVED_CLAIMS = {"exp", "sub", "aud", "iss", "nbf", "iat"}
+ for claim in required_claims:
+ if claim in RESERVED_CLAIMS:
+ raise Exception(f"Reserved JWT claim cannot be required: {claim}")
+
+ for claim in required_claims:
+ if claim not in decoded:
+ raise Exception(f"Missing JWT claim: {claim}")
+
+ return decoded
+
+ try:
+ security_config=webhook_cfg.get("security", {})
+ await validate_webhook_security(security_config)
+ except Exception as e:
+ return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST
+ if not isinstance(cvs.dsl, str):
+ dsl = json.dumps(cvs.dsl, ensure_ascii=False)
+ try:
+ canvas = Canvas(dsl, cvs.user_id, agent_id, canvas_id=agent_id)
+ except Exception as e:
+        resp = get_data_error_result(code=RetCode.BAD_REQUEST, message=str(e))
+ resp.status_code = RetCode.BAD_REQUEST
+ return resp
+
+ # 7. Parse request body
+ async def parse_webhook_request(content_type):
+ """Parse request based on content-type and return structured data."""
+
+ # 1. Query
+ query_data = {k: v for k, v in request.args.items()}
+
+ # 2. Headers
+ header_data = {k: v for k, v in request.headers.items()}
+
+ # 3. Body
+ ctype = request.headers.get("Content-Type", "").split(";")[0].strip()
+ if ctype and ctype != content_type:
+ raise ValueError(
+ f"Invalid Content-Type: expect '{content_type}', got '{ctype}'"
+ )
+
+ body_data: dict = {}
+
+ try:
+ if ctype == "application/json":
+ body_data = await request.get_json() or {}
+
+ elif ctype == "multipart/form-data":
+ nonlocal canvas
+ form = await request.form
+ files = await request.files
+
+ body_data = {}
+
+ for key, value in form.items():
+ body_data[key] = value
+
+ if len(files) > 10:
+ raise Exception("Too many uploaded files")
+ for key, file in files.items():
+ desc = FileService.upload_info(
+ cvs.user_id, # user
+ file, # FileStorage
+ None # url (None for webhook)
+ )
+                    file_parsed = await canvas.get_files_async([desc])
+ body_data[key] = file_parsed
+
+ elif ctype == "application/x-www-form-urlencoded":
+ form = await request.form
+ body_data = dict(form)
+
+ else:
+ # text/plain / octet-stream / empty / unknown
+ raw = await request.get_data()
+ if raw:
+ try:
+ body_data = json.loads(raw.decode("utf-8"))
+ except Exception:
+ body_data = {}
+ else:
+ body_data = {}
+
+ except Exception:
+ body_data = {}
+
+ return {
+ "query": query_data,
+ "headers": header_data,
+ "body": body_data,
+ "content_type": ctype,
+ }
+
+ def extract_by_schema(data, schema, name="section"):
+ """
+ Extract only fields defined in schema.
+ Required fields must exist.
+ Optional fields default to type-based default values.
+ Type validation included.
+ """
+ props = schema.get("properties", {})
+ required = schema.get("required", [])
+
+ extracted = {}
+
+ for field, field_schema in props.items():
+ field_type = field_schema.get("type")
+
+ # 1. Required field missing
+ if field in required and field not in data:
+ raise Exception(f"{name} missing required field: {field}")
+
+ # 2. Optional → default value
+ if field not in data:
+ extracted[field] = default_for_type(field_type)
+ continue
+
+ raw_value = data[field]
+
+ # 3. Auto convert value
+ try:
+ value = auto_cast_value(raw_value, field_type)
+ except Exception as e:
+ raise Exception(f"{name}.{field} auto-cast failed: {str(e)}")
+
+ # 4. Type validation
+ if not validate_type(value, field_type):
+ raise Exception(
+ f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}"
+ )
+
+ extracted[field] = value
+
+ return extracted
+
+ def default_for_type(t):
+ """Return default value for the given schema type."""
+ if t == "file":
+ return []
+ if t == "object":
+ return {}
+ if t == "boolean":
+ return False
+ if t == "number":
+ return 0
+ if t == "string":
+ return ""
+ if t and t.startswith("array"):
+ return []
+ if t == "null":
+ return None
+ return None
+
+ def auto_cast_value(value, expected_type):
+ """Convert string values into schema type when possible."""
+
+ # Non-string values already good
+ if not isinstance(value, str):
+ return value
+
+ v = value.strip()
+
+ # Boolean
+ if expected_type == "boolean":
+ if v.lower() in ["true", "1"]:
+ return True
+ if v.lower() in ["false", "0"]:
+ return False
+ raise Exception(f"Cannot convert '{value}' to boolean")
+
+ # Number
+ if expected_type == "number":
+ # integer
+ if v.isdigit() or (v.startswith("-") and v[1:].isdigit()):
+ return int(v)
+
+ # float
+ try:
+ return float(v)
+ except Exception:
+ raise Exception(f"Cannot convert '{value}' to number")
+
+ # Object
+ if expected_type == "object":
+ try:
+ parsed = json.loads(v)
+ if isinstance(parsed, dict):
+ return parsed
+ else:
+ raise Exception("JSON is not an object")
+ except Exception:
+ raise Exception(f"Cannot convert '{value}' to object")
+
+ # Array
+ if expected_type.startswith("array"):
+ try:
+ parsed = json.loads(v)
+ if isinstance(parsed, list):
+ return parsed
+ else:
+ raise Exception("JSON is not an array")
+ except Exception:
+ raise Exception(f"Cannot convert '{value}' to array")
+
+ # String (accept original)
+ if expected_type == "string":
+ return value
+
+ # File
+ if expected_type == "file":
+ return value
+ # Default: do nothing
+ return value
+
+ def validate_type(value, t):
+ """Validate value type against schema type t."""
+ if t == "file":
+ return isinstance(value, list)
+
+ if t == "string":
+ return isinstance(value, str)
+
+ if t == "number":
+ return isinstance(value, (int, float))
+
+ if t == "boolean":
+ return isinstance(value, bool)
+
+ if t == "object":
+ return isinstance(value, dict)
+
+        # plain "array" or parameterized "array<inner>" variants
+ if t.startswith("array"):
+ if not isinstance(value, list):
+ return False
+
+ if "<" in t and ">" in t:
+ inner = t[t.find("<") + 1 : t.find(">")]
+
+ # Check each element type
+ for item in value:
+ if not validate_type(item, inner):
+ return False
+
+ return True
+
+ return True
+
+    parsed = await parse_webhook_request(webhook_cfg.get("content_types"))
+ SCHEMA = webhook_cfg.get("schema", {"query": {}, "headers": {}, "body": {}})
+
+ # Extract strictly by schema
+ try:
+ query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query")
+ header_clean = extract_by_schema(parsed["headers"], SCHEMA.get("headers", {}), name="headers")
+ body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body")
+ except Exception as e:
+        return get_data_error_result(code=RetCode.BAD_REQUEST, message=str(e)), RetCode.BAD_REQUEST
+
+ clean_request = {
+ "query": query_clean,
+ "headers": header_clean,
+ "body": body_clean,
+ "input": parsed
+ }
+
+ execution_mode = webhook_cfg.get("execution_mode", "Immediately")
+ response_cfg = webhook_cfg.get("response", {})
+
+    def append_webhook_trace(agent_id: str, start_ts: float, event: dict, ttl=600):
+ key = f"webhook-trace-{agent_id}-logs"
+
+ raw = REDIS_CONN.get(key)
+ obj = json.loads(raw) if raw else {"webhooks": {}}
+
+ ws = obj["webhooks"].setdefault(
+ str(start_ts),
+ {"start_ts": start_ts, "events": []}
+ )
+
+ ws["events"].append({
+ "ts": time.time(),
+ **event
+ })
+
+ REDIS_CONN.set_obj(key, obj, ttl)
+
+ if execution_mode == "Immediately":
+ status = response_cfg.get("status", 200)
+ try:
+ status = int(status)
+ except (TypeError, ValueError):
+ return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}")),RetCode.BAD_REQUEST
+
+ if not (200 <= status <= 399):
+ return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}, must be between 200 and 399")),RetCode.BAD_REQUEST
+
+ body_tpl = response_cfg.get("body_template", "")
+
+ def parse_body(body: str):
+ if not body:
+ return None, "application/json"
+
+ try:
+ parsed = json.loads(body)
+ return parsed, "application/json"
+ except (json.JSONDecodeError, TypeError):
+ return body, "text/plain"
+
+ body, content_type = parse_body(body_tpl)
+ resp = Response(
+ json.dumps(body, ensure_ascii=False) if content_type == "application/json" else body,
+ status=status,
+ content_type=content_type,
+ )
+
+ async def background_run():
+ try:
+ async for ans in canvas.run(
+ query="",
+ user_id=cvs.user_id,
+ webhook_payload=clean_request
+ ):
+ if is_test:
+ append_webhook_trace(agent_id, start_ts, ans)
+
+ if is_test:
+ append_webhook_trace(
+ agent_id,
+ start_ts,
+ {
+ "event": "finished",
+ "elapsed_time": time.time() - start_ts,
+ "success": True,
+ }
+ )
+
+ cvs.dsl = json.loads(str(canvas))
+                UserCanvasService.update_by_id(cvs.id, cvs.to_dict())  # update by the canvas row id, not the owner id
+
+ except Exception as e:
+ logging.exception("Webhook background run failed")
+ if is_test:
+ try:
+ append_webhook_trace(
+ agent_id,
+ start_ts,
+ {
+ "event": "error",
+ "message": str(e),
+ "error_type": type(e).__name__,
+ }
+ )
+ append_webhook_trace(
+ agent_id,
+ start_ts,
+ {
+ "event": "finished",
+ "elapsed_time": time.time() - start_ts,
+ "success": False,
+ }
+ )
+ except Exception:
+ logging.exception("Failed to append webhook trace")
+
+ asyncio.create_task(background_run())
+ return resp
+ else:
+        async def run_and_collect():
+ nonlocal canvas
+ contents: list[str] = []
+ status = 200
+ try:
+ async for ans in canvas.run(
+ query="",
+ user_id=cvs.user_id,
+ webhook_payload=clean_request,
+ ):
+ if ans["event"] == "message":
+ content = ans["data"]["content"]
+ if ans["data"].get("start_to_think", False):
+ content = ""
+ elif ans["data"].get("end_to_think", False):
+ content = " "
+ if content:
+ contents.append(content)
+ if ans["event"] == "message_end":
+ status = int(ans["data"].get("status", status))
+ if is_test:
+ append_webhook_trace(
+ agent_id,
+ start_ts,
+ ans
+ )
+ if is_test:
+ append_webhook_trace(
+ agent_id,
+ start_ts,
+ {
+ "event": "finished",
+ "elapsed_time": time.time() - start_ts,
+ "success": True,
+ }
+ )
+ final_content = "".join(contents)
+ return {
+ "message": final_content,
+ "success": True,
+ "code": status,
+ }
+
+ except Exception as e:
+ if is_test:
+ append_webhook_trace(
+ agent_id,
+ start_ts,
+ {
+ "event": "error",
+ "message": str(e),
+ "error_type": type(e).__name__,
+ }
+ )
+ append_webhook_trace(
+ agent_id,
+ start_ts,
+ {
+ "event": "finished",
+ "elapsed_time": time.time() - start_ts,
+ "success": False,
+ }
+ )
+ return {"code": 400, "message": str(e),"success":False}
+
+        result = await run_and_collect()
+ return Response(
+ json.dumps(result),
+ status=result["code"],
+ mimetype="application/json",
+ )
+
+
+@manager.route("/agents//webhook/logs", methods=["GET"]) # noqa: F821
+@login_required
+async def webhook_trace(agent_id: str):
+ exists, cvs = UserCanvasService.get_by_id(agent_id)
+ if not exists or str(cvs.user_id) != str(current_user.id):
+ return get_data_error_result(
+ message="Canvas not found.",
+ )
+
+ def encode_webhook_id(start_ts: str) -> str:
+ WEBHOOK_ID_SECRET = "webhook_id_secret"
+ sig = hmac.new(
+ WEBHOOK_ID_SECRET.encode("utf-8"),
+ start_ts.encode("utf-8"),
+ hashlib.sha256,
+ ).digest()
+ return base64.urlsafe_b64encode(sig).decode("utf-8").rstrip("=")
+
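+    # Client-facing webhook ids are HMAC digests of the internal start timestamp, so trace keys
+    # cannot be forged or enumerated; decode_webhook_id re-hashes the known keys instead of
+    # reversing the digest. Note the secret is a fixed in-code constant here.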
+ def decode_webhook_id(enc_id: str, webhooks: dict) -> str | None:
+ for ts in webhooks.keys():
+ if encode_webhook_id(ts) == enc_id:
+ return ts
+ return None
+ since_ts = request.args.get("since_ts", type=float)
+ webhook_id = request.args.get("webhook_id")
+
+ key = f"webhook-trace-{agent_id}-logs"
+ raw = REDIS_CONN.get(key)
+
+ if since_ts is None:
+ now = time.time()
+ return get_json_result(
+ data={
+ "webhook_id": None,
+ "events": [],
+ "next_since_ts": now,
+ "finished": False,
+ }
+ )
+
+ if not raw:
+ return get_json_result(
+ data={
+ "webhook_id": None,
+ "events": [],
+ "next_since_ts": since_ts,
+ "finished": False,
+ }
+ )
+
+ obj = json.loads(raw)
+ webhooks = obj.get("webhooks", {})
+
+ if webhook_id is None:
+ candidates = [
+ float(k) for k in webhooks.keys() if float(k) > since_ts
+ ]
+
+ if not candidates:
+ return get_json_result(
+ data={
+ "webhook_id": None,
+ "events": [],
+ "next_since_ts": since_ts,
+ "finished": False,
+ }
+ )
+
+ start_ts = min(candidates)
+ real_id = str(start_ts)
+ webhook_id = encode_webhook_id(real_id)
+
+ return get_json_result(
+ data={
+ "webhook_id": webhook_id,
+ "events": [],
+ "next_since_ts": start_ts,
+ "finished": False,
+ }
+ )
+
+ real_id = decode_webhook_id(webhook_id, webhooks)
+
+ if not real_id:
+ return get_json_result(
+ data={
+ "webhook_id": webhook_id,
+ "events": [],
+ "next_since_ts": since_ts,
+ "finished": True,
+ }
+ )
+
+ ws = webhooks.get(str(real_id))
+ events = ws.get("events", [])
+ new_events = [e for e in events if e.get("ts", 0) > since_ts]
+
+ next_ts = since_ts
+ for e in new_events:
+ next_ts = max(next_ts, e["ts"])
+
+ finished = any(e.get("event") == "finished" for e in new_events)
+
+ return get_json_result(
+ data={
+ "webhook_id": webhook_id,
+ "events": new_events,
+ "next_since_ts": next_ts,
+ "finished": finished,
+ }
+ )
diff --git a/api/apps/restful_apis/chat_api.py b/api/apps/restful_apis/chat_api.py
index 263294b53fa..fab74f5c62a 100644
--- a/api/apps/restful_apis/chat_api.py
+++ b/api/apps/restful_apis/chat_api.py
@@ -20,6 +20,7 @@
import re
import tempfile
from copy import deepcopy
+from types import SimpleNamespace
from quart import Response, request
@@ -30,7 +31,7 @@
)
from api.db.services.chunk_feedback_service import ChunkFeedbackService
from api.db.services.conversation_service import ConversationService, structure_answer
-from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap
+from api.db.services.dialog_service import DialogService, async_chat, gen_mindmap
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
@@ -67,6 +68,15 @@
"tts": False,
"refine_multiturn": True,
}
+_DEFAULT_DIRECT_CHAT_PROMPT_CONFIG = {
+ "system": "",
+ "prologue": "",
+ "parameters": [],
+ "empty_response": "",
+ "quote": False,
+ "tts": False,
+ "refine_multiturn": True,
+}
_DEFAULT_RERANK_MODELS = {"BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"}
_READONLY_FIELDS = {"id", "tenant_id", "created_by", "create_time", "create_date", "update_time", "update_date"}
_PERSISTED_FIELDS = set(DialogService.model._meta.fields)
@@ -124,6 +134,39 @@ def _ensure_owned_chat(chat_id):
)
+def _build_default_completion_dialog():
+ return SimpleNamespace(
+ tenant_id=current_user.id,
+ llm_id="",
+ tenant_llm_id=None,
+ llm_setting={},
+ prompt_config=deepcopy(_DEFAULT_DIRECT_CHAT_PROMPT_CONFIG),
+ kb_ids=[],
+ top_n=6,
+ top_k=1024,
+ rerank_id="",
+ similarity_threshold=0.1,
+ vector_similarity_weight=0.3,
+ meta_data_filter=None,
+ )
+
+
+def _create_session_for_completion(chat_id, dialog, user_id):
+ conv = {
+ "id": get_uuid(),
+ "dialog_id": chat_id,
+ "name": "New session",
+ "message": [{"role": "assistant", "content": dialog.prompt_config.get("prologue", "")}],
+ "user_id": user_id,
+ "reference": [],
+ }
+ ConversationService.save(**conv)
+ ok, conv_obj = ConversationService.get_by_id(conv["id"])
+ if not ok:
+ raise LookupError("Fail to create a session!")
+ return conv_obj
+
+
def _validate_llm_id(llm_id, tenant_id, llm_setting=None):
if not llm_id:
return None
@@ -565,6 +608,15 @@ async def bulk_delete_chats():
if not ids:
return get_json_result(data={})
else:
+            # Backward compatibility: allow DELETE with a single chat_id in the request body.
+ chat_id = req.get("chat_id")
+ if chat_id:
+ try:
+ if not DialogService.update_by_id(chat_id, {"status": StatusEnum.INVALID.value}):
+ return get_data_error_result(message=f"Failed to delete chat {chat_id}")
+ return get_json_result(data=True)
+ except Exception as ex:
+ return server_error_response(ex)
return get_json_result(data={})
errors = []
@@ -671,7 +723,7 @@ async def get_session(chat_id, session_id):
return server_error_response(ex)
-@manager.route("/chats//sessions/", methods=["PUT"]) # noqa: F821
+@manager.route("/chats//sessions/", methods=["PATCH"]) # noqa: F821
@login_required
async def update_session(chat_id, session_id):
if not _ensure_owned_chat(chat_id):
@@ -829,7 +881,7 @@ async def update_message_feedback(chat_id, session_id, msg_id):
return server_error_response(ex)
-@manager.route("/chats/tts", methods=["POST"]) # noqa: F821
+@manager.route("/chat/audio/speech", methods=["POST"]) # noqa: F821
@login_required
async def tts():
req = await get_request_json()
@@ -857,9 +909,9 @@ def stream_audio():
return resp
-@manager.route("/chats/transcriptions", methods=["POST"]) # noqa: F821
+@manager.route("/chat/audio/transcription", methods=["POST"]) # noqa: F821
@login_required
-async def transcriptions():
+async def transcription():
req = await request.form
stream_mode = req.get("stream", "false").lower() == "true"
files = await request.files
@@ -915,7 +967,7 @@ async def event_stream():
return Response(event_stream(), content_type="text/event-stream")
-@manager.route("/chats/mindmap", methods=["POST"]) # noqa: F821
+@manager.route("/chat/mindmap", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question", "kb_ids")
async def mindmap():
@@ -933,10 +985,10 @@ async def mindmap():
return get_json_result(data=mind_map)
-@manager.route("/chats/related_questions", methods=["POST"]) # noqa: F821
+@manager.route("/chat/recommendation", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question")
-async def related_questions():
+async def recommendation():
req = await get_request_json()
search_id = req.get("search_id", "")
@@ -971,10 +1023,10 @@ async def related_questions():
return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])
-@manager.route("/chats//sessions//completions", methods=["POST"]) # noqa: F821
+@manager.route("/chat/completions", methods=["POST"]) # noqa: F821
@login_required
@validate_request("messages")
-async def session_completion(chat_id, session_id):
+async def session_completion(chat_id_in_arg=""):
req = await get_request_json()
msg = []
for m in req["messages"]:
@@ -984,6 +1036,9 @@ async def session_completion(chat_id, session_id):
continue
msg.append(m)
message_id = msg[-1].get("id") if msg else None
+ chat_id = req.pop("chat_id", "") or ""
+ chat_id = chat_id or chat_id_in_arg
+ session_id = req.pop("session_id", "") or ""
chat_model_id = req.pop("llm_id", "")
chat_model_config = {}
@@ -993,21 +1048,41 @@ async def session_completion(chat_id, session_id):
chat_model_config[model_config] = config
try:
- e, conv = ConversationService.get_by_id(session_id)
- if not e:
- return get_data_error_result(message="Session not found!")
- if conv.dialog_id != chat_id:
- return get_data_error_result(message="Session does not belong to this chat!")
- conv.message = deepcopy(req["messages"])
- e, dia = DialogService.get_by_id(chat_id)
- if not e:
- return get_data_error_result(message="Chat not found!")
+ conv = None
+ if session_id and not chat_id:
+ return get_data_error_result(message="`chat_id` is required when `session_id` is provided.")
+
+ if chat_id:
+ if not _ensure_owned_chat(chat_id):
+ return get_json_result(
+ data=False,
+ message="No authorization.",
+ code=RetCode.AUTHENTICATION_ERROR,
+ )
+ e, dia = DialogService.get_by_id(chat_id)
+ if not e:
+ return get_data_error_result(message="Chat not found!")
+ if session_id:
+ e, conv = ConversationService.get_by_id(session_id)
+ if not e:
+ return get_data_error_result(message="Session not found!")
+ if conv.dialog_id != chat_id:
+ return get_data_error_result(message="Session does not belong to this chat!")
+ else:
+ conv = _create_session_for_completion(chat_id, dia, req.get("user_id", current_user.id))
+ session_id = conv.id
+ conv.message = deepcopy(req["messages"])
+ else:
+ dia = _build_default_completion_dialog()
+ dia.llm_setting = chat_model_config
+
del req["messages"]
- if not conv.reference:
- conv.reference = []
- conv.reference = [r for r in conv.reference if r]
- conv.reference.append({"chunks": [], "doc_aggs": []})
+ if conv is not None:
+ if not conv.reference:
+ conv.reference = []
+ conv.reference = [r for r in conv.reference if r]
+ conv.reference.append({"chunks": [], "doc_aggs": []})
if chat_model_id:
if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=chat_model_id):
@@ -1015,16 +1090,21 @@ async def session_completion(chat_id, session_id):
dia.llm_id = chat_model_id
dia.llm_setting = chat_model_config
- is_embedded = bool(chat_model_id)
stream_mode = req.pop("stream", True)
+ def _format_answer(ans):
+ formatted = structure_answer(conv, ans, message_id, session_id)
+ if chat_id:
+ formatted["chat_id"] = chat_id
+ return formatted
+
async def stream():
nonlocal dia, msg, req, conv
try:
async for ans in async_chat(dia, msg, True, **req):
- ans = structure_answer(conv, ans, message_id, conv.id)
+ ans = _format_answer(ans)
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
- if not is_embedded:
+ if conv is not None:
ConversationService.update_by_id(conv.id, conv.to_dict())
except Exception as ex:
logging.exception(ex)
@@ -1041,40 +1121,10 @@ async def stream():
answer = None
async for ans in async_chat(dia, msg, **req):
- answer = structure_answer(conv, ans, message_id, conv.id)
- if not is_embedded:
+ answer = _format_answer(ans)
+ if conv is not None:
ConversationService.update_by_id(conv.id, conv.to_dict())
break
return get_json_result(data=answer)
except Exception as ex:
return server_error_response(ex)
-
-
-@manager.route("/chats/ask", methods=["POST"]) # noqa: F821
-@login_required
-@validate_request("question", "kb_ids")
-async def ask():
- req = await get_request_json()
- uid = current_user.id
-
- search_id = req.get("search_id", "")
- search_config = {}
- if search_id:
- if search_app := SearchService.get_detail(search_id):
- search_config = search_app.get("search_config", {})
-
- async def stream():
- nonlocal req, uid
- try:
- async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
- yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
- except Exception as ex:
- yield "data:" + json.dumps({"code": 500, "message": str(ex), "data": {"answer": "**ERROR**: " + str(ex), "reference": []}}, ensure_ascii=False) + "\n\n"
- yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
-
- resp = Response(stream(), mimetype="text/event-stream")
- resp.headers.add_header("Cache-control", "no-cache")
- resp.headers.add_header("Connection", "keep-alive")
- resp.headers.add_header("X-Accel-Buffering", "no")
- resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
- return resp
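
A rough client sketch of the relocated completion route may help reviewers: `chat_id` and `session_id` now travel in the JSON body instead of the URL path, and both are optional (the handler creates a session, or falls back to a default dialog, when they are absent). The host, blueprint prefix, and bearer-token header below are assumptions for illustration only, not part of this patch.

```python
# Illustrative only -- not part of the patch. Host, prefix, and auth are assumed.
import json
import requests

BASE_URL = "http://localhost:9380/v1"               # assumed host and blueprint prefix
HEADERS = {"Authorization": "Bearer YOUR_API_KEY"}   # assumed auth scheme

payload = {
    "messages": [{"role": "user", "content": "Summarize the onboarding document."}],
    "chat_id": "CHAT_ID",    # optional; omit to use the built-in default completion dialog
    # "session_id": "...",   # optional; requires chat_id, otherwise a session is created
    "stream": True,
}

with requests.post(f"{BASE_URL}/chat/completions", headers=HEADERS, json=payload, stream=True) as resp:
    # The handler streams Server-Sent Events, one "data:<json>" line per chunk.
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue
        event = json.loads(line[len("data:"):])
        data = event.get("data")
        if data is True:          # some streams end with a terminal `true` marker
            break
        if isinstance(data, dict):
            print(data.get("answer", ""), end="\r")
```
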
diff --git a/api/apps/restful_apis/chunk_api.py b/api/apps/restful_apis/chunk_api.py
new file mode 100644
index 00000000000..13b5cb5801e
--- /dev/null
+++ b/api/apps/restful_apis/chunk_api.py
@@ -0,0 +1,445 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import base64
+import datetime
+import re
+
+import xxhash
+from pydantic import BaseModel, Field, validator
+from quart import request
+
+from api.apps import login_required
+from api.db.joint_services.tenant_model_service import (
+ get_model_config_by_id,
+ get_model_config_by_type_and_name,
+)
+from api.db.services.document_service import DocumentService
+from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.tenant_llm_service import TenantLLMService
+from api.utils.api_utils import (
+ add_tenant_id_to_kwargs,
+ check_duplicate_ids,
+ get_error_data_result,
+ get_request_json,
+ get_result,
+ server_error_response,
+)
+from api.utils.image_utils import store_chunk_image
+from common import settings
+from common.constants import LLMType, ParserType, RetCode
+from common.misc_utils import thread_pool_exec
+from common.string_utils import is_content_empty, remove_redundant_spaces
+from common.tag_feature_utils import validate_tag_features
+from rag.app.qa import beAdoc, rmPrefix
+from rag.nlp import rag_tokenizer, search
+
+
+class Chunk(BaseModel):
+ id: str = ""
+ content: str = ""
+ document_id: str = ""
+ docnm_kwd: str = ""
+ important_keywords: list = Field(default_factory=list)
+ tag_kwd: list = Field(default_factory=list)
+ questions: list = Field(default_factory=list)
+ question_tks: str = ""
+ image_id: str = ""
+ available: bool = True
+ positions: list[list[int]] = Field(default_factory=list)
+
+ @validator("positions")
+ def validate_positions(cls, value):
+ for sublist in value:
+ if len(sublist) != 5:
+ raise ValueError("Each sublist in positions must have a length of 5")
+ return value
+
+
+def _map_doc(doc):
+ key_mapping = {
+ "chunk_num": "chunk_count",
+ "kb_id": "dataset_id",
+ "token_num": "token_count",
+ "parser_id": "chunk_method",
+ }
+ run_mapping = {
+ "0": "UNSTART",
+ "1": "RUNNING",
+ "2": "CANCEL",
+ "3": "DONE",
+ "4": "FAIL",
+ }
+ renamed_doc = {}
+ for key, value in doc.to_dict().items():
+ renamed_doc[key_mapping.get(key, key)] = value
+ if key == "run":
+ renamed_doc["run"] = run_mapping.get(str(value))
+ return renamed_doc
+
+
+def _strip_chunk_runtime_fields(chunk):
+ for name in [name for name in chunk.keys() if re.search(r"(_vec$|_sm_|_tks|_ltks)", name)]:
+ del chunk[name]
+ return chunk
+
+
+@manager.route("/datasets//documents//chunks", methods=["GET"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+async def list_chunks(tenant_id, dataset_id, document_id):
+ if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
+ return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
+ doc = DocumentService.query(id=document_id, kb_id=dataset_id)
+ if not doc:
+ return get_error_data_result(message=f"You don't own the document {document_id}.")
+ doc = doc[0]
+ req = request.args
+ page = int(req.get("page", 1))
+ size = int(req.get("page_size", 30))
+ question = req.get("keywords", "")
+ query = {
+ "doc_ids": [document_id],
+ "page": page,
+ "size": size,
+ "question": question,
+ "sort": True,
+ }
+ if "available" in req:
+ query["available_int"] = 1 if req["available"] == "true" else 0
+
+ res = {"total": 0, "chunks": [], "doc": _map_doc(doc)}
+ if req.get("id"):
+ chunk = settings.docStoreConn.get(req.get("id"), search.index_name(tenant_id), [dataset_id])
+ if not chunk:
+ return get_result(message=f"Chunk not found: {dataset_id}/{req.get('id')}", code=RetCode.DATA_ERROR)
+ if str(chunk.get("doc_id", chunk.get("document_id"))) != str(document_id):
+ return get_result(message=f"Chunk not found: {dataset_id}/{req.get('id')}", code=RetCode.DATA_ERROR)
+ _strip_chunk_runtime_fields(chunk)
+ res["total"] = 1
+ final_chunk = {
+ "id": chunk.get("id", chunk.get("chunk_id")),
+ "content": chunk["content_with_weight"],
+ "document_id": chunk.get("doc_id", chunk.get("document_id")),
+ "docnm_kwd": chunk["docnm_kwd"],
+ "important_keywords": chunk.get("important_kwd", []),
+ "questions": chunk.get("question_kwd", []),
+ "dataset_id": chunk.get("kb_id", chunk.get("dataset_id")),
+ "image_id": chunk.get("img_id", ""),
+ "available": bool(chunk.get("available_int", 1)),
+ "positions": chunk.get("position_int", []),
+ "tag_kwd": chunk.get("tag_kwd", []),
+ "tag_feas": chunk.get("tag_feas", {}),
+ }
+ res["chunks"].append(final_chunk)
+ _ = Chunk(**final_chunk)
+ elif settings.docStoreConn.index_exist(search.index_name(tenant_id), dataset_id):
+ sres = await settings.retriever.search(
+ query,
+ search.index_name(tenant_id),
+ [dataset_id],
+ emb_mdl=None,
+ highlight=True,
+ )
+ res["total"] = sres.total
+ for chunk_id in sres.ids:
+ d = {
+ "id": chunk_id,
+ "content": (
+ remove_redundant_spaces(sres.highlight[chunk_id])
+ if question and chunk_id in sres.highlight
+ else sres.field[chunk_id].get("content_with_weight", "")
+ ),
+ "document_id": sres.field[chunk_id]["doc_id"],
+ "docnm_kwd": sres.field[chunk_id]["docnm_kwd"],
+ "important_keywords": sres.field[chunk_id].get("important_kwd", []),
+ "tag_kwd": sres.field[chunk_id].get("tag_kwd", []),
+ "questions": sres.field[chunk_id].get("question_kwd", []),
+ "dataset_id": sres.field[chunk_id].get("kb_id", sres.field[chunk_id].get("dataset_id")),
+ "image_id": sres.field[chunk_id].get("img_id", ""),
+ "available": bool(int(sres.field[chunk_id].get("available_int", "1"))),
+ "positions": sres.field[chunk_id].get("position_int", []),
+ }
+ res["chunks"].append(d)
+ _ = Chunk(**d)
+ return get_result(data=res)
+
+
+@manager.route("/datasets//documents//chunks/", methods=["GET"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+async def get_chunk(tenant_id, dataset_id, document_id, chunk_id):
+ if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
+ return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
+ doc = DocumentService.query(id=document_id, kb_id=dataset_id)
+ if not doc:
+ return get_error_data_result(message=f"You don't own the document {document_id}.")
+ try:
+ chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
+ if chunk is None or str(chunk.get("doc_id", chunk.get("document_id"))) != str(document_id):
+ return get_result(data=False, message="Chunk not found!", code=RetCode.DATA_ERROR)
+ return get_result(data=_strip_chunk_runtime_fields(chunk))
+ except Exception as e:
+ if str(e).find("NotFoundError") >= 0:
+ return get_result(data=False, message="Chunk not found!", code=RetCode.DATA_ERROR)
+ return server_error_response(e)
+
+
+@manager.route("/datasets//documents//chunks", methods=["POST"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+async def add_chunk(tenant_id, dataset_id, document_id):
+ if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
+ return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
+ doc = DocumentService.query(id=document_id, kb_id=dataset_id)
+ if not doc:
+ return get_error_data_result(message=f"You don't own the document {document_id}.")
+ doc = doc[0]
+ req = await get_request_json()
+ if is_content_empty(req.get("content")):
+ return get_error_data_result(message="`content` is required")
+ if "important_keywords" in req and not isinstance(req["important_keywords"], list):
+ return get_error_data_result("`important_keywords` is required to be a list")
+ if "questions" in req and not isinstance(req["questions"], list):
+ return get_error_data_result("`questions` is required to be a list")
+
+ chunk_id = xxhash.xxh64((req["content"] + document_id).encode("utf-8")).hexdigest()
+ d = {
+ "id": chunk_id,
+ "content_ltks": rag_tokenizer.tokenize(req["content"]),
+ "content_with_weight": req["content"],
+ }
+ d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
+ d["important_kwd"] = req.get("important_keywords", [])
+ d["important_tks"] = rag_tokenizer.tokenize(" ".join(req.get("important_keywords", [])))
+ d["question_kwd"] = [str(q).strip() for q in req.get("questions", []) if str(q).strip()]
+ d["question_tks"] = rag_tokenizer.tokenize("\n".join(req.get("questions", [])))
+ d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
+ d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
+ d["kb_id"] = dataset_id
+ d["docnm_kwd"] = doc.name
+ d["doc_id"] = document_id
+
+ if "tag_kwd" in req:
+ if not isinstance(req["tag_kwd"], list):
+ return get_error_data_result("`tag_kwd` is required to be a list")
+ if not all(isinstance(t, str) for t in req["tag_kwd"]):
+ return get_error_data_result("`tag_kwd` must be a list of strings")
+ d["tag_kwd"] = req["tag_kwd"]
+ if "tag_feas" in req:
+ try:
+ d["tag_feas"] = validate_tag_features(req["tag_feas"])
+ except ValueError as exc:
+ return get_error_data_result(f"`tag_feas` {exc}")
+
+ image_base64 = req.get("image_base64")
+ if image_base64:
+ d["img_id"] = f"{dataset_id}-{chunk_id}"
+ d["doc_type_kwd"] = "image"
+
+ tenant_embd_id = DocumentService.get_tenant_embd_id(document_id)
+ if tenant_embd_id:
+ model_config = get_model_config_by_id(tenant_embd_id)
+ else:
+ embd_id = DocumentService.get_embd_id(document_id)
+ model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING.value, embd_id)
+ embd_mdl = TenantLLMService.model_instance(model_config)
+ v, c = embd_mdl.encode([doc.name, req["content"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
+ v = 0.1 * v[0] + 0.9 * v[1]
+ d[f"q_{len(v)}_vec"] = v.tolist()
+ settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
+
+ if image_base64:
+ store_chunk_image(dataset_id, chunk_id, base64.b64decode(image_base64))
+
+ DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
+ key_mapping = {
+ "id": "id",
+ "content_with_weight": "content",
+ "doc_id": "document_id",
+ "important_kwd": "important_keywords",
+ "tag_kwd": "tag_kwd",
+ "question_kwd": "questions",
+ "kb_id": "dataset_id",
+ "create_timestamp_flt": "create_timestamp",
+ "create_time": "create_time",
+ "document_keyword": "document",
+ "img_id": "image_id",
+ }
+ renamed_chunk = {new_key: d[key] for key, new_key in key_mapping.items() if key in d}
+ _ = Chunk(**renamed_chunk)
+ return get_result(data={"chunk": renamed_chunk})
+
+
+@manager.route("/datasets//documents//chunks", methods=["DELETE"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+async def rm_chunk(tenant_id, dataset_id, document_id):
+ if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
+ return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
+ docs = DocumentService.query(id=document_id, kb_id=dataset_id)
+ if not docs:
+ return get_error_data_result(message=f"You don't own the document {document_id}.")
+ req = await get_request_json()
+ if not req:
+ return get_result()
+
+ chunk_ids = req.get("chunk_ids")
+ if not chunk_ids:
+ if req.get("delete_all") is True:
+ doc = docs[0]
+ DocumentService.delete_chunk_images(doc, tenant_id)
+ chunk_number = settings.docStoreConn.delete({"doc_id": document_id}, search.index_name(tenant_id), dataset_id)
+ if chunk_number != 0:
+ DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
+ return get_result(message=f"deleted {chunk_number} chunks")
+ return get_result()
+
+ unique_chunk_ids, duplicate_messages = check_duplicate_ids(chunk_ids, "chunk")
+ chunk_number = settings.docStoreConn.delete(
+ {"doc_id": document_id, "id": unique_chunk_ids},
+ search.index_name(tenant_id),
+ dataset_id,
+ )
+ if chunk_number != 0:
+ DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
+ if chunk_number != len(unique_chunk_ids):
+ if len(unique_chunk_ids) == 0:
+ return get_result(message=f"deleted {chunk_number} chunks")
+ return get_error_data_result(message=f"rm_chunk deleted chunks {chunk_number}, expect {len(unique_chunk_ids)}")
+ if duplicate_messages:
+ return get_result(
+ message=f"Partially deleted {chunk_number} chunks with {len(duplicate_messages)} errors",
+ data={"success_count": chunk_number, "errors": duplicate_messages},
+ )
+ return get_result(message=f"deleted {chunk_number} chunks")
+
+
+@manager.route("/datasets//documents//chunks/", methods=["PATCH"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+async def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
+ if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
+ return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
+ doc = DocumentService.query(id=document_id, kb_id=dataset_id)
+ if not doc:
+ return get_error_data_result(message=f"You don't own the document {document_id}.")
+ doc = doc[0]
+ chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
+ if chunk is None or str(chunk.get("doc_id", chunk.get("document_id"))) != str(document_id):
+ return get_error_data_result(f"Can't find this chunk {chunk_id}")
+ req = await get_request_json()
+ content = req.get("content")
+ if content is not None:
+ if is_content_empty(content):
+ return get_error_data_result(message="`content` is required")
+ else:
+ content = chunk.get("content_with_weight", "")
+ d = {"id": chunk_id, "content_with_weight": content}
+ d["content_ltks"] = rag_tokenizer.tokenize(d["content_with_weight"])
+ d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
+ if "important_keywords" in req:
+ if not isinstance(req["important_keywords"], list):
+ return get_error_data_result("`important_keywords` should be a list")
+ d["important_kwd"] = req.get("important_keywords", [])
+ d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_keywords"]))
+ if "questions" in req:
+ if not isinstance(req["questions"], list):
+ return get_error_data_result("`questions` should be a list")
+ d["question_kwd"] = [str(q).strip() for q in req.get("questions", []) if str(q).strip()]
+ d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["questions"]))
+ if "available" in req:
+ d["available_int"] = int(req["available"])
+ if "positions" in req:
+ if not isinstance(req["positions"], list):
+ return get_error_data_result("`positions` should be a list")
+ d["position_int"] = req["positions"]
+ if "tag_kwd" in req:
+ if not isinstance(req["tag_kwd"], list):
+ return get_error_data_result("`tag_kwd` should be a list")
+ if not all(isinstance(t, str) for t in req["tag_kwd"]):
+ return get_error_data_result("`tag_kwd` must be a list of strings")
+ d["tag_kwd"] = req["tag_kwd"]
+ if "tag_feas" in req:
+ try:
+ d["tag_feas"] = validate_tag_features(req["tag_feas"])
+ except ValueError as exc:
+ return get_error_data_result(f"`tag_feas` {exc}")
+ image_base64 = req.get("image_base64")
+ if image_base64:
+ d["img_id"] = f"{dataset_id}-{chunk_id}"
+ d["doc_type_kwd"] = "image"
+
+ tenant_embd_id = DocumentService.get_tenant_embd_id(document_id)
+ if tenant_embd_id:
+ model_config = get_model_config_by_id(tenant_embd_id)
+ else:
+ embd_id = DocumentService.get_embd_id(document_id)
+ model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING.value, embd_id)
+ embd_mdl = TenantLLMService.model_instance(model_config)
+ if doc.parser_id == ParserType.QA:
+ arr = [t for t in re.split(r"[\n\t]", d["content_with_weight"]) if len(t) > 1]
+ if len(arr) != 2:
+ return get_error_data_result(message="Q&A must be separated by TAB/ENTER key.")
+ q, a = rmPrefix(arr[0]), rmPrefix(arr[1])
+ d = beAdoc(d, arr[0], arr[1], not any([rag_tokenizer.is_chinese(t) for t in q + a]))
+
+ v, _ = embd_mdl.encode(
+ [
+ doc.name,
+ d["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"]),
+ ]
+ )
+ v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
+ d[f"q_{len(v)}_vec"] = v.tolist()
+ settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
+ if image_base64:
+ store_chunk_image(dataset_id, chunk_id, base64.b64decode(image_base64))
+ return get_result()
+
+
+@manager.route("/datasets//documents//chunks", methods=["PATCH"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+async def switch_chunks(tenant_id, dataset_id, document_id):
+ if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
+ return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
+ req = await get_request_json()
+ if not req.get("chunk_ids"):
+ return get_error_data_result(message="`chunk_ids` is required.")
+ if "available_int" not in req and "available" not in req:
+ return get_error_data_result(message="`available_int` or `available` is required.")
+ available_int = int(req["available_int"]) if "available_int" in req else (1 if req.get("available") else 0)
+
+ try:
+ def _switch_sync():
+ e, doc = DocumentService.get_by_id(document_id)
+ if not e:
+ return get_error_data_result(message="Document not found!")
+ if not doc or str(doc.kb_id) != str(dataset_id):
+ return get_error_data_result(message="Document not found!")
+ for cid in req["chunk_ids"]:
+ if not settings.docStoreConn.update(
+ {"id": cid},
+ {"available_int": available_int},
+ search.index_name(tenant_id),
+ doc.kb_id,
+ ):
+ return get_error_data_result(message="Index updating failure")
+ return get_result(data=True)
+
+ return await thread_pool_exec(_switch_sync)
+ except Exception as e:
+ return server_error_response(e)
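
Taken together, the handlers in this new file replace the legacy chunk routes with document-scoped REST paths. A minimal client sketch follows; the `/api/v1` mount point and bearer-token header are assumptions, and the response shapes simply mirror the `get_result(...)` calls in the handlers above.

```python
# Illustrative sketch only; host, prefix, and auth header are assumptions.
import requests

BASE = "http://localhost:9380/api/v1"                # assumed mount point for restful_apis
HEADERS = {"Authorization": "Bearer YOUR_API_KEY"}    # assumed auth scheme
DATASET_ID, DOCUMENT_ID = "DATASET_ID", "DOCUMENT_ID"
CHUNKS_URL = f"{BASE}/datasets/{DATASET_ID}/documents/{DOCUMENT_ID}/chunks"

# Create a chunk: `content` is required; keyword/question lists are optional.
created = requests.post(CHUNKS_URL, headers=HEADERS, json={
    "content": "RAGFlow chunks can carry keywords and questions.",
    "important_keywords": ["keywords", "questions"],
    "questions": ["What chunk-level metadata can I attach?"],
}).json()
chunk_id = created["data"]["chunk"]["id"]             # handler returns {"chunk": {...}}

# List chunks of the document, with paging and keyword filtering.
listing = requests.get(CHUNKS_URL, headers=HEADERS,
                       params={"page": 1, "page_size": 30, "keywords": "keywords"}).json()
print(listing["data"]["total"])

# Flip availability for one or more chunks via the bulk PATCH route.
requests.patch(CHUNKS_URL, headers=HEADERS, json={"chunk_ids": [chunk_id], "available": True})

# Remove chunks again (DELETE with explicit chunk_ids, or {"delete_all": true}).
requests.delete(CHUNKS_URL, headers=HEADERS, json={"chunk_ids": [chunk_id]})
```
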
diff --git a/api/apps/connector_app.py b/api/apps/restful_apis/connector_api.py
similarity index 86%
rename from api/apps/connector_app.py
rename to api/apps/restful_apis/connector_api.py
index 0c123f70077..99a58930211 100644
--- a/api/apps/connector_app.py
+++ b/api/apps/restful_apis/connector_api.py
@@ -35,15 +35,30 @@
from api.apps import login_required, current_user
from box_sdk_gen import BoxOAuth, OAuthConfig, GetAuthorizeUrlOptions
-
-@manager.route("/set", methods=["POST"]) # noqa: F821
+@manager.route("/connectors/", methods=["PATCH"]) # noqa: F821
@login_required
-async def set_connector():
+async def update_connector(connector_id):
req = await get_request_json()
- if req.get("id"):
+ e, conn = ConnectorService.get_by_id(connector_id)
+ if not e:
+ return get_data_error_result(message="Can't find this Connector!")
+
+ if req:
conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req}
- ConnectorService.update_by_id(req["id"], conn)
- else:
+ conn["id"] = connector_id
+ ConnectorService.update_by_id(connector_id, conn)
+
+ await asyncio.sleep(1)
+ e, conn = ConnectorService.get_by_id(connector_id)
+
+ return get_json_result(data=conn.to_dict())
+
+
+@manager.route("/connectors", methods=["POST"]) # noqa: F821
+@login_required
+async def create_connector():
+ req = await get_request_json()
+ if req:
req["id"] = get_uuid()
conn = {
"id": req["id"],
@@ -65,13 +80,13 @@ async def set_connector():
return get_json_result(data=conn.to_dict())
-@manager.route("/list", methods=["GET"]) # noqa: F821
+@manager.route("/connectors", methods=["GET"]) # noqa: F821
@login_required
def list_connector():
return get_json_result(data=ConnectorService.list(current_user.id))
-@manager.route("/", methods=["GET"]) # noqa: F821
+@manager.route("/connectors/", methods=["GET"]) # noqa: F821
@login_required
def get_connector(connector_id):
e, conn = ConnectorService.get_by_id(connector_id)
@@ -80,7 +95,7 @@ def get_connector(connector_id):
return get_json_result(data=conn.to_dict())
-@manager.route("//logs", methods=["GET"]) # noqa: F821
+@manager.route("/connectors//logs", methods=["GET"]) # noqa: F821
@login_required
def list_logs(connector_id):
req = request.args.to_dict(flat=True)
@@ -88,7 +103,7 @@ def list_logs(connector_id):
return get_json_result(data={"total": total, "logs": arr})
-@manager.route("//resume", methods=["PUT"]) # noqa: F821
+@manager.route("/connectors//resume", methods=["POST"]) # noqa: F821
@login_required
async def resume(connector_id):
req = await get_request_json()
@@ -99,7 +114,7 @@ async def resume(connector_id):
return get_json_result(data=True)
-@manager.route("//rebuild", methods=["PUT"]) # noqa: F821
+@manager.route("/connectors//rebuild", methods=["POST"]) # noqa: F821
@login_required
@validate_request("kb_id")
async def rebuild(connector_id):
@@ -110,7 +125,7 @@ async def rebuild(connector_id):
return get_json_result(data=True)
-@manager.route("//rm", methods=["POST"]) # noqa: F821
+@manager.route("/connectors/", methods=["DELETE"]) # noqa: F821
@login_required
def rm_connector(connector_id):
ConnectorService.resume(connector_id, TaskStatus.CANCEL)
@@ -157,6 +172,22 @@ def _get_web_client_config(credentials: dict[str, Any]) -> dict[str, Any]:
return {"web": web_section}
+def _exchange_google_web_oauth_code(
+ client_config: dict[str, Any],
+ scopes: list[str],
+ redirect_uri: str,
+ code: str,
+ code_verifier: str | None,
+) -> Flow:
+ flow = Flow.from_client_config(client_config, scopes=scopes)
+ flow.redirect_uri = redirect_uri
+ fetch_token_kwargs: dict[str, Any] = {"code": code}
+ if code_verifier:
+ fetch_token_kwargs["code_verifier"] = code_verifier
+ flow.fetch_token(**fetch_token_kwargs)
+ return flow
+
+
async def _render_web_oauth_popup(flow_id: str, success: bool, message: str, source="drive"):
status = "success" if success else "error"
auto_close = "window.close();" if success else ""
@@ -185,7 +216,7 @@ async def _render_web_oauth_popup(flow_id: str, success: bool, message: str, sou
return response
-@manager.route("/google/oauth/web/start", methods=["POST"]) # noqa: F821
+@manager.route("/connectors/google/oauth/web/start", methods=["POST"]) # noqa: F821
@login_required
@validate_request("credentials")
async def start_google_web_oauth():
@@ -252,6 +283,7 @@ async def start_google_web_oauth():
"user_id": current_user.id,
"client_config": client_config,
"redirect_uri": redirect_uri,
+ "code_verifier": flow.code_verifier,
"created_at": int(time.time()),
}
REDIS_CONN.set_obj(_web_state_cache_key(flow_id, source), cache_payload, WEB_FLOW_TTL_SECS)
@@ -265,7 +297,7 @@ async def start_google_web_oauth():
)
-@manager.route("/gmail/oauth/web/callback", methods=["GET"]) # noqa: F821
+@manager.route("/connectors/gmail/oauth/web/callback", methods=["GET"]) # noqa: F821
async def google_gmail_web_oauth_callback():
state_id = request.args.get("state")
error = request.args.get("error")
@@ -283,6 +315,7 @@ async def google_gmail_web_oauth_callback():
state_obj = json.loads(state_cache)
client_config = state_obj.get("client_config")
redirect_uri = state_obj.get("redirect_uri", GMAIL_WEB_OAUTH_REDIRECT_URI)
+ code_verifier = state_obj.get("code_verifier")
if not client_config:
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source)
@@ -296,10 +329,13 @@ async def google_gmail_web_oauth_callback():
return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source)
try:
- # TODO(google-oauth): branch scopes/redirect_uri based on source_type (drive vs gmail)
- flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GMAIL])
- flow.redirect_uri = redirect_uri
- flow.fetch_token(code=code)
+ flow = _exchange_google_web_oauth_code(
+ client_config=client_config,
+ scopes=GOOGLE_SCOPES[DocumentSource.GMAIL],
+ redirect_uri=redirect_uri,
+ code=code,
+ code_verifier=code_verifier,
+ )
except Exception as exc: # pragma: no cover - defensive
logging.exception("Failed to exchange Google OAuth code: %s", exc)
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
@@ -316,7 +352,7 @@ async def google_gmail_web_oauth_callback():
return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source)
-@manager.route("/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821
+@manager.route("/connectors/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821
async def google_drive_web_oauth_callback():
state_id = request.args.get("state")
error = request.args.get("error")
@@ -334,6 +370,7 @@ async def google_drive_web_oauth_callback():
state_obj = json.loads(state_cache)
client_config = state_obj.get("client_config")
redirect_uri = state_obj.get("redirect_uri", GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI)
+ code_verifier = state_obj.get("code_verifier")
if not client_config:
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source)
@@ -347,10 +384,13 @@ async def google_drive_web_oauth_callback():
return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source)
try:
- # TODO(google-oauth): branch scopes/redirect_uri based on source_type (drive vs gmail)
- flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
- flow.redirect_uri = redirect_uri
- flow.fetch_token(code=code)
+ flow = _exchange_google_web_oauth_code(
+ client_config=client_config,
+ scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE],
+ redirect_uri=redirect_uri,
+ code=code,
+ code_verifier=code_verifier,
+ )
except Exception as exc: # pragma: no cover - defensive
logging.exception("Failed to exchange Google OAuth code: %s", exc)
REDIS_CONN.delete(_web_state_cache_key(state_id, source))
@@ -366,7 +406,7 @@ async def google_drive_web_oauth_callback():
return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source)
-@manager.route("/google/oauth/web/result", methods=["POST"]) # noqa: F821
+@manager.route("/connectors/google/oauth/web/result", methods=["POST"]) # noqa: F821
@login_required
@validate_request("flow_id")
async def poll_google_web_result():
@@ -386,7 +426,7 @@ async def poll_google_web_result():
REDIS_CONN.delete(_web_result_cache_key(flow_id, source))
return get_json_result(data={"credentials": result.get("credentials")})
-@manager.route("/box/oauth/web/start", methods=["POST"]) # noqa: F821
+@manager.route("/connectors/box/oauth/web/start", methods=["POST"]) # noqa: F821
@login_required
async def start_box_web_oauth():
req = await get_request_json()
@@ -429,7 +469,7 @@ async def start_box_web_oauth():
"expires_in": WEB_FLOW_TTL_SECS,}
)
-@manager.route("/box/oauth/web/callback", methods=["GET"]) # noqa: F821
+@manager.route("/connectors/box/oauth/web/callback", methods=["GET"]) # noqa: F821
async def box_web_oauth_callback():
flow_id = request.args.get("state")
if not flow_id:
@@ -471,7 +511,7 @@ async def box_web_oauth_callback():
return await _render_web_oauth_popup(flow_id, True, "Authorization completed successfully.", "box")
-@manager.route("/box/oauth/web/result", methods=["POST"]) # noqa: F821
+@manager.route("/connectors/box/oauth/web/result", methods=["POST"]) # noqa: F821
@login_required
@validate_request("flow_id")
async def poll_box_web_result():
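
For reference, the connector endpoints now follow resource-style verbs (POST to create, PATCH to update, DELETE to remove, POST on `/resume` and `/rebuild` for actions). The sketch below is illustrative only: the prefix, auth header, and creation payload fields are assumptions and are not taken from this patch.

```python
# Illustrative sketch only; prefix, auth, and creation fields are assumptions.
import requests

BASE = "http://localhost:9380/v1"                    # assumed blueprint prefix
HEADERS = {"Authorization": "Bearer YOUR_API_KEY"}    # assumed auth scheme

# POST /connectors replaces the old POST /set for creation.
# The payload below is purely illustrative; see ConnectorService for the accepted fields.
conn = requests.post(f"{BASE}/connectors", headers=HEADERS,
                     json={"source": "s3", "config": {"bucket": "my-bucket"}, "refresh_freq": 30}).json()["data"]

# PATCH /connectors/<connector_id> updates only the whitelisted fields
# (prune_freq, refresh_freq, config, timeout_secs).
requests.patch(f"{BASE}/connectors/{conn['id']}", headers=HEADERS, json={"refresh_freq": 60})

# Action endpoints moved from PUT to POST; payload fields omitted here, see the resume handler.
requests.post(f"{BASE}/connectors/{conn['id']}/resume", headers=HEADERS, json={})

# DELETE /connectors/<connector_id> replaces the old POST /<id>/rm.
requests.delete(f"{BASE}/connectors/{conn['id']}", headers=HEADERS)
```
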
diff --git a/api/apps/restful_apis/dataset_api.py b/api/apps/restful_apis/dataset_api.py
index 4f3ff2d59a4..55ded90e028 100644
--- a/api/apps/restful_apis/dataset_api.py
+++ b/api/apps/restful_apis/dataset_api.py
@@ -19,11 +19,13 @@
from quart import request
from common.constants import RetCode
from api.apps import login_required, current_user
-from api.utils.api_utils import get_error_argument_result, get_error_data_result, get_result, add_tenant_id_to_kwargs
+from api.utils.api_utils import get_error_argument_result, get_error_data_result, get_json_result, get_result, add_tenant_id_to_kwargs
from api.utils.validation_utils import (
CreateDatasetReq,
DeleteDatasetReq,
ListDatasetReq,
+ SearchDatasetReq,
+ SearchDatasetsReq,
UpdateDatasetReq,
validate_and_parse_json_request,
validate_and_parse_request_args,
@@ -31,10 +33,54 @@
from api.apps.services import dataset_api_service
+@manager.route("/datasets/tags/aggregation", methods=["GET"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+def aggregate_tags(tenant_id):
+ dataset_ids = request.args.get("dataset_ids", "").split(",")
+ dataset_ids = [d for d in dataset_ids if d]
+ if not dataset_ids:
+ return get_error_data_result(message="Lack of dataset_ids in query parameters")
+
+ try:
+ success, result = dataset_api_service.aggregate_tags(dataset_ids, tenant_id)
+ if success:
+ return get_result(data=result)
+ else:
+ return get_error_data_result(message=result)
+ except ValueError as e:
+ return get_error_argument_result(str(e))
+ except Exception as e:
+ logging.exception(e)
+ return get_error_data_result(message="Internal server error")
+
+
+@manager.route("/datasets/metadata/flattened", methods=["GET"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+def get_flattened_metadata(tenant_id):
+ dataset_ids = request.args.get("dataset_ids", "").split(",")
+ dataset_ids = [d for d in dataset_ids if d]
+ if not dataset_ids:
+ return get_error_data_result(message="Lack of dataset_ids in query parameters")
+
+ try:
+ success, result = dataset_api_service.get_flattened_metadata(dataset_ids, tenant_id)
+ if success:
+ return get_result(data=result)
+ else:
+ return get_error_data_result(message=result)
+ except ValueError as e:
+ return get_error_argument_result(str(e))
+ except Exception as e:
+ logging.exception(e)
+ return get_error_data_result(message="Internal server error")
+
+
@manager.route("/datasets", methods=["POST"]) # noqa: F821
@login_required
@add_tenant_id_to_kwargs
-async def create(tenant_id: str=None):
+async def create(tenant_id: str = None):
"""
Create a new dataset.
---
@@ -102,6 +148,8 @@ async def create(tenant_id: str=None):
return get_result(data=result)
else:
return get_error_data_result(message=result)
+ except ValueError as e:
+ return get_error_argument_result(str(e))
except Exception as e:
logging.exception(e)
return get_error_data_result(message="Internal server error")
@@ -330,26 +378,188 @@ def list_datasets(tenant_id):
return get_error_data_result(message="Internal server error")
-@manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['GET']) # noqa: F821
+@manager.route("/datasets/<dataset_id>", methods=["GET"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+def get_dataset(tenant_id, dataset_id):
+ try:
+ success, result = dataset_api_service.get_dataset(dataset_id, tenant_id)
+ if success:
+ return get_result(data=result)
+ else:
+ return get_error_data_result(message=result)
+ except ValueError as e:
+ return get_error_argument_result(str(e))
+ except Exception as e:
+ logging.exception(e)
+ return get_error_data_result(message="Internal server error")
+
+
+@manager.route("/datasets//ingestions/summary", methods=["GET"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+def get_ingestion_summary(tenant_id, dataset_id):
+ try:
+ success, result = dataset_api_service.get_ingestion_summary(dataset_id, tenant_id)
+ if success:
+ return get_result(data=result)
+ else:
+ return get_error_data_result(message=result)
+ except ValueError as e:
+ return get_error_argument_result(str(e))
+ except Exception as e:
+ logging.exception(e)
+ return get_error_data_result(message="Internal server error")
+
+
+@manager.route("/datasets//tags", methods=["GET"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+def list_tags(tenant_id, dataset_id):
+ try:
+ success, result = dataset_api_service.list_tags(dataset_id, tenant_id)
+ if success:
+ return get_result(data=result)
+ else:
+ return get_error_data_result(message=result)
+ except ValueError as e:
+ return get_error_argument_result(str(e))
+ except Exception as e:
+ logging.exception(e)
+ return get_error_data_result(message="Internal server error")
+
+
+@manager.route("/datasets//tags", methods=["DELETE"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+async def delete_tags(tenant_id, dataset_id):
+ req = await request.get_json()
+ if not req or "tags" not in req:
+ return get_error_data_result(message="Lack of tags in request body")
+ if not isinstance(req["tags"], list) or not all(isinstance(t, str) for t in req["tags"]):
+ return get_error_argument_result("tags must be a list of strings")
+
+ try:
+ success, result = dataset_api_service.delete_tags(dataset_id, tenant_id, req["tags"])
+ if success:
+ return get_result(data=result)
+ else:
+ return get_error_data_result(message=result)
+ except ValueError as e:
+ return get_error_argument_result(str(e))
+ except Exception as e:
+ logging.exception(e)
+ return get_error_data_result(message="Internal server error")
+
+
+@manager.route("/datasets//tags", methods=["PUT"]) # noqa: F821
@login_required
@add_tenant_id_to_kwargs
-async def knowledge_graph(tenant_id, dataset_id):
+async def rename_tag(tenant_id, dataset_id):
+ req = await request.get_json()
+ if not req or "from_tag" not in req or "to_tag" not in req:
+ return get_error_data_result(message="Lack of from_tag or to_tag in request body")
+ if not isinstance(req["from_tag"], str) or not isinstance(req["to_tag"], str):
+ return get_error_argument_result("from_tag and to_tag must be strings")
+
+ if not req["from_tag"].strip() or not req["to_tag"].strip():
+ return get_error_argument_result("from_tag and to_tag must not be empty")
+
+ try:
+ success, result = dataset_api_service.rename_tag(dataset_id, tenant_id, req["from_tag"], req["to_tag"])
+ if success:
+ return get_result(data=result)
+ else:
+ return get_error_data_result(message=result)
+ except ValueError as e:
+ return get_error_argument_result(str(e))
+ except Exception as e:
+ logging.exception(e)
+ return get_error_data_result(message="Internal server error")
+
+
+@manager.route("/datasets/search", methods=["POST"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+async def search_datasets(tenant_id):
+ """Search (retrieval test) across multiple datasets.
+
+ POST /api/v1/datasets/search
+ JSON body: {"dataset_ids": list[str] (required), "question": str (required), "doc_ids": list[str], "top_k": int, "page": int, "size": int,
+ "similarity_threshold": float, "vector_similarity_weight": float, "use_kg": bool,
+ "cross_languages": list[str], "keyword": bool, "meta_data_filter": dict}
+ Success: {"code": 0, "data": {"chunks": [...], "total": int, "labels": [...]}}
+ Errors: ARGUMENT_ERROR (101) for invalid payload; DATA_ERROR (102) for access denied or internal errors.
+ """
+ req, err = await validate_and_parse_json_request(request, SearchDatasetsReq)
+ if err is not None:
+ return get_error_argument_result(err)
+ try:
+ success, result = await dataset_api_service.search_datasets(tenant_id, req)
+ if success:
+ return get_result(data=result)
+ else:
+ return get_error_data_result(message=result)
+ except Exception as e:
+ logging.exception(e)
+ if "not_found" in str(e):
+ return get_error_data_result(message="No chunk found! Check the chunk status please!")
+ return get_error_data_result(message="Internal server error")
+
+
+@manager.route("/datasets//search", methods=["POST"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+async def search(tenant_id, dataset_id):
+ """Search (retrieval test) within a dataset.
+
+ POST /api/v1/datasets/<dataset_id>/search
+ JSON body: {"question": str (required), "doc_ids": list[str], "top_k": int, "page": int, "size": int,
+ "similarity_threshold": float, "vector_similarity_weight": float, "use_kg": bool,
+ "cross_languages": list[str], "keyword": bool, "meta_data_filter": dict}
+ Success: {"code": 0, "data": {"chunks": [...], "total": int, "labels": [...]}}
+ Errors: ARGUMENT_ERROR (101) for invalid payload; DATA_ERROR (102) for access denied or internal errors.
+ """
+ req, err = await validate_and_parse_json_request(request, SearchDatasetReq)
+ if err is not None:
+ return get_error_argument_result(err)
+ req['dataset_ids'] = [dataset_id]
+ try:
+ success, result = await dataset_api_service.search_datasets(tenant_id, req)
+ if success:
+ return get_result(data=result)
+ else:
+ return get_error_data_result(message=result)
+ except Exception as e:
+ logging.exception(e)
+ if "not_found" in str(e):
+ return get_error_data_result(message="No chunk found! Check the chunk status please!")
+ return get_error_data_result(message="Internal server error")
+
+
+@manager.route("/datasets//graph", methods=["GET"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+async def get_knowledge_graph(tenant_id, dataset_id):
+ """Get the knowledge graph of a dataset.
+
+ GET /api/v1/datasets/<dataset_id>/graph
+ Query params: optional filter params.
+ Success: {"code": 0, "data": {...}}
+ Errors: AUTHENTICATION_ERROR for access denied; DATA_ERROR for internal errors.
+ """
try:
success, result = await dataset_api_service.get_knowledge_graph(dataset_id, tenant_id)
if success:
return get_result(data=result)
else:
- return get_result(
- data=False,
- message=result,
- code=RetCode.AUTHENTICATION_ERROR
- )
+ return get_result(data=False, message=result, code=RetCode.AUTHENTICATION_ERROR)
except Exception as e:
logging.exception(e)
return get_error_data_result(message="Internal server error")
-@manager.route('/datasets/<dataset_id>/knowledge_graph', methods=['DELETE']) # noqa: F821
+@manager.route("/datasets/<dataset_id>/graph", methods=["DELETE"]) # noqa: F821
@login_required
@add_tenant_id_to_kwargs
def delete_knowledge_graph(tenant_id, dataset_id):
@@ -358,67 +568,82 @@ def delete_knowledge_graph(tenant_id, dataset_id):
if success:
return get_result(data=result)
else:
- return get_result(
- data=False,
- message=result,
- code=RetCode.AUTHENTICATION_ERROR
- )
+ return get_result(data=False, message=result, code=RetCode.AUTHENTICATION_ERROR)
except Exception as e:
logging.exception(e)
return get_error_data_result(message="Internal server error")
-@manager.route("/datasets//run_graphrag", methods=["POST"]) # noqa: F821
+@manager.route("/datasets//index", methods=["POST"]) # noqa: F821
@login_required
@add_tenant_id_to_kwargs
-async def run_graphrag(tenant_id, dataset_id):
+async def run_index(tenant_id, dataset_id):
+ index_type = request.args.get("type", "")
+ index_type = index_type.lower()
try:
- success, result = dataset_api_service.run_graphrag(dataset_id, tenant_id)
+ success, result = dataset_api_service.run_index(dataset_id, tenant_id, index_type)
if success:
return get_result(data=result)
else:
return get_error_data_result(message=result)
+ except ValueError as e:
+ return get_error_argument_result(str(e))
except Exception as e:
logging.exception(e)
return get_error_data_result(message="Internal server error")
-@manager.route("/datasets//trace_graphrag", methods=["GET"]) # noqa: F821
+@manager.route("/datasets//index", methods=["GET"]) # noqa: F821
@login_required
@add_tenant_id_to_kwargs
-def trace_graphrag(tenant_id, dataset_id):
+def trace_index(tenant_id, dataset_id):
+ index_type = request.args.get("type", "")
+ index_type = index_type.lower()
try:
- success, result = dataset_api_service.trace_graphrag(dataset_id, tenant_id)
+ success, result = dataset_api_service.trace_index(dataset_id, tenant_id, index_type)
if success:
return get_result(data=result)
else:
return get_error_data_result(message=result)
+ except ValueError as e:
+ return get_error_argument_result(str(e))
except Exception as e:
logging.exception(e)
return get_error_data_result(message="Internal server error")
-@manager.route("/datasets//run_raptor", methods=["POST"]) # noqa: F821
+@manager.route("/datasets//", methods=["DELETE"]) # noqa: F821
@login_required
@add_tenant_id_to_kwargs
-async def run_raptor(tenant_id, dataset_id):
+def delete_index(tenant_id, dataset_id, index_type):
+ index_type = index_type.lower()
+ if index_type not in dataset_api_service._VALID_INDEX_TYPES:
+ return get_error_argument_result(f"Invalid index type '{index_type}'")
+ # `wipe` controls whether the persisted index artefacts (graph rows /
+ # raptor summaries) are removed. Default true preserves historical
+ # behaviour; pass wipe=false to cancel the running task while keeping
+ # prior progress so it can be resumed later.
+ wipe_arg = (request.args.get("wipe", "true") or "true").strip().lower()
+ wipe = wipe_arg not in ("false", "0", "no", "off")
try:
- success, result = dataset_api_service.run_raptor(dataset_id, tenant_id)
+ success, result = dataset_api_service.delete_index(dataset_id, tenant_id, index_type, wipe=wipe)
if success:
return get_result(data=result)
else:
return get_error_data_result(message=result)
+ except ValueError as e:
+ return get_error_argument_result(str(e))
except Exception as e:
logging.exception(e)
return get_error_data_result(message="Internal server error")
-@manager.route("/datasets//trace_raptor", methods=["GET"]) # noqa: F821
+@manager.route("/datasets//embedding", methods=["POST"]) # noqa: F821
@login_required
@add_tenant_id_to_kwargs
-def trace_raptor(tenant_id, dataset_id):
+async def run_embedding(tenant_id, dataset_id):
try:
- success, result = dataset_api_service.trace_raptor(dataset_id, tenant_id)
+ success, result = dataset_api_service.run_embedding(dataset_id, tenant_id)
if success:
return get_result(data=result)
else:
@@ -428,7 +653,70 @@ def trace_raptor(tenant_id, dataset_id):
return get_error_data_result(message="Internal server error")
-@manager.route("/datasets//auto_metadata", methods=["GET"]) # noqa: F821
+@manager.route("/datasets//embedding/check", methods=["POST"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+async def check_embedding(tenant_id, dataset_id):
+ try:
+ req = await request.get_json()
+ if not req or not req.get("embd_id"):
+ return get_error_data_result(message="`embd_id` is required.")
+ status, result = dataset_api_service.check_embedding(dataset_id, tenant_id, req)
+ if status is True:
+ return get_result(data=result)
+ elif status == "not_effective":
+ return get_json_result(code=result["code"], message=result["message"], data=result["data"])
+ else:
+ return get_error_data_result(message=result)
+ except Exception as e:
+ logging.exception(e)
+ return get_error_data_result(message="Internal server error")
+
+
+@manager.route("/datasets//ingestions", methods=["GET"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+def list_ingestion_logs(tenant_id, dataset_id):
+ try:
+ page = int(request.args.get("page", 0))
+ page_size = int(request.args.get("page_size", 0))
+ orderby = request.args.get("orderby", "create_time")
+ desc = request.args.get("desc", "true").lower() != "false"
+ operation_status = request.args.getlist("operation_status")
+ create_date_from = request.args.get("create_date_from", None)
+ create_date_to = request.args.get("create_date_to", None)
+ log_type = request.args.get("log_type", "dataset")
+ keywords = request.args.get("keywords", None)
+ success, result = dataset_api_service.list_ingestion_logs(dataset_id, tenant_id, page, page_size, orderby, desc, operation_status, create_date_from, create_date_to, log_type, keywords)
+ if success:
+ return get_result(data=result)
+ else:
+ return get_error_data_result(message=result)
+ except ValueError as e:
+ return get_error_argument_result(str(e))
+ except Exception as e:
+ logging.exception(e)
+ return get_error_data_result(message="Internal server error")
+
+
+@manager.route("/datasets//ingestions/", methods=["GET"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+def get_ingestion_log(tenant_id, dataset_id, log_id):
+ try:
+ success, result = dataset_api_service.get_ingestion_log(dataset_id, tenant_id, log_id)
+ if success:
+ return get_result(data=result)
+ else:
+ return get_error_data_result(message=result)
+ except ValueError as e:
+ return get_error_argument_result(str(e))
+ except Exception as e:
+ logging.exception(e)
+ return get_error_data_result(message="Internal server error")
+
+
+@manager.route("/datasets//metadata/config", methods=["GET"]) # noqa: F821
@login_required
@add_tenant_id_to_kwargs
def get_auto_metadata(tenant_id, dataset_id):
@@ -462,12 +750,14 @@ def get_auto_metadata(tenant_id, dataset_id):
return get_result(data=result)
else:
return get_error_data_result(message=result)
+ except ValueError as e:
+ return get_error_argument_result(str(e))
except Exception as e:
logging.exception(e)
return get_error_data_result(message="Internal server error")
-@manager.route("/datasets//auto_metadata", methods=["PUT"]) # noqa: F821
+@manager.route("/datasets//metadata/config", methods=["PUT"]) # noqa: F821
@login_required
@add_tenant_id_to_kwargs
async def update_auto_metadata(tenant_id, dataset_id):
@@ -502,6 +792,7 @@ async def update_auto_metadata(tenant_id, dataset_id):
type: object
"""
from api.utils.validation_utils import AutoMetadataConfig
+
cfg, err = await validate_and_parse_json_request(request, AutoMetadataConfig)
if err is not None:
return get_error_argument_result(err)
@@ -512,6 +803,8 @@ async def update_auto_metadata(tenant_id, dataset_id):
return get_result(data=result)
else:
return get_error_data_result(message=result)
+ except ValueError as e:
+ return get_error_argument_result(str(e))
except Exception as e:
logging.exception(e)
return get_error_data_result(message="Internal server error")
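
The search docstrings added earlier in this file fully specify the retrieval-test payloads; a short client sketch is included here for reviewers. The host and auth header are assumptions, while the payload and response fields mirror those docstrings.

```python
# Illustrative sketch only; host and auth header are assumptions.
import requests

BASE = "http://localhost:9380/api/v1"
HEADERS = {"Authorization": "Bearer YOUR_API_KEY"}   # assumed auth scheme

# Cross-dataset retrieval test: POST /api/v1/datasets/search
multi = requests.post(f"{BASE}/datasets/search", headers=HEADERS, json={
    "dataset_ids": ["DATASET_A", "DATASET_B"],
    "question": "How do I rotate an API key?",
    "similarity_threshold": 0.2,
    "vector_similarity_weight": 0.3,
    "page": 1,
    "size": 10,
}).json()
print(multi["data"]["total"], "matching chunks")

# Single-dataset variant: POST /api/v1/datasets/<dataset_id>/search (dataset_ids is implied)
single = requests.post(f"{BASE}/datasets/DATASET_A/search", headers=HEADERS,
                       json={"question": "How do I rotate an API key?", "keyword": True}).json()
for chunk in single["data"]["chunks"]:
    print(chunk.get("content", "")[:80])
```
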
diff --git a/api/apps/restful_apis/document_api.py b/api/apps/restful_apis/document_api.py
index b2e749f3e51..7300a55a9f7 100644
--- a/api/apps/restful_apis/document_api.py
+++ b/api/apps/restful_apis/document_api.py
@@ -15,26 +15,107 @@
#
import logging
import json
+import os
+import re
+from pathlib import Path
-from quart import request
+from quart import request, make_response
from peewee import OperationalError
from pydantic import ValidationError
-from api.apps import login_required
+from api.apps import current_user, login_required
+from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
from api.apps.services.document_api_service import validate_document_update_fields, map_doc_keys, \
- map_doc_keys_with_run_status, update_document_name_only, update_chunk_method_only, update_document_status_only
-from api.constants import IMG_BASE64_PREFIX
-from api.db import VALID_FILE_TYPES
+ map_doc_keys_with_run_status, update_document_name_only, update_chunk_method, update_document_status_only, \
+ reset_document_for_reparse
+from api.db import VALID_FILE_TYPES, FileType
+from api.db.services import duplicate_name
from api.db.services.doc_metadata_service import DocMetadataService
+from api.db.db_models import Task
from api.db.services.document_service import DocumentService
+from api.db.services.file2document_service import File2DocumentService
+from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.common.check_team_permission import check_kb_team_permission
+from api.db.services.task_service import TaskService, cancel_all_task_of
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_result, get_json_result, \
- server_error_response, add_tenant_id_to_kwargs, get_request_json
+ server_error_response, add_tenant_id_to_kwargs, get_request_json, get_error_argument_result, check_duplicate_ids
from api.utils.validation_utils import (
- UpdateDocumentReq, format_validation_error_message,
+ UpdateDocumentReq, format_validation_error_message, validate_and_parse_json_request, DeleteDocumentReq,
)
-from common.constants import RetCode
+
+from common import settings
+from common.constants import ParserType, RetCode, TaskStatus, SANDBOX_ARTIFACT_BUCKET
from common.metadata_utils import convert_conditions, meta_filter, turn2jsonschema
+from common.misc_utils import get_uuid, thread_pool_exec
+from api.utils.file_utils import filename_type, thumbnail
+from api.utils.web_utils import CONTENT_TYPE_MAP, html2pdf, is_valid_url, apply_safe_file_response_headers
+from common.ssrf_guard import assert_url_is_safe
+from rag.nlp import search
+
+
+@manager.route("/documents/upload", methods=["POST"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+async def upload_info(tenant_id: str):
+ """
+ Upload a document and get its parsed info.
+ ---
+ tags:
+ - Documents
+ security:
+ - ApiKeyAuth: []
+ parameters:
+ - in: header
+ name: Authorization
+ type: string
+ required: true
+ description: Bearer token for authentication.
+ - in: formData
+ name: file
+ type: file
+ required: false
+ description: File to upload.
+ - in: query
+ name: url
+ type: string
+ required: false
+ description: URL to fetch file from.
+ responses:
+ 200:
+ description: Successful operation.
+ """
+ files = await request.files
+ file_objs = files.getlist("file") if files and files.get("file") else []
+ url = request.args.get("url")
+
+ if file_objs and url:
+ return get_error_argument_result("Provide either multipart file(s) or ?url=..., not both.")
+
+ if not file_objs and not url:
+ return get_error_argument_result("Missing input: provide multipart file(s) or url")
+
+ try:
+ if url and not file_objs:
+ try:
+ assert_url_is_safe(url)
+ except ValueError as ve:
+ logging.warning("upload_info: rejected unsafe url: %s", ve)
+ return get_error_argument_result(str(ve))
+
+ data = await thread_pool_exec(FileService.upload_info, tenant_id, None, url)
+ return get_result(data=data)
+
+ if len(file_objs) == 1:
+ data = await thread_pool_exec(FileService.upload_info, tenant_id, file_objs[0], None)
+ return get_result(data=data)
+
+ results = [await thread_pool_exec(FileService.upload_info, tenant_id, f, None) for f in file_objs]
+ return get_result(data=results)
+ except Exception as e:
+ logging.exception("upload_info failed")
+ return server_error_response(e)
+
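A quick sketch of the upload helper defined above: it accepts either multipart files or a `?url=` parameter, never both. Host, prefix, and auth header in the snippet are assumptions.

```python
# Illustrative sketch only; host, prefix, and auth header are assumptions.
import requests

BASE = "http://localhost:9380/api/v1"
HEADERS = {"Authorization": "Bearer YOUR_API_KEY"}

# Multipart upload of a local file.
with open("manual.pdf", "rb") as fh:
    info = requests.post(f"{BASE}/documents/upload", headers=HEADERS, files={"file": fh}).json()

# Or let the server fetch an SSRF-checked URL instead (mutually exclusive with files).
info_from_url = requests.post(f"{BASE}/documents/upload", headers=HEADERS,
                              params={"url": "https://example.com/spec.html"}).json()
```
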
@manager.route("/datasets//documents/", methods=["PATCH"]) # noqa: F821
@login_required
@@ -125,16 +206,26 @@ async def update_document(tenant_id, dataset_id, document_id):
if error := update_document_name_only(document_id, req["name"]):
return error
+ # "parser_id" provided but does not match with existing doc's file type
+ if "parser_id" in req and ((doc.type == FileType.VISUAL and req["parser_id"] != "picture")
+ or (re.search(r"\.(ppt|pptx|pages)$", doc.name) and req["parser_id"] != "presentation")):
+ return get_data_error_result(message="Not supported yet!")
+
# parser config provided (already validated in UpdateDocumentReq), update it
if update_doc_req.parser_config:
+ req["parser_config"].update(update_doc_req.parser_config.ext)
DocumentService.update_parser_config(doc.id, req["parser_config"])
+ # pipeline_id provided - reset document for reparse
+ if update_doc_req.pipeline_id:
+ if error := reset_document_for_reparse(doc, tenant_id, pipeline_id=update_doc_req.pipeline_id):
+ return error
# chunk method provided - the update method will check if it's different with existing one
- if update_doc_req.chunk_method:
- if error := update_chunk_method_only(req, doc, dataset_id, tenant_id):
+ elif update_doc_req.chunk_method:
+ if error := update_chunk_method(req, doc, tenant_id):
return error
- if "enabled" in req: # already checked in UpdateDocumentReq - it's int if it's present
+ if "enabled" in req: # already checked in UpdateDocumentReq - it's int if present
# "enabled" flag provided, the update method will check if it's changed and then update if so
if error := update_document_status_only(int(req["enabled"]), doc, kb):
return error
@@ -189,6 +280,88 @@ async def metadata_summary(dataset_id, tenant_id):
return server_error_response(e)
+@manager.route("/datasets//metadata/update", methods=["POST"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+async def metadata_batch_update(dataset_id, tenant_id):
+ """
+ Batch update metadata for documents in a dataset.
+ ---
+ tags:
+ - Documents
+ security:
+ - ApiKeyAuth: []
+ parameters:
+ - in: path
+ name: dataset_id
+ type: string
+ required: true
+ description: ID of the dataset.
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ selector:
+ type: object
+ updates:
+ type: array
+ deletes:
+ type: array
+ responses:
+ 200:
+ description: Metadata updated successfully.
+ """
+ if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
+ return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
+
+ req = await get_request_json()
+ selector = req.get("selector", {}) or {}
+ updates = req.get("updates", []) or []
+ deletes = req.get("deletes", []) or []
+
+ if not isinstance(selector, dict):
+ return get_error_data_result(message="selector must be an object.")
+ if not isinstance(updates, list) or not isinstance(deletes, list):
+ return get_error_data_result(message="updates and deletes must be lists.")
+
+ metadata_condition = selector.get("metadata_condition", {}) or {}
+ if metadata_condition and not isinstance(metadata_condition, dict):
+ return get_error_data_result(message="metadata_condition must be an object.")
+
+ document_ids = selector.get("document_ids", []) or []
+ if document_ids and not isinstance(document_ids, list):
+ return get_error_data_result(message="document_ids must be a list.")
+
+ for upd in updates:
+ if not isinstance(upd, dict) or not upd.get("key") or "value" not in upd:
+ return get_error_data_result(message="Each update requires key and value.")
+ for d in deletes:
+ if not isinstance(d, dict) or not d.get("key"):
+ return get_error_data_result(message="Each delete requires key.")
+
+ target_doc_ids = set()
+ if document_ids:
+ kb_doc_ids = KnowledgebaseService.list_documents_by_ids([dataset_id])
+ invalid_ids = set(document_ids) - set(kb_doc_ids)
+ if invalid_ids:
+ return get_error_data_result(message=f"These documents do not belong to dataset {dataset_id}: {', '.join(invalid_ids)}")
+ target_doc_ids = set(document_ids)
+
+ if metadata_condition:
+ metas = DocMetadataService.get_flatted_meta_by_kbs([dataset_id])
+ filtered_ids = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
+ target_doc_ids = target_doc_ids & filtered_ids
+ if metadata_condition.get("conditions") and not target_doc_ids:
+ return get_result(data={"updated": 0, "matched_docs": 0})
+
+ target_doc_ids = list(target_doc_ids)
+ updated = DocMetadataService.batch_update_metadata(dataset_id, target_doc_ids, updates, deletes)
+ return get_result(data={"updated": updated, "matched_docs": len(target_doc_ids)})
+
+
@manager.route("/datasets//documents", methods=["POST"]) # noqa: F821
@login_required
@add_tenant_id_to_kwargs
@@ -259,19 +432,148 @@ async def upload_document(dataset_id, tenant_id):
type: string
description: Processing status.
"""
- from api.constants import FILE_NAME_LEN_LIMIT
- from api.common.check_team_permission import check_kb_team_permission
- from api.db.services.file_service import FileService
- from common.misc_utils import thread_pool_exec
-
+ upload_type = (request.args.get("type") or "local").lower()
+ e, kb = KnowledgebaseService.get_by_id(dataset_id)
+ if not e:
+ logging.error(f"Can't find the dataset with ID {dataset_id}!")
+ return get_error_data_result(message=f"Can't find the dataset with ID {dataset_id}!", code=RetCode.DATA_ERROR)
+
+ if not check_kb_team_permission(kb, tenant_id):
+ logging.error("No authorization.")
+ return get_error_data_result(message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
+
+ if upload_type == "web":
+ return await _upload_web_document(dataset_id, kb, tenant_id)
+
+ if upload_type == "empty":
+ return await _upload_empty_document(dataset_id, kb, tenant_id)
+
+ if upload_type != "local":
+ return get_error_data_result(
+ message='`type` must be one of "local", "web", or "empty".',
+ code=RetCode.ARGUMENT_ERROR,
+ )
+
+ return await _upload_local_documents(kb, tenant_id)
+
+
+async def _upload_web_document(dataset_id, kb, tenant_id):
+ form = await request.form
+ name = (form.get("name") or "").strip()
+ url = form.get("url")
+
+ if not name:
+ return get_error_data_result(message='Lack of "name"', code=RetCode.ARGUMENT_ERROR)
+ if not url:
+ return get_error_data_result(message='Lack of "url"', code=RetCode.ARGUMENT_ERROR)
+ if len(name.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
+ return get_error_data_result(
+ message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.",
+ code=RetCode.ARGUMENT_ERROR,
+ )
+ if not is_valid_url(url):
+ return get_error_data_result(message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR)
+
+ blob = html2pdf(url)
+ if not blob:
+ return server_error_response(ValueError("Download failure."))
+
+ root_folder = FileService.get_root_folder(tenant_id)
+ FileService.init_knowledgebase_docs(root_folder["id"], tenant_id)
+ kb_root_folder = FileService.get_kb_folder(tenant_id)
+ kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
+
+ try:
+ filename = duplicate_name(DocumentService.query, name=f"{name}.pdf", kb_id=kb.id)
+ filetype = filename_type(filename)
+ if filetype == FileType.OTHER.value:
+ raise RuntimeError("This type of file has not been supported yet!")
+
+ location = filename
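+        # append "_" until the storage key is free so an existing object in this dataset's bucket is never overwritten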
+ while settings.STORAGE_IMPL.obj_exist(dataset_id, location):
+ location += "_"
+ settings.STORAGE_IMPL.put(dataset_id, location, blob)
+
+ doc = {
+ "id": get_uuid(),
+ "kb_id": kb.id,
+ "parser_id": kb.parser_id,
+ "pipeline_id": kb.pipeline_id,
+ "parser_config": kb.parser_config,
+ "created_by": tenant_id,
+ "type": filetype,
+ "name": filename,
+ "location": location,
+ "size": len(blob),
+ "thumbnail": thumbnail(filename, blob),
+ "suffix": Path(filename).suffix.lstrip("."),
+ }
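+        # pick a more specific parser based on the detected file type or extension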
+ if doc["type"] == FileType.VISUAL:
+ doc["parser_id"] = ParserType.PICTURE.value
+ if doc["type"] == FileType.AURAL:
+ doc["parser_id"] = ParserType.AUDIO.value
+ if re.search(r"\.(ppt|pptx|pages)$", filename):
+ doc["parser_id"] = ParserType.PRESENTATION.value
+ if re.search(r"\.(eml)$", filename):
+ doc["parser_id"] = ParserType.EMAIL.value
+
+ DocumentService.insert(doc)
+ FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
+ return get_result(data=map_doc_keys_with_run_status(doc, run_status="0"))
+ except Exception as e:
+ return server_error_response(e)
+
+
+async def _upload_empty_document(dataset_id, kb, tenant_id):
+ req = await get_request_json()
+ name = (req.get("name") or "").strip()
+
+ if not name:
+ return get_error_data_result(message="File name can't be empty.", code=RetCode.ARGUMENT_ERROR)
+ if len(name.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
+ return get_error_data_result(
+ message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.",
+ code=RetCode.ARGUMENT_ERROR,
+ )
+ if DocumentService.query(name=name, kb_id=dataset_id):
+ return get_error_data_result(message="Duplicated document name in the same dataset.")
+
+ try:
+ kb_root_folder = FileService.get_kb_folder(kb.tenant_id)
+ if not kb_root_folder:
+ return get_error_data_result(message="Cannot find the root folder.")
+ kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
+ if not kb_folder:
+ return get_error_data_result(message="Cannot find the kb folder for this file.")
+
+ doc = DocumentService.insert(
+ {
+ "id": get_uuid(),
+ "kb_id": kb.id,
+ "parser_id": kb.parser_id,
+ "pipeline_id": kb.pipeline_id,
+ "parser_config": kb.parser_config,
+ "created_by": tenant_id,
+ "type": FileType.VIRTUAL,
+ "name": name,
+ "suffix": Path(name).suffix.lstrip("."),
+ "location": "",
+ "size": 0,
+ }
+ )
+ FileService.add_file_from_kb(doc.to_dict(), kb_folder["id"], kb.tenant_id)
+ return get_result(data=map_doc_keys(doc))
+ except Exception as e:
+ return server_error_response(e)
+
+
+async def _upload_local_documents(kb, tenant_id):
form = await request.form
files = await request.files
-
- # Validation
if "file" not in files:
logging.error("No file part!")
return get_error_data_result(message="No file part!", code=RetCode.ARGUMENT_ERROR)
-
+
file_objs = files.getlist("file")
for file_obj in file_objs:
if file_obj is None or file_obj.filename is None or file_obj.filename == "":
@@ -282,18 +584,6 @@ async def upload_document(dataset_id, tenant_id):
logging.error(msg)
return get_error_data_result(message=msg, code=RetCode.ARGUMENT_ERROR)
- # KB Lookup
- e, kb = KnowledgebaseService.get_by_id(dataset_id)
- if not e:
- logging.error(f"Can't find the dataset with ID {dataset_id}!")
- return get_error_data_result(message=f"Can't find the dataset with ID {dataset_id}!", code=RetCode.DATA_ERROR)
-
- # Permission Check
- if not check_kb_team_permission(kb, tenant_id):
- logging.error("No authorization.")
- return get_error_data_result(message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
-
- # File Upload (async)
err, files = await thread_pool_exec(
FileService.upload_document, kb, file_objs, tenant_id,
parent_path=form.get("parent_path")
@@ -307,10 +597,8 @@ async def upload_document(dataset_id, tenant_id):
msg = "There seems to be an issue with your file format. please verify it is correct and not corrupted."
logging.error(msg)
return get_error_data_result(message=msg, code=RetCode.DATA_ERROR)
-
- files = [f[0] for f in files] # remove the blob
- # Check if we should return raw files without document key mapping
+ files = [f[0] for f in files] # remove the blob
return_raw_files = request.args.get("return_raw_files", "false").lower() == "true"
if return_raw_files:
@@ -432,19 +720,24 @@ def list_docs(dataset_id, tenant_id):
logging.error(f"You don't own the dataset {dataset_id}. ")
return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
- err_code, err_msg, docs, total = _get_docs_with_request(request, dataset_id)
+ if request.args.get("type") == "filter":
+ err_code, err_msg, payload, total = _get_doc_filters_with_request(request, dataset_id)
+ if err_code != RetCode.SUCCESS:
+ return get_data_error_result(code=err_code, message=err_msg)
+ return get_json_result(data={"total": total, "filter": payload})
+
+ err_code, err_msg, payload, total = _get_docs_with_request(request, dataset_id)
if err_code != RetCode.SUCCESS:
return get_data_error_result(code=err_code, message=err_msg)
- renamed_doc_list = [map_doc_keys(doc) for doc in docs]
+ renamed_doc_list = [map_doc_keys(doc) for doc in payload]
for doc_item in renamed_doc_list:
if doc_item["thumbnail"] and not doc_item["thumbnail"].startswith(IMG_BASE64_PREFIX):
- doc_item["thumbnail"] = f"/v1/document/image/{dataset_id}-{doc_item['thumbnail']}"
+ doc_item["thumbnail"] = f"/api/v1/documents/images/{dataset_id}-{doc_item['thumbnail']}"
if doc_item.get("source_type"):
doc_item["source_type"] = doc_item["source_type"].split("/")[0]
if doc_item["parser_config"].get("metadata"):
doc_item["parser_config"]["metadata"] = turn2jsonschema(doc_item["parser_config"]["metadata"])
-
return get_json_result(data={"total": total, "docs": renamed_doc_list})
@@ -517,13 +810,21 @@ def _get_docs_with_request(req, dataset_id:str):
doc_name = q.get("name")
doc_id = q.get("id")
- if doc_id and not DocumentService.query(id=doc_id, kb_id=dataset_id):
- return RetCode.DATA_ERROR, f"You don't own the document {doc_id}.", [], 0
+ if doc_id:
+ if not DocumentService.query(id=doc_id, kb_id=dataset_id):
+ return RetCode.DATA_ERROR, f"You don't own the document {doc_id}.", [], 0
+ doc_ids_filter = [doc_id] # id provided, ignore other filters
if doc_name and not DocumentService.query(name=doc_name, kb_id=dataset_id):
return RetCode.DATA_ERROR, f"You don't own the document {doc_name}.", [], 0
+ doc_ids = q.getlist("ids")
+ if doc_id and len(doc_ids) > 0:
+ return RetCode.DATA_ERROR, f"Should not provide both 'id':{doc_id} and 'ids'{doc_ids}"
+ if len(doc_ids) > 0:
+ doc_ids_filter = doc_ids
+
docs, total = DocumentService.get_by_kb_id(dataset_id, page, page_size, orderby, desc, keywords, run_status_converted, types, suffix,
- doc_id=doc_id, name=doc_name, doc_ids_filter=doc_ids_filter, return_empty_metadata=return_empty_metadata)
+ name=doc_name, doc_ids=doc_ids_filter, return_empty_metadata=return_empty_metadata)
# time range filter (0 means no bound)
create_time_from = int(q.get("create_time_from", 0))
@@ -533,6 +834,40 @@ def _get_docs_with_request(req, dataset_id:str):
return RetCode.SUCCESS, "", docs, total
+
+def _get_doc_filters_with_request(req, dataset_id: str):
+ """Get aggregated document filters with request parameters from a dataset."""
+ q = req.args
+
+ keywords = q.get("keywords", "")
+
+ suffix = q.getlist("suffix")
+
+ types = q.getlist("types")
+ if types:
+ invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
+ if invalid_types:
+ msg = f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}"
+ return RetCode.DATA_ERROR, msg, {}, 0
+
+ run_status = q.getlist("run")
+ run_status_text_to_numeric = {"UNSTART": "0", "RUNNING": "1", "CANCEL": "2", "DONE": "3", "FAIL": "4"}
+ run_status_converted = [run_status_text_to_numeric.get(v, v) for v in run_status]
+ if run_status_converted:
+ invalid_status = {s for s in run_status_converted if s not in run_status_text_to_numeric.values()}
+ if invalid_status:
+ msg = f"Invalid filter run status conditions: {', '.join(invalid_status)}"
+ return RetCode.DATA_ERROR, msg, {}, 0
+
+ docs_filter, total = DocumentService.get_filter_by_kb_id(
+ dataset_id,
+ keywords,
+ run_status_converted,
+ types,
+ suffix,
+ )
+ return RetCode.SUCCESS, "", docs_filter, total
+
def _parse_doc_id_filter_with_metadata(req, kb_id):
"""Parse document ID filter based on metadata conditions from the request.
@@ -568,7 +903,7 @@ def _parse_doc_id_filter_with_metadata(req, kb_id):
- The metadata_condition uses operators like: =, !=, >, <, >=, <=, contains, not contains,
in, not in, start with, end with, empty, not empty.
- The metadata parameter performs exact matching where values are OR'd within the same key
- and AND'd across different keys.
+ & AND'd across different keys.
Examples:
Simple metadata filter (exact match):
@@ -622,11 +957,11 @@ def _parse_doc_id_filter_with_metadata(req, kb_id):
if metadata and not isinstance(metadata, dict):
return RetCode.DATA_ERROR, "metadata must be an object.", [], return_empty_metadata
- doc_ids_filter = None
- metas = None
+ metas = dict()
if metadata_condition or metadata:
metas = DocMetadataService.get_flatted_meta_by_kbs([kb_id])
+ doc_ids_filter = None
if metadata_condition:
doc_ids_filter = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
if metadata_condition.get("conditions") and not doc_ids_filter:
@@ -651,6 +986,7 @@ def _parse_doc_id_filter_with_metadata(req, kb_id):
metadata_doc_ids &= key_doc_ids
if not metadata_doc_ids:
return RetCode.SUCCESS, "", [], return_empty_metadata
+
if metadata_doc_ids is not None:
if doc_ids_filter is None:
doc_ids_filter = metadata_doc_ids
@@ -660,3 +996,900 @@ def _parse_doc_id_filter_with_metadata(req, kb_id):
return RetCode.SUCCESS, "", [], return_empty_metadata
return RetCode.SUCCESS, "", list(doc_ids_filter) if doc_ids_filter is not None else [], return_empty_metadata
+
+
+@manager.route("/datasets//documents", methods=["DELETE"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+async def delete_documents(tenant_id, dataset_id):
+ """
+ Delete documents from a dataset.
+ ---
+ tags:
+ - Documents
+ security:
+ - ApiKeyAuth: []
+ parameters:
+ - in: path
+ name: dataset_id
+ type: string
+ required: true
+ description: ID of the dataset containing the documents.
+ - in: header
+ name: Authorization
+ type: string
+ required: true
+ description: Bearer token for authentication.
+ - in: body
+ name: body
+ description: Document deletion parameters.
+ required: true
+ schema:
+ type: object
+ properties:
+ ids:
+ type: array or null
+ items:
+ type: string
+ description: |
+ Specifies the documents to delete:
+                - An array of IDs: only the specified documents will be deleted.
+ delete_all:
+ type: boolean
+ default: false
+ description: Whether to delete all documents in the dataset.
+ responses:
+ 200:
+ description: Successful operation.
+ schema:
+ type: object
+ """
+ req, err = await validate_and_parse_json_request(request, DeleteDocumentReq)
+ if err is not None or req is None:
+ return get_error_argument_result(err)
+
+ try:
+ # Validate dataset exists and user has permission
+ if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
+ return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
+
+ # Get documents to delete
+ doc_ids = req.get("ids") or []
+ delete_all = req.get("delete_all", False)
+ if not delete_all and len(doc_ids) == 0:
+ return get_error_data_result(message=f"should either provide doc ids or set delete_all(true), dataset: {dataset_id}. ")
+
+ if len(doc_ids) > 0 and delete_all:
+ return get_error_data_result(message=f"should not provide both doc ids and delete_all(true), dataset: {dataset_id}. ")
+ if delete_all:
+ doc_ids = [doc.id for doc in DocumentService.query(kb_id=dataset_id)]
+
+ dataset_doc_ids = {doc.id for doc in DocumentService.query(kb_id=dataset_id)}
+ invalid_ids = [doc_id for doc_id in doc_ids if doc_id not in dataset_doc_ids]
+ if invalid_ids:
+ return get_error_data_result(
+ message=f"These documents do not belong to dataset {dataset_id} or Document not found: {', '.join(invalid_ids)}"
+ )
+
+ # make sure each id is unique
+ unique_doc_ids, duplicate_messages = check_duplicate_ids(doc_ids, "document")
+        if duplicate_messages:
+            logging.warning(f"duplicate_messages:{duplicate_messages}")
+        doc_ids = unique_doc_ids
+
+ # Delete documents using existing FileService.delete_docs
+ errors = await thread_pool_exec(FileService.delete_docs, doc_ids, tenant_id)
+
+ if errors:
+ return get_error_data_result(message=str(errors))
+
+ return get_result(data={"deleted": len(doc_ids)})
+ except Exception as e:
+ logging.exception(e)
+ return get_error_data_result(message="Internal server error")
+
+@manager.route("/datasets//documents//metadata/config", methods=["PUT"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+async def update_metadata_config(tenant_id, dataset_id, document_id):
+ """
+ Update document metadata configuration.
+ ---
+ tags:
+ - Documents
+ security:
+ - ApiKeyAuth: []
+ parameters:
+ - in: path
+ name: dataset_id
+ type: string
+ required: true
+ description: ID of the dataset.
+ - in: path
+ name: document_id
+ type: string
+ required: true
+ description: ID of the document.
+ - in: header
+ name: Authorization
+ type: string
+ required: true
+ description: Bearer token for authentication.
+ - in: body
+ name: body
+ description: Metadata configuration.
+ required: true
+ schema:
+ type: object
+ properties:
+ metadata:
+ type: object
+ description: Metadata configuration JSON.
+ responses:
+ 200:
+ description: Document updated successfully.
+ """
+ # Verify ownership and existence of dataset
+ if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
+ return get_error_data_result(message="You don't own the dataset.")
+
+ # Verify document exists in the dataset
+ doc = DocumentService.query(id=document_id, kb_id=dataset_id)
+ if not doc:
+ msg = f"Document {document_id} not found in dataset {dataset_id}"
+ return get_error_data_result(message=msg)
+ doc = doc[0]
+
+ # Get request body
+ req = await get_request_json()
+ if "metadata" not in req:
+ return get_error_argument_result(message="metadata is required")
+
+ # Update parser config with metadata
+ try:
+ DocumentService.update_parser_config(doc.id, {"metadata": req["metadata"]})
+ except Exception as e:
+ logging.error("error when update_parser_config", exc_info=e)
+ return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
+
+ # Get updated document
+ try:
+ e, doc = DocumentService.get_by_id(doc.id)
+ if not e:
+ return get_data_error_result(message="Document not found!")
+ except Exception as e:
+ return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
+
+ return get_result(data=doc.to_dict())
+
+
+@manager.route("/thumbnails", methods=["GET"]) # noqa: F821
+def list_thumbnails():
+ """
+ Get thumbnails for documents.
+ ---
+ tags:
+ - Documents
+ parameters:
+ - in: query
+ name: doc_ids
+ type: array
+ required: true
+ description: List of document IDs to get thumbnails for.
+ responses:
+ 200:
+ description: Successfully retrieved thumbnails
+ 400:
+ description: Missing document IDs
+ """
+ from api.constants import IMG_BASE64_PREFIX
+ from api.db.services.document_service import DocumentService
+
+ doc_ids = request.args.getlist("doc_ids")
+ if not doc_ids:
+ return get_json_result(data=False, message='Lack of "Document ID"', code=RetCode.ARGUMENT_ERROR)
+
+ try:
+ docs = DocumentService.get_thumbnails(doc_ids)
+
+ for doc_item in docs:
+ if doc_item["thumbnail"] and not doc_item["thumbnail"].startswith(IMG_BASE64_PREFIX):
+ doc_item["thumbnail"] = f"/api/v1/documents/images/{doc_item['kb_id']}-{doc_item['thumbnail']}"
+
+ return get_json_result(data={d["id"]: d["thumbnail"] for d in docs})
+ except Exception as e:
+ return server_error_response(e)
+
+
+@manager.route("/datasets//documents/metadatas", methods=["PATCH"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+async def update_metadata(tenant_id, dataset_id):
+ """
+ Update document metadata in batch.
+ ---
+ tags:
+ - Documents
+ security:
+ - ApiKeyAuth: []
+ parameters:
+ - in: path
+ name: dataset_id
+ type: string
+ required: true
+ description: ID of the dataset.
+ - in: header
+ name: Authorization
+ type: string
+ required: true
+ description: Bearer token for authentication.
+ - in: body
+ name: body
+ description: Metadata update request.
+ required: true
+ schema:
+ type: object
+ properties:
+ selector:
+ type: object
+ description: Document selector.
+ properties:
+ document_ids:
+ type: array
+ items:
+ type: string
+ description: List of document IDs to update.
+ metadata_condition:
+ type: object
+ description: Filter documents by existing metadata.
+ updates:
+ type: array
+ items:
+ type: object
+ properties:
+ key:
+ type: string
+ value:
+ type: any
+ description: List of metadata key-value pairs to update.
+ deletes:
+ type: array
+ items:
+ type: object
+ properties:
+ key:
+ type: string
+ description: List of metadata keys to delete.
+ responses:
+ 200:
+ description: Metadata updated successfully.
+ """
+ # Verify ownership of dataset
+ if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
+ return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
+
+ # Get request body
+ req = await get_request_json()
+ selector = req.get("selector", {}) or {}
+ updates = req.get("updates", []) or []
+ deletes = req.get("deletes", []) or []
+
+ # Validate selector
+ if not isinstance(selector, dict):
+ return get_error_data_result(message="selector must be an object.")
+ if not isinstance(updates, list) or not isinstance(deletes, list):
+ return get_error_data_result(message="updates and deletes must be lists.")
+
+ # Validate metadata_condition
+ metadata_condition = selector.get("metadata_condition", {}) or {}
+ if metadata_condition and not isinstance(metadata_condition, dict):
+ return get_error_data_result(message="metadata_condition must be an object.")
+
+ # Validate document_ids
+ document_ids = selector.get("document_ids", []) or []
+ if document_ids and not isinstance(document_ids, list):
+ return get_error_data_result(message="document_ids must be a list.")
+
+ # Validate updates
+ for upd in updates:
+ if not isinstance(upd, dict) or not upd.get("key") or "value" not in upd:
+ return get_error_data_result(message="Each update requires key and value.")
+
+ # Validate deletes
+ for d in deletes:
+ if not isinstance(d, dict) or not d.get("key"):
+ return get_error_data_result(message="Each delete requires key.")
+
+ # Initialize target document IDs
+ target_doc_ids = set()
+
+ # If document_ids provided, validate they belong to the dataset
+ if document_ids:
+ kb_doc_ids = KnowledgebaseService.list_documents_by_ids([dataset_id])
+ invalid_ids = set(document_ids) - set(kb_doc_ids)
+ if invalid_ids:
+ return get_error_data_result(
+ message=f"These documents do not belong to dataset {dataset_id}: {', '.join(invalid_ids)}"
+ )
+ target_doc_ids = set(document_ids)
+
+ # Apply metadata_condition filtering if provided
+ if metadata_condition:
+ metas = DocMetadataService.get_flatted_meta_by_kbs([dataset_id])
+ filtered_ids = set(
+ meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))
+ )
+ target_doc_ids = target_doc_ids & filtered_ids
+ if metadata_condition.get("conditions") and not target_doc_ids:
+ return get_result(data={"updated": 0, "matched_docs": 0})
+
+ # Convert to list and perform update
+ target_doc_ids = list(target_doc_ids)
+ updated = DocMetadataService.batch_update_metadata(dataset_id, target_doc_ids, updates, deletes)
+ return get_result(data={"updated": updated, "matched_docs": len(target_doc_ids)})
+
+
+@manager.route("/documents/ingest", methods=["POST"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+async def ingest(tenant_id):
+ req = await get_request_json()
+ try:
+ user_id = tenant_id
+
+ error_code, error_message = await thread_pool_exec(_run_sync, user_id, req)
+
+ if error_code:
+ logging.error(f"error when ingest documents:{req}, error message:{error_message}")
+ return get_json_result(error_code, error_message)
+
+ return get_json_result(data=True)
+ except Exception as e:
+ logging.exception("document ingest/run failed")
+ return server_error_response(e)
+
+def _run_sync(user_id: str, req):
+ for doc_id in req["doc_ids"]:
+ if not DocumentService.accessible(doc_id, user_id):
+ return RetCode.AUTHENTICATION_ERROR, "No authorization."
+
+ kb_table_num_map = {}
+ for doc_id in req["doc_ids"]:
+ info = {"run": str(req["run"]), "progress": 0}
+ rerun_with_delete = str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False)
+ if rerun_with_delete:
+ info["progress_msg"] = ""
+ info["chunk_num"] = 0
+ info["token_num"] = 0
+
+ doc_tenant_id = DocumentService.get_tenant_id(doc_id)
+ if not doc_tenant_id:
+ return RetCode.DATA_ERROR, "Tenant not found!"
+ e, doc = DocumentService.get_by_id(doc_id)
+ if not e:
+ return RetCode.DATA_ERROR, "Document not found!"
+
+ if str(req["run"]) == TaskStatus.CANCEL.value:
+ tasks = list(TaskService.query(doc_id=doc_id))
+ has_unfinished_task = any((task.progress or 0) < 1 for task in tasks)
+ if str(doc.run) in [TaskStatus.RUNNING.value, TaskStatus.CANCEL.value] or has_unfinished_task:
+ cancel_all_task_of(doc_id)
+ else:
+ return RetCode.DATA_ERROR, "Cannot cancel a task that is not in RUNNING status"
+ if all([rerun_with_delete, str(doc.run) == TaskStatus.DONE.value]):
+ DocumentService.clear_chunk_num_when_rerun(doc_id)
+
+ DocumentService.update_by_id(doc_id, info)
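+        # when delete=True, also drop the document's queued tasks and its chunks from the doc store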
+ if req.get("delete", False):
+ TaskService.filter_delete([Task.doc_id == doc_id])
+ if settings.docStoreConn.index_exist(search.index_name(doc_tenant_id), doc.kb_id):
+ settings.docStoreConn.delete({"doc_id": doc_id}, search.index_name(doc_tenant_id), doc.kb_id)
+
+ if str(req["run"]) == TaskStatus.RUNNING.value:
+ if req.get("apply_kb"):
+ e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
+ if not e:
+ raise LookupError("Can't find this dataset!")
+ doc.parser_config["llm_id"] = kb.parser_config.get("llm_id")
+ doc.parser_config["enable_metadata"] = kb.parser_config.get("enable_metadata", False)
+ doc.parser_config["metadata"] = kb.parser_config.get("metadata", {})
+ DocumentService.update_parser_config(doc.id, doc.parser_config)
+ doc_dict = doc.to_dict()
+ DocumentService.run(doc_tenant_id, doc_dict, kb_table_num_map)
+
+ return None, None
+
+
+@manager.route("/datasets//documents/parse", methods=["POST"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+async def parse_documents(tenant_id, dataset_id):
+ """
+ Start parsing documents in a dataset.
+ ---
+ tags:
+ - Documents
+ security:
+ - ApiKeyAuth: []
+ parameters:
+ - in: path
+ name: dataset_id
+ type: string
+ required: true
+ description: ID of the dataset.
+ - in: header
+ name: Authorization
+ type: string
+ required: true
+ description: Bearer token for authentication.
+ - in: body
+ name: body
+ description: Document parse parameters.
+ required: true
+ schema:
+ type: object
+ properties:
+ document_ids:
+ type: array
+ items:
+ type: string
+ description: List of document IDs to parse.
+ responses:
+ 200:
+ description: Successful operation.
+ """
+ if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
+ return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
+
+ req = await get_request_json()
+ if req is None:
+ return get_error_data_result(message="Request body is required")
+
+ document_ids = req.get("document_ids")
+ if document_ids is None or not isinstance(document_ids, list):
+ return get_error_data_result(message="`document_ids` is required")
+ if len(document_ids) == 0:
+ return get_error_data_result(message="`document_ids` is required")
+
+ # Check for duplicate document IDs
+ unique_doc_ids, duplicate_messages = check_duplicate_ids(document_ids, "document")
+ errors = duplicate_messages if duplicate_messages else []
+
+ # Validate all document IDs belong to the dataset
+ not_found_ids = []
+ valid_doc_ids = []
+ for doc_id in unique_doc_ids:
+ docs = DocumentService.query(kb_id=dataset_id, id=doc_id)
+ if not docs:
+ not_found_ids.append(doc_id)
+ else:
+ valid_doc_ids.append(doc_id)
+
+ if not_found_ids:
+ errors.append(f"Documents not found: {not_found_ids}")
+ # Still parse valid documents, but return error code
+ if not valid_doc_ids:
+ return get_error_data_result(message=f"Documents not found: {not_found_ids}")
+
+ try:
+ def _run_sync():
+ kb_table_num_map = {}
+ success_count = 0
+ for doc_id in valid_doc_ids:
+ e, doc = DocumentService.get_by_id(doc_id)
+ if not e:
+ errors.append(f"Document not found: {doc_id}")
+ continue
+
+ info = {"run": str(TaskStatus.RUNNING.value), "progress": 0}
+ # If re-running a completed document, clear previous chunks
+ if str(doc.run) == TaskStatus.DONE.value:
+ DocumentService.clear_chunk_num_when_rerun(doc.id)
+ info["progress_msg"] = ""
+ info["chunk_num"] = 0
+ info["token_num"] = 0
+
+ DocumentService.update_by_id(doc_id, info)
+ TaskService.filter_delete([Task.doc_id == doc_id])
+ if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
+ settings.docStoreConn.delete({"doc_id": doc_id}, search.index_name(tenant_id), doc.kb_id)
+
+ doc_dict = doc.to_dict()
+ DocumentService.run(tenant_id, doc_dict, kb_table_num_map)
+ success_count += 1
+
+ result = {"success_count": success_count}
+ if errors:
+ result["errors"] = errors
+ return result
+
+ result = await thread_pool_exec(_run_sync)
+ if not_found_ids:
+ return get_error_data_result(message=f"Documents not found: {not_found_ids}")
+ return get_result(data=result)
+ except Exception as e:
+ logging.exception(e)
+ return get_error_data_result(message="Internal server error")
+
+
+@manager.route("/datasets//documents/stop", methods=["POST"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+async def stop_parse_documents(tenant_id, dataset_id):
+ """
+ Stop parsing documents in a dataset.
+ ---
+ tags:
+ - Documents
+ security:
+ - ApiKeyAuth: []
+ parameters:
+ - in: path
+ name: dataset_id
+ type: string
+ required: true
+ description: ID of the dataset.
+ - in: header
+ name: Authorization
+ type: string
+ required: true
+ description: Bearer token for authentication.
+ - in: body
+ name: body
+ description: Document stop parse parameters.
+ required: true
+ schema:
+ type: object
+ properties:
+ document_ids:
+ type: array
+ items:
+ type: string
+ description: List of document IDs to stop parsing.
+ responses:
+ 200:
+ description: Successful operation.
+ """
+ if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
+ return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
+
+ req = await get_request_json()
+ if req is None:
+ return get_error_data_result(message="Request body is required")
+
+ document_ids = req.get("document_ids")
+ if document_ids is None or not isinstance(document_ids, list):
+ return get_error_data_result(message="`document_ids` is required")
+ if len(document_ids) == 0:
+ return get_error_data_result(message="`document_ids` is required")
+
+ # Check for duplicate document IDs
+ unique_doc_ids, duplicate_messages = check_duplicate_ids(document_ids, "document")
+ errors = duplicate_messages if duplicate_messages else []
+
+ # Validate all document IDs belong to the dataset
+ not_found_ids = []
+ valid_doc_ids = []
+ for doc_id in unique_doc_ids:
+ docs = DocumentService.query(kb_id=dataset_id, id=doc_id)
+ if not docs:
+ not_found_ids.append(doc_id)
+ else:
+ valid_doc_ids.append(doc_id)
+
+ if not_found_ids:
+ return get_error_data_result(message=f"Documents not found: {not_found_ids}")
+
+ try:
+ def _run_sync():
+ success_count = 0
+ for doc_id in valid_doc_ids:
+ e, doc = DocumentService.get_by_id(doc_id)
+ if not e:
+ errors.append(f"Document not found: {doc_id}")
+ continue
+
+ # Check if the document is currently running
+ tasks = list(TaskService.query(doc_id=doc_id))
+ has_unfinished_task = any((task.progress or 0) < 1 for task in tasks)
+ if str(doc.run) not in [TaskStatus.RUNNING.value, TaskStatus.CANCEL.value] and not has_unfinished_task:
+ errors.append("Can't stop parsing document that has not started or already completed")
+ continue
+
+ cancel_all_task_of(doc_id)
+ DocumentService.update_by_id(doc_id, {"run": str(TaskStatus.CANCEL.value)})
+ success_count += 1
+
+ result = {"success_count": success_count}
+ if errors:
+ result["errors"] = errors
+ return result
+
+ result = await thread_pool_exec(_run_sync)
+ if not_found_ids:
+ return get_error_data_result(message=f"Documents not found: {not_found_ids}")
+ return get_result(data=result)
+ except Exception as e:
+ logging.exception(e)
+ return get_error_data_result(message="Internal server error")
+
+
+@manager.route("/documents/images/", methods=["GET"]) # noqa: F821
+async def get_document_image(image_id):
+ """
+ Get a document image by ID.
+ ---
+ tags:
+ - Documents
+ parameters:
+ - name: image_id
+ in: path
+ required: true
+ schema:
+ type: string
+ description: The image ID (format: bucket-name-image-name)
+ responses:
+ 200:
+ description: Image file
+ content:
+ image/jpeg:
+ schema:
+ type: string
+ format: binary
+ """
+ try:
+ arr = image_id.split("-")
+ if len(arr) != 2:
+ return get_data_error_result(message="Image not found.")
+ bkt, nm = image_id.split("-")
+ data = await thread_pool_exec(settings.STORAGE_IMPL.get, bkt, nm)
+ response = await make_response(data)
+ response.headers.set("Content-Type", "image/JPEG")
+ return response
+ except Exception as e:
+ return server_error_response(e)
+
+
+ARTIFACT_CONTENT_TYPES = {
+ ".png": "image/png",
+ ".jpg": "image/jpeg",
+ ".jpeg": "image/jpeg",
+ ".svg": "image/svg+xml",
+ ".pdf": "application/pdf",
+ ".csv": "text/csv",
+ ".json": "application/json",
+ ".html": "text/html",
+}
+
+
+@manager.route("/documents/artifact/", methods=["GET"]) # noqa: F821
+@login_required
+async def get_artifact(filename):
+ """
+ Get an artifact file.
+ ---
+ tags:
+ - Documents
+ security:
+ - ApiKeyAuth: []
+ parameters:
+ - in: path
+ name: filename
+ type: string
+ required: true
+ description: Name of the artifact file.
+ - in: header
+ name: Authorization
+ type: string
+ required: true
+ description: Bearer token for authentication.
+ responses:
+ 200:
+ description: Artifact file returned successfully.
+ """
+ from common import settings
+
+ try:
+ bucket = SANDBOX_ARTIFACT_BUCKET
+ # Validate filename: must be uuid hex + allowed extension, nothing else
+ basename = os.path.basename(filename)
+ if basename != filename or "/" in filename or "\\" in filename:
+ return get_data_error_result(message="Invalid filename.")
+ ext = os.path.splitext(basename)[1].lower()
+ if ext not in ARTIFACT_CONTENT_TYPES:
+ return get_data_error_result(message="Invalid file type.")
+ data = await thread_pool_exec(settings.STORAGE_IMPL.get, bucket, basename)
+ if not data:
+ return get_data_error_result(message="Artifact not found.")
+ content_type = ARTIFACT_CONTENT_TYPES.get(ext, "application/octet-stream")
+ response = await make_response(data)
+ safe_filename = re.sub(r"[^\w.\-]", "_", basename)
+ apply_safe_file_response_headers(response, content_type, ext)
+ if not response.headers.get("Content-Disposition"):
+ response.headers.set("Content-Disposition", f'inline; filename="{safe_filename}"')
+ return response
+ except Exception as e:
+ return server_error_response(e)
+
+
+@manager.route("/datasets//documents/batch-update-status", methods=["POST"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+async def batch_update_document_status(tenant_id, dataset_id):
+ """
+ Batch update status of documents within a dataset.
+ ---
+ tags:
+ - Documents
+ security:
+ - ApiKeyAuth: []
+ parameters:
+ - in: path
+ name: dataset_id
+ type: string
+ required: true
+ description: ID of the dataset.
+ - in: header
+ name: Authorization
+ type: string
+ required: true
+ description: Bearer token for authentication.
+ - in: body
+ name: body
+ description: Document status update parameters.
+ required: true
+ schema:
+ type: object
+ required:
+ - doc_ids
+ - status
+ properties:
+ doc_ids:
+ type: array
+ items:
+ type: string
+ description: List of document IDs to update.
+ status:
+ type: string
+ enum: ["0", "1"]
+ description: New status (0 = disabled, 1 = enabled).
+ responses:
+ 200:
+ description: Document statuses updated successfully.
+ """
+
+ req = await get_request_json()
+ doc_ids = req.get("doc_ids", [])
+ if not isinstance(doc_ids, list) or not doc_ids:
+ return get_error_argument_result(message='"doc_ids" must be a non-empty list.')
+ if any(not isinstance(doc_id, str) or not doc_id for doc_id in doc_ids):
+ return get_error_argument_result(message='"doc_ids" must contain non-empty document IDs.')
+
+ status = str(req.get("status", -1))
+ if status not in ["0", "1"]:
+ return get_error_argument_result(message=f'"Status" must be either 0 or 1:{status}!')
+
+ # Verify dataset ownership
+ if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
+ return get_error_data_result(message="You don't own the dataset.")
+
+ e, kb = KnowledgebaseService.get_by_id(dataset_id)
+ if not e:
+ return get_error_data_result(message="Can't find this dataset!")
+
+ result = {}
+ has_error = False
+ for doc_id in doc_ids:
+ try:
+ e, doc = DocumentService.get_by_id(doc_id)
+ if not e:
+ result[doc_id] = {"error": "Document not found"}
+ has_error = True
+ continue
+
+ if doc.kb_id != dataset_id:
+ logging.warning(f"Document {doc.kb_id} not in dataset {dataset_id}")
+ result[doc_id] = {"error": "Document not found in this dataset."}
+ has_error = True
+ continue
+
+ current_status = str(doc.status)
+ if current_status == status:
+ result[doc_id] = {"status": status}
+ continue
+ if not DocumentService.update_by_id(doc_id, {"status": str(status)}):
+ result[doc_id] = {"error": "Database error (Document update)!"}
+ has_error = True
+ continue
+
+ status_int = int(status)
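+            # if the document already has indexed chunks, mirror the new status into the doc store via available_int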
+ if getattr(doc, "chunk_num", 0) > 0:
+ try:
+ ok = settings.docStoreConn.update(
+ {"doc_id": doc_id},
+ {"available_int": status_int},
+ search.index_name(kb.tenant_id),
+ doc.kb_id,
+ )
+ except Exception as exc:
+ msg = str(exc)
+ if "3022" in msg:
+ result[doc_id] = {"error": "Document store table missing."}
+ else:
+ result[doc_id] = {"error": f"Document store update failed: {msg}"}
+ has_error = True
+ continue
+ if not ok:
+ result[doc_id] = {"error": "Database error (docStore update)!"}
+ has_error = True
+ continue
+ result[doc_id] = {"status": status}
+ except Exception as e:
+ result[doc_id] = {"error": f"Internal server error: {str(e)}"}
+ has_error = True
+
+ if has_error:
+ return get_json_result(data=result, message="Partial failure", code=RetCode.SERVER_ERROR)
+ return get_json_result(data=result)
+
+@manager.route("/documents//preview", methods=["GET"]) # noqa: F821
+@login_required
+async def get(doc_id):
+ """Return the raw file bytes for a document the requesting user is authorized to read.
+
+ The user must belong to the tenant that owns the document's knowledge base; otherwise
+ the response is indistinguishable from a missing document to avoid cross-tenant ID
+ enumeration.
+ """
+ try:
+ if not DocumentService.accessible(doc_id, current_user.id):
+ return get_data_error_result(message="Document not found!")
+
+ e, doc = DocumentService.get_by_id(doc_id)
+ if not e:
+ return get_data_error_result(message="Document not found!")
+
+ b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
+ data = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n)
+ response = await make_response(data)
+
+ ext = re.search(r"\.([^.]+)$", doc.name.lower())
+ ext = ext.group(1) if ext else None
+ content_type = None
+ if ext:
+ fallback_prefix = "image" if doc.type == FileType.VISUAL.value else "application"
+ content_type = CONTENT_TYPE_MAP.get(ext, f"{fallback_prefix}/{ext}")
+ apply_safe_file_response_headers(response, content_type, ext)
+ return response
+ except Exception as e:
+ return server_error_response(e)
+
+
+@manager.route("/documents//download", methods=["GET"]) # noqa: F821
+@login_required
+@add_tenant_id_to_kwargs
+async def download_attachment(tenant_id=None, doc_id=None, attachment_id=None):
+ """Stream a document's underlying file to the requesting user.
+
+ Mirrors the authorization model of the preview endpoint: the user must belong
+ to the tenant that owns the document's knowledge base. A denial returns the
+ same "Document not found!" response so the endpoint cannot be used to
+ enumerate doc ids across tenants.
+ """
+ try:
+ # Keep backward compatibility with older callers and unit tests that still
+ # pass `attachment_id` instead of the route parameter name.
+ doc_id = doc_id or attachment_id
+ if not DocumentService.accessible(doc_id, current_user.id):
+ return get_data_error_result(message="Document not found!")
+ ext = request.args.get("ext", "markdown")
+ data = await thread_pool_exec(settings.STORAGE_IMPL.get, tenant_id, doc_id)
+ response = await make_response(data)
+ content_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}")
+ apply_safe_file_response_headers(response, content_type, ext)
+
+ return response
+
+ except Exception as e:
+ return server_error_response(e)
diff --git a/api/apps/file2document_app.py b/api/apps/restful_apis/file2document_api.py
similarity index 63%
rename from api/apps/file2document_app.py
rename to api/apps/restful_apis/file2document_api.py
index c82207ab73a..9c466a441d3 100644
--- a/api/apps/file2document_app.py
+++ b/api/apps/restful_apis/file2document_api.py
@@ -18,6 +18,7 @@
import logging
from pathlib import Path
+from api.common.check_team_permission import check_file_team_permission, check_kb_team_permission
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
@@ -25,10 +26,11 @@
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
from common.misc_utils import get_uuid
-from common.constants import RetCode
from api.db import FileType
from api.db.services.document_service import DocumentService
+logger = logging.getLogger(__name__)
+
def _convert_files(file_ids, kb_ids, user_id):
"""Synchronous worker: delete old docs and insert new ones for the given file/kb pairs."""
@@ -74,7 +76,7 @@ def _convert_files(file_ids, kb_ids, user_id):
})
-@manager.route('/convert', methods=['POST']) # noqa: F821
+@manager.route('/files/link-to-datasets', methods=['POST']) # noqa: F821
@login_required
@validate_request("file_ids", "kb_ids")
async def convert():
@@ -89,13 +91,29 @@ async def convert():
# Validate all files exist before starting any work
for file_id in file_ids:
if not files_set.get(file_id):
+ logger.warning(
+ "user_id=%s resource_type=file resource_id=%s action=validate_file_lookup result=not_found file_ids=%s kb_ids=%s",
+ current_user.id,
+ file_id,
+ file_ids,
+ kb_ids,
+ )
return get_data_error_result(message="File not found!")
# Validate all kb_ids exist before scheduling background work
+ kb_map = {}
for kb_id in kb_ids:
- e, _ = KnowledgebaseService.get_by_id(kb_id)
+ e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
+ logger.warning(
+ "user_id=%s resource_type=dataset resource_id=%s action=validate_dataset_lookup result=not_found file_ids=%s kb_ids=%s",
+ current_user.id,
+ kb_id,
+ file_ids,
+ kb_ids,
+ )
return get_data_error_result(message="Can't find this dataset!")
+ kb_map[kb_id] = kb
# Expand folders to their innermost file IDs
all_file_ids = []
@@ -107,6 +125,38 @@ async def convert():
all_file_ids.append(file_id)
user_id = current_user.id
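+    # authorize every expanded file and every target dataset before scheduling the background conversion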
+ for file_id in all_file_ids:
+ e, file = FileService.get_by_id(file_id)
+ if not e or not file:
+ logger.warning(
+ "user_id=%s resource_type=file resource_id=%s action=validate_expanded_file_lookup result=not_found file_ids=%s kb_ids=%s",
+ user_id,
+ file_id,
+ file_ids,
+ kb_ids,
+ )
+ return get_data_error_result(message="File not found!")
+ if not check_file_team_permission(file, user_id):
+ logger.warning(
+ "user_id=%s resource_type=file resource_id=%s action=authorize_file result=denied file_ids=%s kb_ids=%s",
+ user_id,
+ file_id,
+ file_ids,
+ kb_ids,
+ )
+ return get_data_error_result(message="No authorization.")
+
+ for kb_id, kb in kb_map.items():
+ if not check_kb_team_permission(kb, user_id):
+ logger.warning(
+ "user_id=%s resource_type=dataset resource_id=%s action=authorize_dataset result=denied file_ids=%s kb_ids=%s",
+ user_id,
+ kb_id,
+ file_ids,
+ kb_ids,
+ )
+ return get_data_error_result(message="No authorization.")
+
# Run the blocking DB work in a thread so the event loop is not blocked.
# For large folders this prevents 504 Gateway Timeout by returning as
# soon as the background task is scheduled.
@@ -115,39 +165,12 @@ async def convert():
future.add_done_callback(
lambda f: logging.error("_convert_files failed: %s", f.exception()) if f.exception() else None
)
- return get_json_result(data=True)
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route('/rm', methods=['POST']) # noqa: F821
-@login_required
-@validate_request("file_ids")
-async def rm():
- req = await get_request_json()
- file_ids = req["file_ids"]
- if not file_ids:
- return get_json_result(
- data=False, message='Lack of "Files ID"', code=RetCode.ARGUMENT_ERROR)
- try:
- for file_id in file_ids:
- informs = File2DocumentService.get_by_file_id(file_id)
- if not informs:
- return get_data_error_result(message="Inform not found!")
- for inform in informs:
- if not inform:
- return get_data_error_result(message="Inform not found!")
- File2DocumentService.delete_by_file_id(file_id)
- doc_id = inform.document_id
- e, doc = DocumentService.get_by_id(doc_id)
- if not e:
- return get_data_error_result(message="Document not found!")
- tenant_id = DocumentService.get_tenant_id(doc_id)
- if not tenant_id:
- return get_data_error_result(message="Tenant not found!")
- if not DocumentService.remove_document(doc, tenant_id):
- return get_data_error_result(
- message="Database error (Document removal)!")
+ logger.info(
+ "user_id=%s resource_type=file_to_dataset_link resource_id=batch action=schedule_convert result=scheduled file_ids=%s kb_ids=%s",
+ user_id,
+ all_file_ids,
+ kb_ids,
+ )
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
diff --git a/api/apps/restful_apis/file_api.py b/api/apps/restful_apis/file_api.py
index fbe1e39d50a..b67aa30ffce 100644
--- a/api/apps/restful_apis/file_api.py
+++ b/api/apps/restful_apis/file_api.py
@@ -24,8 +24,10 @@
add_tenant_id_to_kwargs,
get_error_argument_result,
get_error_data_result,
+ get_json_result,
get_result,
)
+from common.constants import RetCode
from api.utils.validation_utils import (
CreateFolderReq,
DeleteFileReq,
@@ -99,7 +101,7 @@ async def create_or_upload(tenant_id: str = None):
@manager.route("/files", methods=["GET"]) # noqa: F821
@login_required
@add_tenant_id_to_kwargs
-def list_files(tenant_id: str = None):
+async def list_files(tenant_id: str = None):
"""
List files under a folder.
---
@@ -185,10 +187,22 @@ async def delete(tenant_id: str = None):
return get_error_argument_result(err)
try:
- success, result = await file_api_service.delete_files(tenant_id, req["ids"])
+ # Get Authorization header to pass to Go backend
+ auth_header = request.headers.get("Authorization", "")
+ success, result = await file_api_service.delete_files(tenant_id, req["ids"], auth_header)
if success:
return get_result(data=result)
else:
+ if isinstance(result, dict):
+ success_count = result.get("success_count", 0)
+ errors = result.get("errors", [])
+ return get_json_result(
+ code=RetCode.DATA_ERROR,
+ message=f"Partially deleted {success_count} files with {len(errors)} errors"
+ if success_count > 0
+ else f"Deleted files failed with {len(errors)} errors",
+ data=result,
+ )
return get_error_data_result(message=result)
except Exception as e:
logging.exception(e)
@@ -303,7 +317,7 @@ async def download(tenant_id: str = None, file_id: str = None):
@manager.route("/files//parent", methods=["GET"]) # noqa: F821
@login_required
@add_tenant_id_to_kwargs
-def parent_folder(tenant_id: str = None, file_id: str = None):
+async def parent_folder(tenant_id: str = None, file_id: str = None):
"""
Get parent folder of a file.
---
@@ -321,7 +335,7 @@ def parent_folder(tenant_id: str = None, file_id: str = None):
description: Parent folder information.
"""
try:
- success, result = file_api_service.get_parent_folder(file_id)
+ success, result = file_api_service.get_parent_folder(file_id, user_id=tenant_id)
if success:
return get_result(data=result)
else:
@@ -334,7 +348,7 @@ def parent_folder(tenant_id: str = None, file_id: str = None):
@manager.route("/files//ancestors", methods=["GET"]) # noqa: F821
@login_required
@add_tenant_id_to_kwargs
-def ancestors(tenant_id: str = None, file_id: str = None):
+async def ancestors(tenant_id: str = None, file_id: str = None):
"""
Get all ancestor folders of a file.
---
@@ -352,7 +366,7 @@ def ancestors(tenant_id: str = None, file_id: str = None):
description: List of ancestor folders.
"""
try:
- success, result = file_api_service.get_all_parent_folders(file_id)
+ success, result = file_api_service.get_all_parent_folders(file_id, user_id=tenant_id)
if success:
return get_result(data=result)
else:
@@ -360,5 +374,3 @@ def ancestors(tenant_id: str = None, file_id: str = None):
except Exception as e:
logging.exception(e)
return get_error_data_result(message="Internal server error")
-
-
diff --git a/api/apps/langfuse_app.py b/api/apps/restful_apis/langfuse_api.py
similarity index 94%
rename from api/apps/langfuse_app.py
rename to api/apps/restful_apis/langfuse_api.py
index 1d7993d365c..70b81b42c63 100644
--- a/api/apps/langfuse_app.py
+++ b/api/apps/restful_apis/langfuse_api.py
@@ -23,7 +23,7 @@
from api.utils.api_utils import get_error_data_result, get_json_result, get_request_json, server_error_response, validate_request
-@manager.route("/api_key", methods=["POST", "PUT"]) # noqa: F821
+@manager.route("/langfuse/api-key", methods=["POST", "PUT"]) # noqa: F821
@login_required
@validate_request("secret_key", "public_key", "host")
async def set_api_key():
@@ -58,7 +58,7 @@ async def set_api_key():
return server_error_response(e)
-@manager.route("/api_key", methods=["GET"]) # noqa: F821
+@manager.route("/langfuse/api-key", methods=["GET"]) # noqa: F821
@login_required
@validate_request()
def get_api_key():
@@ -82,7 +82,7 @@ def get_api_key():
return get_json_result(data=langfuse_entry)
-@manager.route("/api_key", methods=["DELETE"]) # noqa: F821
+@manager.route("/langfuse/api-key", methods=["DELETE"]) # noqa: F821
@login_required
@validate_request()
def delete_api_key():
diff --git a/api/apps/mcp_server_app.py b/api/apps/restful_apis/mcp_api.py
similarity index 62%
rename from api/apps/mcp_server_app.py
rename to api/apps/restful_apis/mcp_api.py
index 187560d626b..ec384f6074d 100644
--- a/api/apps/mcp_server_app.py
+++ b/api/apps/restful_apis/mcp_api.py
@@ -1,5 +1,5 @@
#
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,20 +13,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+
from quart import Response, request
-from api.apps import current_user, login_required
+from api.apps import current_user, login_required
from api.db.db_models import MCPServer
from api.db.services.mcp_server_service import MCPServerService
from api.db.services.user_service import TenantService
-from common.constants import RetCode, VALID_MCP_SERVER_TYPES
-
-from common.misc_utils import get_uuid, thread_pool_exec
from api.utils.api_utils import get_data_error_result, get_json_result, get_mcp_tools, get_request_json, server_error_response, validate_request
from api.utils.web_utils import get_float, safe_json_parse
+from common.constants import VALID_MCP_SERVER_TYPES
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
+from common.misc_utils import get_uuid, thread_pool_exec
-@manager.route("/list", methods=["POST"]) # noqa: F821
+
+def _get_mcp_ids_from_args() -> list[str]:
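+    """Collect MCP server ids from repeated ``mcp_ids`` query args (comma-separated values allowed) or a single comma-separated ``mcp_id``."""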
+ mcp_ids = request.args.getlist("mcp_ids")
+ if mcp_ids:
+ return [mcp_id for item in mcp_ids for mcp_id in item.split(",") if mcp_id]
+ mcp_ids = request.args.get("mcp_id", "")
+ return [mcp_id for mcp_id in mcp_ids.split(",") if mcp_id]
+
+
+def _export_mcp_servers(mcp_ids: list[str]) -> dict | None:
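+    """Export the current user's MCP servers, keyed by name, in the importable ``mcpServers`` layout; return None if none match."""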
+ exported_servers = {}
+ for mcp_id in mcp_ids:
+ e, mcp_server = MCPServerService.get_by_id(mcp_id)
+ if e and mcp_server.tenant_id == current_user.id:
+ server_key = mcp_server.name
+ exported_servers[server_key] = {
+ "type": mcp_server.server_type,
+ "url": mcp_server.url,
+ "name": mcp_server.name,
+ "authorization_token": mcp_server.variables.get("authorization_token", ""),
+ "tools": mcp_server.variables.get("tools", {}),
+ }
+
+ if not exported_servers:
+ return None
+
+ return {"mcpServers": exported_servers}
+
+
+@manager.route("/mcp/servers", methods=["GET"]) # noqa: F821
@login_required
async def list_mcp() -> Response:
keywords = request.args.get("keywords", "")
@@ -38,8 +67,7 @@ async def list_mcp() -> Response:
else:
desc = True
- req = await get_request_json()
- mcp_ids = req.get("mcp_ids", [])
+ mcp_ids = _get_mcp_ids_from_args()
try:
servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or []
total = len(servers)
@@ -52,22 +80,27 @@ async def list_mcp() -> Response:
return server_error_response(e)
-@manager.route("/detail", methods=["GET"]) # noqa: F821
+@manager.route("/mcp/servers/", methods=["GET"]) # noqa: F821
@login_required
-def detail() -> Response:
- mcp_id = request.args["mcp_id"]
+def detail(mcp_id: str) -> Response:
try:
+ if request.args.get("mode") == "download":
+ exported_servers = _export_mcp_servers([mcp_id])
+ if exported_servers is None:
+ return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")
+ return get_json_result(data=exported_servers)
+
mcp_server = MCPServerService.get_or_none(id=mcp_id, tenant_id=current_user.id)
if mcp_server is None:
- return get_json_result(code=RetCode.NOT_FOUND, data=None)
+ return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")
return get_json_result(data=mcp_server.to_dict())
except Exception as e:
return server_error_response(e)
-@manager.route("/create", methods=["POST"]) # noqa: F821
+@manager.route("/mcp/servers", methods=["POST"]) # noqa: F821
@login_required
@validate_request("name", "url", "server_type")
async def create() -> Response:
@@ -107,7 +140,7 @@ async def create() -> Response:
mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
if err_message:
- return get_data_error_result(err_message)
+ return get_data_error_result(message=err_message)
tools = server_tools[server_name]
tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool}
@@ -115,20 +148,18 @@ async def create() -> Response:
req["variables"] = variables
if not MCPServerService.insert(**req):
- return get_data_error_result("Failed to create MCP server.")
+ return get_data_error_result(message="Failed to create MCP server.")
return get_json_result(data=req)
except Exception as e:
return server_error_response(e)
-@manager.route("/update", methods=["POST"]) # noqa: F821
+@manager.route("/mcp/servers/", methods=["PUT"]) # noqa: F821
@login_required
-@validate_request("mcp_id")
-async def update() -> Response:
+async def update(mcp_id: str) -> Response:
req = await get_request_json()
- mcp_id = req.get("mcp_id", "")
e, mcp_server = MCPServerService.get_by_id(mcp_id)
if not e or mcp_server.tenant_id != current_user.id:
return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")
@@ -153,13 +184,12 @@ async def update() -> Response:
try:
req["tenant_id"] = current_user.id
- req.pop("mcp_id", None)
req["id"] = mcp_id
mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout)
if err_message:
- return get_data_error_result(err_message)
+ return get_data_error_result(message=err_message)
tools = server_tools[server_name]
tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool}
@@ -178,25 +208,22 @@ async def update() -> Response:
return server_error_response(e)
-@manager.route("/rm", methods=["POST"]) # noqa: F821
+@manager.route("/mcp/servers/", methods=["DELETE"]) # noqa: F821
@login_required
-@validate_request("mcp_ids")
-async def rm() -> Response:
- req = await get_request_json()
- mcp_ids = req.get("mcp_ids", [])
-
+async def rm(mcp_id: str) -> Response:
try:
- req["tenant_id"] = current_user.id
-
- if not MCPServerService.delete_by_ids(mcp_ids):
- return get_data_error_result(message=f"Failed to delete MCP servers {mcp_ids}")
+ e, mcp_server = MCPServerService.get_by_id(mcp_id)
+ if not e or mcp_server.tenant_id != current_user.id:
+ return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")
+ if not MCPServerService.delete_by_ids([mcp_id]):
+ return get_data_error_result(message=f"Failed to delete MCP servers {[mcp_id]}")
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
-@manager.route("/import", methods=["POST"]) # noqa: F821
+@manager.route("/mcp/servers/import", methods=["POST"]) # noqa: F821
@login_required
@validate_request("mcpServers")
async def import_multiple() -> Response:
@@ -263,144 +290,10 @@ async def import_multiple() -> Response:
return server_error_response(e)
-@manager.route("/export", methods=["POST"]) # noqa: F821
+@manager.route("/mcp/servers//test", methods=["POST"]) # noqa: F821
@login_required
-@validate_request("mcp_ids")
-async def export_multiple() -> Response:
- req = await get_request_json()
- mcp_ids = req.get("mcp_ids", [])
-
- if not mcp_ids:
- return get_data_error_result(message="No MCP server IDs provided.")
-
- try:
- exported_servers = {}
-
- for mcp_id in mcp_ids:
- e, mcp_server = MCPServerService.get_by_id(mcp_id)
-
- if e and mcp_server.tenant_id == current_user.id:
- server_key = mcp_server.name
-
- exported_servers[server_key] = {
- "type": mcp_server.server_type,
- "url": mcp_server.url,
- "name": mcp_server.name,
- "authorization_token": mcp_server.variables.get("authorization_token", ""),
- "tools": mcp_server.variables.get("tools", {}),
- }
-
- return get_json_result(data={"mcpServers": exported_servers})
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route("/list_tools", methods=["POST"]) # noqa: F821
-@login_required
-@validate_request("mcp_ids")
-async def list_tools() -> Response:
- req = await get_request_json()
- mcp_ids = req.get("mcp_ids", [])
- if not mcp_ids:
- return get_data_error_result(message="No MCP server IDs provided.")
-
- timeout = get_float(req, "timeout", 10)
-
- results = {}
- tool_call_sessions = []
- try:
- for mcp_id in mcp_ids:
- e, mcp_server = MCPServerService.get_by_id(mcp_id)
-
- if e and mcp_server.tenant_id == current_user.id:
- server_key = mcp_server.id
-
- cached_tools = mcp_server.variables.get("tools", {})
-
- tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
- tool_call_sessions.append(tool_call_session)
-
- try:
- tools = await thread_pool_exec(tool_call_session.get_tools, timeout)
- except Exception as e:
- return get_data_error_result(message=f"MCP list tools error: {e}")
-
- results[server_key] = []
- for tool in tools:
- tool_dict = tool.model_dump()
- cached_tool = cached_tools.get(tool_dict["name"], {})
-
- tool_dict["enabled"] = cached_tool.get("enabled", True)
- results[server_key].append(tool_dict)
-
- return get_json_result(data=results)
- except Exception as e:
- return server_error_response(e)
- finally:
- # PERF: blocking call to close sessions — consider moving to background thread or task queue
- await thread_pool_exec(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
-
-
-@manager.route("/test_tool", methods=["POST"]) # noqa: F821
-@login_required
-@validate_request("mcp_id", "tool_name", "arguments")
-async def test_tool() -> Response:
- req = await get_request_json()
- mcp_id = req.get("mcp_id", "")
- if not mcp_id:
- return get_data_error_result(message="No MCP server ID provided.")
-
- timeout = get_float(req, "timeout", 10)
-
- tool_name = req.get("tool_name", "")
- arguments = req.get("arguments", {})
- if not all([tool_name, arguments]):
- return get_data_error_result(message="Require provide tool name and arguments.")
-
- tool_call_sessions = []
- try:
- e, mcp_server = MCPServerService.get_by_id(mcp_id)
- if not e or mcp_server.tenant_id != current_user.id:
- return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")
-
- tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
- tool_call_sessions.append(tool_call_session)
- result = await thread_pool_exec(tool_call_session.tool_call, tool_name, arguments, timeout)
-
- # PERF: blocking call to close sessions — consider moving to background thread or task queue
- await thread_pool_exec(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
- return get_json_result(data=result)
- except Exception as e:
- return server_error_response(e)
-
-
-@manager.route("/cache_tools", methods=["POST"]) # noqa: F821
-@login_required
-@validate_request("mcp_id", "tools")
-async def cache_tool() -> Response:
- req = await get_request_json()
- mcp_id = req.get("mcp_id", "")
- if not mcp_id:
- return get_data_error_result(message="No MCP server ID provided.")
- tools = req.get("tools", [])
-
- e, mcp_server = MCPServerService.get_by_id(mcp_id)
- if not e or mcp_server.tenant_id != current_user.id:
- return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")
-
- variables = mcp_server.variables
- tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool}
- variables["tools"] = tools
-
- if not MCPServerService.filter_update([MCPServer.id == mcp_id, MCPServer.tenant_id == current_user.id], {"variables": variables}):
- return get_data_error_result(message="Failed to updated MCP server.")
-
- return get_json_result(data=tools)
-
-
-@manager.route("/test_mcp", methods=["POST"]) # noqa: F821
@validate_request("url", "server_type")
-async def test_mcp() -> Response:
+async def test_mcp(mcp_id: str) -> Response:
req = await get_request_json()
url = req.get("url", "")
@@ -415,7 +308,7 @@ async def test_mcp() -> Response:
headers = safe_json_parse(req.get("headers", {}))
variables = safe_json_parse(req.get("variables", {}))
- mcp_server = MCPServer(id=f"{server_type}: {url}", server_type=server_type, url=url, headers=headers, variables=variables)
+ mcp_server = MCPServer(id=mcp_id, server_type=server_type, url=url, headers=headers, variables=variables)
result = []
try:
@@ -426,7 +319,6 @@ async def test_mcp() -> Response:
except Exception as e:
return get_data_error_result(message=f"Test MCP error: {e}")
finally:
- # PERF: blocking call to close sessions — consider moving to background thread or task queue
await thread_pool_exec(close_multiple_mcp_toolcall_sessions, [tool_call_session])
for tool in tools:
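The hunks above collapse the old action-style MCP routes (`/create`, `/update`, `/rm`, `/import`, `/export`, `/test_mcp`) into a resource-style `/mcp/servers` collection. Below is a minimal client sketch of the new routes; the base URL, `/v1` prefix, and bearer-style header are illustrative assumptions only (the endpoints are `@login_required`, so a session cookie may be what is actually checked).

```python
import requests

BASE = "http://localhost:9380/v1"                   # assumed mount point
HEADERS = {"Authorization": "Bearer <YOUR_TOKEN>"}  # assumed auth scheme

# Create: POST /mcp/servers replaces the old /create action route
resp = requests.post(
    f"{BASE}/mcp/servers",
    json={"name": "demo-mcp", "url": "http://localhost:8000/sse", "server_type": "sse"},
    headers=HEADERS,
    timeout=30,
)
print(resp.json())

server_id = "demo-mcp"  # hypothetical server id

# Probe the server's tools without saving: POST /mcp/servers/<mcp_id>/test
requests.post(
    f"{BASE}/mcp/servers/{server_id}/test",
    json={"url": "http://localhost:8000/sse", "server_type": "sse"},
    headers=HEADERS,
    timeout=30,
)

# Update and delete now address a single server by id in the path
requests.put(f"{BASE}/mcp/servers/{server_id}", json={"url": "http://localhost:8001/sse"}, headers=HEADERS, timeout=30)
requests.delete(f"{BASE}/mcp/servers/{server_id}", headers=HEADERS, timeout=30)
```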
diff --git a/api/apps/restful_apis/memory_api.py b/api/apps/restful_apis/memory_api.py
index 8f92661e700..c361d816b60 100644
--- a/api/apps/restful_apis/memory_api.py
+++ b/api/apps/restful_apis/memory_api.py
@@ -130,7 +130,7 @@ async def delete_memory(memory_id):
@login_required
async def list_memory():
filter_params = {
- k: request.args.get(k) for k in ["memory_type", "tenant_id", "storage_type"] if k in request.args
+ k: request.args.get(k) for k in ["memory_type", "tenant_id", "owner_ids", "storage_type"] if k in request.args
}
keywords = request.args.get("keywords")
page = int(request.args.get("page", 1))
diff --git a/api/apps/restful_apis/openai_api.py b/api/apps/restful_apis/openai_api.py
new file mode 100644
index 00000000000..baa011f32a8
--- /dev/null
+++ b/api/apps/restful_apis/openai_api.py
@@ -0,0 +1,300 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import time
+
+from quart import Response, jsonify
+
+from api.apps import current_user, login_required
+from api.db.services.dialog_service import DialogService, async_chat
+from api.db.services.doc_metadata_service import DocMetadataService
+from api.db.services.tenant_llm_service import TenantLLMService
+from api.utils.api_utils import get_error_data_result, get_request_json, validate_request
+from common.constants import RetCode, StatusEnum
+from common.metadata_utils import convert_conditions, meta_filter
+from common.token_utils import num_tokens_from_string
+from rag.prompts.generator import chunks_format
+
+def _validate_llm_id(llm_id, tenant_id, llm_setting=None):
+ if not llm_id:
+ return None
+
+ llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(llm_id)
+ model_type = (llm_setting or {}).get("model_type")
+ if model_type not in {"chat", "image2text"}:
+ model_type = "chat"
+
+ if not TenantLLMService.query(
+ tenant_id=tenant_id,
+ llm_name=llm_name,
+ llm_factory=llm_factory,
+ model_type=model_type,
+ ):
+ return f"`llm_id` {llm_id} doesn't exist"
+ return None
+
+
+import logging
+from api.utils.reference_metadata_utils import enrich_chunks_with_document_metadata
+
+def _build_reference_chunks(reference, include_metadata=False, metadata_fields=None):
+ chunks = chunks_format(reference)
+ if not include_metadata:
+ logging.debug("Skipping document metadata enrichment (include_metadata=False)")
+ return chunks
+
+ normalized_fields = None
+ if metadata_fields is not None:
+ if not isinstance(metadata_fields, list):
+ return chunks
+ normalized_fields = {f for f in metadata_fields if isinstance(f, str)}
+ if not normalized_fields:
+ return chunks
+
+ logging.debug(
+ "Enriching %d chunks with document metadata (fields: %s)",
+ len(chunks),
+ "ALL" if normalized_fields is None else list(normalized_fields),
+ )
+
+ enrich_chunks_with_document_metadata(
+ chunks,
+ normalized_fields,
+ kb_field="dataset_id",
+ doc_field="document_id",
+ )
+
+ return chunks
+
+
+def _build_sse_response(body):
+ resp = Response(body, mimetype="text/event-stream")
+ resp.headers.add_header("Cache-control", "no-cache")
+ resp.headers.add_header("Connection", "keep-alive")
+ resp.headers.add_header("X-Accel-Buffering", "no")
+ resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
+ return resp
+
+
+@manager.route("/openai//chat/completions", methods=["POST"]) # noqa: F821
+@login_required
+@validate_request("model", "messages")
+async def openai_chat_completions(chat_id):
+ req = await get_request_json()
+
+ extra_body = req.get("extra_body") or {}
+ if extra_body and not isinstance(extra_body, dict):
+ return get_error_data_result("extra_body must be an object.")
+
+ need_reference = bool(extra_body.get("reference", False))
+ reference_metadata = extra_body.get("reference_metadata") or {}
+ if reference_metadata and not isinstance(reference_metadata, dict):
+ return get_error_data_result("reference_metadata must be an object.")
+ include_reference_metadata = bool(reference_metadata.get("include", False))
+ metadata_fields = reference_metadata.get("fields")
+ if metadata_fields is not None and not isinstance(metadata_fields, list):
+ return get_error_data_result("reference_metadata.fields must be an array.")
+
+ messages = req.get("messages", [])
+ if len(messages) < 1:
+ return get_error_data_result("You have to provide messages.")
+ if messages[-1]["role"] != "user":
+ return get_error_data_result("The last content of this conversation is not from user.")
+
+ prompt = messages[-1]["content"]
+ context_token_used = sum(num_tokens_from_string(message["content"]) for message in messages)
+ requested_model = req.get("model", "") or ""
+ completion_id = f"chatcmpl-{chat_id}"
+
+ dia = DialogService.query(tenant_id=current_user.id, id=chat_id, status=StatusEnum.VALID.value)
+ if not dia:
+ return get_error_data_result(f"You don't own the chat {chat_id}")
+ dia = dia[0]
+
+ using_placeholder_model = requested_model == "model"
+ if using_placeholder_model:
+ requested_model = dia.llm_id or requested_model
+ else:
+ llm_id_error = _validate_llm_id(requested_model, current_user.id, {"model_type": "chat"})
+ if llm_id_error:
+ return get_error_data_result(message=llm_id_error, code=RetCode.ARGUMENT_ERROR)
+ dia.llm_id = requested_model
+ if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=requested_model):
+ return get_error_data_result(message=f"Cannot use specified model {requested_model}.")
+
+ metadata_condition = extra_body.get("metadata_condition") or {}
+ if metadata_condition and not isinstance(metadata_condition, dict):
+ return get_error_data_result(message="metadata_condition must be an object.")
+
+ doc_ids_str = None
+ if metadata_condition:
+ metas = DocMetadataService.get_flatted_meta_by_kbs(dia.kb_ids or [])
+ filtered_doc_ids = meta_filter(
+ metas,
+ convert_conditions(metadata_condition),
+ metadata_condition.get("logic", "and"),
+ )
+ if metadata_condition.get("conditions") and not filtered_doc_ids:
+ filtered_doc_ids = ["-999"]
+ doc_ids_str = ",".join(filtered_doc_ids) if filtered_doc_ids else None
+
+ msg = []
+ for message in messages:
+ if message["role"] == "system":
+ continue
+ if message["role"] == "assistant" and not msg:
+ continue
+ msg.append(message)
+
+ tools = None
+ toolcall_session = None
+ stream_mode = req.get("stream", True)
+
+ if stream_mode:
+ async def streamed_response_generator():
+ token_used = 0
+ last_ans = {}
+ full_content = ""
+ final_answer = None
+ final_reference = None
+ in_think = False
+ response = {
+ "id": completion_id,
+ "choices": [
+ {
+ "delta": {
+ "content": "",
+ "role": "assistant",
+ "function_call": None,
+ "tool_calls": None,
+ "reasoning_content": "",
+ },
+ "finish_reason": None,
+ "index": 0,
+ "logprobs": None,
+ }
+ ],
+ "created": int(time.time()),
+ "model": requested_model,
+ "object": "chat.completion.chunk",
+ "system_fingerprint": "",
+ "usage": None,
+ }
+
+ try:
+ chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference}
+ if doc_ids_str:
+ chat_kwargs["doc_ids"] = doc_ids_str
+ async for ans in async_chat(dia, msg, True, **chat_kwargs):
+ last_ans = ans
+ if ans.get("final"):
+ if ans.get("answer"):
+ full_content = ans["answer"]
+ response["choices"][0]["delta"]["content"] = full_content
+ response["choices"][0]["delta"]["reasoning_content"] = None
+ yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
+ final_answer = full_content
+ final_reference = ans.get("reference", {})
+ continue
+ if ans.get("start_to_think"):
+ in_think = True
+ continue
+ if ans.get("end_to_think"):
+ in_think = False
+ continue
+ delta = ans.get("answer") or ""
+ if not delta:
+ continue
+ token_used += num_tokens_from_string(delta)
+ if in_think:
+ response["choices"][0]["delta"]["reasoning_content"] = delta
+ response["choices"][0]["delta"]["content"] = None
+ else:
+ full_content += delta
+ response["choices"][0]["delta"]["content"] = delta
+ response["choices"][0]["delta"]["reasoning_content"] = None
+ yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
+ except Exception as e:
+ response["choices"][0]["delta"]["content"] = "**ERROR**: " + str(e)
+ yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
+
+ response["choices"][0]["delta"]["content"] = None
+ response["choices"][0]["delta"]["reasoning_content"] = None
+ response["choices"][0]["finish_reason"] = "stop"
+ prompt_tokens = num_tokens_from_string(prompt)
+ response["usage"] = {
+ "prompt_tokens": prompt_tokens,
+ "completion_tokens": token_used,
+ "total_tokens": prompt_tokens + token_used,
+ }
+ if need_reference:
+ reference_payload = final_reference if final_reference is not None else last_ans.get("reference", [])
+ response["choices"][0]["delta"]["reference"] = _build_reference_chunks(
+ reference_payload,
+ include_metadata=include_reference_metadata,
+ metadata_fields=metadata_fields,
+ )
+ response["choices"][0]["delta"]["final_content"] = final_answer if final_answer is not None else full_content
+ yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
+ yield "data:[DONE]\n\n"
+
+ return _build_sse_response(streamed_response_generator())
+
+ answer = None
+ chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference}
+ if doc_ids_str:
+ chat_kwargs["doc_ids"] = doc_ids_str
+ async for ans in async_chat(dia, msg, False, **chat_kwargs):
+ answer = ans
+ break
+
+ content = answer["answer"]
+ response = {
+ "id": completion_id,
+ "object": "chat.completion",
+ "created": int(time.time()),
+ "model": requested_model,
+ "usage": {
+ "prompt_tokens": num_tokens_from_string(prompt),
+ "completion_tokens": num_tokens_from_string(content),
+ "total_tokens": num_tokens_from_string(prompt) + num_tokens_from_string(content),
+ "completion_tokens_details": {
+ "reasoning_tokens": context_token_used,
+ "accepted_prediction_tokens": num_tokens_from_string(content),
+ "rejected_prediction_tokens": 0,
+ },
+ },
+ "choices": [
+ {
+ "message": {
+ "role": "assistant",
+ "content": content,
+ },
+ "logprobs": None,
+ "finish_reason": "stop",
+ "index": 0,
+ }
+ ],
+ }
+ if need_reference:
+ response["choices"][0]["message"]["reference"] = _build_reference_chunks(
+ answer.get("reference", {}),
+ include_metadata=include_reference_metadata,
+ metadata_fields=metadata_fields,
+ )
+
+ return jsonify(response)
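The new `openai_api.py` exposes an OpenAI-compatible completion route per chat assistant. The sketch below shows a non-streaming call using the optional `extra_body` controls defined above (`reference`, `reference_metadata`); the host, `/v1` prefix, Authorization header, and dialog id are assumptions for illustration.

```python
import requests

BASE = "http://localhost:9380/v1"   # assumed mount point
CHAT_ID = "your-dialog-id"          # hypothetical id of an existing chat assistant

payload = {
    "model": "model",               # the literal "model" keeps the assistant's configured llm_id
    "messages": [{"role": "user", "content": "Summarize the onboarding document."}],
    "stream": False,
    "extra_body": {
        "reference": True,          # ask for retrieved chunks in the response
        "reference_metadata": {"include": True, "fields": ["author", "department"]},
    },
}

resp = requests.post(
    f"{BASE}/openai/{CHAT_ID}/chat/completions",
    json=payload,
    headers={"Authorization": "Bearer <YOUR_TOKEN>"},  # assumed auth scheme
    timeout=120,
)
data = resp.json()
print(data["choices"][0]["message"]["content"])
print(data["choices"][0]["message"].get("reference", []))
```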
diff --git a/api/apps/plugin_app.py b/api/apps/restful_apis/plugin_api.py
similarity index 93%
rename from api/apps/plugin_app.py
rename to api/apps/restful_apis/plugin_api.py
index fb0a7bb6106..6d53fbc6267 100644
--- a/api/apps/plugin_app.py
+++ b/api/apps/restful_apis/plugin_api.py
@@ -21,7 +21,7 @@
from agent.plugin import GlobalPluginManager
-@manager.route('/llm_tools', methods=['GET']) # noqa: F821
+@manager.route('/plugin/tools', methods=['GET']) # noqa: F821
@login_required
def llm_tools() -> Response:
tools = GlobalPluginManager.get_llm_tools()
diff --git a/api/apps/restful_apis/search_api.py b/api/apps/restful_apis/search_api.py
index 82a357f306b..c56d0ff8344 100644
--- a/api/apps/restful_apis/search_api.py
+++ b/api/apps/restful_apis/search_api.py
@@ -14,7 +14,10 @@
# limitations under the License.
#
-from quart import request
+import json
+
+from quart import Response, request
+from api.db.services.dialog_service import async_ask
from api.apps import current_user, login_required
from api.constants import DATASET_NAME_LIMIT
@@ -168,3 +171,46 @@ def delete_search(search_id):
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
+
+
+@manager.route("/searches//completion", methods=["POST"]) # noqa: F821
+@manager.route("/searches//completions", methods=["POST"]) # noqa: F821
+@login_required
+@validate_request("question")
+async def completion(search_id):
+ if not SearchService.accessible4deletion(search_id, current_user.id):
+ return get_json_result(
+ data=False,
+ message="No authorization.",
+ code=RetCode.AUTHENTICATION_ERROR,
+ )
+
+ req = await get_request_json()
+ uid = current_user.id
+ search_app = SearchService.get_detail(search_id)
+ if not search_app:
+ return get_data_error_result(message=f"Cannot find search {search_id}")
+
+ search_config = search_app.get("search_config", {})
+ kb_ids = search_config.get("kb_ids") or req.get("kb_ids") or []
+ if not kb_ids:
+ return get_data_error_result(message="`kb_ids` is required.")
+
+ async def stream():
+ nonlocal req, uid, kb_ids, search_config
+ try:
+ async for ans in async_ask(req["question"], kb_ids, uid, search_config=search_config):
+ yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
+ except Exception as ex:
+ yield "data:" + json.dumps(
+ {"code": 500, "message": str(ex), "data": {"answer": "**ERROR**: " + str(ex), "reference": []}},
+ ensure_ascii=False,
+ ) + "\n\n"
+ yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
+
+ resp = Response(stream(), mimetype="text/event-stream")
+ resp.headers.add_header("Cache-control", "no-cache")
+ resp.headers.add_header("Connection", "keep-alive")
+ resp.headers.add_header("X-Accel-Buffering", "no")
+ resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
+ return resp
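The new search completion route streams Server-Sent Events framed as `data:{...}\n\n` and ends with a `data:{"code": 0, ..., "data": true}` sentinel. A minimal consumer sketch, with the same host/prefix/auth assumptions as above:

```python
import json
import requests

BASE = "http://localhost:9380/v1"   # assumed mount point
SEARCH_ID = "your-search-app-id"    # hypothetical search app id

with requests.post(
    f"{BASE}/searches/{SEARCH_ID}/completion",
    json={"question": "What changed in the latest release?"},
    headers={"Authorization": "Bearer <YOUR_TOKEN>"},  # assumed auth scheme
    stream=True,
    timeout=300,
) as resp:
    for raw in resp.iter_lines(decode_unicode=True):
        if not raw or not raw.startswith("data:"):
            continue
        event = json.loads(raw[len("data:"):])
        if event.get("data") is True:   # final sentinel emitted by the endpoint
            break
        print(event["data"].get("answer", ""))
```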
diff --git a/api/apps/api_app.py b/api/apps/restful_apis/stats_api.py
similarity index 97%
rename from api/apps/api_app.py
rename to api/apps/restful_apis/stats_api.py
index 0d5d62334ed..7185194327d 100644
--- a/api/apps/api_app.py
+++ b/api/apps/restful_apis/stats_api.py
@@ -20,7 +20,7 @@
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response
from api.apps import login_required, current_user
-@manager.route('/stats', methods=['GET']) # noqa: F821
+@manager.route('/system/stats', methods=['GET']) # noqa: F821
@login_required
def stats():
try:
diff --git a/api/apps/restful_apis/system_api.py b/api/apps/restful_apis/system_api.py
index 467d9111d90..55c34c25a34 100644
--- a/api/apps/restful_apis/system_api.py
+++ b/api/apps/restful_apis/system_api.py
@@ -14,25 +14,31 @@
# limitations under the License.
#
+import json
+import logging
+from datetime import datetime
+from timeit import default_timer as timer
+
from quart import jsonify
from api.apps import login_required, current_user
from api.utils.api_utils import get_json_result, get_data_error_result, server_error_response, generate_confirmation_token
-from api.utils.health_utils import run_health_checks
+from api.utils.health_utils import run_health_checks, get_oceanbase_status
from common.versions import get_ragflow_version
-from datetime import datetime
from common.time_utils import current_timestamp, datetime_format
from api.db.db_models import APIToken
from api.db.services.api_service import APITokenService
+from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import UserTenantService
from common.log_utils import get_log_levels, set_log_level
+from common import settings
+from rag.utils.redis_conn import REDIS_CONN
@manager.route("/system/ping", methods=["GET"]) # noqa: F821
async def ping():
return "pong", 200
@manager.route("/system/version", methods=["GET"]) # noqa: F821
-@login_required
def version():
"""
Get the current version of the application.
@@ -53,6 +59,174 @@ def version():
"""
return get_json_result(data=get_ragflow_version())
+
+@manager.route("/system/status", methods=["GET"]) # noqa: F821
+@login_required
+def status():
+ """
+ Get the system status.
+ ---
+ tags:
+ - System
+ security:
+ - ApiKeyAuth: []
+ responses:
+ 200:
+ description: System is operational.
+ schema:
+ type: object
+ properties:
+ es:
+ type: object
+ description: Elasticsearch status.
+ storage:
+ type: object
+ description: Storage status.
+ database:
+ type: object
+ description: Database status.
+ 503:
+ description: Service unavailable.
+ schema:
+ type: object
+ properties:
+ error:
+ type: string
+ description: Error message.
+ """
+ res = {}
+ st = timer()
+ try:
+ res["doc_engine"] = settings.docStoreConn.health()
+ res["doc_engine"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
+ except Exception as e:
+ res["doc_engine"] = {
+ "type": "unknown",
+ "status": "red",
+ "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
+ "error": str(e),
+ }
+
+ st = timer()
+ try:
+ settings.STORAGE_IMPL.health()
+ res["storage"] = {
+ "storage": settings.STORAGE_IMPL_TYPE.lower(),
+ "status": "green",
+ "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
+ }
+ except Exception as e:
+ res["storage"] = {
+ "storage": settings.STORAGE_IMPL_TYPE.lower(),
+ "status": "red",
+ "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
+ "error": str(e),
+ }
+
+ st = timer()
+ try:
+ KnowledgebaseService.get_by_id("x")
+ res["database"] = {
+ "database": settings.DATABASE_TYPE.lower(),
+ "status": "green",
+ "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
+ }
+ except Exception as e:
+ res["database"] = {
+ "database": settings.DATABASE_TYPE.lower(),
+ "status": "red",
+ "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
+ "error": str(e),
+ }
+
+ st = timer()
+ try:
+ if not REDIS_CONN.health():
+ raise Exception("Lost connection!")
+ res["redis"] = {
+ "status": "green",
+ "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
+ }
+ except Exception as e:
+ res["redis"] = {
+ "status": "red",
+ "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
+ "error": str(e),
+ }
+
+ task_executor_heartbeats = {}
+ try:
+ task_executors = REDIS_CONN.smembers("TASKEXE")
+ now = datetime.now().timestamp()
+ for task_executor_id in task_executors:
+ heartbeats = REDIS_CONN.zrangebyscore(task_executor_id, now - 60 * 30, now)
+ heartbeats = [json.loads(heartbeat) for heartbeat in heartbeats]
+ task_executor_heartbeats[task_executor_id] = heartbeats
+ except Exception:
+ logging.exception("get task executor heartbeats failed!")
+ res["task_executor_heartbeats"] = task_executor_heartbeats
+
+ return get_json_result(data=res)
+
+
+@manager.route("/system/oceanbase/status", methods=["GET"]) # noqa: F821
+@login_required
+def oceanbase_status():
+ """
+ Get OceanBase health status and performance metrics.
+ ---
+ tags:
+ - System
+ security:
+ - ApiKeyAuth: []
+ responses:
+ 200:
+ description: OceanBase status retrieved successfully.
+ schema:
+ type: object
+ properties:
+ status:
+ type: string
+ description: Status (alive/timeout).
+ message:
+ type: object
+ description: Detailed status information including health and performance metrics.
+ """
+ try:
+ status_info = get_oceanbase_status()
+ return get_json_result(data=status_info)
+ except Exception as e:
+ return get_json_result(
+ data={
+ "status": "error",
+ "message": f"Failed to get OceanBase status: {str(e)}"
+ },
+ code=500
+ )
+
+
+@manager.route("/system/config", methods=["GET"]) # noqa: F821
+def get_config():
+ """
+ Get system configuration.
+ ---
+ tags:
+ - System
+ responses:
+ 200:
+ description: Return system configuration
+ schema:
+ type: object
+ properties:
+            registerEnabled:
+              type: integer
+              description: Whether user registration is enabled (0 means disabled, 1 means enabled)
+ """
+ return get_json_result(data={
+ "registerEnabled": settings.REGISTER_ENABLED,
+ "disablePasswordLogin": settings.DISABLE_PASSWORD_LOGIN,
+ })
+
@manager.route("/system/healthz", methods=["GET"]) # noqa: F821
def healthz():
result, all_ok = run_health_checks()
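`/system/status` aggregates per-component health (`doc_engine`, `storage`, `database`, `redis`, plus task-executor heartbeats), each with a `status` color and elapsed milliseconds. A quick polling sketch, assuming the usual `{code, data}` envelope produced by `get_json_result` and the same host/auth assumptions as earlier examples:

```python
import requests

BASE = "http://localhost:9380/v1"   # assumed mount point
resp = requests.get(
    f"{BASE}/system/status",
    headers={"Authorization": "Bearer <YOUR_TOKEN>"},  # assumed auth scheme
    timeout=10,
)
status = resp.json().get("data", {})  # assumes get_json_result wraps the payload under "data"
for component in ("doc_engine", "storage", "database", "redis"):
    info = status.get(component, {})
    print(f"{component:10s} {info.get('status', 'unknown'):7s} {info.get('elapsed', '-')} ms")
print("task executors:", list(status.get("task_executor_heartbeats", {}).keys()))
```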
diff --git a/api/apps/restful_apis/task_api.py b/api/apps/restful_apis/task_api.py
new file mode 100644
index 00000000000..2bd7a41802f
--- /dev/null
+++ b/api/apps/restful_apis/task_api.py
@@ -0,0 +1,101 @@
+#
+# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import logging
+from datetime import datetime
+
+from api.apps import login_required
+from api.db.services.task_service import TaskService, CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID
+from api.utils.api_utils import (
+ get_json_result,
+ get_request_json,
+ validate_request,
+)
+from common.constants import RetCode, TaskStatus
+from rag.utils.redis_conn import REDIS_CONN
+
+
+@manager.route("/tasks//cancel", methods=["POST"]) # noqa: F821
+@login_required
+async def cancel_task(task_id):
+ """Cancel a running task.
+ """
+ return await _cancel_task(task_id)
+
+
+@manager.route("/tasks/", methods=["PATCH"]) # noqa: F821
+@login_required
+@validate_request("action")
+async def patch_task(task_id):
+ req = await get_request_json()
+ action = req.get("action")
+
+ if action != "stop":
+ return get_json_result(
+ code=RetCode.ARGUMENT_ERROR,
+ message=f"Invalid action '{action}'. Only 'stop' is supported.",
+ )
+
+ return await _cancel_task(task_id)
+
+
+async def _cancel_task(task_id):
+ """
+ Sets a Redis cancel flag, updates the task progress to -1 (cancelled),
+ and marks the associated document's run status as CANCEL if applicable.
+ """
+ try:
+ REDIS_CONN.set(f"{task_id}-cancel", "x")
+ except Exception as e:
+ logging.exception("Failed to set cancel flag for task %s: %s", task_id, str(e))
+ return get_json_result(
+ code=RetCode.CONNECTION_ERROR,
+ message="Failed to stop task",
+ )
+
+ exists, task = TaskService.get_by_id(task_id)
+ if not exists:
+ return get_json_result(data=True)
+
+ # Append a cancellation message so the user can see it in progress_msg.
+ try:
+ cancel_msg = f"\n{datetime.now().strftime('%H:%M:%S')} Task stopped by user."
+ # Only transition to -1 if the task is still in a non-terminal state,
+ # mirroring TaskService.update_progress semantics.
+ TaskService.model.update(
+ progress_msg=TaskService.model.progress_msg + cancel_msg,
+ progress=-1,
+ ).where(
+ (TaskService.model.id == task_id)
+ & (TaskService.model.progress >= 0)
+ & (TaskService.model.progress < 1)
+ ).execute()
+ except Exception as e:
+ logging.warning("Failed to update task %s progress after cancellation: %s", task_id, str(e))
+
+ # If the task belongs to a document, also mark the document's run status as
+ # cancelled so that the UI reflects the state correctly.
+ try:
+ from api.db.services.document_service import DocumentService
+ doc_id = task.doc_id
+ if doc_id and doc_id not in (CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID):
+ _, doc = DocumentService.get_by_id(doc_id)
+ if doc and str(doc.run) in (TaskStatus.RUNNING.value, TaskStatus.SCHEDULE.value):
+ DocumentService.update_by_id(doc_id, {"run": TaskStatus.CANCEL.value, "progress": 0})
+ except Exception as e:
+ logging.warning("Failed to update document run status for task %s: %s", task_id, str(e))
+
+ logging.info(f"Cancel task succeeded: task_id={task_id} doc_id={task.doc_id}")
+ return get_json_result(data=True)
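Both new task routes funnel into `_cancel_task`: it sets the `{task_id}-cancel` flag in Redis, forces a non-terminal task's `progress` to `-1`, and flips the owning document's run status to `CANCEL`. Either call below should trigger the same path; the host, prefix, and auth header remain assumptions.

```python
import requests

BASE = "http://localhost:9380/v1"                   # assumed mount point
HEADERS = {"Authorization": "Bearer <YOUR_TOKEN>"}  # assumed auth scheme
TASK_ID = "your-task-id"                            # hypothetical task id

# Dedicated cancel route
requests.post(f"{BASE}/tasks/{TASK_ID}/cancel", headers=HEADERS, timeout=10)

# Or the generic PATCH with an explicit action; any action other than "stop" is rejected
requests.patch(f"{BASE}/tasks/{TASK_ID}", json={"action": "stop"}, headers=HEADERS, timeout=10)
```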
diff --git a/api/apps/tenant_app.py b/api/apps/restful_apis/tenant_api.py
similarity index 59%
rename from api/apps/tenant_app.py
rename to api/apps/restful_apis/tenant_api.py
index be6305e8911..4d45337cb0b 100644
--- a/api/apps/tenant_app.py
+++ b/api/apps/restful_apis/tenant_api.py
@@ -1,5 +1,5 @@
#
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,48 +13,56 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import logging
import asyncio
+import logging
+
+from api.apps import current_user, login_required
from api.db import UserTenantRole
from api.db.db_models import UserTenant
-from api.db.services.user_service import UserTenantService, UserService
-
+from api.db.services.user_service import UserService, UserTenantService
+from api.utils.api_utils import (
+ get_data_error_result,
+ get_json_result,
+ get_request_json,
+ server_error_response,
+ validate_request,
+)
+from api.utils.web_utils import send_invite_email
+from common import settings
from common.constants import RetCode, StatusEnum
from common.misc_utils import get_uuid
from common.time_utils import delta_seconds
-from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request
-from api.utils.web_utils import send_invite_email
-from common import settings
-from api.apps import login_required, current_user
-@manager.route("//user/list", methods=["GET"]) # noqa: F821
+@manager.route("/tenants//users", methods=["GET"]) # noqa: F821
@login_required
def user_list(tenant_id):
if current_user.id != tenant_id:
return get_json_result(
data=False,
- message='No authorization.',
- code=RetCode.AUTHENTICATION_ERROR)
+ message="No authorization.",
+ code=RetCode.AUTHENTICATION_ERROR,
+ )
try:
users = UserTenantService.get_by_tenant_id(tenant_id)
- for u in users:
- u["delta_seconds"] = delta_seconds(str(u["update_date"]))
+ for user in users:
+ user["delta_seconds"] = delta_seconds(str(user["update_date"]))
return get_json_result(data=users)
- except Exception as e:
- return server_error_response(e)
+ except Exception as exc:
+ return server_error_response(exc)
-@manager.route('/<tenant_id>/user', methods=['POST']) # noqa: F821
+@manager.route("/tenants/<tenant_id>/users", methods=["POST"]) # noqa: F821
@login_required
@validate_request("email")
async def create(tenant_id):
if current_user.id != tenant_id:
return get_json_result(
data=False,
- message='No authorization.',
- code=RetCode.AUTHENTICATION_ERROR)
+ message="No authorization.",
+ code=RetCode.AUTHENTICATION_ERROR,
+ )
req = await get_request_json()
invite_user_email = req["email"]
@@ -71,7 +79,8 @@ async def create(tenant_id):
if user_tenant_role == UserTenantRole.OWNER:
return get_data_error_result(message=f"{invite_user_email} is the owner of the team.")
return get_data_error_result(
- message=f"{invite_user_email} is in the team, but the role: {user_tenant_role} is invalid.")
+ message=f"{invite_user_email} is in the team, but the role: {user_tenant_role} is invalid."
+ )
UserTenantService.save(
id=get_uuid(),
@@ -79,10 +88,10 @@ async def create(tenant_id):
tenant_id=tenant_id,
invited_by=current_user.id,
role=UserTenantRole.INVITE,
- status=StatusEnum.VALID.value)
+ status=StatusEnum.VALID.value,
+ )
try:
-
user_name = ""
_, user = UserService.get_by_id(current_user.id)
if user:
@@ -93,52 +102,62 @@ async def create(tenant_id):
to_email=invite_user_email,
invite_url=settings.MAIL_FRONTEND_URL,
tenant_id=tenant_id,
- inviter=user_name or current_user.email
+ inviter=user_name or current_user.email,
)
)
- except Exception as e:
- logging.exception(f"Failed to send invite email to {invite_user_email}: {e}")
- return get_json_result(data=False, message="Failed to send invite email.", code=RetCode.SERVER_ERROR)
- usr = invite_users[0].to_dict()
- usr = {k: v for k, v in usr.items() if k in ["id", "avatar", "email", "nickname"]}
+ except Exception as exc:
+ logging.exception(f"Failed to send invite email to {invite_user_email}: {exc}")
+ return get_json_result(
+ data=False,
+ message="Failed to send invite email.",
+ code=RetCode.SERVER_ERROR,
+ )
- return get_json_result(data=usr)
+ user = invite_users[0].to_dict()
+ user = {k: v for k, v in user.items() if k in ["id", "avatar", "email", "nickname"]}
+ return get_json_result(data=user)
-@manager.route('/<tenant_id>/user/<user_id>', methods=['DELETE']) # noqa: F821
+@manager.route("/tenants/<tenant_id>/users", methods=["DELETE"]) # noqa: F821
@login_required
-def rm(tenant_id, user_id):
+@validate_request("user_id")
+async def rm(tenant_id):
+ req = await get_request_json()
+ user_id = req["user_id"]
if current_user.id != tenant_id and current_user.id != user_id:
return get_json_result(
data=False,
- message='No authorization.',
- code=RetCode.AUTHENTICATION_ERROR)
+ message="No authorization.",
+ code=RetCode.AUTHENTICATION_ERROR,
+ )
try:
UserTenantService.filter_delete([UserTenant.tenant_id == tenant_id, UserTenant.user_id == user_id])
return get_json_result(data=True)
- except Exception as e:
- return server_error_response(e)
+ except Exception as exc:
+ return server_error_response(exc)
-@manager.route("/list", methods=["GET"]) # noqa: F821
+@manager.route("/tenants", methods=["GET"]) # noqa: F821
@login_required
def tenant_list():
try:
users = UserTenantService.get_tenants_by_user_id(current_user.id)
- for u in users:
- u["delta_seconds"] = delta_seconds(str(u["update_date"]))
+ for user in users:
+ user["delta_seconds"] = delta_seconds(str(user["update_date"]))
return get_json_result(data=users)
- except Exception as e:
- return server_error_response(e)
+ except Exception as exc:
+ return server_error_response(exc)
-@manager.route("/agree/", methods=["PUT"]) # noqa: F821
+@manager.route("/tenants/", methods=["PATCH"]) # noqa: F821
@login_required
def agree(tenant_id):
try:
- UserTenantService.filter_update([UserTenant.tenant_id == tenant_id, UserTenant.user_id == current_user.id],
- {"role": UserTenantRole.NORMAL})
+ UserTenantService.filter_update(
+ [UserTenant.tenant_id == tenant_id, UserTenant.user_id == current_user.id],
+ {"role": UserTenantRole.NORMAL},
+ )
return get_json_result(data=True)
- except Exception as e:
- return server_error_response(e)
+ except Exception as exc:
+ return server_error_response(exc)
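The tenant routes move from `/<tenant_id>/user/...` to a `/tenants/<tenant_id>/users` collection, and member removal now takes `user_id` in the JSON body rather than the path. A hedged sketch of the reshaped calls (host, prefix, and auth header assumed):

```python
import requests

BASE = "http://localhost:9380/v1"                   # assumed mount point
HEADERS = {"Authorization": "Bearer <YOUR_TOKEN>"}  # assumed auth scheme
TENANT_ID = "your-tenant-id"                        # hypothetical tenant id

# Invite a member to the team
requests.post(f"{BASE}/tenants/{TENANT_ID}/users", json={"email": "teammate@example.com"}, headers=HEADERS, timeout=10)

# List members of the team
print(requests.get(f"{BASE}/tenants/{TENANT_ID}/users", headers=HEADERS, timeout=10).json())

# Remove a member (user_id moved from the path into the body)
requests.delete(f"{BASE}/tenants/{TENANT_ID}/users", json={"user_id": "member-user-id"}, headers=HEADERS, timeout=10)

# Accept an invitation to the team (was PUT /agree/<tenant_id>)
requests.patch(f"{BASE}/tenants/{TENANT_ID}", headers=HEADERS, timeout=10)
```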
diff --git a/api/apps/user_app.py b/api/apps/restful_apis/user_api.py
similarity index 75%
rename from api/apps/user_app.py
rename to api/apps/restful_apis/user_api.py
index 74248992696..714453ac6fa 100644
--- a/api/apps/user_app.py
+++ b/api/apps/restful_apis/user_api.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import json
import logging
import string
import os
@@ -60,10 +59,9 @@
captcha_key,
)
from common import settings
-from common.http_client import async_request
-@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
+@manager.route("/auth/login", methods=["POST"]) # noqa: F821
async def login():
"""
User login endpoint.
@@ -140,7 +138,7 @@ async def login():
)
-@manager.route("/login/channels", methods=["GET"]) # noqa: F821
+@manager.route("/auth/login/channels", methods=["GET"]) # noqa: F821
async def get_login_channels():
"""
Get all supported authentication channels.
@@ -161,7 +159,7 @@ async def get_login_channels():
return get_json_result(data=[], message=f"Load channels failure, error: {str(e)}", code=RetCode.EXCEPTION_ERROR)
-@manager.route("/login/", methods=["GET"]) # noqa: F821
+@manager.route("/auth/login/", methods=["GET"]) # noqa: F821
async def oauth_login(channel):
channel_config = settings.OAUTH_CONFIG.get(channel)
if not channel_config:
@@ -174,7 +172,7 @@ async def oauth_login(channel):
return redirect(auth_url)
-@manager.route("/oauth/callback/", methods=["GET"]) # noqa: F821
+@manager.route("/auth/oauth//callback", methods=["GET"]) # noqa: F821
async def oauth_callback(channel):
"""
Handle the OAuth/OIDC callback for various channels dynamically.
@@ -269,224 +267,7 @@ async def oauth_callback(channel):
return redirect(f"/?error={str(e)}")
-@manager.route("/github_callback", methods=["GET"]) # noqa: F821
-async def github_callback():
- """
-    **Deprecated**, Use `/oauth/callback/<channel>` instead.
-
- GitHub OAuth callback endpoint.
- ---
- tags:
- - OAuth
- parameters:
- - in: query
- name: code
- type: string
- required: true
- description: Authorization code from GitHub.
- responses:
- 200:
- description: Authentication successful.
- schema:
- type: object
- """
- res = await async_request(
- "POST",
- settings.GITHUB_OAUTH.get("url"),
- data={
- "client_id": settings.GITHUB_OAUTH.get("client_id"),
- "client_secret": settings.GITHUB_OAUTH.get("secret_key"),
- "code": request.args.get("code"),
- },
- headers={"Accept": "application/json"},
- )
- res = res.json()
- if "error" in res:
- return redirect("/?error=%s" % res["error_description"])
-
- if "user:email" not in res["scope"].split(","):
- return redirect("/?error=user:email not in scope")
-
- session["access_token"] = res["access_token"]
- session["access_token_from"] = "github"
- user_info = await user_info_from_github(session["access_token"])
- email_address = user_info["email"]
- users = UserService.query(email=email_address)
- user_id = get_uuid()
- if not users:
- # User isn't try to register
- try:
- try:
- avatar = await download_img(user_info["avatar_url"])
- except Exception as e:
- logging.exception(e)
- avatar = ""
- users = user_register(
- user_id,
- {
- "access_token": session["access_token"],
- "email": email_address,
- "avatar": avatar,
- "nickname": user_info["login"],
- "login_channel": "github",
- "last_login_time": get_format_time(),
- "is_superuser": False,
- },
- )
- if not users:
- raise Exception(f"Fail to register {email_address}.")
- if len(users) > 1:
- raise Exception(f"Same email: {email_address} exists!")
-
- # Try to log in
- user = users[0]
- login_user(user)
- return redirect("/?auth=%s" % user.get_id())
- except Exception as e:
- rollback_user_registration(user_id)
- logging.exception(e)
- return redirect("/?error=%s" % str(e))
-
- # User has already registered, try to log in
- user = users[0]
- user.access_token = get_uuid()
- if user and hasattr(user, 'is_active') and user.is_active == "0":
- return redirect("/?error=user_inactive")
- login_user(user)
- user.save()
- return redirect("/?auth=%s" % user.get_id())
-
-
-@manager.route("/feishu_callback", methods=["GET"]) # noqa: F821
-async def feishu_callback():
- """
- Feishu OAuth callback endpoint.
- ---
- tags:
- - OAuth
- parameters:
- - in: query
- name: code
- type: string
- required: true
- description: Authorization code from Feishu.
- responses:
- 200:
- description: Authentication successful.
- schema:
- type: object
- """
- app_access_token_res = await async_request(
- "POST",
- settings.FEISHU_OAUTH.get("app_access_token_url"),
- data=json.dumps(
- {
- "app_id": settings.FEISHU_OAUTH.get("app_id"),
- "app_secret": settings.FEISHU_OAUTH.get("app_secret"),
- }
- ),
- headers={"Content-Type": "application/json; charset=utf-8"},
- )
- app_access_token_res = app_access_token_res.json()
- if app_access_token_res["code"] != 0:
- return redirect("/?error=%s" % app_access_token_res)
-
- res = await async_request(
- "POST",
- settings.FEISHU_OAUTH.get("user_access_token_url"),
- data=json.dumps(
- {
- "grant_type": settings.FEISHU_OAUTH.get("grant_type"),
- "code": request.args.get("code"),
- }
- ),
- headers={
- "Content-Type": "application/json; charset=utf-8",
- "Authorization": f"Bearer {app_access_token_res['app_access_token']}",
- },
- )
- res = res.json()
- if res["code"] != 0:
- return redirect("/?error=%s" % res["message"])
-
- if "contact:user.email:readonly" not in res["data"]["scope"].split():
- return redirect("/?error=contact:user.email:readonly not in scope")
- session["access_token"] = res["data"]["access_token"]
- session["access_token_from"] = "feishu"
- user_info = await user_info_from_feishu(session["access_token"])
- email_address = user_info["email"]
- users = UserService.query(email=email_address)
- user_id = get_uuid()
- if not users:
- # User isn't try to register
- try:
- try:
- avatar = await download_img(user_info["avatar_url"])
- except Exception as e:
- logging.exception(e)
- avatar = ""
- users = user_register(
- user_id,
- {
- "access_token": session["access_token"],
- "email": email_address,
- "avatar": avatar,
- "nickname": user_info["en_name"],
- "login_channel": "feishu",
- "last_login_time": get_format_time(),
- "is_superuser": False,
- },
- )
- if not users:
- raise Exception(f"Fail to register {email_address}.")
- if len(users) > 1:
- raise Exception(f"Same email: {email_address} exists!")
-
- # Try to log in
- user = users[0]
- login_user(user)
- return redirect("/?auth=%s" % user.get_id())
- except Exception as e:
- rollback_user_registration(user_id)
- logging.exception(e)
- return redirect("/?error=%s" % str(e))
-
- # User has already registered, try to log in
- user = users[0]
- if user and hasattr(user, 'is_active') and user.is_active == "0":
- return redirect("/?error=user_inactive")
- user.access_token = get_uuid()
- login_user(user)
- user.save()
- return redirect("/?auth=%s" % user.get_id())
-
-
-async def user_info_from_feishu(access_token):
- headers = {
- "Content-Type": "application/json; charset=utf-8",
- "Authorization": f"Bearer {access_token}",
- }
- res = await async_request("GET", "https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers)
- user_info = res.json()["data"]
- user_info["email"] = None if user_info.get("email") == "" else user_info["email"]
- return user_info
-
-
-async def user_info_from_github(access_token):
- headers = {"Accept": "application/json", "Authorization": f"token {access_token}"}
- res = await async_request("GET", f"https://api.github.com/user?access_token={access_token}", headers=headers)
- user_info = res.json()
- email_info_response = await async_request(
- "GET",
- f"https://api.github.com/user/emails?access_token={access_token}",
- headers=headers,
- )
- email_info = email_info_response.json()
- user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"]
- return user_info
-
-
-@manager.route("/logout", methods=["GET"]) # noqa: F821
+@manager.route("/auth/logout", methods=["POST"]) # noqa: F821
@login_required
async def log_out():
"""
@@ -508,7 +289,7 @@ async def log_out():
return get_json_result(data=True)
-@manager.route("/setting", methods=["POST"]) # noqa: F821
+@manager.route("/users/me", methods=["PATCH"]) # noqa: F821
@login_required
async def setting_user():
"""
@@ -576,7 +357,7 @@ async def setting_user():
return get_json_result(data=False, message="Update failure!", code=RetCode.EXCEPTION_ERROR)
-@manager.route("/info", methods=["GET"]) # noqa: F821
+@manager.route("/users/me", methods=["GET"]) # noqa: F821
@login_required
async def user_profile():
"""
@@ -667,7 +448,7 @@ def user_register(user_id, user):
return UserService.query(email=user["email"])
-@manager.route("/register", methods=["POST"]) # noqa: F821
+@manager.route("/users", methods=["POST"]) # noqa: F821
@validate_request("nickname", "email", "password")
async def user_add():
"""
@@ -761,7 +542,7 @@ async def user_add():
)
-@manager.route("/tenant_info", methods=["GET"]) # noqa: F821
+@manager.route("/users/me/models", methods=["GET"]) # noqa: F821
@login_required
async def tenant_info():
"""
@@ -799,7 +580,7 @@ async def tenant_info():
return server_error_response(e)
-@manager.route("/set_tenant_info", methods=["POST"]) # noqa: F821
+@manager.route("/users/me/models", methods=["PATCH"]) # noqa: F821
@login_required
@validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id")
async def set_tenant_info():
@@ -849,7 +630,7 @@ async def set_tenant_info():
return server_error_response(e)
-@manager.route("/forget/captcha", methods=["GET"]) # noqa: F821
+@manager.route("/auth/password/forgot/captcha", methods=["POST"]) # noqa: F821
async def forget_get_captcha():
"""
GET /forget/captcha?email=
@@ -877,7 +658,7 @@ async def forget_get_captcha():
return response
-@manager.route("/forget/otp", methods=["POST"]) # noqa: F821
+@manager.route("/auth/password/forgot/otp", methods=["POST"]) # noqa: F821
async def forget_send_otp():
"""
POST /forget/otp
@@ -947,7 +728,7 @@ def _verified_key(email: str) -> str:
return f"otp:verified:{email}"
-@manager.route("/forget/verify-otp", methods=["POST"]) # noqa: F821
+@manager.route("/auth/password/forgot/otp/verify", methods=["POST"]) # noqa: F821
async def forget_verify_otp():
"""
Verify email + OTP only. On success:
@@ -1008,7 +789,7 @@ async def forget_verify_otp():
return get_json_result(data=True, code=RetCode.SUCCESS, message="otp verified")
-@manager.route("/forget/reset-password", methods=["POST"]) # noqa: F821
+@manager.route("/auth/password/reset", methods=["POST"]) # noqa: F821
async def forget_reset_password():
"""
Reset password after successful OTP verification.
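User and auth routes are regrouped under `/auth/*` and `/users/me`. The sketch below only illustrates the renamed paths; the login payload fields are assumptions (this hunk does not show the request schema), and `@login_required` is session-based, so the cookie jar of the `requests.Session` is what carries authentication.

```python
import requests

BASE = "http://localhost:9380/v1"   # assumed mount point
s = requests.Session()

# POST /auth/login replaces GET|POST /login (payload fields assumed, not shown in this diff)
s.post(f"{BASE}/auth/login", json={"email": "me@example.com", "password": "<encrypted-password>"}, timeout=10)

# Profile read/update now lives under /users/me (was /info and /setting)
print(s.get(f"{BASE}/users/me", timeout=10).json())
s.patch(f"{BASE}/users/me", json={"nickname": "New Nickname"}, timeout=10)

# Default model bindings moved to /users/me/models (was /tenant_info and /set_tenant_info)
print(s.get(f"{BASE}/users/me/models", timeout=10).json())

# Logout is now an explicit POST
s.post(f"{BASE}/auth/logout", timeout=10)
```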
diff --git a/api/apps/sdk/agents.py b/api/apps/sdk/agents.py
deleted file mode 100644
index f7f36fa19f0..00000000000
--- a/api/apps/sdk/agents.py
+++ /dev/null
@@ -1,938 +0,0 @@
-#
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import asyncio
-import base64
-import hashlib
-import hmac
-import ipaddress
-import json
-import logging
-import time
-from typing import Any, cast
-
-import jwt
-
-from agent.canvas import Canvas
-from api.apps.services.canvas_replica_service import CanvasReplicaService
-from api.db import CanvasCategory
-from api.db.services.canvas_service import UserCanvasService
-from api.db.services.file_service import FileService
-from api.db.services.user_service import UserService
-from api.db.services.user_canvas_version import UserCanvasVersionService
-from common.constants import RetCode
-from common.misc_utils import get_uuid
-from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, get_request_json, token_required
-from api.utils.api_utils import get_result
-from quart import request, Response
-from rag.utils.redis_conn import REDIS_CONN
-
-
-def _get_user_nickname(user_id: str) -> str:
- exists, user = UserService.get_by_id(user_id)
- if not exists:
- return user_id
- return str(getattr(user, "nickname", "") or user_id)
-
-
-@manager.route('/agents', methods=['GET']) # noqa: F821
-@token_required
-def list_agents(tenant_id):
- id = request.args.get("id")
- title = request.args.get("title")
- if id or title:
- canvas = UserCanvasService.query(id=id, title=title, user_id=tenant_id)
- if not canvas:
- return get_error_data_result("The agent doesn't exist.")
- page_number = int(request.args.get("page", 1))
- items_per_page = int(request.args.get("page_size", 30))
- order_by = request.args.get("orderby", "update_time")
- if str(request.args.get("desc","false")).lower() == "false":
- desc = False
- else:
- desc = True
- canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, order_by, desc, id, title)
- return get_result(data=canvas)
-
-
-@manager.route("/agents", methods=["POST"]) # noqa: F821
-@token_required
-async def create_agent(tenant_id: str):
- req: dict[str, Any] = cast(dict[str, Any], await get_request_json())
- req["user_id"] = tenant_id
-
- if req.get("dsl") is not None:
- try:
- req["dsl"] = CanvasReplicaService.normalize_dsl(req["dsl"])
- except ValueError as e:
- return get_json_result(data=False, message=str(e), code=RetCode.ARGUMENT_ERROR)
- else:
- return get_json_result(data=False, message="No DSL data in request.", code=RetCode.ARGUMENT_ERROR)
-
- if req.get("title") is not None:
- req["title"] = req["title"].strip()
- else:
- return get_json_result(data=False, message="No title in request.", code=RetCode.ARGUMENT_ERROR)
-
- if UserCanvasService.query(user_id=tenant_id, title=req["title"]):
- return get_data_error_result(message=f"Agent with title {req['title']} already exists.")
-
- agent_id = get_uuid()
- req["id"] = agent_id
-
- if not UserCanvasService.save(**req):
- return get_data_error_result(message="Fail to create agent.")
-
- owner_nickname = _get_user_nickname(tenant_id)
- UserCanvasVersionService.save_or_replace_latest(
- user_canvas_id=agent_id,
- title=UserCanvasVersionService.build_version_title(owner_nickname, req.get("title")),
- dsl=req["dsl"]
- )
-
- return get_json_result(data=True)
-
-
-@manager.route("/agents/", methods=["PUT"]) # noqa: F821
-@token_required
-async def update_agent(tenant_id: str, agent_id: str):
- req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await get_request_json())).items() if v is not None}
- req["user_id"] = tenant_id
-
- if req.get("dsl") is not None:
- try:
- req["dsl"] = CanvasReplicaService.normalize_dsl(req["dsl"])
- except ValueError as e:
- return get_json_result(data=False, message=str(e), code=RetCode.ARGUMENT_ERROR)
-
- if req.get("title") is not None:
- req["title"] = req["title"].strip()
-
- if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
- return get_json_result(
- data=False, message="Only owner of canvas authorized for this operation.",
- code=RetCode.OPERATING_ERROR)
-
- _, current_agent = UserCanvasService.get_by_id(agent_id)
- agent_title_for_version = req.get("title") or (current_agent.title if current_agent else "")
- owner_nickname = _get_user_nickname(tenant_id)
-
- UserCanvasService.update_by_id(agent_id, req)
-
- if req.get("dsl") is not None:
- UserCanvasVersionService.save_or_replace_latest(
- user_canvas_id=agent_id,
- title=UserCanvasVersionService.build_version_title(owner_nickname, agent_title_for_version),
- dsl=req["dsl"]
- )
-
- return get_json_result(data=True)
-
-
-@manager.route("/agents/", methods=["DELETE"]) # noqa: F821
-@token_required
-def delete_agent(tenant_id: str, agent_id: str):
- if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
- return get_json_result(
- data=False, message="Only owner of canvas authorized for this operation.",
- code=RetCode.OPERATING_ERROR)
-
- UserCanvasService.delete_by_id(agent_id)
- return get_json_result(data=True)
-
-@manager.route("/webhook/", methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"]) # noqa: F821
-@manager.route("/webhook_test/",methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"],) # noqa: F821
-async def webhook(agent_id: str):
- is_test = request.path.startswith("/api/v1/webhook_test")
- start_ts = time.time()
-
- # 1. Fetch canvas by agent_id
- exists, cvs = UserCanvasService.get_by_id(agent_id)
- if not exists:
- return get_data_error_result(code=RetCode.BAD_REQUEST,message="Canvas not found."),RetCode.BAD_REQUEST
-
- # 2. Check canvas category
- if cvs.canvas_category == CanvasCategory.DataFlow:
- return get_data_error_result(code=RetCode.BAD_REQUEST,message="Dataflow can not be triggered by webhook."),RetCode.BAD_REQUEST
-
- # 3. Load DSL from canvas
- dsl = getattr(cvs, "dsl", None)
- if not isinstance(dsl, dict):
- return get_data_error_result(code=RetCode.BAD_REQUEST,message="Invalid DSL format."),RetCode.BAD_REQUEST
-
- # 4. Check webhook configuration in DSL
- webhook_cfg = {}
- components = dsl.get("components", {})
- for k, _ in components.items():
- cpn_obj = components[k]["obj"]
- if cpn_obj["component_name"].lower() == "begin" and cpn_obj["params"]["mode"] == "Webhook":
- webhook_cfg = cpn_obj["params"]
-
- if not webhook_cfg:
- return get_data_error_result(code=RetCode.BAD_REQUEST,message="Webhook not configured for this agent."),RetCode.BAD_REQUEST
-
- # 5. Validate request method against webhook_cfg.methods
- allowed_methods = webhook_cfg.get("methods", [])
- request_method = request.method.upper()
- if allowed_methods and request_method not in allowed_methods:
- return get_data_error_result(
- code=RetCode.BAD_REQUEST,message=f"HTTP method '{request_method}' not allowed for this webhook."
- ),RetCode.BAD_REQUEST
-
- # 6. Validate webhook security
- async def validate_webhook_security(security_cfg: dict):
- """Validate webhook security rules based on security configuration."""
-
- if not security_cfg:
- return # No security config → allowed by default
-
- # 1. Validate max body size
- await _validate_max_body_size(security_cfg)
-
- # 2. Validate IP whitelist
- _validate_ip_whitelist(security_cfg)
-
- # # 3. Validate rate limiting
- _validate_rate_limit(security_cfg)
-
- # 4. Validate authentication
- auth_type = security_cfg.get("auth_type", "none")
-
- if auth_type == "none":
- return
-
- if auth_type == "token":
- _validate_token_auth(security_cfg)
-
- elif auth_type == "basic":
- _validate_basic_auth(security_cfg)
-
- elif auth_type == "jwt":
- _validate_jwt_auth(security_cfg)
-
- else:
- raise Exception(f"Unsupported auth_type: {auth_type}")
-
- async def _validate_max_body_size(security_cfg):
- """Check request size does not exceed max_body_size."""
- max_size = security_cfg.get("max_body_size")
- if not max_size:
- return
-
- # Convert "10MB" → bytes
- units = {"kb": 1024, "mb": 1024**2}
- size_str = max_size.lower()
-
- for suffix, factor in units.items():
- if size_str.endswith(suffix):
- limit = int(size_str.replace(suffix, "")) * factor
- break
- else:
- raise Exception("Invalid max_body_size format")
- MAX_LIMIT = 10 * 1024 * 1024 # 10MB
- if limit > MAX_LIMIT:
- raise Exception("max_body_size exceeds maximum allowed size (10MB)")
-
- content_length = request.content_length or 0
- if content_length > limit:
- raise Exception(f"Request body too large: {content_length} > {limit}")
-
- def _validate_ip_whitelist(security_cfg):
- """Allow only IPs listed in ip_whitelist."""
- whitelist = security_cfg.get("ip_whitelist", [])
- if not whitelist:
- return
-
- client_ip = request.remote_addr
-
-
- for rule in whitelist:
- if "/" in rule:
- # CIDR notation
- if ipaddress.ip_address(client_ip) in ipaddress.ip_network(rule, strict=False):
- return
- else:
- # Single IP
- if client_ip == rule:
- return
-
- raise Exception(f"IP {client_ip} is not allowed by whitelist")
-
- def _validate_rate_limit(security_cfg):
- """Simple in-memory rate limiting."""
- rl = security_cfg.get("rate_limit")
- if not rl:
- return
-
- limit = int(rl.get("limit", 60))
- if limit <= 0:
- raise Exception("rate_limit.limit must be > 0")
- per = rl.get("per", "minute")
-
- window = {
- "second": 1,
- "minute": 60,
- "hour": 3600,
- "day": 86400,
- }.get(per)
-
- if not window:
- raise Exception(f"Invalid rate_limit.per: {per}")
-
- capacity = limit
- rate = limit / window
- cost = 1
-
- key = f"rl:tb:{agent_id}"
- now = time.time()
-
-        try:
-            res = REDIS_CONN.lua_token_bucket(
-                keys=[key],
-                args=[capacity, rate, now, cost],
-                client=REDIS_CONN.REDIS,
-            )
-        except Exception as e:
-            raise Exception(f"Rate limit error: {e}")
-
-        # Raise the rejection outside the try block so it is not re-wrapped
-        # as a "Rate limit error".
-        allowed = int(res[0])
-        if allowed != 1:
-            raise Exception("Too many requests (rate limit exceeded)")
-
- def _validate_token_auth(security_cfg):
- """Validate header-based token authentication."""
- token_cfg = security_cfg.get("token",{})
- header = token_cfg.get("token_header")
- token_value = token_cfg.get("token_value")
-
-        provided = request.headers.get(header)
-        # Compare in constant time so the token cannot be probed via response timing.
-        if not provided or not hmac.compare_digest(provided, token_value or ""):
-            raise Exception("Invalid token authentication")
-
- def _validate_basic_auth(security_cfg):
- """Validate HTTP Basic Auth credentials."""
- auth_cfg = security_cfg.get("basic_auth", {})
- username = auth_cfg.get("username")
- password = auth_cfg.get("password")
-
- auth = request.authorization
- if not auth or auth.username != username or auth.password != password:
- raise Exception("Invalid Basic Auth credentials")
-
- def _validate_jwt_auth(security_cfg):
- """Validate JWT token in Authorization header."""
- jwt_cfg = security_cfg.get("jwt", {})
- secret = jwt_cfg.get("secret")
- if not secret:
- raise Exception("JWT secret not configured")
-
- auth_header = request.headers.get("Authorization", "")
- if not auth_header.startswith("Bearer "):
- raise Exception("Missing Bearer token")
-
- token = auth_header[len("Bearer "):].strip()
- if not token:
- raise Exception("Empty Bearer token")
-
- alg = (jwt_cfg.get("algorithm") or "HS256").upper()
-
- decode_kwargs = {
- "key": secret,
- "algorithms": [alg],
- }
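-        # When no expected audience/issuer is configured, disable that check so
-        # PyJWT does not reject tokens that happen to carry aud/iss claims.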
- options = {}
- if jwt_cfg.get("audience"):
- decode_kwargs["audience"] = jwt_cfg["audience"]
- options["verify_aud"] = True
- else:
- options["verify_aud"] = False
-
- if jwt_cfg.get("issuer"):
- decode_kwargs["issuer"] = jwt_cfg["issuer"]
- options["verify_iss"] = True
- else:
- options["verify_iss"] = False
- try:
- decoded = jwt.decode(
- token,
- options=options,
- **decode_kwargs,
- )
- except Exception as e:
- raise Exception(f"Invalid JWT: {str(e)}")
-
- raw_required_claims = jwt_cfg.get("required_claims", [])
- if isinstance(raw_required_claims, str):
- required_claims = [raw_required_claims]
- elif isinstance(raw_required_claims, (list, tuple, set)):
- required_claims = list(raw_required_claims)
- else:
- required_claims = []
-
- required_claims = [
- c for c in required_claims
- if isinstance(c, str) and c.strip()
- ]
-
- RESERVED_CLAIMS = {"exp", "sub", "aud", "iss", "nbf", "iat"}
- for claim in required_claims:
- if claim in RESERVED_CLAIMS:
- raise Exception(f"Reserved JWT claim cannot be required: {claim}")
-
- for claim in required_claims:
- if claim not in decoded:
- raise Exception(f"Missing JWT claim: {claim}")
-
- return decoded
-
- try:
-        security_config = webhook_cfg.get("security", {})
-        await validate_webhook_security(security_config)
-    except Exception as e:
-        return get_data_error_result(code=RetCode.BAD_REQUEST, message=str(e)), RetCode.BAD_REQUEST
- if not isinstance(cvs.dsl, str):
- dsl = json.dumps(cvs.dsl, ensure_ascii=False)
- try:
- canvas = Canvas(dsl, cvs.user_id, agent_id, canvas_id=agent_id)
- except Exception as e:
-        resp = get_data_error_result(code=RetCode.BAD_REQUEST, message=str(e))
- resp.status_code = RetCode.BAD_REQUEST
- return resp
-
- # 7. Parse request body
- async def parse_webhook_request(content_type):
- """Parse request based on content-type and return structured data."""
-
- # 1. Query
- query_data = {k: v for k, v in request.args.items()}
-
- # 2. Headers
- header_data = {k: v for k, v in request.headers.items()}
-
- # 3. Body
- ctype = request.headers.get("Content-Type", "").split(";")[0].strip()
- if ctype and ctype != content_type:
- raise ValueError(
- f"Invalid Content-Type: expect '{content_type}', got '{ctype}'"
- )
-
- body_data: dict = {}
-
- try:
- if ctype == "application/json":
- body_data = await request.get_json() or {}
-
- elif ctype == "multipart/form-data":
- nonlocal canvas
- form = await request.form
- files = await request.files
-
- body_data = {}
-
- for key, value in form.items():
- body_data[key] = value
-
- if len(files) > 10:
- raise Exception("Too many uploaded files")
- for key, file in files.items():
- desc = FileService.upload_info(
- cvs.user_id, # user
- file, # FileStorage
- None # url (None for webhook)
- )
-                        file_parsed = await canvas.get_files_async([desc])
- body_data[key] = file_parsed
-
- elif ctype == "application/x-www-form-urlencoded":
- form = await request.form
- body_data = dict(form)
-
- else:
- # text/plain / octet-stream / empty / unknown
- raw = await request.get_data()
- if raw:
- try:
- body_data = json.loads(raw.decode("utf-8"))
- except Exception:
- body_data = {}
- else:
- body_data = {}
-
-        except Exception:
-            # Malformed bodies fall back to {} rather than failing the request; log
-            # the cause so parsing errors (including the file-count guard) are not silent.
-            logging.exception("Failed to parse webhook request body")
-            body_data = {}
-
- return {
- "query": query_data,
- "headers": header_data,
- "body": body_data,
- "content_type": ctype,
- }
-
- def extract_by_schema(data, schema, name="section"):
- """
- Extract only fields defined in schema.
- Required fields must exist.
- Optional fields default to type-based default values.
- Type validation included.
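-        Example (illustrative): with schema {"properties": {"page": {"type": "number"}}}
-        and data {"page": "2"}, the result is {"page": 2}: "2" is auto-cast to a
-        number before type validation.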
- """
- props = schema.get("properties", {})
- required = schema.get("required", [])
-
- extracted = {}
-
- for field, field_schema in props.items():
- field_type = field_schema.get("type")
-
- # 1. Required field missing
- if field in required and field not in data:
- raise Exception(f"{name} missing required field: {field}")
-
- # 2. Optional → default value
- if field not in data:
- extracted[field] = default_for_type(field_type)
- continue
-
- raw_value = data[field]
-
- # 3. Auto convert value
- try:
- value = auto_cast_value(raw_value, field_type)
- except Exception as e:
- raise Exception(f"{name}.{field} auto-cast failed: {str(e)}")
-
- # 4. Type validation
- if not validate_type(value, field_type):
- raise Exception(
- f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}"
- )
-
- extracted[field] = value
-
- return extracted
-
-
- def default_for_type(t):
- """Return default value for the given schema type."""
- if t == "file":
- return []
- if t == "object":
- return {}
- if t == "boolean":
- return False
- if t == "number":
- return 0
- if t == "string":
- return ""
- if t and t.startswith("array"):
- return []
- if t == "null":
- return None
- return None
-
- def auto_cast_value(value, expected_type):
- """Convert string values into schema type when possible."""
-
- # Non-string values already good
- if not isinstance(value, str):
- return value
-
- v = value.strip()
-
- # Boolean
- if expected_type == "boolean":
- if v.lower() in ["true", "1"]:
- return True
- if v.lower() in ["false", "0"]:
- return False
- raise Exception(f"Cannot convert '{value}' to boolean")
-
- # Number
- if expected_type == "number":
- # integer
- if v.isdigit() or (v.startswith("-") and v[1:].isdigit()):
- return int(v)
-
- # float
- try:
- return float(v)
- except Exception:
- raise Exception(f"Cannot convert '{value}' to number")
-
- # Object
- if expected_type == "object":
- try:
- parsed = json.loads(v)
- if isinstance(parsed, dict):
- return parsed
- else:
- raise Exception("JSON is not an object")
- except Exception:
- raise Exception(f"Cannot convert '{value}' to object")
-
- # Array
- if expected_type.startswith("array"):
- try:
- parsed = json.loads(v)
- if isinstance(parsed, list):
- return parsed
- else:
- raise Exception("JSON is not an array")
- except Exception:
- raise Exception(f"Cannot convert '{value}' to array")
-
- # String (accept original)
- if expected_type == "string":
- return value
-
- # File
- if expected_type == "file":
- return value
- # Default: do nothing
- return value
-
-
- def validate_type(value, t):
- """Validate value type against schema type t."""
- if t == "file":
- return isinstance(value, list)
-
- if t == "string":
- return isinstance(value, str)
-
- if t == "number":
- return isinstance(value, (int, float))
-
- if t == "boolean":
- return isinstance(value, bool)
-
- if t == "object":
- return isinstance(value, dict)
-
-        # array / array<inner>, e.g. array<string>
- if t.startswith("array"):
- if not isinstance(value, list):
- return False
-
- if "<" in t and ">" in t:
- inner = t[t.find("<") + 1 : t.find(">")]
-
- # Check each element type
- for item in value:
- if not validate_type(item, inner):
- return False
-
- return True
-
- return True
- parsed = await parse_webhook_request(webhook_cfg.get("content_types"))
- SCHEMA = webhook_cfg.get("schema", {"query": {}, "headers": {}, "body": {}})
-
- # Extract strictly by schema
- try:
- query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query")
- header_clean = extract_by_schema(parsed["headers"], SCHEMA.get("headers", {}), name="headers")
- body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body")
- except Exception as e:
-        return get_data_error_result(code=RetCode.BAD_REQUEST, message=str(e)), RetCode.BAD_REQUEST
-
- clean_request = {
- "query": query_clean,
- "headers": header_clean,
- "body": body_clean,
- "input": parsed
- }
-
- execution_mode = webhook_cfg.get("execution_mode", "Immediately")
- response_cfg = webhook_cfg.get("response", {})
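-    # "Immediately" acknowledges the caller with the configured static response and
-    # runs the canvas in the background; any other mode awaits the run and returns
-    # its final answer as the HTTP response.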
-
-    def append_webhook_trace(agent_id: str, start_ts: float, event: dict, ttl=600):
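-        # Events are grouped per run under the webhook's start_ts and kept in Redis
-        # for `ttl` seconds so the /webhook_trace endpoint can poll them in order.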
- key = f"webhook-trace-{agent_id}-logs"
-
- raw = REDIS_CONN.get(key)
- obj = json.loads(raw) if raw else {"webhooks": {}}
-
- ws = obj["webhooks"].setdefault(
- str(start_ts),
- {"start_ts": start_ts, "events": []}
- )
-
- ws["events"].append({
- "ts": time.time(),
- **event
- })
-
- REDIS_CONN.set_obj(key, obj, ttl)
-
- if execution_mode == "Immediately":
- status = response_cfg.get("status", 200)
- try:
- status = int(status)
- except (TypeError, ValueError):
- return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}")),RetCode.BAD_REQUEST
-
- if not (200 <= status <= 399):
- return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}, must be between 200 and 399")),RetCode.BAD_REQUEST
-
- body_tpl = response_cfg.get("body_template", "")
-
- def parse_body(body: str):
- if not body:
- return None, "application/json"
-
- try:
- parsed = json.loads(body)
- return parsed, "application/json"
- except (json.JSONDecodeError, TypeError):
- return body, "text/plain"
-
- body, content_type = parse_body(body_tpl)
- resp = Response(
- json.dumps(body, ensure_ascii=False) if content_type == "application/json" else body,
- status=status,
- content_type=content_type,
- )
-
- async def background_run():
- try:
- async for ans in canvas.run(
- query="",
- user_id=cvs.user_id,
- webhook_payload=clean_request
- ):
- if is_test:
- append_webhook_trace(agent_id, start_ts, ans)
-
- if is_test:
- append_webhook_trace(
- agent_id,
- start_ts,
- {
- "event": "finished",
- "elapsed_time": time.time() - start_ts,
- "success": True,
- }
- )
-
- cvs.dsl = json.loads(str(canvas))
- UserCanvasService.update_by_id(cvs.user_id, cvs.to_dict())
-
- except Exception as e:
- logging.exception("Webhook background run failed")
- if is_test:
- try:
- append_webhook_trace(
- agent_id,
- start_ts,
- {
- "event": "error",
- "message": str(e),
- "error_type": type(e).__name__,
- }
- )
- append_webhook_trace(
- agent_id,
- start_ts,
- {
- "event": "finished",
- "elapsed_time": time.time() - start_ts,
- "success": False,
- }
- )
- except Exception:
- logging.exception("Failed to append webhook trace")
-
- asyncio.create_task(background_run())
- return resp
- else:
-        async def run_and_collect():
- nonlocal canvas
- contents: list[str] = []
- status = 200
- try:
- async for ans in canvas.run(
- query="",
- user_id=cvs.user_id,
- webhook_payload=clean_request,
- ):
- if ans["event"] == "message":
- content = ans["data"]["content"]
- if ans["data"].get("start_to_think", False):
- content = ""
- elif ans["data"].get("end_to_think", False):
- content = " "
- if content:
- contents.append(content)
- if ans["event"] == "message_end":
- status = int(ans["data"].get("status", status))
- if is_test:
- append_webhook_trace(
- agent_id,
- start_ts,
- ans
- )
- if is_test:
- append_webhook_trace(
- agent_id,
- start_ts,
- {
- "event": "finished",
- "elapsed_time": time.time() - start_ts,
- "success": True,
- }
- )
- final_content = "".join(contents)
- return {
- "message": final_content,
- "success": True,
- "code": status,
- }
-
- except Exception as e:
- if is_test:
- append_webhook_trace(
- agent_id,
- start_ts,
- {
- "event": "error",
- "message": str(e),
- "error_type": type(e).__name__,
- }
- )
- append_webhook_trace(
- agent_id,
- start_ts,
- {
- "event": "finished",
- "elapsed_time": time.time() - start_ts,
- "success": False,
- }
- )
- return {"code": 400, "message": str(e),"success":False}
-
-        result = await run_and_collect()
- return Response(
- json.dumps(result),
- status=result["code"],
- mimetype="application/json",
- )
-
-
-@manager.route("/webhook_trace/", methods=["GET"]) # noqa: F821
-async def webhook_trace(agent_id: str):
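-    # IDs handed to clients are HMAC digests of the internal start_ts key, so a
-    # caller can poll a specific run without learning or forging raw timestamps.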
- def encode_webhook_id(start_ts: str) -> str:
- WEBHOOK_ID_SECRET = "webhook_id_secret"
- sig = hmac.new(
- WEBHOOK_ID_SECRET.encode("utf-8"),
- start_ts.encode("utf-8"),
- hashlib.sha256,
- ).digest()
- return base64.urlsafe_b64encode(sig).decode("utf-8").rstrip("=")
-
- def decode_webhook_id(enc_id: str, webhooks: dict) -> str | None:
- for ts in webhooks.keys():
- if encode_webhook_id(ts) == enc_id:
- return ts
- return None
-
-    since_ts = request.args.get("since_ts", type=float)
- webhook_id = request.args.get("webhook_id")
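-    # Polling protocol: a call without since_ts returns a baseline timestamp; the
-    # next call discovers the earliest run after that baseline and hands back its
-    # webhook_id; subsequent calls stream that run's events until "finished".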
-
- key = f"webhook-trace-{agent_id}-logs"
- raw = REDIS_CONN.get(key)
-
- if since_ts is None:
- now = time.time()
- return get_json_result(
- data={
- "webhook_id": None,
- "events": [],
- "next_since_ts": now,
- "finished": False,
- }
- )
-
- if not raw:
- return get_json_result(
- data={
- "webhook_id": None,
- "events": [],
- "next_since_ts": since_ts,
- "finished": False,
- }
- )
-
- obj = json.loads(raw)
- webhooks = obj.get("webhooks", {})
-
- if webhook_id is None:
- candidates = [
- float(k) for k in webhooks.keys() if float(k) > since_ts
- ]
-
- if not candidates:
- return get_json_result(
- data={
- "webhook_id": None,
- "events": [],
- "next_since_ts": since_ts,
- "finished": False,
- }
- )
-
- start_ts = min(candidates)
- real_id = str(start_ts)
- webhook_id = encode_webhook_id(real_id)
-
- return get_json_result(
- data={
- "webhook_id": webhook_id,
- "events": [],
- "next_since_ts": start_ts,
- "finished": False,
- }
- )
-
- real_id = decode_webhook_id(webhook_id, webhooks)
-
- if not real_id:
- return get_json_result(
- data={
- "webhook_id": webhook_id,
- "events": [],
- "next_since_ts": since_ts,
- "finished": True,
- }
- )
-
- ws = webhooks.get(str(real_id))
- events = ws.get("events", [])
- new_events = [e for e in events if e.get("ts", 0) > since_ts]
-
- next_ts = since_ts
- for e in new_events:
- next_ts = max(next_ts, e["ts"])
-
- finished = any(e.get("event") == "finished" for e in new_events)
-
- return get_json_result(
- data={
- "webhook_id": webhook_id,
- "events": new_events,
- "next_since_ts": next_ts,
- "finished": finished,
- }
- )
diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/sdk/dify_retrieval.py
index e6dd61d035e..e85a1d439c5 100644
--- a/api/apps/sdk/dify_retrieval.py
+++ b/api/apps/sdk/dify_retrieval.py
@@ -122,6 +122,8 @@ async def retrieval(tenant_id):
retrieval_setting = req.get("retrieval_setting", {})
similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0))
top = int(retrieval_setting.get("top_k", 1024))
+ if top <= 0:
+ return build_error_result(message="`top_k` must be greater than 0", code=RetCode.DATA_ERROR)
metadata_condition = req.get("metadata_condition", {}) or {}
metas = DocMetadataService.get_flatted_meta_by_kbs([kb_id])
diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py
index bff583e4976..cf297c4b250 100644
--- a/api/apps/sdk/doc.py
+++ b/api/apps/sdk/doc.py
@@ -13,59 +13,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import datetime
-import re
+import logging
from io import BytesIO
-import xxhash
-from pydantic import BaseModel, Field, validator
from quart import request, send_file
-from api.db.db_models import APIToken, Document, File, Task
+from api.db.db_models import APIToken, Document, Task
from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type
from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
-from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService, cancel_all_task_of, queue_tasks
from api.db.services.tenant_llm_service import TenantLLMService
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_request_json, get_result, server_error_response, token_required
-from api.utils.image_utils import store_chunk_image
from common import settings
-from common.constants import FileSource, LLMType, ParserType, RetCode, TaskStatus
+from common.constants import LLMType, RetCode, TaskStatus
from common.metadata_utils import convert_conditions, meta_filter
-from common.misc_utils import thread_pool_exec
-from common.string_utils import is_content_empty, remove_redundant_spaces
-from common.tag_feature_utils import validate_tag_features
-from rag.app.qa import beAdoc, rmPrefix
from rag.app.tag import label_question
-from rag.nlp import rag_tokenizer, search
+from rag.nlp import search
from rag.prompts.generator import cross_languages, keyword_extraction
MAXIMUM_OF_UPLOADING_FILES = 256
-class Chunk(BaseModel):
- id: str = ""
- content: str = ""
- document_id: str = ""
- docnm_kwd: str = ""
- important_keywords: list = Field(default_factory=list)
- tag_kwd: list = Field(default_factory=list)
- questions: list = Field(default_factory=list)
- question_tks: str = ""
- image_id: str = ""
- available: bool = True
- positions: list[list[int]] = Field(default_factory=list)
+from api.utils.reference_metadata_utils import (
+ enrich_chunks_with_document_metadata,
+ resolve_reference_metadata_preferences,
+)
+
+def _resolve_reference_metadata(req: dict, search_config: dict | None = None):
+ return resolve_reference_metadata_preferences(req, search_config)
- @validator("positions")
- def validate_positions(cls, value):
- for sublist in value:
- if len(sublist) != 5:
- raise ValueError("Each sublist in positions must have a length of 5")
- return value
+def _enrich_chunks_with_document_metadata(chunks: list[dict], metadata_fields=None) -> None:
+ enrich_chunks_with_document_metadata(chunks, metadata_fields)
@manager.route("/datasets//documents/", methods=["GET"]) # noqa: F821
@@ -134,15 +116,30 @@ async def download_doc(document_id):
if len(token) != 2:
return get_error_data_result(message="Authorization is not valid!")
token = token[1]
+ logging.info("Beta API token lookup attempted for document download")
objs = APIToken.query(beta=token)
if not objs:
+ logging.warning("Beta API token lookup failed for document download: invalid API key")
return get_error_data_result(message='Authentication error: API key is invalid!')
+ if len(objs) > 1:
+ logging.error("Beta API token lookup is ambiguous for document download: matches=%s", len(objs))
+ return get_error_data_result(message="Authentication error: API key configuration is ambiguous.")
+ tenant_id = objs[0].tenant_id
+ logging.info("Beta API token authorized for document download: tenant_id=%s", tenant_id)
if not document_id:
return get_error_data_result(message="Specify document_id please.")
doc = DocumentService.query(id=document_id)
if not doc:
return get_error_data_result(message=f"The dataset not own the document {document_id}.")
+ if not KnowledgebaseService.query(id=doc[0].kb_id, tenant_id=tenant_id):
+ logging.warning(
+ "cross-tenant access denied for document download: tenant_id=%s kb_id=%s document_id=%s",
+ tenant_id,
+ doc[0].kb_id,
+ document_id,
+ )
+ return get_error_data_result(message="You do not have access to this document.")
# The process of downloading
doc_id, doc_location = File2DocumentService.get_storage_address(doc_id=document_id) # minio address
file_stream = settings.STORAGE_IMPL.get(doc_id, doc_location)
@@ -158,171 +155,6 @@ async def download_doc(document_id):
)
-@manager.route("/datasets//metadata/update", methods=["POST"]) # noqa: F821
-@token_required
-async def metadata_batch_update(dataset_id, tenant_id):
- if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
- return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
-
- req = await get_request_json()
- selector = req.get("selector", {}) or {}
- updates = req.get("updates", []) or []
- deletes = req.get("deletes", []) or []
-
- if not isinstance(selector, dict):
- return get_error_data_result(message="selector must be an object.")
- if not isinstance(updates, list) or not isinstance(deletes, list):
- return get_error_data_result(message="updates and deletes must be lists.")
-
- metadata_condition = selector.get("metadata_condition", {}) or {}
- if metadata_condition and not isinstance(metadata_condition, dict):
- return get_error_data_result(message="metadata_condition must be an object.")
-
- document_ids = selector.get("document_ids", []) or []
- if document_ids and not isinstance(document_ids, list):
- return get_error_data_result(message="document_ids must be a list.")
-
- for upd in updates:
- if not isinstance(upd, dict) or not upd.get("key") or "value" not in upd:
- return get_error_data_result(message="Each update requires key and value.")
- for d in deletes:
- if not isinstance(d, dict) or not d.get("key"):
- return get_error_data_result(message="Each delete requires key.")
-
-    kb_doc_ids = KnowledgebaseService.list_documents_by_ids([dataset_id])
-    # Default to every document in the dataset so later filters always have a
-    # defined base set, then narrow to explicitly requested IDs if given.
-    target_doc_ids = set(kb_doc_ids)
-    if document_ids:
-        invalid_ids = set(document_ids) - set(kb_doc_ids)
-        if invalid_ids:
-            return get_error_data_result(message=f"These documents do not belong to dataset {dataset_id}: {', '.join(invalid_ids)}")
-        target_doc_ids = set(document_ids)
-
- if metadata_condition:
- metas = DocMetadataService.get_flatted_meta_by_kbs([dataset_id])
- filtered_ids = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
- target_doc_ids = target_doc_ids & filtered_ids
- if metadata_condition.get("conditions") and not target_doc_ids:
- return get_result(data={"updated": 0, "matched_docs": 0})
-
- target_doc_ids = list(target_doc_ids)
- updated = DocMetadataService.batch_update_metadata(dataset_id, target_doc_ids, updates, deletes)
- return get_result(data={"updated": updated, "matched_docs": len(target_doc_ids)})
-
-
-@manager.route("/datasets//documents", methods=["DELETE"]) # noqa: F821
-@token_required
-async def delete(tenant_id, dataset_id):
- """
- Delete documents from a dataset.
- ---
- tags:
- - Documents
- security:
- - ApiKeyAuth: []
- parameters:
- - in: path
- name: dataset_id
- type: string
- required: true
- description: ID of the dataset.
- - in: body
- name: body
- description: Document deletion parameters.
- required: true
- schema:
- type: object
- properties:
- ids:
- type: array
- items:
- type: string
- description: |
- List of document IDs to delete.
- If omitted, `null`, or an empty array is provided, no documents will be deleted.
- - in: header
- name: Authorization
- type: string
- required: true
- description: Bearer token for authentication.
- responses:
- 200:
- description: Documents deleted successfully.
- schema:
- type: object
- """
- if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
- return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ")
- req = await get_request_json()
- if not req:
- return get_result()
-
- doc_ids = req.get("ids")
- if not doc_ids:
- if req.get("delete_all") is True:
- doc_ids = [doc.id for doc in DocumentService.query(kb_id=dataset_id)]
- if not doc_ids:
- return get_result()
- else:
- return get_result()
-
- doc_list = doc_ids
-
- unique_doc_ids, duplicate_messages = check_duplicate_ids(doc_list, "document")
- doc_list = unique_doc_ids
-
- root_folder = FileService.get_root_folder(tenant_id)
- pf_id = root_folder["id"]
- FileService.init_knowledgebase_docs(pf_id, tenant_id)
- errors = ""
- not_found = []
- success_count = 0
- for doc_id in doc_list:
- try:
- e, doc = DocumentService.get_by_id(doc_id)
- if not e:
- not_found.append(doc_id)
- continue
-            # Use a local name so the route's tenant_id parameter is not shadowed.
-            doc_tenant_id = DocumentService.get_tenant_id(doc_id)
-            if not doc_tenant_id:
-                return get_error_data_result(message="Tenant not found!")
-
-            b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
-
-            if not DocumentService.remove_document(doc, doc_tenant_id):
-                return get_error_data_result(message="Database error (Document removal)!")
-
- f2d = File2DocumentService.get_by_document_id(doc_id)
- FileService.filter_delete(
- [
- File.source_type == FileSource.KNOWLEDGEBASE,
- File.id == f2d[0].file_id,
- ]
- )
- File2DocumentService.delete_by_document_id(doc_id)
-
- settings.STORAGE_IMPL.rm(b, n)
- success_count += 1
- except Exception as e:
- errors += str(e)
-
- if not_found:
- return get_result(message=f"Documents not found: {not_found}", code=RetCode.DATA_ERROR)
-
- if errors:
- return get_result(message=errors, code=RetCode.SERVER_ERROR)
-
- if duplicate_messages:
- if success_count > 0:
- return get_result(
- message=f"Partially deleted {success_count} datasets with {len(duplicate_messages)} errors",
- data={"success_count": success_count, "errors": duplicate_messages},
- )
- else:
- return get_error_data_result(message=";".join(duplicate_messages))
-
- return get_result()
-
-
DOC_STOP_PARSING_INVALID_STATE_MESSAGE = "Can't stop parsing document that has not started or already completed"
DOC_STOP_PARSING_INVALID_STATE_ERROR_CODE = "DOC_STOP_PARSING_INVALID_STATE"
@@ -495,642 +327,6 @@ async def stop_parsing(tenant_id, dataset_id):
return get_result()
-@manager.route("/datasets//documents//chunks", methods=["GET"]) # noqa: F821
-@token_required
-async def list_chunks(tenant_id, dataset_id, document_id):
- """
- List chunks of a document.
- ---
- tags:
- - Chunks
- security:
- - ApiKeyAuth: []
- parameters:
- - in: path
- name: dataset_id
- type: string
- required: true
- description: ID of the dataset.
- - in: path
- name: document_id
- type: string
- required: true
- description: ID of the document.
- - in: query
- name: page
- type: integer
- required: false
- default: 1
- description: Page number.
- - in: query
- name: page_size
- type: integer
- required: false
- default: 30
- description: Number of items per page.
- - in: query
- name: id
- type: string
- required: false
- default: ""
- description: Chunk id.
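-      - in: query
-        name: keywords
-        type: string
-        required: false
-        default: ""
-        description: Keywords to match against chunk content.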
- - in: header
- name: Authorization
- type: string
- required: true
- description: Bearer token for authentication.
- responses:
- 200:
- description: List of chunks.
- schema:
- type: object
- properties:
- total:
- type: integer
- description: Total number of chunks.
- chunks:
- type: array
- items:
- type: object
- properties:
- id:
- type: string
- description: Chunk ID.
- content:
- type: string
- description: Chunk content.
- document_id:
- type: string
- description: ID of the document.
- important_keywords:
- type: array
- items:
- type: string
- description: Important keywords.
- tag_kwd:
- type: array
- items:
- type: string
- description: Tag keywords.
- image_id:
- type: string
- description: Image ID associated with the chunk.
- doc:
- type: object
- description: Document details.
- """
- if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
- return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
- doc = DocumentService.query(id=document_id, kb_id=dataset_id)
- if not doc:
- return get_error_data_result(message=f"You don't own the document {document_id}.")
- doc = doc[0]
- req = request.args
- doc_id = document_id
- page = int(req.get("page", 1))
- size = int(req.get("page_size", 30))
- question = req.get("keywords", "")
- query = {
- "doc_ids": [doc_id],
- "page": page,
- "size": size,
- "question": question,
- "sort": True,
- }
- if "available" in req:
- query["available_int"] = 1 if req["available"] == "true" else 0
- key_mapping = {
- "chunk_num": "chunk_count",
- "kb_id": "dataset_id",
- "token_num": "token_count",
- "parser_id": "chunk_method",
- }
- run_mapping = {
- "0": "UNSTART",
- "1": "RUNNING",
- "2": "CANCEL",
- "3": "DONE",
- "4": "FAIL",
- }
- doc = doc.to_dict()
- renamed_doc = {}
- for key, value in doc.items():
- new_key = key_mapping.get(key, key)
- renamed_doc[new_key] = value
- if key == "run":
- renamed_doc["run"] = run_mapping.get(str(value))
-
- res = {"total": 0, "chunks": [], "doc": renamed_doc}
- if req.get("id"):
- chunk = settings.docStoreConn.get(req.get("id"), search.index_name(tenant_id), [dataset_id])
- if not chunk:
- return get_result(message=f"Chunk not found: {dataset_id}/{req.get('id')}", code=RetCode.NOT_FOUND)
- k = []
- for n in chunk.keys():
- if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
- k.append(n)
- for n in k:
- del chunk[n]
- if not chunk:
- return get_error_data_result(f"Chunk `{req.get('id')}` not found.")
- res["total"] = 1
- final_chunk = {
- "id": chunk.get("id", chunk.get("chunk_id")),
- "content": chunk["content_with_weight"],
- "document_id": chunk.get("doc_id", chunk.get("document_id")),
- "docnm_kwd": chunk["docnm_kwd"],
- "important_keywords": chunk.get("important_kwd", []),
- "questions": chunk.get("question_kwd", []),
- "dataset_id": chunk.get("kb_id", chunk.get("dataset_id")),
- "image_id": chunk.get("img_id", ""),
- "available": bool(chunk.get("available_int", 1)),
- "positions": chunk.get("position_int", []),
- "tag_kwd": chunk.get("tag_kwd", []),
- "tag_feas": chunk.get("tag_feas", {}),
- }
- res["chunks"].append(final_chunk)
- _ = Chunk(**final_chunk)
-
- elif settings.docStoreConn.index_exist(search.index_name(tenant_id), dataset_id):
- sres = await settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
- res["total"] = sres.total
- for id in sres.ids:
- d = {
- "id": id,
- "content": (remove_redundant_spaces(sres.highlight[id]) if question and id in sres.highlight else sres.field[id].get("content_with_weight", "")),
- "document_id": sres.field[id]["doc_id"],
- "docnm_kwd": sres.field[id]["docnm_kwd"],
- "important_keywords": sres.field[id].get("important_kwd", []),
- "tag_kwd": sres.field[id].get("tag_kwd", []),
- "questions": sres.field[id].get("question_kwd", []),
- "dataset_id": sres.field[id].get("kb_id", sres.field[id].get("dataset_id")),
- "image_id": sres.field[id].get("img_id", ""),
- "available": bool(int(sres.field[id].get("available_int", "1"))),
- "positions": sres.field[id].get("position_int", []),
- }
- res["chunks"].append(d)
- _ = Chunk(**d) # validate the chunk
- return get_result(data=res)
-
-
-@manager.route( # noqa: F821
- "/datasets//documents//chunks", methods=["POST"]
-)
-@token_required
-async def add_chunk(tenant_id, dataset_id, document_id):
- """
- Add a chunk to a document.
- ---
- tags:
- - Chunks
- security:
- - ApiKeyAuth: []
- parameters:
- - in: path
- name: dataset_id
- type: string
- required: true
- description: ID of the dataset.
- - in: path
- name: document_id
- type: string
- required: true
- description: ID of the document.
- - in: body
- name: body
- description: Chunk data.
- required: true
- schema:
- type: object
- properties:
- content:
- type: string
- required: true
- description: Content of the chunk.
- important_keywords:
- type: array
- items:
- type: string
- description: Important keywords.
- image_base64:
- type: string
- description: Base64-encoded image to associate with the chunk.
- - in: header
- name: Authorization
- type: string
- required: true
- description: Bearer token for authentication.
- responses:
- 200:
- description: Chunk added successfully.
- schema:
- type: object
- properties:
- chunk:
- type: object
- properties:
- id:
- type: string
- description: Chunk ID.
- content:
- type: string
- description: Chunk content.
- document_id:
- type: string
- description: ID of the document.
- important_keywords:
- type: array
- items:
- type: string
- description: Important keywords.
- """
- if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
- return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
- doc = DocumentService.query(id=document_id, kb_id=dataset_id)
- if not doc:
- return get_error_data_result(message=f"You don't own the document {document_id}.")
- doc = doc[0]
- req = await get_request_json()
- if is_content_empty(req.get("content")):
- return get_error_data_result(message="`content` is required")
- if "important_keywords" in req:
- if not isinstance(req["important_keywords"], list):
- return get_error_data_result("`important_keywords` is required to be a list")
- if "questions" in req:
- if not isinstance(req["questions"], list):
- return get_error_data_result("`questions` is required to be a list")
- chunk_id = xxhash.xxh64((req["content"] + document_id).encode("utf-8")).hexdigest()
- d = {
- "id": chunk_id,
- "content_ltks": rag_tokenizer.tokenize(req["content"]),
- "content_with_weight": req["content"],
- }
- d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
- d["important_kwd"] = req.get("important_keywords", [])
- d["important_tks"] = rag_tokenizer.tokenize(" ".join(req.get("important_keywords", [])))
- d["question_kwd"] = [str(q).strip() for q in req.get("questions", []) if str(q).strip()]
- d["question_tks"] = rag_tokenizer.tokenize("\n".join(req.get("questions", [])))
- d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
- d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
- d["kb_id"] = dataset_id
- d["docnm_kwd"] = doc.name
- d["doc_id"] = document_id
- if "tag_kwd" in req:
- if not isinstance(req["tag_kwd"], list):
- return get_error_data_result("`tag_kwd` is required to be a list")
- if not all(isinstance(t, str) for t in req["tag_kwd"]):
- return get_error_data_result("`tag_kwd` must be a list of strings")
- d["tag_kwd"] = req["tag_kwd"]
- if "tag_feas" in req:
- try:
- d["tag_feas"] = validate_tag_features(req["tag_feas"])
- except ValueError as exc:
- return get_error_data_result(f"`tag_feas` {exc}")
- import base64
-
- image_base64 = req.get("image_base64", None)
- if image_base64:
- d["img_id"] = "{}-{}".format(dataset_id, chunk_id)
- d["doc_type_kwd"] = "image"
-
- tenant_embd_id = DocumentService.get_tenant_embd_id(document_id)
- if tenant_embd_id:
- model_config = get_model_config_by_id(tenant_embd_id)
- else:
- embd_id = DocumentService.get_embd_id(document_id)
- model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING.value, embd_id)
- embd_mdl = TenantLLMService.model_instance(model_config)
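-    # Embed the document title and the chunk content (or its questions); the
-    # stored vector weights the title at 0.1 and the content at 0.9.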
- v, c = embd_mdl.encode([doc.name, req["content"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
- v = 0.1 * v[0] + 0.9 * v[1]
- d["q_%d_vec" % len(v)] = v.tolist()
- settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
-
- if image_base64:
- store_chunk_image(dataset_id, chunk_id, base64.b64decode(image_base64))
-
- DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
- # rename keys
- key_mapping = {
- "id": "id",
- "content_with_weight": "content",
- "doc_id": "document_id",
- "important_kwd": "important_keywords",
- "tag_kwd": "tag_kwd",
- "question_kwd": "questions",
- "kb_id": "dataset_id",
- "create_timestamp_flt": "create_timestamp",
- "create_time": "create_time",
- "document_keyword": "document",
- "img_id": "image_id",
- }
- renamed_chunk = {}
- for key, value in d.items():
- if key in key_mapping:
- new_key = key_mapping.get(key, key)
- renamed_chunk[new_key] = value
- _ = Chunk(**renamed_chunk) # validate the chunk
- return get_result(data={"chunk": renamed_chunk})
- # return get_result(data={"chunk_id": chunk_id})
-
-
-@manager.route( # noqa: F821
- "datasets//documents//chunks", methods=["DELETE"]
-)
-@token_required
-async def rm_chunk(tenant_id, dataset_id, document_id):
- """
- Remove chunks from a document.
- ---
- tags:
- - Chunks
- security:
- - ApiKeyAuth: []
- parameters:
- - in: path
- name: dataset_id
- type: string
- required: true
- description: ID of the dataset.
- - in: path
- name: document_id
- type: string
- required: true
- description: ID of the document.
- - in: body
- name: body
- description: Chunk removal parameters.
- required: true
- schema:
- type: object
- properties:
- chunk_ids:
- type: array
- items:
- type: string
- description: |
- List of chunk IDs to remove.
- If omitted, `null`, or an empty array is provided, no chunks will be deleted.
- - in: header
- name: Authorization
- type: string
- required: true
- description: Bearer token for authentication.
- responses:
- 200:
- description: Chunks removed successfully.
- schema:
- type: object
- """
- if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
- return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
- docs = DocumentService.get_by_ids([document_id])
- if not docs:
- raise LookupError(f"Can't find the document with ID {document_id}!")
- req = await get_request_json()
- if not req:
- return get_result()
-
- chunk_ids = req.get("chunk_ids")
- if not chunk_ids:
- if req.get("delete_all") is True:
- doc = docs[0]
- # Clean up storage assets while index rows still exist for discovery
- DocumentService.delete_chunk_images(doc, tenant_id)
- condition = {"doc_id": document_id}
- chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
- if chunk_number != 0:
- DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
- return get_result(message=f"deleted {chunk_number} chunks")
- else:
- return get_result()
-
- condition = {"doc_id": document_id}
- unique_chunk_ids, duplicate_messages = check_duplicate_ids(chunk_ids, "chunk")
- condition["id"] = unique_chunk_ids
- chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
- if chunk_number != 0:
- DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
- if chunk_number != len(unique_chunk_ids):
- if len(unique_chunk_ids) == 0:
- return get_result(message=f"deleted {chunk_number} chunks")
- return get_error_data_result(message=f"rm_chunk deleted chunks {chunk_number}, expect {len(unique_chunk_ids)}")
- if duplicate_messages:
- return get_result(
- message=f"Partially deleted {chunk_number} chunks with {len(duplicate_messages)} errors",
- data={"success_count": chunk_number, "errors": duplicate_messages},
- )
- return get_result(message=f"deleted {chunk_number} chunks")
-
-
-@manager.route( # noqa: F821
- "/datasets//documents//chunks/", methods=["PUT"]
-)
-@token_required
-async def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
- """
- Update a chunk within a document.
- ---
- tags:
- - Chunks
- security:
- - ApiKeyAuth: []
- parameters:
- - in: path
- name: dataset_id
- type: string
- required: true
- description: ID of the dataset.
- - in: path
- name: document_id
- type: string
- required: true
- description: ID of the document.
- - in: path
- name: chunk_id
- type: string
- required: true
- description: ID of the chunk to update.
- - in: body
- name: body
- description: Chunk update parameters.
- required: true
- schema:
- type: object
- properties:
- content:
- type: string
- description: Updated content of the chunk.
- important_keywords:
- type: array
- items:
- type: string
- description: Updated important keywords.
- tag_kwd:
- type: array
- items:
- type: string
- description: Updated tag keywords.
- available:
- type: boolean
- description: Availability status of the chunk.
- - in: header
- name: Authorization
- type: string
- required: true
- description: Bearer token for authentication.
- responses:
- 200:
- description: Chunk updated successfully.
- schema:
- type: object
- """
- chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
- if chunk is None:
- return get_error_data_result(f"Can't find this chunk {chunk_id}")
- if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
- return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
- doc = DocumentService.query(id=document_id, kb_id=dataset_id)
- if not doc:
- return get_error_data_result(message=f"You don't own the document {document_id}.")
- doc = doc[0]
- req = await get_request_json()
- content = req.get("content")
- if content is not None:
- if is_content_empty(content):
- return get_error_data_result(message="`content` is required")
- else:
- content = chunk.get("content_with_weight", "")
- d = {"id": chunk_id, "content_with_weight": content}
- d["content_ltks"] = rag_tokenizer.tokenize(d["content_with_weight"])
- d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
- if "important_keywords" in req:
- if not isinstance(req["important_keywords"], list):
- return get_error_data_result("`important_keywords` should be a list")
- d["important_kwd"] = req.get("important_keywords", [])
- d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_keywords"]))
- if "questions" in req:
- if not isinstance(req["questions"], list):
- return get_error_data_result("`questions` should be a list")
- d["question_kwd"] = [str(q).strip() for q in req.get("questions", []) if str(q).strip()]
- d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["questions"]))
- if "available" in req:
- d["available_int"] = int(req["available"])
- if "positions" in req:
- if not isinstance(req["positions"], list):
- return get_error_data_result("`positions` should be a list")
- d["position_int"] = req["positions"]
- if "tag_kwd" in req:
- if not isinstance(req["tag_kwd"], list):
- return get_error_data_result("`tag_kwd` should be a list")
- if not all(isinstance(t, str) for t in req["tag_kwd"]):
- return get_error_data_result("`tag_kwd` must be a list of strings")
- d["tag_kwd"] = req["tag_kwd"]
- if "tag_feas" in req:
- try:
- d["tag_feas"] = validate_tag_features(req["tag_feas"])
- except ValueError as exc:
- return get_error_data_result(f"`tag_feas` {exc}")
- tenant_embd_id = DocumentService.get_tenant_embd_id(document_id)
- if tenant_embd_id:
- model_config = get_model_config_by_id(tenant_embd_id)
- else:
- embd_id = DocumentService.get_embd_id(document_id)
- model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING.value, embd_id)
- embd_mdl = TenantLLMService.model_instance(model_config)
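-    # Q&A chunks store a question/answer pair separated by a tab or newline;
-    # beAdoc rebuilds the indexed fields from that pair before re-embedding.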
- if doc.parser_id == ParserType.QA:
- arr = [t for t in re.split(r"[\n\t]", d["content_with_weight"]) if len(t) > 1]
- if len(arr) != 2:
- return get_error_data_result(message="Q&A must be separated by TAB/ENTER key.")
- q, a = rmPrefix(arr[0]), rmPrefix(arr[1])
- d = beAdoc(d, arr[0], arr[1], not any([rag_tokenizer.is_chinese(t) for t in q + a]))
-
- v, c = embd_mdl.encode([doc.name, d["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
- v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
- d["q_%d_vec" % len(v)] = v.tolist()
- settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
- return get_result()
-
-
-@manager.route( # noqa: F821
- "/datasets//documents//chunks/switch", methods=["POST"]
-)
-@token_required
-async def switch_chunks(tenant_id, dataset_id, document_id):
- """
- Switch availability of specified chunks (same as chunk_app switch).
- ---
- tags:
- - Chunks
- security:
- - ApiKeyAuth: []
- parameters:
- - in: path
- name: dataset_id
- type: string
- required: true
- description: ID of the dataset.
- - in: path
- name: document_id
- type: string
- required: true
- description: ID of the document.
- - in: body
- name: body
- required: true
- schema:
- type: object
- properties:
- chunk_ids:
- type: array
- items:
- type: string
- description: List of chunk IDs to switch.
- available_int:
- type: integer
- description: 1 for available, 0 for unavailable.
- available:
- type: boolean
- description: Availability status (alternative to available_int).
- - in: header
- name: Authorization
- type: string
- required: true
- description: Bearer token for authentication.
- responses:
- 200:
- description: Chunks availability switched successfully.
- """
- if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
- return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
- req = await get_request_json()
- if not req.get("chunk_ids"):
- return get_error_data_result(message="`chunk_ids` is required.")
- if "available_int" not in req and "available" not in req:
- return get_error_data_result(message="`available_int` or `available` is required.")
- available_int = int(req["available_int"]) if "available_int" in req else (1 if req.get("available") else 0)
- try:
-
- def _switch_sync():
- e, doc = DocumentService.get_by_id(document_id)
- if not e:
- return get_error_data_result(message="Document not found!")
- if not doc or str(doc.kb_id) != str(dataset_id):
- return get_error_data_result(message="Document not found!")
- for cid in req["chunk_ids"]:
- if not settings.docStoreConn.update(
- {"id": cid},
- {"available_int": available_int},
- search.index_name(tenant_id),
- doc.kb_id,
- ):
- return get_error_data_result(message="Index updating failure")
- return get_result(data=True)
-
- return await thread_pool_exec(_switch_sync)
- except Exception as e:
- return server_error_response(e)
-
-
@manager.route("/retrieval", methods=["POST"]) # noqa: F821
@token_required
async def retrieval_test(tenant_id):
@@ -1268,6 +464,8 @@ async def retrieval_test(tenant_id):
similarity_threshold = float(req.get("similarity_threshold", 0.2))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
top = int(req.get("top_k", 1024))
+ if top <= 0:
+ return get_error_data_result("`top_k` must be greater than 0")
highlight_val = req.get("highlight", None)
if highlight_val is None:
highlight = False
@@ -1280,6 +478,7 @@ async def retrieval_test(tenant_id):
return get_error_data_result("`highlight` should be a boolean")
else:
return get_error_data_result("`highlight` should be a boolean")
+ include_metadata, metadata_fields = _resolve_reference_metadata(req)
try:
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
@@ -1338,6 +537,15 @@ async def retrieval_test(tenant_id):
for c in ranks["chunks"]:
c.pop("vector", None)
+ if include_metadata:
+ logging.info(
+ "sdk.retrieval reference_metadata enabled dataset_ids=%s fields=%s chunks=%s",
+ kb_ids,
+ sorted(metadata_fields) if metadata_fields else None,
+ len(ranks["chunks"]),
+ )
+ enrich_chunks_with_document_metadata(ranks["chunks"], metadata_fields)
+
# rename keys
renamed_chunks = []
for chunk in ranks["chunks"]:
diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py
index 82e048ff17b..11960dcf65c 100644
--- a/api/apps/sdk/session.py
+++ b/api/apps/sdk/session.py
@@ -14,47 +14,44 @@
# limitations under the License.
#
import json
-import copy
import re
-import time
-import os
-import tempfile
import logging
-from quart import Response, jsonify, request
-
-from common.token_utils import num_tokens_from_string
+from quart import Response, request
from agent.canvas import Canvas
from api.db.db_models import APIToken
from api.db.services.api_service import API4ConversationService
-from api.db.services.canvas_service import UserCanvasService, completion_openai
+from api.db.services.canvas_service import UserCanvasService
from api.db.services.canvas_service import completion as agent_completion
-from api.db.services.conversation_service import ConversationService
from api.db.services.user_canvas_version import UserCanvasVersionService
from api.db.services.conversation_service import async_iframe_completion as iframe_completion
-from api.db.services.conversation_service import async_completion as rag_completion
-from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap
+from api.db.services.dialog_service import DialogService, async_ask, gen_mindmap
from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
-from common.metadata_utils import apply_meta_data_filter, convert_conditions, meta_filter
+from common.metadata_utils import apply_meta_data_filter
from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_by_id, \
get_model_config_by_type_and_name
from common.misc_utils import get_uuid
-from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \
+from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_json_result, \
get_result, get_request_json, server_error_response, token_required, validate_request
from rag.app.tag import label_question
from rag.prompts.template import load_prompt
-from rag.prompts.generator import cross_languages, keyword_extraction, chunks_format
+from rag.prompts.generator import cross_languages, keyword_extraction
from common.constants import RetCode, LLMType, StatusEnum
from common import settings
+from api.utils.reference_metadata_utils import (
+ enrich_chunks_with_document_metadata,
+ resolve_reference_metadata_preferences,
+)
+
+logger = logging.getLogger(__name__)
-@manager.route("/agents//sessions", methods=["POST"]) # noqa: F821
@token_required
async def create_agent_session(tenant_id, agent_id):
req = await get_request_json()
@@ -92,558 +89,6 @@ async def create_agent_session(tenant_id, agent_id):
return get_result(data=conv)
-@manager.route("/chats//completions", methods=["POST"]) # noqa: F821
-@token_required
-async def chat_completion(tenant_id, chat_id):
- req = await get_request_json()
- if not req:
- req = {"question": ""}
- if not req.get("session_id"):
- req["question"] = ""
- dia = DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value)
- if not dia:
- return get_error_data_result(f"You don't own the chat {chat_id}")
- dia = dia[0]
- if req.get("session_id"):
- if not ConversationService.query(id=req["session_id"], dialog_id=chat_id):
- return get_error_data_result(f"You don't own the session {req['session_id']}")
-
- metadata_condition = req.get("metadata_condition") or {}
- if metadata_condition and not isinstance(metadata_condition, dict):
- return get_error_data_result(message="metadata_condition must be an object.")
-
- if metadata_condition and req.get("question"):
- metas = DocMetadataService.get_flatted_meta_by_kbs(dia.kb_ids or [])
- filtered_doc_ids = meta_filter(
- metas,
- convert_conditions(metadata_condition),
- metadata_condition.get("logic", "and"),
- )
- if metadata_condition.get("conditions") and not filtered_doc_ids:
- filtered_doc_ids = ["-999"]
-
- if filtered_doc_ids:
- req["doc_ids"] = ",".join(filtered_doc_ids)
- else:
- req.pop("doc_ids", None)
-
- if req.get("stream", True):
- resp = Response(rag_completion(tenant_id, chat_id, **req), mimetype="text/event-stream")
- resp.headers.add_header("Cache-control", "no-cache")
- resp.headers.add_header("Connection", "keep-alive")
- resp.headers.add_header("X-Accel-Buffering", "no")
- resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
-
- return resp
- else:
- answer = None
- async for ans in rag_completion(tenant_id, chat_id, **req):
- answer = ans
- break
- return get_result(data=answer)
-
-
-@manager.route("/chats_openai//chat/completions", methods=["POST"]) # noqa: F821
-@validate_request("model", "messages") # noqa: F821
-@token_required
-async def chat_completion_openai_like(tenant_id, chat_id):
- """
- OpenAI-like chat completion API that simulates the behavior of OpenAI's completions endpoint.
-
- This function allows users to interact with a model and receive responses based on a series of historical messages.
-    If `stream` is set to True (the default), the response is streamed in chunks, mimicking the OpenAI-style API.
-    If `stream` is set to False explicitly, the response is returned as a single complete answer.
-
- Reference:
-
- - If `stream` is True, the final answer and reference information will appear in the **last chunk** of the stream.
- - If `stream` is False, the reference will be included in `choices[0].message.reference`.
- - If `extra_body.reference_metadata.include` is True, each reference chunk may include `document_metadata` in both streaming and non-streaming responses.
-
- Example usage:
-
-    curl -X POST https://ragflow_address.com/api/v1/chats_openai/<chat_id>/chat/completions \
- -H "Content-Type: application/json" \
- -H "Authorization: Bearer $RAGFLOW_API_KEY" \
- -d '{
- "model": "model",
- "messages": [{"role": "user", "content": "Say this is a test!"}],
- "stream": true
- }'
-
- Alternatively, you can use Python's `OpenAI` client:
-
-    NOTE: Streaming via `client.chat.completions.create(stream=True, ...)` does
-    not currently return `reference`. The only way to obtain `reference` is
-    non-stream mode with `with_raw_response`.
-
- from openai import OpenAI
- import json
-
- model = "model"
- client = OpenAI(api_key="ragflow-api-key", base_url=f"http://ragflow_address/api/v1/chats_openai/")
-
- stream = True
- reference = True
-
- request_kwargs = dict(
- model="model",
- messages=[
- {"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "Who are you?"},
- {"role": "assistant", "content": "I am an AI assistant named..."},
- {"role": "user", "content": "Can you tell me how to install neovim"},
- ],
- extra_body={
- "reference": reference,
- "reference_metadata": {
- "include": True,
- "fields": ["author", "year", "source"],
- },
- "metadata_condition": {
- "logic": "and",
- "conditions": [
- {
- "name": "author",
- "comparison_operator": "is",
- "value": "bob"
- }
- ]
- }
- },
- )
-
- if stream:
- completion = client.chat.completions.create(stream=True, **request_kwargs)
- for chunk in completion:
- print(chunk)
- else:
- resp = client.chat.completions.with_raw_response.create(
- stream=False, **request_kwargs
- )
- print("status:", resp.http_response.status_code)
- raw_text = resp.http_response.text
- print("raw:", raw_text)
-
- data = json.loads(raw_text)
- print("assistant:", data["choices"][0]["message"].get("content"))
- print("reference:", data["choices"][0]["message"].get("reference"))
-
- """
- req = await get_request_json()
-
- extra_body = req.get("extra_body") or {}
- if extra_body and not isinstance(extra_body, dict):
- return get_error_data_result("extra_body must be an object.")
-
- need_reference = bool(extra_body.get("reference", False))
- reference_metadata = extra_body.get("reference_metadata") or {}
- if reference_metadata and not isinstance(reference_metadata, dict):
- return get_error_data_result("reference_metadata must be an object.")
- include_reference_metadata = bool(reference_metadata.get("include", False))
- metadata_fields = reference_metadata.get("fields")
- if metadata_fields is not None and not isinstance(metadata_fields, list):
- return get_error_data_result("reference_metadata.fields must be an array.")
-
- messages = req.get("messages", [])
- # To prevent empty [] input
- if len(messages) < 1:
- return get_error_data_result("You have to provide messages.")
- if messages[-1]["role"] != "user":
- return get_error_data_result("The last content of this conversation is not from user.")
-
- prompt = messages[-1]["content"]
- # Treat context tokens as reasoning tokens
- context_token_used = sum(num_tokens_from_string(message["content"]) for message in messages)
-
- dia = DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value)
- if not dia:
- return get_error_data_result(f"You don't own the chat {chat_id}")
- dia = dia[0]
-
- metadata_condition = extra_body.get("metadata_condition") or {}
- if metadata_condition and not isinstance(metadata_condition, dict):
- return get_error_data_result(message="metadata_condition must be an object.")
-
- doc_ids_str = None
- if metadata_condition:
- metas = DocMetadataService.get_flatted_meta_by_kbs(dia.kb_ids or [])
- filtered_doc_ids = meta_filter(
- metas,
- convert_conditions(metadata_condition),
- metadata_condition.get("logic", "and"),
- )
- if metadata_condition.get("conditions") and not filtered_doc_ids:
- filtered_doc_ids = ["-999"]
- doc_ids_str = ",".join(filtered_doc_ids) if filtered_doc_ids else None
-
- # Filter system and non-sense assistant messages
- msg = []
- for m in messages:
- if m["role"] == "system":
- continue
- if m["role"] == "assistant" and not msg:
- continue
- msg.append(m)
-
- # tools = get_tools()
- # toolcall_session = SimpleFunctionCallServer()
- tools = None
- toolcall_session = None
-
- if req.get("stream", True):
- # Per the OpenAI streaming contract, the usage field is null on every chunk
- # except the last, whose usage carries token statistics for the entire request.
- # (OpenAI also sends an empty choices array on that final chunk; this
- # implementation keeps one choice so it can carry finish_reason and reference.)
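- # A sketch of the closing chunk this generator emits (token counts are
- # hypothetical): data:{"choices": [{"delta": {"content": null, "reference": [...]},
- # "finish_reason": "stop"}], "usage": {"prompt_tokens": 12,
- # "completion_tokens": 87, "total_tokens": 99}}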
- async def streamed_response_generator(chat_id, dia, msg):
- token_used = 0
- last_ans = {}
- full_content = ""
- full_reasoning = ""
- final_answer = None
- final_reference = None
- in_think = False
- response = {
- "id": f"chatcmpl-{chat_id}",
- "choices": [
- {
- "delta": {
- "content": "",
- "role": "assistant",
- "function_call": None,
- "tool_calls": None,
- "reasoning_content": "",
- },
- "finish_reason": None,
- "index": 0,
- "logprobs": None,
- }
- ],
- "created": int(time.time()),
- "model": "model",
- "object": "chat.completion.chunk",
- "system_fingerprint": "",
- "usage": None,
- }
-
- try:
- chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference}
- if doc_ids_str:
- chat_kwargs["doc_ids"] = doc_ids_str
- async for ans in async_chat(dia, msg, True, **chat_kwargs):
- last_ans = ans
- if ans.get("final"):
- if ans.get("answer"):
- full_content = ans["answer"]
- response["choices"][0]["delta"]["content"] = full_content
- response["choices"][0]["delta"]["reasoning_content"] = None
- yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
- final_answer = full_content
- final_reference = ans.get("reference", {})
- continue
- if ans.get("start_to_think"):
- in_think = True
- continue
- if ans.get("end_to_think"):
- in_think = False
- continue
- delta = ans.get("answer") or ""
- if not delta:
- continue
- token_used += num_tokens_from_string(delta)
- if in_think:
- full_reasoning += delta
- response["choices"][0]["delta"]["reasoning_content"] = delta
- response["choices"][0]["delta"]["content"] = None
- else:
- full_content += delta
- response["choices"][0]["delta"]["content"] = delta
- response["choices"][0]["delta"]["reasoning_content"] = None
- yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
- except Exception as e:
- response["choices"][0]["delta"]["content"] = "**ERROR**: " + str(e)
- yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
-
- # The last chunk
- response["choices"][0]["delta"]["content"] = None
- response["choices"][0]["delta"]["reasoning_content"] = None
- response["choices"][0]["finish_reason"] = "stop"
- prompt_tokens = num_tokens_from_string(prompt)
- response["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": token_used, "total_tokens": prompt_tokens + token_used}
- if need_reference:
- reference_payload = final_reference if final_reference is not None else last_ans.get("reference", [])
- response["choices"][0]["delta"]["reference"] = _build_reference_chunks(
- reference_payload,
- include_metadata=include_reference_metadata,
- metadata_fields=metadata_fields,
- )
- response["choices"][0]["delta"]["final_content"] = final_answer if final_answer is not None else full_content
- yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
- yield "data:[DONE]\n\n"
-
- resp = Response(streamed_response_generator(chat_id, dia, msg), mimetype="text/event-stream")
- resp.headers.add_header("Cache-control", "no-cache")
- resp.headers.add_header("Connection", "keep-alive")
- resp.headers.add_header("X-Accel-Buffering", "no")
- resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
- return resp
- else:
- answer = None
- chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference}
- if doc_ids_str:
- chat_kwargs["doc_ids"] = doc_ids_str
- async for ans in async_chat(dia, msg, False, **chat_kwargs):
- # non-streaming: take the first complete answer only
- answer = ans
- break
- content = answer["answer"]
-
- response = {
- "id": f"chatcmpl-{chat_id}",
- "object": "chat.completion",
- "created": int(time.time()),
- "model": req.get("model", ""),
- "usage": {
- "prompt_tokens": num_tokens_from_string(prompt),
- "completion_tokens": num_tokens_from_string(content),
- "total_tokens": num_tokens_from_string(prompt) + num_tokens_from_string(content),
- "completion_tokens_details": {
- "reasoning_tokens": context_token_used,
- "accepted_prediction_tokens": num_tokens_from_string(content),
- "rejected_prediction_tokens": 0, # 0 for simplicity
- },
- },
- "choices": [
- {
- "message": {
- "role": "assistant",
- "content": content,
- },
- "logprobs": None,
- "finish_reason": "stop",
- "index": 0,
- }
- ],
- }
- if need_reference:
- response["choices"][0]["message"]["reference"] = _build_reference_chunks(
- answer.get("reference", {}),
- include_metadata=include_reference_metadata,
- metadata_fields=metadata_fields,
- )
-
- return jsonify(response)
-
-
-@manager.route("/agents_openai//chat/completions", methods=["POST"]) # noqa: F821
-@validate_request("model", "messages") # noqa: F821
-@token_required
-async def agents_completion_openai_compatibility(tenant_id, agent_id):
- req = await get_request_json()
- messages = req.get("messages", [])
- if not messages:
- return get_error_data_result("You must provide at least one message.")
- if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
- return get_error_data_result(f"You don't own the agent {agent_id}")
-
- filtered_messages = [m for m in messages if m["role"] in ["user", "assistant"]]
- prompt_tokens = sum(num_tokens_from_string(m["content"]) for m in filtered_messages)
- if not filtered_messages:
- return jsonify(
- get_data_openai(
- id=agent_id,
- content="No valid messages found (user or assistant).",
- finish_reason="stop",
- model=req.get("model", ""),
- completion_tokens=num_tokens_from_string("No valid messages found (user or assistant)."),
- prompt_tokens=prompt_tokens,
- )
- )
-
- question = next((m["content"] for m in reversed(messages) if m["role"] == "user"), "")
-
- stream = req.pop("stream", False)
- if stream:
- resp = Response(
- completion_openai(
- tenant_id,
- agent_id,
- question,
- session_id=req.pop("session_id", req.get("id", "")) or req.get("metadata", {}).get("id", ""),
- stream=True,
- **req,
- ),
- mimetype="text/event-stream",
- )
- resp.headers.add_header("Cache-control", "no-cache")
- resp.headers.add_header("Connection", "keep-alive")
- resp.headers.add_header("X-Accel-Buffering", "no")
- resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
- return resp
- else:
- # For non-streaming, just return the response directly
- async for response in completion_openai(
- tenant_id,
- agent_id,
- question,
- session_id=req.pop("session_id", req.get("id", "")) or req.get("metadata", {}).get("id", ""),
- stream=False,
- **req,
- ):
- return jsonify(response)
-
- return None
-
-
-@manager.route("/agents//completions", methods=["POST"]) # noqa: F821
-@token_required
-async def agent_completions(tenant_id, agent_id):
- req = await get_request_json()
- return_trace = bool(req.get("return_trace", False))
-
- if req.get("stream", True):
-
- async def generate():
- trace_items = []
- async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
- if isinstance(answer, str):
- try:
- ans = json.loads(answer[5:]) # remove "data:"
- except Exception:
- continue
-
- event = ans.get("event")
- if event == "node_finished":
- if return_trace:
- data = ans.get("data", {})
- trace_items.append(
- {
- "component_id": data.get("component_id"),
- "trace": [copy.deepcopy(data)],
- }
- )
- ans.setdefault("data", {})["trace"] = trace_items
- answer = "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
- yield answer
-
- if event not in ["message", "message_end"]:
- continue
-
- yield answer
-
- yield "data:[DONE]\n\n"
-
- resp = Response(generate(), mimetype="text/event-stream")
- resp.headers.add_header("Cache-control", "no-cache")
- resp.headers.add_header("Connection", "keep-alive")
- resp.headers.add_header("X-Accel-Buffering", "no")
- resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
- return resp
-
- full_content = ""
- reference = {}
- final_ans = ""
- trace_items = []
- structured_output = {}
- async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
- try:
- ans = json.loads(answer[5:])
-
- if ans["event"] == "message":
- full_content += ans["data"]["content"]
-
- if ans.get("data", {}).get("reference", None):
- reference.update(ans["data"]["reference"])
-
- if ans.get("event") == "node_finished":
- data = ans.get("data", {})
- node_out = data.get("outputs", {})
- component_id = data.get("component_id")
- if component_id is not None and "structured" in node_out:
- structured_output[component_id] = copy.deepcopy(node_out["structured"])
- if return_trace:
- trace_items.append(
- {
- "component_id": data.get("component_id"),
- "trace": [copy.deepcopy(data)],
- }
- )
-
- final_ans = ans
- except Exception as e:
- return get_result(data=f"**ERROR**: {str(e)}")
- final_ans["data"]["content"] = full_content
- final_ans["data"]["reference"] = reference
- if structured_output:
- final_ans["data"]["structured"] = structured_output
- if return_trace and final_ans:
- final_ans["data"]["trace"] = trace_items
- return get_result(data=final_ans)
-
-
-@manager.route("/agents//sessions", methods=["GET"]) # noqa: F821
-@token_required
-async def list_agent_session(tenant_id, agent_id):
- if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
- return get_error_data_result(message=f"You don't own the agent {agent_id}.")
- id = request.args.get("id")
- user_id = request.args.get("user_id")
- page_number = int(request.args.get("page", 1))
- items_per_page = int(request.args.get("page_size", 30))
- orderby = request.args.get("orderby", "update_time")
- if request.args.get("desc") == "False" or request.args.get("desc") == "false":
- desc = False
- else:
- desc = True
- # dsl defaults to True unless the query param is explicitly "False"/"false"
- include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
- total, convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id,
- user_id, include_dsl)
- if not convs:
- return get_result(data=[])
- for conv in convs:
- conv["messages"] = conv.pop("message")
- infos = conv["messages"]
- for info in infos:
- if "prompt" in info:
- info.pop("prompt")
- conv["agent_id"] = conv.pop("dialog_id")
- # Fix for session listing endpoint
- if conv["reference"]:
- messages = conv["messages"]
- message_num = 0
- chunk_num = 0
- # Ensure reference is a list type to prevent KeyError
- if not isinstance(conv["reference"], list):
- conv["reference"] = []
- while message_num < len(messages):
- if message_num != 0 and messages[message_num]["role"] != "user":
- chunk_list = []
- # Add boundary and type checks to prevent KeyError
- if chunk_num < len(conv["reference"]) and conv["reference"][chunk_num] is not None and isinstance(
- conv["reference"][chunk_num], dict) and "chunks" in conv["reference"][chunk_num]:
- chunks = conv["reference"][chunk_num]["chunks"]
- for chunk in chunks:
- # Ensure chunk is a dictionary before calling get method
- if not isinstance(chunk, dict):
- continue
- new_chunk = {
- "id": chunk.get("chunk_id", chunk.get("id")),
- "content": chunk.get("content_with_weight", chunk.get("content")),
- "document_id": chunk.get("doc_id", chunk.get("document_id")),
- "document_name": chunk.get("docnm_kwd", chunk.get("document_name")),
- "dataset_id": chunk.get("kb_id", chunk.get("dataset_id")),
- "image_id": chunk.get("image_id", chunk.get("img_id")),
- "positions": chunk.get("positions", chunk.get("position_int")),
- }
- chunk_list.append(new_chunk)
- chunk_num += 1
- messages[message_num]["reference"] = chunk_list
- message_num += 1
- del conv["reference"]
- return get_result(data=convs)
-
-
@manager.route("/agents//sessions", methods=["DELETE"]) # noqa: F821
@token_required
async def delete_agent_session(tenant_id, agent_id):
@@ -697,97 +142,6 @@ async def delete_agent_session(tenant_id, agent_id):
return get_result()
-@manager.route("/sessions/ask", methods=["POST"]) # noqa: F821
-@token_required
-async def ask_about(tenant_id):
- req = await get_request_json()
- if not req.get("question"):
- return get_error_data_result("`question` is required.")
- if not req.get("dataset_ids"):
- return get_error_data_result("`dataset_ids` is required.")
- if not isinstance(req.get("dataset_ids"), list):
- return get_error_data_result("`dataset_ids` should be a list.")
- req["kb_ids"] = req.pop("dataset_ids")
- for kb_id in req["kb_ids"]:
- if not KnowledgebaseService.accessible(kb_id, tenant_id):
- return get_error_data_result(f"You don't own the dataset {kb_id}.")
- kbs = KnowledgebaseService.query(id=kb_id)
- kb = kbs[0]
- if kb.chunk_num == 0:
- return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
- uid = tenant_id
-
- async def stream():
- nonlocal req, uid
- try:
- async for ans in async_ask(req["question"], req["kb_ids"], uid):
- yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
- except Exception as e:
- yield "data:" + json.dumps(
- {"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
- ensure_ascii=False) + "\n\n"
- yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
-
- resp = Response(stream(), mimetype="text/event-stream")
- resp.headers.add_header("Cache-control", "no-cache")
- resp.headers.add_header("Connection", "keep-alive")
- resp.headers.add_header("X-Accel-Buffering", "no")
- resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
- return resp
-
-
-@manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821
-@token_required
-async def related_questions(tenant_id):
- req = await get_request_json()
- if not req.get("question"):
- return get_error_data_result("`question` is required.")
- question = req["question"]
- industry = req.get("industry", "")
- chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
- chat_mdl = LLMBundle(tenant_id, chat_model_config)
- prompt = """
-Objective: To generate search terms related to the user's search keywords, helping users find more valuable information.
-Instructions:
- - Based on the keywords provided by the user, generate 5-10 related search terms.
- - Each search term should be directly or indirectly related to the keyword, guiding the user to find more valuable information.
- - Use common, general terms as much as possible, avoiding obscure words or technical jargon.
- - Keep the term length between 2-4 words, concise and clear.
- - DO NOT translate, use the language of the original keywords.
-"""
- if industry:
- prompt += f" - Ensure all search terms are relevant to the industry: {industry}.\n"
- prompt += """
-### Example:
-Keywords: Chinese football
-Related search terms:
-1. Current status of Chinese football
-2. Reform of Chinese football
-3. Youth training of Chinese football
-4. Chinese football in the Asian Cup
-5. Chinese football in the World Cup
-
-Reason:
- - When searching, users often only use one or two keywords, making it difficult to fully express their information needs.
- - Generating related search terms can help users dig deeper into relevant information and improve search efficiency.
- - At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
-
-"""
- ans = await chat_mdl.async_chat(
- prompt,
- [
- {
- "role": "user",
- "content": f"""
-Keywords: {question}
-Related search terms:
- """,
- }
- ],
- {"temperature": 0.9},
- )
- return get_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])
-
@manager.route("/chatbots//completions", methods=["POST"]) # noqa: F821
async def chatbot_completions(dialog_id):
@@ -800,20 +154,69 @@ async def chatbot_completions(dialog_id):
objs = APIToken.query(beta=token)
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!')
+ tenant_id = objs[0].tenant_id
+ exists, dialog = DialogService.get_by_id(dialog_id)
+ if (not exists
+ or getattr(dialog, "tenant_id", None) != tenant_id
+ or str(getattr(dialog, "status", "")) != StatusEnum.VALID.value):
+ logger.warning(
+ "Denied chatbot access: reason=%s tenant_id=%s dialog_id=%s user_id=%s session_id=%s",
+ "no access to this chatbot",
+ tenant_id,
+ dialog_id,
+ req.get("user_id"),
+ req.get("session_id"),
+ )
+ return get_error_data_result(message="Authentication error: no access to this chatbot!")
if "quote" not in req:
req["quote"] = False
+ def _validate_iframe_access():
+ if req.get("session_id"):
+ exists, conv = API4ConversationService.get_by_id(req.get("session_id"))
+ if not exists:
+ raise AssertionError("Session not found!")
+ if conv.dialog_id != dialog_id:
+ raise AssertionError("Session does not belong to this dialog")
+ if tenant_id and conv.user_id and conv.user_id != tenant_id:
+ raise AssertionError("Session does not belong to this tenant")
+
if req.get("stream", True):
- resp = Response(iframe_completion(dialog_id, **req), mimetype="text/event-stream")
+ try:
+ _validate_iframe_access()
+ except AssertionError:
+ logger.warning(
+ "Denied chatbot completion stream: reason=%s tenant_id=%s dialog_id=%s user_id=%s session_id=%s",
+ "no access to this chatbot",
+ tenant_id,
+ dialog_id,
+ req.get("user_id"),
+ req.get("session_id"),
+ )
+ return get_error_data_result(message="Authentication error: no access to this chatbot!")
+
+ resp = Response(iframe_completion(dialog_id, tenant_id=tenant_id, **req), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
- async for answer in iframe_completion(dialog_id, **req):
- return get_result(data=answer)
+ try:
+ _validate_iframe_access()
+ async for answer in iframe_completion(dialog_id, tenant_id=tenant_id, **req):
+ return get_result(data=answer)
+ except AssertionError:
+ logger.warning(
+ "Denied chatbot completion: reason=%s tenant_id=%s dialog_id=%s user_id=%s session_id=%s",
+ "no access to this chatbot",
+ tenant_id,
+ dialog_id,
+ req.get("user_id"),
+ req.get("session_id"),
+ )
+ return get_error_data_result(message="Authentication error: no access to this chatbot!")
return None
@@ -826,11 +229,23 @@ async def chatbots_inputs(dialog_id):
objs = APIToken.query(beta=token)
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!')
-
- e, dialog = DialogService.get_by_id(dialog_id)
- if not e:
- return get_error_data_result(f"Can't find dialog by ID: {dialog_id}")
-
+ tenant_id = objs[0].tenant_id
+ exists, dialog = DialogService.get_by_id(dialog_id)
+ if (not exists
+ or getattr(dialog, "tenant_id", None) != tenant_id
+ or str(getattr(dialog, "status", "")) != StatusEnum.VALID.value):
+ request_args = getattr(request, "args", {}) or {}
+ request_user_id = request_args.get("user_id") if hasattr(request_args, "get") else None
+ request_session_id = request_args.get("session_id") if hasattr(request_args, "get") else None
+ logger.warning(
+ "Denied chatbot access: reason=%s tenant_id=%s dialog_id=%s user_id=%s session_id=%s",
+ "no access to this chatbot",
+ tenant_id,
+ dialog_id,
+ request_user_id,
+ request_session_id,
+ )
+ return get_error_data_result(message="Authentication error: no access to this chatbot!")
return get_result(
data={
"title": dialog.name,
@@ -971,12 +386,15 @@ async def retrieval_test_embedded():
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
use_kg = req.get("use_kg", False)
top = int(req.get("top_k", 1024))
+ if top <= 0:
+ return get_error_data_result("`top_k` must be greater than 0")
langs = req.get("cross_languages", [])
rerank_id = req.get("rerank_id", "")
tenant_rerank_id = req.get("tenant_rerank_id", "")
tenant_id = objs[0].tenant_id
if not tenant_id:
return get_error_data_result(message="permission denined.")
+ search_config = {}
async def _retrieval():
nonlocal similarity_threshold, vector_similarity_weight, top, rerank_id
@@ -987,8 +405,11 @@ async def _retrieval():
meta_data_filter = {}
chat_mdl = None
if req.get("search_id", ""):
- search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
- meta_data_filter = search_config.get("meta_data_filter", {})
+ nonlocal search_config
+ detail = SearchService.get_detail(req.get("search_id", ""))
+ if detail:
+ search_config = detail.get("search_config", {})
+ meta_data_filter = search_config.get("meta_data_filter", {})
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
chat_id = search_config.get("chat_id", "")
if chat_id:
@@ -1012,8 +433,15 @@ async def _retrieval():
chat_mdl = LLMBundle(tenant_id, chat_model_config)
if meta_data_filter:
- metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
- local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, _question, chat_mdl, local_doc_ids)
+ local_doc_ids = await apply_meta_data_filter(
+ meta_data_filter,
+ None,
+ _question,
+ chat_mdl,
+ local_doc_ids,
+ kb_ids=kb_ids,
+ metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs(kb_ids),
+ )
tenants = UserTenantService.query(user_id=tenant_id)
for kb_id in kb_ids:
@@ -1064,6 +492,11 @@ async def _retrieval():
for c in ranks["chunks"]:
c.pop("vector", None)
+
+ include_metadata, metadata_fields = _resolve_reference_metadata(req, search_config)
+ if include_metadata:
+ enrich_chunks_with_document_metadata(ranks["chunks"], metadata_fields)
+
ranks["labels"] = labels
return get_json_result(data=ranks)
@@ -1179,126 +612,6 @@ async def mindmap():
return server_error_response(Exception(mind_map["error"]))
return get_json_result(data=mind_map)
-@manager.route("/sequence2txt", methods=["POST"]) # noqa: F821
-@token_required
-async def sequence2txt(tenant_id):
- req = await request.form
- stream_mode = req.get("stream", "false").lower() == "true"
- files = await request.files
- if "file" not in files:
- return get_error_data_result(message="Missing 'file' in multipart form-data")
-
- uploaded = files["file"]
-
- ALLOWED_EXTS = {
- ".wav", ".mp3", ".m4a", ".aac",
- ".flac", ".ogg", ".webm",
- ".opus", ".wma"
- }
-
- filename = uploaded.filename or ""
- suffix = os.path.splitext(filename)[-1].lower()
- if suffix not in ALLOWED_EXTS:
- return get_error_data_result(message=
- f"Unsupported audio format: {suffix}. "
- f"Allowed: {', '.join(sorted(ALLOWED_EXTS))}"
- )
- fd, temp_audio_path = tempfile.mkstemp(suffix=suffix)
- os.close(fd)
- await uploaded.save(temp_audio_path)
-
- try:
- default_asr_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.SPEECH2TEXT)
- except Exception as e:
- return get_error_data_result(message=str(e))
- asr_mdl = LLMBundle(tenant_id, default_asr_model_config)
- if not stream_mode:
- text = asr_mdl.transcription(temp_audio_path)
- try:
- os.remove(temp_audio_path)
- except Exception as e:
- logging.error(f"Failed to remove temp audio file: {str(e)}")
- return get_json_result(data={"text": text})
- async def event_stream():
- try:
- for evt in asr_mdl.stream_transcription(temp_audio_path):
- yield f"data: {json.dumps(evt, ensure_ascii=False)}\n\n"
- except Exception as e:
- err = {"event": "error", "text": str(e)}
- yield f"data: {json.dumps(err, ensure_ascii=False)}\n\n"
- finally:
- try:
- os.remove(temp_audio_path)
- except Exception as e:
- logging.error(f"Failed to remove temp audio file: {str(e)}")
-
- return Response(event_stream(), content_type="text/event-stream")
-
-@manager.route("/tts", methods=["POST"]) # noqa: F821
-@token_required
-async def tts(tenant_id):
- req = await get_request_json()
- text = req["text"]
-
- try:
- default_tts_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.TTS)
- except Exception as e:
- return get_error_data_result(message=str(e))
- tts_mdl = LLMBundle(tenant_id, default_tts_model_config)
-
- def stream_audio():
- try:
- for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text):
- for chunk in tts_mdl.tts(txt):
- yield chunk
- except Exception as e:
- yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8")
-
- resp = Response(stream_audio(), mimetype="audio/mpeg")
- resp.headers.add_header("Cache-Control", "no-cache")
- resp.headers.add_header("Connection", "keep-alive")
- resp.headers.add_header("X-Accel-Buffering", "no")
-
- return resp
-
-
-def _build_reference_chunks(reference, include_metadata=False, metadata_fields=None):
- chunks = chunks_format(reference)
- if not include_metadata:
- return chunks
-
- doc_ids_by_kb = {}
- for chunk in chunks:
- kb_id = chunk.get("dataset_id")
- doc_id = chunk.get("document_id")
- if not kb_id or not doc_id:
- continue
- doc_ids_by_kb.setdefault(kb_id, set()).add(doc_id)
-
- if not doc_ids_by_kb:
- return chunks
-
- meta_by_doc = {}
- for kb_id, doc_ids in doc_ids_by_kb.items():
- meta_map = DocMetadataService.get_metadata_for_documents(list(doc_ids), kb_id)
- if meta_map:
- meta_by_doc.update(meta_map)
-
- if metadata_fields is not None:
- metadata_fields = {f for f in metadata_fields if isinstance(f, str)}
- if not metadata_fields:
- return chunks
-
- for chunk in chunks:
- doc_id = chunk.get("document_id")
- if not doc_id:
- continue
- meta = meta_by_doc.get(doc_id)
- if not meta:
- continue
- if metadata_fields is not None:
- meta = {k: v for k, v in meta.items() if k in metadata_fields}
- if meta:
- chunk["document_metadata"] = meta
- return chunks
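-
- # Usage sketch (hypothetical payload): keep only the "author" and "year"
- # metadata fields on each cited chunk.
- # chunks = _build_reference_chunks(reference, include_metadata=True, metadata_fields=["author", "year"])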
+def _resolve_reference_metadata(req, search_config=None):
+ return resolve_reference_metadata_preferences(req, search_config)
diff --git a/api/apps/services/canvas_replica_service.py b/api/apps/services/canvas_replica_service.py
index a2aa56b6f96..17b6c99cb02 100644
--- a/api/apps/services/canvas_replica_service.py
+++ b/api/apps/services/canvas_replica_service.py
@@ -160,7 +160,7 @@ def bootstrap(
@classmethod
def load_for_run(cls, canvas_id: str, tenant_id: str, runtime_user_id: str):
- """Load current runtime replica used by /completion."""
+ """Load current runtime replica used by /completions."""
replica_key = cls._replica_key(canvas_id, str(tenant_id), str(runtime_user_id))
return cls._read_payload(replica_key)
diff --git a/api/apps/services/dataset_api_service.py b/api/apps/services/dataset_api_service.py
index 8cb718467a3..9e49596539c 100644
--- a/api/apps/services/dataset_api_service.py
+++ b/api/apps/services/dataset_api_service.py
@@ -16,6 +16,7 @@
import logging
import json
import os
+import re
from common.constants import PAGERANK_FLD
from common import settings
from api.db.db_models import File
@@ -25,10 +26,31 @@
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.connector_service import Connector2KbService
from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID, TaskService
-from api.db.services.user_service import TenantService, UserService
+from api.db.services.user_service import TenantService, UserService, UserTenantService
+from api.db.services.tenant_llm_service import TenantLLMService
from common.constants import FileSource, StatusEnum
from api.utils.api_utils import deep_merge, get_parser_config, remap_dictionary_keys, verify_embedding_availability
+_VALID_INDEX_TYPES = {"graph", "raptor", "mindmap"}
+
+_INDEX_TYPE_TO_TASK_TYPE = {
+ "graph": "graphrag",
+ "raptor": "raptor",
+ "mindmap": "mindmap",
+}
+
+_INDEX_TYPE_TO_TASK_ID_FIELD = {
+ "graph": "graphrag_task_id",
+ "raptor": "raptor_task_id",
+ "mindmap": "mindmap_task_id",
+}
+
+_INDEX_TYPE_TO_DISPLAY_NAME = {
+ "graph": "Graph",
+ "raptor": "RAPTOR",
+ "mindmap": "Mindmap",
+}
+
async def create_dataset(tenant_id: str, req: dict):
"""
@@ -61,12 +83,7 @@ async def create_dataset(tenant_id: str, req: dict):
req["parser_config"] = parser_cfg
req.update(ext_fields)
- e, create_dict = KnowledgebaseService.create_with_name(
- name=req.pop("name", None),
- tenant_id=tenant_id,
- parser_id=req.pop("parser_id", None),
- **req
- )
+ e, create_dict = KnowledgebaseService.create_with_name(name=req.pop("name", None), tenant_id=tenant_id, parser_id=req.pop("parser_id", None), **req)
if not e:
return False, create_dict
@@ -132,12 +149,12 @@ async def delete_datasets(tenant_id: str, ids: list = None, delete_all: bool = F
]
)
File2DocumentService.delete_by_document_id(doc.id)
- FileService.filter_delete(
- [File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name])
+ FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name])
# Drop index for this dataset
try:
from rag.nlp import search
+
idxnm = search.index_name(kb.tenant_id)
settings.docStoreConn.delete_idx(idxnm, kb_id)
except Exception as e:
@@ -158,6 +175,57 @@ async def delete_datasets(tenant_id: str, ids: list = None, delete_all: bool = F
return True, {"success_count": success_count, "errors": errors[:5]}
+def get_dataset(dataset_id: str, tenant_id: str):
+ """
+ Get a single dataset.
+
+ :param dataset_id: dataset ID
+ :param tenant_id: tenant ID
+ :return: (success, result) or (success, error_message)
+ """
+ if not dataset_id:
+ return False, 'Lack of "Dataset ID"'
+
+ if not KnowledgebaseService.accessible(dataset_id, tenant_id):
+ return False, f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'"
+
+ ok, kb = KnowledgebaseService.get_by_id(dataset_id)
+ if not ok:
+ return False, "Invalid Dataset ID"
+
+ response_data = remap_dictionary_keys(kb.to_dict())
+ response_data["size"] = DocumentService.get_total_size_by_kb_id(dataset_id)
+ response_data["connectors"] = list(Connector2KbService.list_connectors(dataset_id))
+ return True, response_data
+
+
+def get_ingestion_summary(dataset_id: str, tenant_id: str):
+ """
+ Get ingestion summary for a dataset.
+
+ :param dataset_id: dataset ID
+ :param tenant_id: tenant ID
+ :return: (success, result) or (success, error_message)
+ """
+ if not dataset_id:
+ return False, 'Lack of "Dataset ID"'
+
+ if not KnowledgebaseService.accessible(dataset_id, tenant_id):
+ return False, f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'"
+
+ ok, kb = KnowledgebaseService.get_by_id(dataset_id)
+ if not ok:
+ return False, "Invalid Dataset ID"
+
+ status = DocumentService.get_parsing_status_by_kb_ids([dataset_id]).get(dataset_id, {})
+ return True, {
+ "doc_num": kb.doc_num,
+ "chunk_num": kb.chunk_num,
+ "token_num": kb.token_num,
+ "status": status,
+ }
+
+
async def update_dataset(tenant_id: str, dataset_id: str, req: dict):
"""
Update a dataset.
@@ -195,7 +263,7 @@ async def update_dataset(tenant_id: str, dataset_id: str, req: dict):
parser_cfg["metadata"] = fields
parser_cfg["enable_metadata"] = auto_meta.get("enabled", True)
req["parser_config"] = parser_cfg
-
+
# Merge ext fields with req
req.update(ext_fields)
@@ -232,16 +300,13 @@ async def update_dataset(tenant_id: str, dataset_id: str, req: dict):
req["pipeline_id"] = ""
if "name" in req and req["name"].lower() != kb.name.lower():
- exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id,
- status=StatusEnum.VALID.value)
+ exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)
if exists:
return False, f"Dataset name '{req['name']}' already exists"
if "embd_id" in req:
if not req["embd_id"]:
req["embd_id"] = kb.embd_id
- if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id:
- return False, f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}"
ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
if not ok:
return False, err
@@ -252,13 +317,13 @@ async def update_dataset(tenant_id: str, dataset_id: str, req: dict):
if req["pagerank"] > 0:
from rag.nlp import search
- settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
- search.index_name(kb.tenant_id), kb.id)
+
+ settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id)
else:
# Elasticsearch requires PAGERANK_FLD to be non-zero!
from rag.nlp import search
- settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
- search.index_name(kb.tenant_id), kb.id)
+
+ settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id)
if "parse_type" in req:
del req["parse_type"]
@@ -317,27 +382,13 @@ def list_datasets(tenant_id: str, args: dict):
else:
tenants = TenantService.get_joined_tenants_by_user_id(tenant_id)
tenant_ids = [m["tenant_id"] for m in tenants]
- kbs, total = KnowledgebaseService.get_list(
- tenant_ids,
- tenant_id,
- page,
- page_size,
- orderby,
- desc,
- kb_id,
- name,
- keywords,
- parser_id
- )
+ kbs, total = KnowledgebaseService.get_list(tenant_ids, tenant_id, page, page_size, orderby, desc, kb_id, name, keywords, parser_id)
users = UserService.get_by_ids([m["tenant_id"] for m in kbs])
user_map = {m.id: m.to_dict() for m in users}
response_data_list = []
for kb in kbs:
user_dict = user_map.get(kb["tenant_id"], {})
- kb.update({
- "nickname": user_dict.get("nickname", ""),
- "tenant_avatar": user_dict.get("avatar", "")
- })
+ kb.update({"nickname": user_dict.get("nickname", ""), "tenant_avatar": user_dict.get("avatar", "")})
response_data_list.append(remap_dictionary_keys(kb))
return True, {"data": response_data_list, "total": total}
@@ -354,13 +405,11 @@ async def get_knowledge_graph(dataset_id: str, tenant_id: str):
return False, "No authorization."
_, kb = KnowledgebaseService.get_by_id(dataset_id)
- req = {
- "kb_id": [dataset_id],
- "knowledge_graph_kwd": ["graph"]
- }
+ req = {"kb_id": [dataset_id], "knowledge_graph_kwd": ["graph"]}
obj = {"graph": {}, "mind_map": {}}
from rag.nlp import search
+
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id):
return True, obj
sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
@@ -380,8 +429,7 @@ async def get_knowledge_graph(dataset_id: str, tenant_id: str):
obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256]
if "edges" in obj["graph"]:
node_id_set = {o["id"] for o in obj["graph"]["nodes"]}
- filtered_edges = [o for o in obj["graph"]["edges"] if
- o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set]
+ filtered_edges = [o for o in obj["graph"]["edges"] if o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set]
obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128]
return True, obj
@@ -398,20 +446,28 @@ def delete_knowledge_graph(dataset_id: str, tenant_id: str):
return False, "No authorization."
_, kb = KnowledgebaseService.get_by_id(dataset_id)
from rag.nlp import search
- settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
+ from rag.graphrag.phase_markers import clear_phase_markers
+ settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation", "community_report"]},
search.index_name(kb.tenant_id), dataset_id)
+ # Wiping the graph invalidates any phase-completion markers used to
+ # short-circuit resolution / community detection on resume.
+ clear_phase_markers(dataset_id)
return True, True
-def run_graphrag(dataset_id: str, tenant_id: str):
+def run_index(dataset_id: str, tenant_id: str, index_type: str):
"""
- Run GraphRAG for a dataset.
+ Run an indexing task (graph/raptor/mindmap) for a dataset.
:param dataset_id: dataset ID
:param tenant_id: tenant ID
+ :param index_type: one of "graph", "raptor", "mindmap"
:return: (success, result) or (success, error_message)
"""
+ if index_type not in _VALID_INDEX_TYPES:
+ return False, f"Invalid index type '{index_type}'. Must be one of {sorted(_VALID_INDEX_TYPES)}"
+
if not dataset_id:
return False, 'Lack of "Dataset ID"'
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
@@ -421,14 +477,18 @@ def run_graphrag(dataset_id: str, tenant_id: str):
if not ok:
return False, "Invalid Dataset ID"
- task_id = kb.graphrag_task_id
- if task_id:
- ok, task = TaskService.get_by_id(task_id)
+ task_type = _INDEX_TYPE_TO_TASK_TYPE[index_type]
+ task_id_field = _INDEX_TYPE_TO_TASK_ID_FIELD[index_type]
+ display_name = _INDEX_TYPE_TO_DISPLAY_NAME[index_type]
+
+ existing_task_id = getattr(kb, task_id_field, None)
+ if existing_task_id:
+ ok, task = TaskService.get_by_id(existing_task_id)
if not ok:
- logging.warning(f"A valid GraphRAG task id is expected for Dataset {dataset_id}")
+ logging.warning(f"A valid {display_name} task id is expected for Dataset {dataset_id}")
if task and task.progress not in [-1, 1]:
- return False, f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running."
+ return False, f"Task {existing_task_id} in progress with status {task.progress}. A {display_name} Task is already running."
documents, _ = DocumentService.get_by_kb_id(
kb_id=dataset_id,
@@ -447,24 +507,29 @@ def run_graphrag(dataset_id: str, tenant_id: str):
sample_document = documents[0]
document_ids = [document["id"] for document in documents]
- task_id = queue_raptor_o_graphrag_tasks(sample_doc=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
+ task_id = queue_raptor_o_graphrag_tasks(sample_doc=sample_document, ty=task_type, priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
- if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
- logging.warning(f"Cannot save graphrag_task_id for Dataset {dataset_id}")
+ if not KnowledgebaseService.update_by_id(kb.id, {task_id_field: task_id}):
+ logging.warning(f"Cannot save {task_id_field} for Dataset {dataset_id}")
- return True, {"graphrag_task_id": task_id}
+ return True, {"task_id": task_id}
-def trace_graphrag(dataset_id: str, tenant_id: str):
+def trace_index(dataset_id: str, tenant_id: str, index_type: str):
"""
- Trace GraphRAG task for a dataset.
+ Trace an indexing task (graph/raptor/mindmap) for a dataset.
:param dataset_id: dataset ID
:param tenant_id: tenant ID
+ :param index_type: one of "graph", "raptor", "mindmap"
:return: (success, result) or (success, error_message)
"""
+ if index_type not in _VALID_INDEX_TYPES:
+ return False, f"Invalid index type '{index_type}'. Must be one of {sorted(_VALID_INDEX_TYPES)}"
+
if not dataset_id:
return False, 'Lack of "Dataset ID"'
+
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return False, "No authorization."
@@ -472,7 +537,8 @@ def trace_graphrag(dataset_id: str, tenant_id: str):
if not ok:
return False, "Invalid Dataset ID"
- task_id = kb.graphrag_task_id
+ task_id_field = _INDEX_TYPE_TO_TASK_ID_FIELD[index_type]
+ task_id = getattr(kb, task_id_field, None)
if not task_id:
return True, {}
@@ -483,9 +549,9 @@ def trace_graphrag(dataset_id: str, tenant_id: str):
return True, task.to_dict()
-def run_raptor(dataset_id: str, tenant_id: str):
+def list_tags(dataset_id: str, tenant_id: str):
"""
- Run RAPTOR for a dataset.
+ List tags for a dataset.
:param dataset_id: dataset ID
:param tenant_id: tenant ID
@@ -493,6 +559,118 @@ def run_raptor(dataset_id: str, tenant_id: str):
"""
if not dataset_id:
return False, 'Lack of "Dataset ID"'
+
+ if not KnowledgebaseService.accessible(dataset_id, tenant_id):
+ return False, "No authorization."
+
+ tenants = UserTenantService.get_tenants_by_user_id(tenant_id)
+ tags = []
+ for tenant in tenants:
+ tags += settings.retriever.all_tags(tenant["tenant_id"], [dataset_id])
+ return True, tags
+
+
+def aggregate_tags(dataset_ids: list[str], tenant_id: str):
+ """
+ Aggregate tags across multiple datasets.
+
+ :param dataset_ids: list of dataset IDs
+ :param tenant_id: tenant ID
+ :return: (success, result) or (success, error_message)
+ """
+ if not dataset_ids:
+ return False, 'Lack of "dataset_ids"'
+
+ for dataset_id in dataset_ids:
+ if not KnowledgebaseService.accessible(dataset_id, tenant_id):
+ return False, f"No authorization for dataset '{dataset_id}'"
+
+ dataset_ids_by_tenant = {}
+ for dataset_id in dataset_ids:
+ ok, kb = KnowledgebaseService.get_by_id(dataset_id)
+ if not ok:
+ return False, f"Invalid Dataset ID '{dataset_id}'"
+ dataset_ids_by_tenant.setdefault(kb.tenant_id, []).append(dataset_id)
+
+ merged = {}
+ for kb_tenant_id, kb_ids in dataset_ids_by_tenant.items():
+ for bucket in settings.retriever.all_tags(kb_tenant_id, kb_ids):
+ tag = bucket["value"]
+ merged[tag] = merged.get(tag, 0) + bucket["count"]
+
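+ # Worked example (hypothetical counts): {"value": "ai", "count": 2} from one
+ # dataset and {"value": "ai", "count": 3} from another merge into
+ # {"value": "ai", "count": 5}.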
+ return True, [{"value": tag, "count": count} for tag, count in merged.items()]
+
+
+def get_flattened_metadata(dataset_ids: list[str], tenant_id: str):
+ """
+ Get flattened metadata for datasets.
+
+ :param dataset_ids: list of dataset IDs
+ :param tenant_id: tenant ID
+ :return: (success, result) or (success, error_message)
+ """
+ if not dataset_ids:
+ return False, 'Lack of "dataset_ids"'
+
+ for dataset_id in dataset_ids:
+ if not KnowledgebaseService.accessible(dataset_id, tenant_id):
+ return False, f"No authorization for dataset '{dataset_id}'"
+
+ from api.db.services.doc_metadata_service import DocMetadataService
+
+ return True, DocMetadataService.get_flatted_meta_by_kbs(dataset_ids)
+
+
+def get_auto_metadata(dataset_id: str, tenant_id: str):
+ """
+ Get auto-metadata configuration for a dataset.
+
+ :param dataset_id: dataset ID
+ :param tenant_id: tenant ID
+ :return: (success, result) or (success, error_message)
+ """
+ kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id)
+ if kb is None:
+ return False, f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'"
+ parser_cfg = kb.parser_config or {}
+ return True, {"metadata": parser_cfg.get("metadata") or [], "built_in_metadata": parser_cfg.get("built_in_metadata") or []}
+
+
+async def update_auto_metadata(dataset_id: str, tenant_id: str, cfg: dict):
+ """
+ Update auto-metadata configuration for a dataset.
+
+ :param dataset_id: dataset ID
+ :param tenant_id: tenant ID
+ :param cfg: auto-metadata configuration
+ :return: (success, result) or (success, error_message)
+ """
+ kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id)
+ if kb is None:
+ return False, f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'"
+
+ parser_cfg = kb.parser_config or {}
+ parser_cfg["metadata"] = cfg.get("metadata")
+ parser_cfg["built_in_metadata"] = cfg.get("built_in_metadata")
+
+ if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": parser_cfg}):
+ return False, "Update auto-metadata error.(Database error)"
+
+ return True, cfg
+
+
+def delete_tags(dataset_id: str, tenant_id: str, tags: list[str]):
+ """
+ Delete tags from a dataset.
+
+ :param dataset_id: dataset ID
+ :param tenant_id: tenant ID
+ :param tags: list of tags to delete
+ :return: (success, result) or (success, error_message)
+ """
+ if not dataset_id:
+ return False, 'Lack of "Dataset ID"'
+
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
return False, "No authorization."
@@ -500,14 +678,178 @@ def run_raptor(dataset_id: str, tenant_id: str):
if not ok:
return False, "Invalid Dataset ID"
- task_id = kb.raptor_task_id
+ from rag.nlp import search
+
+ for t in tags:
+ settings.docStoreConn.update({"tag_kwd": t, "kb_id": [dataset_id]}, {"remove": {"tag_kwd": t}}, search.index_name(kb.tenant_id), dataset_id)
+
+ return True, {}
+
+
+def list_ingestion_logs(
+ dataset_id: str,
+ tenant_id: str,
+ page: int,
+ page_size: int,
+ orderby: str,
+ desc: bool,
+ operation_status: list = None,
+ create_date_from: str = None,
+ create_date_to: str = None,
+ log_type: str = "dataset",
+ keywords: str = None,
+):
+ """
+ List ingestion logs for a dataset.
+
+ :param dataset_id: dataset ID
+ :param tenant_id: tenant ID
+ :param page: page number
+ :param page_size: items per page
+ :param orderby: order by field
+ :param desc: descending order
+ :param operation_status: filter by operation status
+ :param create_date_from: filter start date
+ :param create_date_to: filter end date
+ :param log_type: "dataset" or "file"
+ :param keywords: search keywords for file logs
+ :return: (success, result) or (success, error_message)
+ """
+ if not dataset_id:
+ return False, 'Lack of "Dataset ID"'
+
+ if not KnowledgebaseService.accessible(dataset_id, tenant_id):
+ return False, "No authorization."
+
+ from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
+
+ allowed_log_types = {"dataset", "file"}
+ if log_type not in allowed_log_types:
+ logging.warning(
+ "list_ingestion_logs invalid log_type: dataset_id=%s tenant_id=%s log_type=%s",
+ dataset_id,
+ tenant_id,
+ log_type,
+ )
+ return False, 'Invalid "log_type", expected "dataset" or "file"'
+
+ logging.info(
+ "list_ingestion_logs: dataset_id=%s tenant_id=%s log_type=%s page=%s page_size=%s",
+ dataset_id,
+ tenant_id,
+ log_type,
+ page,
+ page_size,
+ )
+
+ if log_type == "file":
+ logs, total = PipelineOperationLogService.get_file_logs_by_kb_id(dataset_id, page, page_size, orderby, desc, keywords, operation_status or [], None, None, create_date_from, create_date_to)
+ else:
+ logs, total = PipelineOperationLogService.get_dataset_logs_by_kb_id(dataset_id, page, page_size, orderby, desc, operation_status or [], create_date_from, create_date_to, keywords)
+ return True, {"total": total, "logs": logs}
+
+
+def get_ingestion_log(dataset_id: str, tenant_id: str, log_id: str):
+ """
+ Get a single ingestion log.
+
+ :param dataset_id: dataset ID
+ :param tenant_id: tenant ID
+ :param log_id: log ID
+ :return: (success, result) or (success, error_message)
+ """
+ if not dataset_id:
+ return False, 'Lack of "Dataset ID"'
+
+ if not KnowledgebaseService.accessible(dataset_id, tenant_id):
+ return False, "No authorization."
+
+ from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
+
+ fields = PipelineOperationLogService.get_dataset_logs_fields()
+ log = PipelineOperationLogService.model.select(*fields).where((PipelineOperationLogService.model.id == log_id) & (PipelineOperationLogService.model.kb_id == dataset_id)).first()
+ if not log:
+ return False, "Log not found"
+
+ return True, log.to_dict()
+
+
+def delete_index(dataset_id: str, tenant_id: str, index_type: str, wipe: bool = True):
+ """
+ Delete an indexing task (graph/raptor/mindmap) for a dataset.
+
+ :param dataset_id: dataset ID
+ :param tenant_id: tenant ID
+ :param index_type: one of "graph", "raptor", "mindmap"
+ :param wipe: when True (default) the persisted artefacts (graph rows,
+ raptor summaries) are removed from the doc store and any GraphRAG
+ phase-completion markers are cleared. Pass False to cancel the
+ running task while keeping prior progress so it can be resumed.
+ :return: (success, result) or (success, error_message)
+ """
+ if index_type not in _VALID_INDEX_TYPES:
+ return False, f"Invalid index type '{index_type}'. Must be one of {sorted(_VALID_INDEX_TYPES)}"
+
+ if not dataset_id:
+ return False, 'Lack of "Dataset ID"'
+
+ if not KnowledgebaseService.accessible(dataset_id, tenant_id):
+ return False, "No authorization."
+
+ ok, kb = KnowledgebaseService.get_by_id(dataset_id)
+ if not ok:
+ return False, "Invalid Dataset ID"
+
+ task_id_field = _INDEX_TYPE_TO_TASK_ID_FIELD[index_type]
+ task_finish_at_field = f"{task_id_field.replace('_task_id', '_task_finish_at')}"
+ task_id = getattr(kb, task_id_field, None)
+
+ logging.info("delete_index: dataset=%s index_type=%s wipe=%s", dataset_id, index_type, wipe)
+
if task_id:
- ok, task = TaskService.get_by_id(task_id)
- if not ok:
- logging.warning(f"A valid RAPTOR task id is expected for Dataset {dataset_id}")
+ from rag.utils.redis_conn import REDIS_CONN
+
+ try:
+ REDIS_CONN.set(f"{task_id}-cancel", "x")
+ except Exception as e:
+ logging.exception(e)
+ TaskService.delete_by_id(task_id)
+
+ if wipe and index_type == "graph":
+ from rag.nlp import search
+ from rag.graphrag.phase_markers import clear_phase_markers
+ settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation", "community_report"]},
+ search.index_name(kb.tenant_id), dataset_id)
+ # Wiping the graph invalidates any phase-completion markers used to
+ # short-circuit resolution / community detection on resume.
+ clear_phase_markers(dataset_id)
+ logging.info("delete_index: cleared GraphRAG artefacts and phase markers for dataset=%s", dataset_id)
+ elif wipe and index_type == "raptor":
+ from rag.nlp import search
+
+ settings.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), dataset_id)
+
+ KnowledgebaseService.update_by_id(kb.id, {task_id_field: "", task_finish_at_field: None})
+ return True, {}
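+
+ # Usage sketch (hypothetical ids): wipe=False cancels a running build but keeps
+ # prior progress so it can resume; the default wipe=True also removes persisted
+ # artefacts.
+ # delete_index("kb-123", "tenant-1", "raptor", wipe=False)
+ # delete_index("kb-123", "tenant-1", "graph")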
- if task and task.progress not in [-1, 1]:
- return False, f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running."
+
+def run_embedding(dataset_id: str, tenant_id: str):
+ """
+ Run embedding for all documents in a dataset.
+
+ :param dataset_id: dataset ID
+ :param tenant_id: tenant ID
+ :return: (success, result) or (success, error_message)
+ """
+ if not dataset_id:
+ return False, 'Lack of "Dataset ID"'
+
+ if not KnowledgebaseService.accessible(dataset_id, tenant_id):
+ return False, "No authorization."
+
+ ok, kb = KnowledgebaseService.get_by_id(dataset_id)
+ if not ok:
+ return False, "Invalid Dataset ID"
documents, _ = DocumentService.get_by_kb_id(
kb_id=dataset_id,
@@ -523,23 +865,22 @@ def run_raptor(dataset_id: str, tenant_id: str):
if not documents:
return False, f"No documents in Dataset {dataset_id}"
- sample_document = documents[0]
- document_ids = [document["id"] for document in documents]
-
- task_id = queue_raptor_o_graphrag_tasks(sample_doc=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
-
- if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}):
- logging.warning(f"Cannot save raptor_task_id for Dataset {dataset_id}")
+ kb_table_num_map = {}
+ for doc in documents:
+ doc["tenant_id"] = tenant_id
+ DocumentService.run(tenant_id, doc, kb_table_num_map)
- return True, {"raptor_task_id": task_id}
+ return True, {"scheduled_count": len(documents)}
-def trace_raptor(dataset_id: str, tenant_id: str):
+def rename_tag(dataset_id: str, tenant_id: str, from_tag: str, to_tag: str):
"""
- Trace RAPTOR task for a dataset.
+ Rename a tag in a dataset.
:param dataset_id: dataset ID
:param tenant_id: tenant ID
+ :param from_tag: original tag name
+ :param to_tag: new tag name
:return: (success, result) or (success, error_message)
"""
if not dataset_id:
@@ -552,78 +893,522 @@ def trace_raptor(dataset_id: str, tenant_id: str):
if not ok:
return False, "Invalid Dataset ID"
- task_id = kb.raptor_task_id
- if not task_id:
- return True, {}
+ from rag.nlp import search
- ok, task = TaskService.get_by_id(task_id)
- if not ok:
- return False, "RAPTOR Task Not Found or Error Occurred"
+ settings.docStoreConn.update({"tag_kwd": from_tag, "kb_id": [dataset_id]}, {"remove": {"tag_kwd": from_tag.strip()}, "add": {"tag_kwd": to_tag}}, search.index_name(kb.tenant_id), dataset_id)
- return True, task.to_dict()
+ return True, {"from": from_tag, "to": to_tag}
-def get_auto_metadata(dataset_id: str, tenant_id: str):
+async def search(dataset_id: str, tenant_id: str, req: dict):
"""
- Get auto-metadata configuration for a dataset.
+ Search (retrieval test) within a dataset.
:param dataset_id: dataset ID
:param tenant_id: tenant ID
+ :param req: search request
:return: (success, result) or (success, error_message)
"""
- kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id)
- if kb is None:
- return False, f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'"
+ from api.db.joint_services.tenant_model_service import (
+ get_model_config_by_id,
+ get_model_config_by_type_and_name,
+ get_tenant_default_model_by_type,
+ )
+ from api.db.services.doc_metadata_service import DocMetadataService
+ from api.db.services.llm_service import LLMBundle
+ from api.db.services.search_service import SearchService
+ from api.db.services.user_service import UserTenantService
+ from common.constants import LLMType
+ from common.metadata_utils import apply_meta_data_filter
+ from rag.app.tag import label_question
+ from rag.prompts.generator import cross_languages, keyword_extraction
+
+ logging.debug(
+ "search(dataset=%s, tenant=%s, question_len=%s)",
+ dataset_id,
+ tenant_id,
+ len(req.get("question", "")),
+ )
- parser_cfg = kb.parser_config or {}
- metadata = parser_cfg.get("metadata") or []
- enabled = parser_cfg.get("enable_metadata", bool(metadata))
- # Normalize to AutoMetadataConfig-like JSON
- fields = []
- for f in metadata:
- if not isinstance(f, dict):
- continue
- fields.append(
- {
- "name": f.get("name", ""),
- "type": f.get("type", ""),
- "description": f.get("description"),
- "examples": f.get("examples"),
- "restrict_values": f.get("restrict_values", False),
- }
+ page = int(req.get("page", 1))
+ size = int(req.get("size", 30))
+ question = req.get("question", "")
+ doc_ids = req.get("doc_ids", [])
+ use_kg = req.get("use_kg", False)
+ top = max(1, min(int(req.get("top_k", 1024)), 2048))
+ langs = req.get("cross_languages", [])
+
+ if not KnowledgebaseService.accessible(dataset_id, tenant_id):
+ logging.warning("search access denied: dataset=%s tenant=%s", dataset_id, tenant_id)
+ return False, "Only owner of dataset authorized for this operation."
+
+ e, kb = KnowledgebaseService.get_by_id(dataset_id)
+ if not e:
+ logging.warning("search dataset not found: dataset=%s", dataset_id)
+ return False, "Dataset not found!"
+
+ if doc_ids is not None and not isinstance(doc_ids, list):
+ return False, "`doc_ids` should be a list"
+ local_doc_ids = list(doc_ids) if doc_ids else []
+
+ meta_data_filter = {}
+ chat_mdl = None
+ if req.get("search_id", ""):
+ search_detail = SearchService.get_detail(req.get("search_id", ""))
+ if not search_detail:
+ logging.warning("search config not found: search_id=%s", req.get("search_id", ""))
+ return False, "Invalid search_id"
+ search_config = search_detail.get("search_config", {})
+ meta_data_filter = search_config.get("meta_data_filter", {})
+ if meta_data_filter.get("method") in ["auto", "semi_auto"]:
+ chat_id = search_config.get("chat_id", "")
+ if chat_id:
+ chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, search_config["chat_id"])
+ else:
+ chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
+ chat_mdl = LLMBundle(tenant_id, chat_model_config)
+ else:
+ meta_data_filter = req.get("meta_data_filter") or {}
+ if meta_data_filter.get("method") in ["auto", "semi_auto"]:
+ chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
+ chat_mdl = LLMBundle(tenant_id, chat_model_config)
+
+ if meta_data_filter:
+ local_doc_ids = await apply_meta_data_filter(
+ meta_data_filter,
+ None,
+ question,
+ chat_mdl,
+ local_doc_ids,
+ kb_ids=[dataset_id],
+ metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs([dataset_id]),
)
- return True, {"enabled": enabled, "fields": fields}
+ tenant_ids = []
+ tenants = UserTenantService.query(user_id=tenant_id)
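+ # for/else: the else branch below fires only when the loop finishes without a
+ # break, i.e. no joined tenant owns this dataset.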
+ for tenant in tenants:
+ if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=dataset_id):
+ tenant_ids.append(tenant.tenant_id)
+ break
+ else:
+ return False, "Only owner of dataset authorized for this operation."
+
+ _question = question
+ if langs:
+ _question = await cross_languages(kb.tenant_id, None, _question, langs)
+ if kb.tenant_embd_id:
+ embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
+ elif kb.embd_id:
+ embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
+ else:
+ embd_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.EMBEDDING)
+ embd_mdl = LLMBundle(kb.tenant_id, embd_model_config)
+
+ rerank_mdl = None
+ if req.get("tenant_rerank_id"):
+ rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"])
+ rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
+ elif req.get("rerank_id"):
+ rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"])
+ rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
+
+ if req.get("keyword", False):
+ default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
+ chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config)
+ _question += await keyword_extraction(chat_mdl, _question)
+
+ labels = label_question(_question, [kb])
+ ranks = await settings.retriever.retrieval(
+ _question,
+ embd_mdl,
+ tenant_ids,
+ [dataset_id],
+ page,
+ size,
+ float(req.get("similarity_threshold", 0.0)),
+ float(req.get("vector_similarity_weight", 0.3)),
+ doc_ids=local_doc_ids,
+ top=top,
+ rerank_mdl=rerank_mdl,
+ rank_feature=labels,
+ )
-async def update_auto_metadata(dataset_id: str, tenant_id: str, cfg: dict):
+ if use_kg:
+ try:
+ default_chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
+ ck = await settings.kg_retriever.retrieval(_question, tenant_ids, [dataset_id], embd_mdl, LLMBundle(kb.tenant_id, default_chat_model_config))
+ if ck["content_with_weight"]:
+ ranks["chunks"].insert(0, ck)
+ except Exception:
+ logging.warning("search KG retrieval failed: dataset=%s tenant=%s", dataset_id, tenant_id, exc_info=True)
+ total = ranks.get("total", 0)
+ ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids)
+ ranks["total"] = total
+
+ for c in ranks["chunks"]:
+ c.pop("vector", None)
+ ranks["labels"] = labels
+
+ return True, ranks
+
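# --- Illustrative sketch (editor's note, not part of this patch) ---
# Example request payload for the retrieval-test helper above, assembled from the
# keys it reads via req.get(); the concrete values are hypothetical and the helper
# is assumed to be an async function taking (dataset_id, tenant_id, req).
example_search_req = {
    "question": "What is the refund policy?",
    "page": 1,
    "size": 30,
    "doc_ids": [],                      # optional; must be a list when present
    "top_k": 1024,                      # clamped to the range [1, 2048]
    "similarity_threshold": 0.2,
    "vector_similarity_weight": 0.3,
    "cross_languages": [],              # e.g. ["English", "Chinese"]
    "use_kg": False,                    # prepend a knowledge-graph chunk when True
    "keyword": False,                   # append LLM-extracted keywords to the query
    "rerank_id": "",                    # optional rerank model name
    "meta_data_filter": {},             # "auto"/"semi_auto" methods require a chat model
}
# ok, ranks = await search(dataset_id, tenant_id, example_search_req)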
+
+def check_embedding(dataset_id: str, tenant_id: str, req: dict):
"""
- Update auto-metadata configuration for a dataset.
+ Check embedding model compatibility by sampling random chunks,
+ re-embedding them with the new model, and computing cosine similarity.
:param dataset_id: dataset ID
:param tenant_id: tenant ID
- :param cfg: auto-metadata configuration
+ :param req: request body with embd_id
:return: (success, result) or (success, error_message)
"""
- kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id)
- if kb is None:
- return False, f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'"
+ import random
- parser_cfg = kb.parser_config or {}
- fields = []
- for f in cfg.get("fields", []):
- fields.append(
- {
- "name": f.get("name", ""),
- "type": f.get("type", ""),
- "description": f.get("description"),
- "examples": f.get("examples"),
- "restrict_values": f.get("restrict_values", False),
- }
+ import numpy as np
+ from common.constants import RetCode
+ from common.doc_store.doc_store_base import OrderByExpr
+ from rag.nlp import search
+
+ from api.db.joint_services.tenant_model_service import (
+ get_model_config_by_type_and_name,
+ )
+ from api.db.services.llm_service import LLMBundle
+ from common.constants import LLMType
+
+ def _guess_vec_field(src: dict):
+ for k in src or {}:
+ if k.endswith("_vec"):
+ return k
+ return None
+
+ def _as_float_vec(v):
+ if v is None:
+ return []
+ if isinstance(v, str):
+ return [float(x) for x in v.split("\t") if x != ""]
+ if isinstance(v, (list, tuple, np.ndarray)):
+ return [float(x) for x in v]
+ return []
+
+ def _to_1d(x):
+ a = np.asarray(x, dtype=np.float32)
+ return a.reshape(-1)
+
+ def _cos_sim(a, b, eps=1e-12):
+ a = _to_1d(a)
+ b = _to_1d(b)
+ na = np.linalg.norm(a)
+ nb = np.linalg.norm(b)
+ if na < eps or nb < eps:
+ return 0.0
+ return float(np.dot(a, b) / (na * nb))
+
+ def sample_random_chunks_with_vectors(
+ docStoreConn,
+ tenant_id: str,
+ kb_id: str,
+ n: int = 5,
+ base_fields=("docnm_kwd", "doc_id", "content_with_weight", "page_num_int", "position_int", "top_int"),
+ ):
+ index_nm = search.index_name(tenant_id)
+
+ res0 = docStoreConn.search(
+ select_fields=[], highlight_fields=[],
+ condition={"kb_id": kb_id, "available_int": 1},
+ match_expressions=[], order_by=OrderByExpr(),
+ offset=0, limit=1,
+ index_names=index_nm, knowledgebase_ids=[kb_id],
)
- parser_cfg["metadata"] = fields
- parser_cfg["enable_metadata"] = cfg.get("enabled", True)
+ total = docStoreConn.get_total(res0)
+ if total <= 0:
+ return []
+
+ n = min(n, total)
+ offsets = sorted(random.sample(range(min(total, 1000)), n))
+ out = []
+
+ for off in offsets:
+ res1 = docStoreConn.search(
+ select_fields=list(base_fields),
+ highlight_fields=[],
+ condition={"kb_id": kb_id, "available_int": 1},
+ match_expressions=[], order_by=OrderByExpr(),
+ offset=off, limit=1,
+ index_names=index_nm, knowledgebase_ids=[kb_id],
+ )
+ ids = docStoreConn.get_doc_ids(res1)
+ if not ids:
+ continue
- if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": parser_cfg}):
- return False, "Update auto-metadata error.(Database error)"
+ cid = ids[0]
+ full_doc = docStoreConn.get(cid, index_nm, [kb_id]) or {}
+ vec_field = _guess_vec_field(full_doc)
+ vec = _as_float_vec(full_doc.get(vec_field))
+
+ out.append({
+ "chunk_id": cid,
+ "kb_id": kb_id,
+ "doc_id": full_doc.get("doc_id"),
+ "doc_name": full_doc.get("docnm_kwd"),
+ "vector_field": vec_field,
+ "vector_dim": len(vec),
+ "vector": vec,
+ "page_num_int": full_doc.get("page_num_int"),
+ "position_int": full_doc.get("position_int"),
+ "top_int": full_doc.get("top_int"),
+ "content_with_weight": full_doc.get("content_with_weight") or "",
+ "question_kwd": full_doc.get("question_kwd") or [],
+ })
+ return out
+
+ def _clean(s: str):
+        return re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "").strip()
+
+ if not dataset_id:
+ return False, 'Lack of "Dataset ID"'
+
+ if not KnowledgebaseService.accessible(dataset_id, tenant_id):
+ return False, "No authorization."
+
+ ok, kb = KnowledgebaseService.get_by_id(dataset_id)
+ if not ok:
+ return False, "Invalid Dataset ID"
+
+ embd_id = req.get("embd_id", "")
+ if not embd_id:
+ return False, "`embd_id` is required."
+
+ logging.info("check_embedding: dataset=%s tenant=%s embd_id=%s", dataset_id, tenant_id, embd_id)
+
+ ok, err = verify_embedding_availability(embd_id, tenant_id)
+ if not ok:
+ return False, err
+
+ embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, embd_id)
+ emb_mdl = LLMBundle(kb.tenant_id, embd_model_config)
+
+ n = int(req.get("check_num", 5))
+ samples = sample_random_chunks_with_vectors(settings.docStoreConn, tenant_id=kb.tenant_id, kb_id=dataset_id, n=n)
+ logging.info("check_embedding: dataset=%s sampled=%d chunks", dataset_id, len(samples))
+
+ results, eff_sims = [], []
+ mode = "content_only"
+ for ck in samples:
+ title = ck.get("doc_name") or "Title"
+
+ txt_in = "\n".join(ck.get("question_kwd") or []) or ck.get("content_with_weight") or ""
+ txt_in = _clean(txt_in)
+ if not txt_in:
+ results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"})
+ continue
+
+ if not ck.get("vector"):
+ results.append({"chunk_id": ck["chunk_id"], "reason": "no_stored_vector"})
+ continue
+
+ try:
+ v, _ = emb_mdl.encode([title, txt_in])
+ assert len(v[1]) == len(ck["vector"]), (
+ f"The dimension ({len(v[1])}) of given embedding model is different from the original ({len(ck['vector'])})"
+ )
+ sim_content = _cos_sim(v[1], ck["vector"])
+ title_w = 0.1
+ qv_mix = title_w * v[0] + (1 - title_w) * v[1]
+ sim_mix = _cos_sim(qv_mix, ck["vector"])
+ sim = sim_content
+ mode = "content_only"
+ if sim_mix > sim:
+ sim = sim_mix
+ mode = "title+content"
+ except Exception as e:
+ return False, f"Embedding failure. {e}"
+
+ eff_sims.append(sim)
+ results.append({
+ "chunk_id": ck["chunk_id"],
+ "doc_id": ck["doc_id"],
+ "doc_name": ck["doc_name"],
+ "vector_field": ck["vector_field"],
+ "vector_dim": ck["vector_dim"],
+ "cos_sim": round(sim, 6),
+ })
+
+ summary = {
+ "kb_id": dataset_id,
+ "model": embd_id,
+ "sampled": len(samples),
+ "valid": len(eff_sims),
+ "avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6),
+ "min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6),
+ "max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6),
+ "match_mode": mode,
+ }
+
+ data = {"summary": summary, "results": results}
+ if not eff_sims:
+ logging.warning("check_embedding: dataset=%s no comparable chunks", dataset_id)
+ return False, "No embedded chunks are available to compare."
+ if summary["avg_cos_sim"] >= 0.9:
+ logging.info("check_embedding: dataset=%s compatible avg_cos_sim=%s valid=%d", dataset_id, summary["avg_cos_sim"], len(eff_sims))
+ return True, data
+ logging.warning("check_embedding: dataset=%s not_effective avg_cos_sim=%s valid=%d", dataset_id, summary["avg_cos_sim"], len(eff_sims))
+ return "not_effective", {"code": RetCode.NOT_EFFECTIVE, "message": "Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.", "data": data}
+
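# --- Illustrative sketch (editor's note, not part of this patch) ---
# How a caller might interpret check_embedding() before switching a dataset's
# embedding model; dataset_id, tenant_id and the new model name are hypothetical.
def can_switch_embedding(dataset_id: str, tenant_id: str, new_embd_id: str) -> bool:
    status, payload = check_embedding(dataset_id, tenant_id, {"embd_id": new_embd_id, "check_num": 5})
    if status is True:
        # Re-embedded samples stay close to the stored vectors (avg cosine >= 0.9).
        return True
    if status == "not_effective":
        # payload carries {"code": RetCode.NOT_EFFECTIVE, "message": ..., "data": ...}.
        return False
    # status is False: access/validation error, payload is the error message.
    return False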
+
+async def search_datasets(tenant_id: str, req: dict):
+ """
+ Search (retrieval test) across multiple datasets.
+
+ :param tenant_id: tenant ID
+ :param req: search request containing dataset_ids and other params
+ :return: (success, result) or (success, error_message)
+ """
+ from api.db.joint_services.tenant_model_service import (
+ get_model_config_by_id,
+ get_model_config_by_type_and_name,
+ get_tenant_default_model_by_type,
+ )
+ from api.db.services.doc_metadata_service import DocMetadataService
+ from api.db.services.llm_service import LLMBundle
+ from api.db.services.search_service import SearchService
+ from api.db.services.user_service import UserTenantService
+ from common.constants import LLMType
+ from common.metadata_utils import apply_meta_data_filter
+ from rag.app.tag import label_question
+ from rag.prompts.generator import cross_languages, keyword_extraction
+
+ kb_ids = req.get("dataset_ids", [])
+ page = int(req.get("page", 1))
+ size = int(req.get("size", 30))
+ question = req.get("question", "")
+ doc_ids = req.get("doc_ids", [])
+ use_kg = req.get("use_kg", False)
+ top = max(1, min(int(req.get("top_k", 1024)), 2048))
+ langs = req.get("cross_languages", [])
+
+ logging.debug(
+ "search_datasets(datasets=%s, tenant=%s, question_len=%s)",
+ kb_ids,
+ tenant_id,
+ len(question),
+ )
+
+ # Access check for all datasets
+ for kb_id in kb_ids:
+ if not KnowledgebaseService.accessible(kb_id, tenant_id):
+ logging.warning("search_datasets access denied: dataset=%s tenant=%s", kb_id, tenant_id)
+ return False, f"Only owner of dataset {kb_id} authorized for this operation."
+
+ kbs = KnowledgebaseService.get_by_ids(kb_ids)
+ if not kbs:
+ return False, "Datasets not found!"
+
+ # All datasets must use the same embedding model
+ embd_nms = list(set([TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs]))
+ if len(embd_nms) != 1:
+ return False, "Datasets use different embedding models."
+
+ if doc_ids is not None and not isinstance(doc_ids, list):
+ return False, "`doc_ids` should be a list"
+ local_doc_ids = list(doc_ids) if doc_ids else []
+
+ meta_data_filter = {}
+ chat_mdl = None
+ if req.get("search_id", ""):
+ search_detail = SearchService.get_detail(req.get("search_id", ""))
+ if not search_detail:
+ logging.warning("search config not found: search_id=%s", req.get("search_id", ""))
+ return False, "Invalid search_id"
+ search_config = search_detail.get("search_config", {})
+ meta_data_filter = search_config.get("meta_data_filter", {})
+ if meta_data_filter.get("method") in ["auto", "semi_auto"]:
+ chat_id = search_config.get("chat_id", "")
+ if chat_id:
+ chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, search_config["chat_id"])
+ else:
+ chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
+ chat_mdl = LLMBundle(tenant_id, chat_model_config)
+ else:
+ meta_data_filter = req.get("meta_data_filter") or {}
+ if meta_data_filter.get("method") in ["auto", "semi_auto"]:
+ chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
+ chat_mdl = LLMBundle(tenant_id, chat_model_config)
+
+ if meta_data_filter:
+ local_doc_ids = await apply_meta_data_filter(
+ meta_data_filter,
+ None,
+ question,
+ chat_mdl,
+ local_doc_ids,
+ kb_ids=kb_ids,
+ metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs(kb_ids),
+ )
+
+ tenant_ids = []
+ tenants = UserTenantService.query(user_id=tenant_id)
+ for tenant in tenants:
+ if any(KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id) for kb_id in kb_ids):
+ tenant_ids.append(tenant.tenant_id)
+ break
+ else:
+ return False, "Only owner of datasets authorized for this operation."
+
+ kb = kbs[0]
+ _question = question
+ if langs:
+ _question = await cross_languages(kb.tenant_id, None, _question, langs)
+ if kb.tenant_embd_id:
+ embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
+ elif kb.embd_id:
+ embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
+ else:
+ embd_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.EMBEDDING)
+ embd_mdl = LLMBundle(kb.tenant_id, embd_model_config)
+
+ rerank_mdl = None
+ if req.get("tenant_rerank_id"):
+ rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"])
+ rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
+ elif req.get("rerank_id"):
+ rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"])
+ rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
+
+ if req.get("keyword", False):
+ default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
+ chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config)
+ _question += await keyword_extraction(chat_mdl, _question)
+
+ labels = label_question(_question, kbs)
+ ranks = await settings.retriever.retrieval(
+ _question,
+ embd_mdl,
+ tenant_ids,
+ kb_ids,
+ page,
+ size,
+ float(req.get("similarity_threshold", 0.0)),
+ float(req.get("vector_similarity_weight", 0.3)),
+ doc_ids=local_doc_ids,
+ top=top,
+ rerank_mdl=rerank_mdl,
+ rank_feature=labels,
+ )
+
+ if use_kg:
+ try:
+ default_chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
+ ck = await settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl, LLMBundle(kb.tenant_id, default_chat_model_config))
+ if ck["content_with_weight"]:
+ ranks["chunks"].insert(0, ck)
+ except Exception:
+ logging.warning("search_datasets KG retrieval failed: datasets=%s tenant=%s", kb_ids, tenant_id, exc_info=True)
+ total = ranks.get("total", 0)
+ ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids)
+ ranks["total"] = total
+
+ for c in ranks["chunks"]:
+ c.pop("vector", None)
+ ranks["labels"] = labels
- return True, {"enabled": parser_cfg["enable_metadata"], "fields": fields}
+ return True, ranks
diff --git a/api/apps/services/document_api_service.py b/api/apps/services/document_api_service.py
index 82dfa37e353..59abbd25072 100644
--- a/api/apps/services/document_api_service.py
+++ b/api/apps/services/document_api_service.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import logging
+
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
@@ -58,7 +60,7 @@ def update_document_name_only(document_id, req_doc_name):
)
return None
-def update_chunk_method_only(req, doc, dataset_id, tenant_id):
+def update_chunk_method(req, doc, tenant_id):
"""
Update chunk method only (without validation).
@@ -69,28 +71,56 @@ def update_chunk_method_only(req, doc, dataset_id, tenant_id):
Args:
req: The request dictionary containing chunk_method and parser_config.
doc: The document model from the database.
- dataset_id: The ID of the dataset containing the document.
tenant_id: The tenant ID for the document store.
Returns:
None if successful, or an error result dictionary if failed.
"""
if doc.parser_id.lower() != req["chunk_method"].lower():
- # if chunk method changed
- e = DocumentService.update_by_id(
- doc.id,
- {
- "parser_id": req["chunk_method"],
- "progress": 0,
- "progress_msg": "",
- "run": TaskStatus.UNSTART.value,
- },
- )
- if not e:
- return get_error_data_result(message="Document not found!")
+ # if chunk method changed, reset document for reparse
+ result = reset_document_for_reparse(doc, tenant_id, parser_id=req["chunk_method"])
+ if result:
+ return result
if not req.get("parser_config"):
req["parser_config"] = get_parser_config(req["chunk_method"], req.get("parser_config"))
DocumentService.update_parser_config(doc.id, req["parser_config"])
+ return None
+
+
+def reset_document_for_reparse(doc, tenant_id, parser_id=None, pipeline_id=None):
+ """
+ Reset document for reparsing.
+
+ Updates the parser_id and/or pipeline_id for a document, resets its progress,
+ clears existing chunks from the document store, and removes chunk images.
+
+ Args:
+ doc: The document model from the database.
+ tenant_id: The tenant ID for the document store.
+ parser_id: Optional new parser_id (chunk method). If None, keeps existing.
+ pipeline_id: Optional new pipeline_id. If None, keeps existing.
+
+ Returns:
+ None if successful, or an error result dictionary if failed.
+ """
+
+ # Build update fields
+ update_fields = {
+ "progress": 0,
+ "progress_msg": "",
+ "run": TaskStatus.UNSTART.value,
+ }
+ if parser_id is not None:
+ update_fields["parser_id"] = parser_id
+ if pipeline_id is not None:
+ update_fields["pipeline_id"] = pipeline_id
+
+ # Update document
+ e = DocumentService.update_by_id(doc.id, update_fields)
+ if not e:
+ return get_error_data_result(message="Document not found!")
+
+ # Delete chunks from document store
if doc.token_num > 0:
e = DocumentService.increment_chunk_num(
doc.id,
@@ -98,12 +128,20 @@ def update_chunk_method_only(req, doc, dataset_id, tenant_id):
doc.token_num * -1,
doc.chunk_num * -1,
doc.process_duration * -1,
- )
+ )
if not e:
return get_error_data_result(message="Document not found!")
- settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
+ settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
+
+ # Delete chunk images
+ try:
+ DocumentService.delete_chunk_images(doc, tenant_id)
+ except Exception as e:
+ logging.error(f"error when delete chunk images:{e}")
+
return None
+
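# --- Illustrative sketch (editor's note, not part of this patch) ---
# reset_document_for_reparse() is the shared reset path for both "chunk method
# changed" and "pipeline changed" flows; doc (a Document model) and tenant_id are
# hypothetical here, and "qa" stands in for any supported chunk method.
err = reset_document_for_reparse(doc, tenant_id, parser_id="qa")
if err is not None:
    # err is the error result dict built by get_error_data_result().
    raise RuntimeError(err)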
def update_document_status_only(status:int, doc, kb):
"""
Update document status only (without validation).
diff --git a/api/apps/services/file_api_service.py b/api/apps/services/file_api_service.py
index d6fe9248a50..cfde3de2948 100644
--- a/api/apps/services/file_api_service.py
+++ b/api/apps/services/file_api_service.py
@@ -67,14 +67,14 @@ async def upload_file(tenant_id: str, pf_id: str, file_objs: list):
if not e:
return False, "Folder not found!"
last_folder = await thread_pool_exec(
- FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names, len_id_list
+ FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names, len_id_list, tenant_id, tenant_id
)
else:
e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 2])
if not e:
return False, "Folder not found!"
last_folder = await thread_pool_exec(
- FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names, len_id_list
+ FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names, len_id_list, tenant_id, tenant_id
)
filetype = filename_type(file_obj_names[file_len - 1])
@@ -121,7 +121,7 @@ async def create_folder(tenant_id: str, name: str, pf_id: str = None, file_type:
if FileService.query(name=name, parent_id=pf_id):
return False, "Duplicated folder name in the same folder."
- if file_type == FileType.FOLDER.value:
+ if (file_type or "").lower() == FileType.FOLDER.value:
ft = FileType.FOLDER.value
else:
ft = FileType.VIRTUAL.value
@@ -158,6 +158,7 @@ def list_files(tenant_id: str, args: dict):
root_folder = FileService.get_root_folder(tenant_id)
pf_id = root_folder["id"]
FileService.init_knowledgebase_docs(pf_id, tenant_id)
+ FileService.init_skills_folder(pf_id, tenant_id)
e, file = FileService.get_by_id(pf_id)
if not e:
@@ -173,92 +174,305 @@ def list_files(tenant_id: str, args: dict):
-def get_parent_folder(file_id: str):
+def get_parent_folder(file_id: str, user_id: str = None):
"""
- Get parent folder of a file.
+ Get parent folder of a file with permission check.
:param file_id: file ID
+ :param user_id: user ID for permission validation
:return: (success, result) or (success, error_message)
"""
+ from api.common.check_team_permission import check_file_team_permission
+
e, file = FileService.get_by_id(file_id)
if not e:
return False, "Folder not found!"
+ # Permission check
+ if user_id and not check_file_team_permission(file, user_id):
+ return False, "No authorization."
+
parent_folder = FileService.get_parent_folder(file_id)
return True, {"parent_folder": parent_folder.to_json()}
-def get_all_parent_folders(file_id: str):
+def get_all_parent_folders(file_id: str, user_id: str = None):
"""
- Get all ancestor folders of a file.
+ Get all ancestor folders of a file with permission check.
:param file_id: file ID
+ :param user_id: user ID for permission validation
:return: (success, result) or (success, error_message)
"""
+ from api.common.check_team_permission import check_file_team_permission
+
e, file = FileService.get_by_id(file_id)
if not e:
return False, "Folder not found!"
+ # Permission check
+ if user_id and not check_file_team_permission(file, user_id):
+ return False, "No authorization."
+
parent_folders = FileService.get_all_parent_folders(file_id)
return True, {"parent_folders": [pf.to_json() for pf in parent_folders]}
-async def delete_files(uid: str, file_ids: list):
+async def delete_files(uid: str, file_ids: list, auth_header: str = ""):
"""
Delete files/folders with team permission check and recursive deletion.
:param uid: user ID
:param file_ids: list of file IDs to delete
+ :param auth_header: Authorization header for Go backend API calls
:return: (success, result) or (success, error_message)
"""
- def _delete_single_file(file):
+ errors: list[str] = []
+ success_count = 0
+
+ def _get_space_uuid_by_name(tenant_id, space_name, authorization):
+ """Get space UUID by space name from Go backend"""
+ try:
+ import requests
+
+ host = getattr(settings, 'HOST_IP', '127.0.0.1')
+ # Go service runs on port+4 (9384 by default)
+ port = getattr(settings, 'HOST_PORT', 9380) + 4
+ service_url = f"http://{host}:{port}"
+
+ # List all spaces and find the one matching the name
+ url = f"{service_url}/api/v1/skills/spaces"
+ headers = {"Content-Type": "application/json"}
+ if authorization:
+ headers["Authorization"] = authorization
+
+ response = requests.get(url, headers=headers, timeout=10)
+
+ if response.status_code == 200:
+ data = response.json()
+ if data.get("code") == 0:
+ spaces = data.get("data", {}).get("spaces", [])
+ for space in spaces:
+ if space.get("name") == space_name:
+ return space.get("id")
+ except Exception as e:
+ logging.warning(f"Error getting space UUID: {e}")
+ return None
+
+ def _delete_skill_index(tenant_id, space_name, skill_name, authorization):
+ """Delete skill index from Go backend.
+
+ Returns:
+ bool: True if deletion succeeded (HTTP 200), False otherwise.
+ """
+ try:
+ import requests
+ from urllib.parse import quote
+
+ # Construct service URL from settings
+ host = getattr(settings, 'HOST_IP', '127.0.0.1')
+ # Go service runs on port+4 (9384 by default)
+ port = getattr(settings, 'HOST_PORT', 9380) + 4
+ service_url = f"http://{host}:{port}"
+
+ # Get space UUID from space name
+ space_uuid = _get_space_uuid_by_name(tenant_id, space_name, authorization)
+ space_id = space_uuid if space_uuid else space_name
+
+ url = f"{service_url}/api/v1/skills/index?skill_id={quote(skill_name)}&space_id={quote(space_id)}"
+ headers = {"Content-Type": "application/json"}
+ if authorization:
+ headers["Authorization"] = authorization
+
+ response = requests.delete(url, headers=headers, timeout=10)
+ if response.status_code == 200:
+ try:
+ data = response.json()
+ if data.get("code") == 0:
+ logging.info(
+ f"Successfully deleted skill index: space={space_name}, skill={skill_name}, "
+ f"status={response.status_code}, code=0"
+ )
+ return True
+ else:
+ app_code = data.get("code", "unknown")
+ app_msg = data.get("message", "no message")
+ logging.error(
+ f"Failed to delete skill index: space={space_name}, skill={skill_name}, "
+ f"status={response.status_code}, app_code={app_code}, app_msg={app_msg}, "
+ f"response={response.text}"
+ )
+ return False
+ except ValueError as json_err:
+ # JSON decode error - treat as failure
+ logging.error(
+ f"Failed to parse delete response JSON: space={space_name}, skill={skill_name}, "
+ f"error={json_err}, raw_response={response.text}"
+ )
+ return False
+ else:
+ logging.error(
+ f"Failed to delete skill index: space={space_name}, skill={skill_name}, "
+ f"status={response.status_code}, response={response.text}"
+ )
+ return False
+ except Exception as e:
+ logging.error(
+ f"Exception deleting skill index: space={space_name}, skill={skill_name}, error={e}"
+ )
+ return False
+
+ def _delete_single_file(file) -> int:
try:
if file.location:
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
except Exception as e:
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")
+ errors.append(f"Failed to remove object {file.parent_id}/{file.location}: {e}")
informs = File2DocumentService.get_by_file_id(file.id)
for inform in informs:
doc_id = inform.document_id
e, doc = DocumentService.get_by_id(doc_id)
- if e and doc:
- tenant_id = DocumentService.get_tenant_id(doc_id)
- if tenant_id:
- DocumentService.remove_document(doc, tenant_id)
- File2DocumentService.delete_by_file_id(file.id)
+ if not e or not doc:
+ errors.append(f"Document not found for file {file.id}: {doc_id}")
+ continue
+
+ tenant_id = DocumentService.get_tenant_id(doc_id)
+ if not tenant_id:
+ errors.append(f"Tenant not found for document {doc_id}")
+ continue
+
+ if not DocumentService.remove_document(doc, tenant_id):
+ errors.append(f"Failed to remove document {doc_id} for file {file.id}")
- FileService.delete(file)
+ try:
+ File2DocumentService.delete_by_file_id(file.id)
+ except Exception as e:
+ logging.exception(f"Fail to remove file-document relations for file {file.id}, error: {e}")
+ errors.append(f"Failed to remove file-document relations for file {file.id}: {e}")
- def _delete_folder_recursive(folder, tenant_id):
+ try:
+ FileService.delete(file)
+ except Exception as e:
+ logging.exception(f"Fail to delete file record {file.id}, error: {e}")
+ errors.append(f"Failed to delete file record {file.id}: {e}")
+ else:
+ return 1
+
+ return 0
+
+ def _find_ancestor_skill_space(folder_id, tenant_id):
+ """Walk up the folder hierarchy to find an ancestor with source_type == 'skill_space'.
+
+ Returns:
+ tuple: (success, folder) where folder has source_type == 'skill_space', or (False, None)
+ """
+ visited = set()
+ current_id = folder_id
+ while current_id and current_id not in visited:
+ visited.add(current_id)
+ success, folder = FileService.get_by_id(current_id)
+ if not success or not folder:
+ return False, None
+ if folder.source_type == "skill_space":
+ return True, folder
+ # Move to parent
+ current_id = folder.parent_id
+ return False, None
+
+ def _delete_folder_recursive(folder, tenant_id) -> int:
+ deleted = 0
+ current_space_name = None
+ is_space_folder = folder.source_type == "skill_space"
+ is_skill_folder = False
+
+ if not is_space_folder:
+ parent_success, parent_folder = FileService.get_by_id(folder.parent_id)
+ if parent_success and parent_folder and parent_folder.source_type == "skill_space":
+ is_skill_folder = True
+ current_space_name = parent_folder.name
+ logging.info(f"Identified skill folder '{folder.name}' (parent space: {current_space_name})")
+ else:
+ ancestor_success, ancestor_folder = _find_ancestor_skill_space(folder.parent_id, tenant_id)
+ if ancestor_success and ancestor_folder:
+ is_skill_folder = True
+ current_space_name = ancestor_folder.name
+ logging.info(f"Identified skill folder '{folder.name}' (ancestor space: {current_space_name})")
+
+ if is_space_folder:
+ current_space_name = folder.name
+ logging.info(f"Processing space folder '{folder.name}' - will delete all skill indexes within")
+
+ if is_skill_folder and current_space_name and not is_space_folder:
+ logging.info(f"Deleting skill index for skill '{folder.name}' in space '{current_space_name}'")
+ index_deleted = _delete_skill_index(tenant_id, current_space_name, folder.name, auth_header)
+ if not index_deleted:
+ logging.error(
+ f"Aborting folder deletion due to index deletion failure: "
+ f"folder={folder.name}, space={current_space_name}"
+ )
+ errors.append(
+ f"Failed to delete skill index for folder '{folder.name}' in space '{current_space_name}'. "
+ f"Folder deletion aborted to prevent orphaned indexes."
+ )
+ return deleted
sub_files = FileService.list_all_files_by_parent_id(folder.id)
+ logging.info(f"Folder '{folder.name}': found {len(sub_files)} children to delete")
+
for sub_file in sub_files:
if sub_file.type == FileType.FOLDER.value:
- _delete_folder_recursive(sub_file, tenant_id)
+ deleted += _delete_folder_recursive(sub_file, tenant_id)
+ else:
+ deleted += _delete_single_file(sub_file)
+ try:
+ FileService.delete(folder)
+ except Exception as e:
+ logging.exception(f"Fail to delete folder record {folder.id}, error: {e}")
+ errors.append(f"Failed to delete folder record {folder.id}: {e}")
+ else:
+ deleted += 1
+
+ try:
+ if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
+ logging.info(f"Removing storage bucket for folder '{folder.name}' (id={folder.id})")
+ settings.STORAGE_IMPL.remove_bucket(folder.id)
else:
- _delete_single_file(sub_file)
- FileService.delete(folder)
+ logging.debug(f"Storage implementation does not support remove_bucket, skipping for folder '{folder.name}'")
+ except Exception as e:
+ logging.warning(f"Failed to remove storage bucket for folder '{folder.name}' (id={folder.id}): {e}")
+
+ return deleted
def _rm_sync():
+ nonlocal success_count
for file_id in file_ids:
e, file = FileService.get_by_id(file_id)
if not e or not file:
- return False, "File or Folder not found!"
+ errors.append(f"File or Folder not found: {file_id}")
+ continue
if not file.tenant_id:
- return False, "Tenant not found!"
+ errors.append(f"Tenant not found for file {file_id}")
+ continue
if not check_file_team_permission(file, uid):
- return False, "No authorization."
+ errors.append(f"No authorization for file {file_id}")
+ continue
if file.source_type == FileSource.KNOWLEDGEBASE:
continue
+ if file.source_type == "skill_space":
+ continue
+
if file.type == FileType.FOLDER.value:
- _delete_folder_recursive(file, uid)
+ success_count += _delete_folder_recursive(file, uid)
continue
- _delete_single_file(file)
+ success_count += _delete_single_file(file)
- return True, True
+ if errors:
+ return False, {"success_count": success_count, "errors": errors}
+ return True, {"success_count": success_count}
return await thread_pool_exec(_rm_sync)
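# --- Illustrative sketch (editor's note, not part of this patch) ---
# The Go backend calls inside delete_files() resolve to URLs like the following with
# the default settings (HOST_IP=127.0.0.1, HOST_PORT=9380, Go service on HOST_PORT + 4);
# the skill name and space UUID are hypothetical.
from urllib.parse import quote

host, port = "127.0.0.1", 9380 + 4
list_spaces_url = f"http://{host}:{port}/api/v1/skills/spaces"
delete_index_url = f"http://{host}:{port}/api/v1/skills/index?skill_id={quote('my-skill')}&space_id={quote('space-uuid')}"
# Both requests carry the caller's Authorization header and report success as {"code": 0, ...}.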
@@ -307,6 +521,18 @@ async def move_files(uid: str, src_file_ids: list, dest_file_id: str = None, new
if f.name == new_name:
return False, "Duplicated file name in the same folder."
+ if dest_folder:
+ for file in files:
+ if file.type == FileType.FOLDER.value and file.id == dest_folder.id:
+ return False, "Cannot move a folder to itself."
+ # Check if any source folder is an ancestor of the destination folder
+ # to prevent infinite recursion in _move_entry_recursive
+ dest_ancestors = FileService.get_all_parent_folders(dest_folder.id)
+ dest_ancestor_ids = {f.id for f in dest_ancestors}
+ for file in files:
+ if file.type == FileType.FOLDER.value and file.id in dest_ancestor_ids:
+ return False, "Cannot move a folder into its own subfolder."
+
def _move_entry_recursive(source_file_entry, dest_folder_entry, override_name=None):
effective_name = override_name or source_file_entry.name
diff --git a/api/apps/services/memory_api_service.py b/api/apps/services/memory_api_service.py
index 1b640cff66b..9040f0ce445 100644
--- a/api/apps/services/memory_api_service.py
+++ b/api/apps/services/memory_api_service.py
@@ -29,6 +29,49 @@
from common.time_utils import current_timestamp, timestamp_to_date
+def _split_filter_values(values):
+ if not values:
+ return []
+ if isinstance(values, str):
+ values = [values]
+ res = []
+ for value in values:
+ if not value:
+ continue
+ if isinstance(value, str):
+ res.extend([v.strip() for v in value.split(",") if v.strip()])
+ else:
+ res.append(value)
+ return res
+
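# --- Illustrative sketch (editor's note, not part of this patch) ---
# _split_filter_values() accepts None, a comma-separated string, or a list and
# always yields a flat list of trimmed values; the sample inputs are hypothetical.
assert _split_filter_values(None) == []
assert _split_filter_values("a,b , c") == ["a", "b", "c"]
assert _split_filter_values(["a,b", "c"]) == ["a", "b", "c"]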
+
+def _joined_tenant_ids(user_id: str) -> set[str]:
+ user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(user_id)
+ return {user_id, *[tenant["tenant_id"] for tenant in user_tenants]}
+
+
+def _memory_accessible(memory) -> bool:
+ if memory.tenant_id == current_user.id:
+ return True
+ if memory.permissions != TenantPermission.TEAM.value:
+ return False
+ return memory.tenant_id in _joined_tenant_ids(current_user.id)
+
+
+def _require_memory_access(memory_id: str):
+ memory = MemoryService.get_by_memory_id(memory_id)
+ if not memory or not _memory_accessible(memory):
+ raise NotFoundException(f"Memory '{memory_id}' not found.")
+ return memory
+
+
+def _filter_accessible_memories(memory_ids: list[str]):
+ memory_ids = _split_filter_values(memory_ids)
+ if not memory_ids:
+ return []
+ return [memory for memory in MemoryService.get_by_ids(memory_ids) if _memory_accessible(memory)]
+
+
async def create_memory(memory_info: dict):
"""
:param memory_info: {
@@ -137,9 +180,7 @@ async def update_memory(memory_id: str, new_memory_setting: dict):
for field in ["avatar", "description", "system_prompt", "user_prompt"]:
if field in new_memory_setting:
update_dict[field] = new_memory_setting[field]
- current_memory = MemoryService.get_by_memory_id(memory_id)
- if not current_memory:
- raise NotFoundException(f"Memory '{memory_id}' not found.")
+ current_memory = _require_memory_access(memory_id)
memory_dict = current_memory.to_dict()
memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)})
@@ -168,9 +209,7 @@ async def update_memory(memory_id: str, new_memory_setting: dict):
async def delete_memory(memory_id):
- memory = MemoryService.get_by_memory_id(memory_id)
- if not memory:
- raise NotFoundException(f"Memory '{memory_id}' not found.")
+ memory = _require_memory_access(memory_id)
MemoryService.delete_memory(memory_id)
if MessageService.has_index(memory.tenant_id, memory_id):
MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id)
@@ -188,19 +227,16 @@ async def list_memory(filter_params: dict, keywords: str, page: int=1, page_size
:param page: int
:param page_size: int
"""
- filter_dict: dict = {"storage_type": filter_params.get("storage_type")}
- tenant_ids = filter_params.get("tenant_id")
- if not filter_params.get("tenant_id"):
- # restrict to current user's tenants
- user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id)
- filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants]
+ filter_dict: dict = {"storage_type": filter_params.get("storage_type"), "accessible_user_id": current_user.id}
+ allowed_tenant_ids = _joined_tenant_ids(current_user.id)
+ tenant_ids = _split_filter_values(filter_params.get("tenant_id") or filter_params.get("owner_ids"))
+ if tenant_ids:
+ filter_dict["tenant_id"] = [tenant_id for tenant_id in tenant_ids if tenant_id in allowed_tenant_ids]
+ if not filter_dict["tenant_id"]:
+ return {"memory_list": [], "total_count": 0}
else:
- if len(tenant_ids) == 1 and ',' in tenant_ids[0]:
- tenant_ids = tenant_ids[0].split(',')
- filter_dict["tenant_id"] = tenant_ids
- memory_types = filter_params.get("memory_type")
- if memory_types and len(memory_types) == 1 and ',' in memory_types[0]:
- memory_types = memory_types[0].split(',')
+ filter_dict["tenant_id"] = list(allowed_tenant_ids)
+ memory_types = _split_filter_values(filter_params.get("memory_type"))
filter_dict["memory_type"] = memory_types
memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size)
@@ -212,15 +248,13 @@ async def list_memory(filter_params: dict, keywords: str, page: int=1, page_size
async def get_memory_config(memory_id):
memory = MemoryService.get_with_owner_name_by_id(memory_id)
- if not memory:
+ if not memory or not _memory_accessible(memory):
raise NotFoundException(f"Memory '{memory_id}' not found.")
return format_ret_data_from_memory(memory)
async def get_memory_messages(memory_id, agent_ids: list[str], keywords: str, page: int=1, page_size: int = 50):
- memory = MemoryService.get_by_memory_id(memory_id)
- if not memory:
- raise NotFoundException(f"Memory '{memory_id}' not found.")
+ memory = _require_memory_access(memory_id)
messages = MessageService.list_message(
memory.tenant_id, memory_id, agent_ids, keywords, page, page_size)
agent_name_mapping = {}
@@ -253,13 +287,14 @@ async def add_message(memory_ids: list[str], message_dict: dict):
"message_type": str
}
"""
- return await queue_save_to_memory_task(memory_ids, message_dict)
+ accessible_memory_ids = [memory.id for memory in _filter_accessible_memories(memory_ids)]
+ if not accessible_memory_ids:
+ return False, "Memory not found."
+ return await queue_save_to_memory_task(accessible_memory_ids, message_dict)
async def forget_message(memory_id: str, message_id: int):
- memory = MemoryService.get_by_memory_id(memory_id)
- if not memory:
- raise NotFoundException(f"Memory '{memory_id}' not found.")
+ memory = _require_memory_access(memory_id)
forget_time = timestamp_to_date(current_timestamp())
update_succeed = MessageService.update_message(
@@ -272,9 +307,7 @@ async def forget_message(memory_id: str, message_id: int):
async def update_message_status(memory_id: str, message_id: int, status: bool):
- memory = MemoryService.get_by_memory_id(memory_id)
- if not memory:
- raise NotFoundException(f"Memory '{memory_id}' not found.")
+ memory = _require_memory_access(memory_id)
update_succeed = MessageService.update_message(
{"memory_id": memory_id, "message_id": int(message_id)},
@@ -300,6 +333,11 @@ async def search_message(filter_dict: dict, params: dict):
"top_n": int
}
"""
+ memory_ids = _split_filter_values(filter_dict.get("memory_id"))
+ accessible_memory_ids = [memory.id for memory in _filter_accessible_memories(memory_ids)]
+ if not accessible_memory_ids:
+ return []
+ filter_dict = {**filter_dict, "memory_id": accessible_memory_ids}
return query_message(filter_dict, params)
@@ -313,11 +351,14 @@ async def get_messages(memory_ids: list[str], agent_id: str = "", session_id: st
:param limit: maximum number of messages to return
:return: list of recent messages
"""
- memory_list = MemoryService.get_by_ids(memory_ids)
+ memory_list = _filter_accessible_memories(memory_ids)
+ if not memory_list:
+ return []
uids = [memory.tenant_id for memory in memory_list]
+ accessible_memory_ids = [memory.id for memory in memory_list]
res = MessageService.get_recent_messages(
uids,
- memory_ids,
+ accessible_memory_ids,
agent_id,
session_id,
limit
@@ -334,11 +375,9 @@ async def get_message_content(memory_id: str, message_id: int):
:return: message content
:raises NotFoundException: if memory or message not found
"""
- memory = MemoryService.get_by_memory_id(memory_id)
- if not memory:
- raise NotFoundException(f"Memory '{memory_id}' not found.")
+ memory = _require_memory_access(memory_id)
res = MessageService.get_by_message_id(memory_id, message_id, memory.tenant_id)
if res:
return res
- raise NotFoundException(f"Message '{message_id}' in memory '{memory_id}' not found.")
\ No newline at end of file
+ raise NotFoundException(f"Message '{message_id}' in memory '{memory_id}' not found.")
diff --git a/api/apps/system_app.py b/api/apps/system_app.py
deleted file mode 100644
index 833a7819dd5..00000000000
--- a/api/apps/system_app.py
+++ /dev/null
@@ -1,197 +0,0 @@
-#
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License
-#
-import logging
-from datetime import datetime
-import json
-
-from api.apps import login_required
-
-from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.utils.api_utils import (
- get_json_result,
-)
-
-from timeit import default_timer as timer
-
-from rag.utils.redis_conn import REDIS_CONN
-from api.utils.health_utils import get_oceanbase_status
-from common import settings
-
-@manager.route("/status", methods=["GET"]) # noqa: F821
-@login_required
-def status():
- """
- Get the system status.
- ---
- tags:
- - System
- security:
- - ApiKeyAuth: []
- responses:
- 200:
- description: System is operational.
- schema:
- type: object
- properties:
- es:
- type: object
- description: Elasticsearch status.
- storage:
- type: object
- description: Storage status.
- database:
- type: object
- description: Database status.
- 503:
- description: Service unavailable.
- schema:
- type: object
- properties:
- error:
- type: string
- description: Error message.
- """
- res = {}
- st = timer()
- try:
- res["doc_engine"] = settings.docStoreConn.health()
- res["doc_engine"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
- except Exception as e:
- res["doc_engine"] = {
- "type": "unknown",
- "status": "red",
- "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
- "error": str(e),
- }
-
- st = timer()
- try:
- settings.STORAGE_IMPL.health()
- res["storage"] = {
- "storage": settings.STORAGE_IMPL_TYPE.lower(),
- "status": "green",
- "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
- }
- except Exception as e:
- res["storage"] = {
- "storage": settings.STORAGE_IMPL_TYPE.lower(),
- "status": "red",
- "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
- "error": str(e),
- }
-
- st = timer()
- try:
- KnowledgebaseService.get_by_id("x")
- res["database"] = {
- "database": settings.DATABASE_TYPE.lower(),
- "status": "green",
- "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
- }
- except Exception as e:
- res["database"] = {
- "database": settings.DATABASE_TYPE.lower(),
- "status": "red",
- "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
- "error": str(e),
- }
-
- st = timer()
- try:
- if not REDIS_CONN.health():
- raise Exception("Lost connection!")
- res["redis"] = {
- "status": "green",
- "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
- }
- except Exception as e:
- res["redis"] = {
- "status": "red",
- "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
- "error": str(e),
- }
-
- task_executor_heartbeats = {}
- try:
- task_executors = REDIS_CONN.smembers("TASKEXE")
- now = datetime.now().timestamp()
- for task_executor_id in task_executors:
- heartbeats = REDIS_CONN.zrangebyscore(task_executor_id, now - 60 * 30, now)
- heartbeats = [json.loads(heartbeat) for heartbeat in heartbeats]
- task_executor_heartbeats[task_executor_id] = heartbeats
- except Exception:
- logging.exception("get task executor heartbeats failed!")
- res["task_executor_heartbeats"] = task_executor_heartbeats
-
- return get_json_result(data=res)
-
-@manager.route("/oceanbase/status", methods=["GET"]) # noqa: F821
-@login_required
-def oceanbase_status():
- """
- Get OceanBase health status and performance metrics.
- ---
- tags:
- - System
- security:
- - ApiKeyAuth: []
- responses:
- 200:
- description: OceanBase status retrieved successfully.
- schema:
- type: object
- properties:
- status:
- type: string
- description: Status (alive/timeout).
- message:
- type: object
- description: Detailed status information including health and performance metrics.
- """
- try:
- status_info = get_oceanbase_status()
- return get_json_result(data=status_info)
- except Exception as e:
- return get_json_result(
- data={
- "status": "error",
- "message": f"Failed to get OceanBase status: {str(e)}"
- },
- code=500
- )
-
-
-@manager.route("/config", methods=["GET"]) # noqa: F821
-def get_config():
- """
- Get system configuration.
- ---
- tags:
- - System
- responses:
- 200:
- description: Return system configuration
- schema:
- type: object
- properties:
- registerEnable:
- type: integer 0 means disabled, 1 means enabled
- description: Whether user registration is enabled
- """
- return get_json_result(data={
- "registerEnabled": settings.REGISTER_ENABLED,
- "disablePasswordLogin": settings.DISABLE_PASSWORD_LOGIN,
- })
diff --git a/api/db/__init__.py b/api/db/__init__.py
index 0ebd9f56f3f..6d7ed9fcb97 100644
--- a/api/db/__init__.py
+++ b/api/db/__init__.py
@@ -74,3 +74,4 @@ class PipelineTaskType(StrEnum):
KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"
+SKILLS_FOLDER_NAME="skills"
diff --git a/api/db/db_models.py b/api/db/db_models.py
index 433ed78afe2..5fe64586c04 100644
--- a/api/db/db_models.py
+++ b/api/db/db_models.py
@@ -55,7 +55,7 @@
from common.time_utils import current_timestamp, timestamp_to_date, date_string_to_timestamp
from common.decorator import singleton
-from common.constants import ParserType
+from common.constants import ParserType, MAXIMUM_TASK_PAGE_NUMBER
from common import settings
@@ -726,7 +726,7 @@ def __str__(self):
return self.email
def get_id(self):
- jwt = Serializer(secret_key=settings.SECRET_KEY)
+ jwt = Serializer(secret_key=settings.get_secret_key())
return jwt.dumps(str(self.access_token))
class Meta:
@@ -945,7 +945,7 @@ class Task(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
doc_id = CharField(max_length=32, null=False, index=True)
from_page = IntegerField(default=0)
- to_page = IntegerField(default=100000000)
+ to_page = IntegerField(default=MAXIMUM_TASK_PAGE_NUMBER)
task_type = CharField(max_length=32, null=False, default="")
priority = IntegerField(default=0)
diff --git a/api/db/joint_services/tenant_model_service.py b/api/db/joint_services/tenant_model_service.py
index f53f83ab957..645d7563812 100644
--- a/api/db/joint_services/tenant_model_service.py
+++ b/api/db/joint_services/tenant_model_service.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import logging
import os
import enum
from common import settings
@@ -20,14 +21,22 @@
from api.db.services.llm_service import LLMService
from api.db.services.tenant_llm_service import TenantLLMService, TenantService
+logger = logging.getLogger(__name__)
+
def get_model_config_by_id(tenant_model_id: int) -> dict:
found, model_config = TenantLLMService.get_by_id(tenant_model_id)
if not found:
raise LookupError(f"Tenant Model with id {tenant_model_id} not found")
config_dict = model_config.to_dict()
+ api_key, is_tools, api_key_payload = TenantLLMService._decode_api_key_config(config_dict.get("api_key", ""))
+ config_dict["api_key"] = api_key
+ if api_key_payload is not None:
+ config_dict["api_key_payload"] = api_key_payload
+ if is_tools is not None:
+ config_dict["is_tools"] = is_tools
llm = LLMService.query(llm_name=config_dict["llm_name"])
- if llm:
+ if "is_tools" not in config_dict and llm:
config_dict["is_tools"] = llm[0].is_tools
return config_dict
@@ -57,6 +66,31 @@ def get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_nam
"api_base": embedding_cfg["base_url"],
"model_type": LLMType.EMBEDDING.value,
}
+ elif model_type_val == LLMType.CHAT.value:
+ # Retry as CHAT with pure_model_name first; then fall back to a multimodal model registered under IMAGE2TEXT.
+ model_config = TenantLLMService.get_api_key(tenant_id, pure_model_name, LLMType.CHAT.value)
+ if not model_config:
+ model_config = TenantLLMService.get_api_key(tenant_id, pure_model_name, LLMType.IMAGE2TEXT.value)
+ if not model_config:
+ raise LookupError(f"Tenant Model with name {model_name} and type {model_type_val} not found")
+ config_dict = model_config.to_dict()
+ elif model_type_val == LLMType.IMAGE2TEXT.value:
+ model_config = TenantLLMService.get_api_key(tenant_id, pure_model_name, LLMType.IMAGE2TEXT.value)
+ if not model_config:
+ # Fall back to a chat model only if it has declared IMAGE2TEXT capability (tag check via llm table)
+ chat_config = TenantLLMService.get_api_key(tenant_id, pure_model_name, LLMType.CHAT.value)
+ logger.debug("IMAGE2TEXT config not found for %s; chat_config found: %s", pure_model_name, chat_config is not None)
+ if chat_config:
+ llm_entry = LLMService.query(fid=chat_config.llm_factory, llm_name=chat_config.llm_name)
+ tags = [t.strip() for t in (llm_entry[0].tags or "").split(",")] if llm_entry else []
+ logger.debug("LLM tags for %s/%s: %s", chat_config.llm_factory, chat_config.llm_name, tags)
+ if "IMAGE2TEXT" in tags:
+ logger.debug("Promoting chat config to IMAGE2TEXT for %s", pure_model_name)
+ model_config = chat_config
+ if not model_config:
+ raise LookupError(f"Tenant Model with name {model_name} and type {model_type_val} not found")
+ config_dict = model_config.to_dict()
+ config_dict["model_type"] = LLMType.IMAGE2TEXT.value
else:
model_config = TenantLLMService.get_api_key(tenant_id, pure_model_name, model_type_val)
if not model_config:
@@ -65,14 +99,26 @@ def get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_nam
else:
# model_name without @factory
config_dict = model_config.to_dict()
+ api_key, is_tools, api_key_payload = TenantLLMService._decode_api_key_config(config_dict.get("api_key", ""))
+ config_dict["api_key"] = api_key
+ if api_key_payload is not None:
+ config_dict["api_key_payload"] = api_key_payload
+ if is_tools is not None:
+ config_dict["is_tools"] = is_tools
config_model_type = config_dict.get("model_type")
config_model_type = config_model_type.value if hasattr(config_model_type, "value") else config_model_type
- if config_model_type != model_type_val:
+ if config_model_type != model_type_val and not (
+ model_type_val == LLMType.CHAT.value
+ and config_model_type == LLMType.IMAGE2TEXT.value
+ ) and not (
+ model_type_val == LLMType.IMAGE2TEXT.value
+ and config_model_type == LLMType.CHAT.value
+ ):
raise LookupError(
f"Tenant Model with name {model_name} has type {config_model_type}, expected {model_type_val}"
)
llm = LLMService.query(llm_name=config_dict["llm_name"])
- if llm:
+ if "is_tools" not in config_dict and llm:
config_dict["is_tools"] = llm[0].is_tools
return config_dict
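# --- Illustrative sketch (editor's note, not part of this patch) ---
# With the CHAT/IMAGE2TEXT fallback above, a multimodal model registered under only
# one role can serve both; the tenant and model name are hypothetical.
tenant_id = "tenant-123"
cfg = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, "qwen-vl-max@Tongyi-Qianwen")
# If no CHAT record exists for that tenant/model, the IMAGE2TEXT record is reused;
# conversely, an IMAGE2TEXT request can fall back to a CHAT record tagged "IMAGE2TEXT".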
diff --git a/api/db/services/api_service.py b/api/db/services/api_service.py
index be41dc1b642..8f60a1c5ab5 100644
--- a/api/db/services/api_service.py
+++ b/api/db/services/api_service.py
@@ -44,6 +44,14 @@ def delete_by_tenant_id(cls, tenant_id):
class API4ConversationService(CommonService):
model = API4Conversation
+ @staticmethod
+ def _normalize_query_date(value, is_end=False):
+ if "T" in value:
+ value = datetime.fromisoformat(value.replace("Z", "+00:00")).astimezone().replace(tzinfo=None).strftime("%Y-%m-%d %H:%M:%S")
+ elif len(value) == 10:
+ value = f"{value} 23:59:59" if is_end else f"{value} 00:00:00"
+ return value
+
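    # --- Illustrative sketch (editor's note, not part of this patch) ---
    # _normalize_query_date() widens date-only bounds to whole days and converts ISO
    # timestamps (trailing "Z" form) to naive local time, e.g. (hypothetical values):
    #   _normalize_query_date("2025-01-02")               -> "2025-01-02 00:00:00"
    #   _normalize_query_date("2025-01-02", is_end=True)  -> "2025-01-02 23:59:59"
    #   _normalize_query_date("2025-01-02T08:30:00Z")     -> local "%Y-%m-%d %H:%M:%S"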
@classmethod
@DB.connection_context()
def get_list(cls, dialog_id, tenant_id,
@@ -62,10 +70,11 @@ def get_list(cls, dialog_id, tenant_id,
sessions = sessions.where(cls.model.user_id == user_id)
if keywords:
sessions = sessions.where(peewee.fn.LOWER(cls.model.message).contains(keywords.lower()))
+ date_field = cls.model.update_date if orderby.startswith("update_") else cls.model.create_date
if from_date:
- sessions = sessions.where(cls.model.create_date >= from_date)
+ sessions = sessions.where(date_field >= cls._normalize_query_date(from_date))
if to_date:
- sessions = sessions.where(cls.model.create_date <= to_date)
+ sessions = sessions.where(date_field <= cls._normalize_query_date(to_date, is_end=True))
if exp_user_id:
sessions = sessions.where(cls.model.exp_user_id == exp_user_id)
if desc:
diff --git a/api/db/services/canvas_service.py b/api/db/services/canvas_service.py
index 98925fa246a..4a5734e155d 100644
--- a/api/db/services/canvas_service.py
+++ b/api/db/services/canvas_service.py
@@ -139,10 +139,17 @@ def get_basic_info_by_canvas_ids(cls, canvas_id):
@classmethod
@DB.connection_context()
- def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
- page_number, items_per_page,
- orderby, desc, keywords, canvas_category=None
- ):
+ def get_by_tenant_ids(
+ cls,
+ joined_tenant_ids,
+ user_id,
+ page_number,
+ items_per_page,
+ orderby,
+ desc,
+ keywords,
+ canvas_category=None,
+ ):
fields = [
cls.model.id,
cls.model.avatar,
@@ -201,7 +208,11 @@ def accessible(cls, canvas_id, tenant_id):
return False
tids = [t.tenant_id for t in UserTenantService.query(user_id=tenant_id)]
- if c["user_id"] != canvas_id and c["user_id"] not in tids:
+ if c["user_id"] == tenant_id:
+ return True
+ if c["user_id"] not in tids:
+ return False
+ if c["permission"] != TenantPermission.TEAM.value:
return False
return True
@@ -210,8 +221,6 @@ def get_agent_dsl_with_release(cls, agent_id, release_mode=False, tenant_id=None
e, cvs = cls.get_by_id(agent_id)
if not e:
raise LookupError("Agent not found.")
- if tenant_id and cvs.user_id != tenant_id:
- raise PermissionError("You do not own the agent.")
if release_mode:
released_version = UserCanvasVersionService.get_latest_released(agent_id)
diff --git a/api/db/services/connector_service.py b/api/db/services/connector_service.py
index 85d495d9d63..9f7b0e6ded1 100644
--- a/api/db/services/connector_service.py
+++ b/api/db/services/connector_service.py
@@ -29,6 +29,7 @@
from api.utils.common import hash128
from common.misc_utils import get_uuid
from common.constants import TaskStatus
+from common.settings import TIMEZONE
from common.time_utils import current_timestamp, timestamp_to_date
class ConnectorService(CommonService):
@@ -99,7 +100,7 @@ def cleanup_stale_documents_for_task(
return 0, []
source_type = f"{conn.source}/{conn.id}"
- retain_doc_ids = {hash128(file.id) for file in file_list}
+ retain_doc_ids = {hash128(f"{connector_id}:{file.id}") for file in file_list}
existing_docs = DocumentService.list_doc_headers_by_kb_and_source_type(
kb_id,
source_type,
@@ -179,14 +180,14 @@ def list_sync_tasks(cls, connector_id=None, page_number=None, items_per_page=15)
else:
database_type = os.getenv("DB_TYPE", "mysql")
if "postgres" in database_type.lower():
- interval_expr = SQL("make_interval(mins => t2.refresh_freq)")
+ expr = SQL(f"NOW() AT TIME ZONE '{TIMEZONE}' - make_interval(mins => t2.refresh_freq)")
else:
- interval_expr = SQL("INTERVAL `t2`.`refresh_freq` MINUTE")
+ expr = SQL("NOW() - INTERVAL `t2`.`refresh_freq` MINUTE")
query = query.where(
Connector.input_type == InputType.POLL,
Connector.status == TaskStatus.SCHEDULE,
cls.model.status == TaskStatus.SCHEDULE,
- cls.model.update_date < (fn.NOW() - interval_expr)
+ cls.model.update_date < expr
)
query = query.distinct().order_by(cls.model.update_time.desc())
diff --git a/api/db/services/conversation_service.py b/api/db/services/conversation_service.py
index 5a205b14219..2603676e98e 100644
--- a/api/db/services/conversation_service.py
+++ b/api/db/services/conversation_service.py
@@ -14,6 +14,7 @@
# limitations under the License.
#
import time
+import logging
from uuid import uuid4
from common.constants import StatusEnum
from api.db.db_models import Conversation, DB
@@ -26,6 +27,9 @@
from rag.prompts.generator import chunks_format
+logger = logging.getLogger(__name__)
+
+
class ConversationService(CommonService):
model = Conversation
@@ -201,9 +205,23 @@ async def async_completion(tenant_id, chat_id, question, name="New session", ses
break
yield answer
-async def async_iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
- e, dia = DialogService.get_by_id(dialog_id)
- assert e, "Dialog not found"
+async def async_iframe_completion(dialog_id, question, session_id=None, stream=True, tenant_id=None, **kwargs):
+ if tenant_id:
+ exists, dia = DialogService.get_by_id(dialog_id)
+ if (not exists
+ or getattr(dia, "tenant_id", None) != tenant_id
+ or str(getattr(dia, "status", "")) != StatusEnum.VALID.value):
+ logger.warning(
+ "Dialog lookup failed for tenant-scoped iframe completion: "
+ "tenant_id=%s dialog_id=%s required_status=%s",
+ tenant_id,
+ dialog_id,
+ StatusEnum.VALID.value,
+ )
+ raise AssertionError("Dialog not found")
+ else:
+ e, dia = DialogService.get_by_id(dialog_id)
+ assert e, "Dialog not found"
if not session_id:
session_id = get_uuid()
conv = {
@@ -228,6 +246,7 @@ async def async_iframe_completion(dialog_id, question, session_id=None, stream=T
session_id = session_id
e, conv = API4ConversationService.get_by_id(session_id)
assert e, "Session not found!"
+ assert conv.dialog_id == dialog_id, "Session does not belong to this dialog"
if not conv.message:
conv.message = []
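
The tenant-scoped branch above deliberately collapses three different failures (missing dialog, foreign tenant, invalid status) into the same opaque "Dialog not found" error, keeping the specific cause in server-side logs only. A standalone sketch of that shape, with plain dicts in place of the ORM model and a string literal standing in for `StatusEnum.VALID.value`:

```python
import logging

logger = logging.getLogger(__name__)

def check_dialog_access(dia, tenant_id, valid_status="1"):  # "1" assumed for StatusEnum.VALID.value
    """Return dia, or raise one opaque error; the detail goes to the log, not the caller."""
    if dia is None or dia.get("tenant_id") != tenant_id or str(dia.get("status", "")) != valid_status:
        logger.warning("Dialog lookup failed: tenant_id=%s dialog_id=%s", tenant_id, dia and dia.get("id"))
        raise AssertionError("Dialog not found")  # same message for all three causes
    return dia

check_dialog_access({"id": "d1", "tenant_id": "t1", "status": "1"}, "t1")  # passes
```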
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index cadf76c2aa8..6f981efb5e6 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -18,7 +18,10 @@
import logging
import re
import time
+import uuid
from copy import deepcopy
+
+logger = logging.getLogger(__name__)
from datetime import datetime
from functools import partial
from timeit import default_timer as timer
@@ -33,6 +36,10 @@
from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.llm_service import LLMBundle
from common.metadata_utils import apply_meta_data_filter
+from api.utils.reference_metadata_utils import (
+ enrich_chunks_with_document_metadata,
+ resolve_reference_metadata_preferences,
+)
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type
from common.time_utils import current_timestamp, datetime_format
@@ -41,13 +48,22 @@
from rag.advanced_rag import DeepResearcher
from rag.app.tag import label_question
from rag.nlp.search import index_name
-from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
- PROMPT_JINJA_ENV, ASK_SUMMARY
+from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, PROMPT_JINJA_ENV, ASK_SUMMARY
from common.token_utils import num_tokens_from_string
from rag.utils.tavily_conn import Tavily
from common.string_utils import remove_redundant_spaces
from common import settings
+def _resolve_reference_metadata(config, request_payload=None):
+ return resolve_reference_metadata_preferences(request_payload or {}, config)
+
+def _enrich_chunks_with_document_metadata(chunks, metadata_fields=None):
+ enrich_chunks_with_document_metadata(chunks, metadata_fields)
+
+def _chunk_kb_id_for_doc(row_dict, kb_ids, doc_id):
+ if len(kb_ids or []) == 1:
+ return kb_ids[0]
+ return row_dict.get("kb_id") or row_dict.get("kb_id_kwd")
def _normalize_internet_flag(value):
if isinstance(value, bool):
@@ -70,6 +86,15 @@ def _should_use_web_search(prompt_config, internet=None):
return normalized is True
class DialogService(CommonService):
model = Dialog
@@ -168,8 +193,7 @@ def get_by_tenant_ids(
cls.model.select(*fields)
.join(User, on=(cls.model.tenant_id == User.id))
.where(
- (cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id))
- & (cls.model.status == StatusEnum.VALID.value),
+ (cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value),
)
)
if id:
@@ -210,22 +234,14 @@ def get_all_dialogs_by_tenant_id(cls, tenant_id):
@classmethod
@DB.connection_context()
def get_null_tenant_llm_id_row(cls):
- fields = [
- cls.model.id,
- cls.model.tenant_id,
- cls.model.llm_id
- ]
+ fields = [cls.model.id, cls.model.tenant_id, cls.model.llm_id]
objs = cls.model.select(*fields).where(cls.model.tenant_llm_id.is_null())
return list(objs)
@classmethod
@DB.connection_context()
def get_null_tenant_rerank_id_row(cls):
- fields = [
- cls.model.id,
- cls.model.tenant_id,
- cls.model.rerank_id
- ]
+ fields = [cls.model.id, cls.model.tenant_id, cls.model.rerank_id]
objs = cls.model.select(*fields).where(cls.model.tenant_rerank_id.is_null())
return list(objs)
@@ -241,7 +257,7 @@ async def async_chat_solo(dialog, messages, stream=True):
else:
text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True)
attachments = "\n\n".join(text_attachments)
-
+
if dialog.llm_id:
model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
elif dialog.tenant_llm_id:
@@ -460,11 +476,11 @@ def find_and_replace(pattern, group_index=1, repl=lambda digits: f"ID:{digits}")
parts = []
last_idx = 0
for match in matches:
- parts.append(answer[last_idx:match.start()])
+ parts.append(answer[last_idx : match.start()])
try:
i = int(match.group(group_index))
except Exception:
- parts.append(answer[match.start():match.end()])
+ parts.append(answer[match.start() : match.end()])
last_idx = match.end()
continue
@@ -473,7 +489,7 @@ def find_and_replace(pattern, group_index=1, repl=lambda digits: f"ID:{digits}")
digits_original = answer[digit_start:digit_end]
parts.append(f"[{repl(digits_original)}]")
else:
- parts.append(answer[match.start():match.end()])
+ parts.append(answer[match.start() : match.end()])
last_idx = match.end()
parts.append(answer[last_idx:])
@@ -534,7 +550,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
attachments = None
if "doc_ids" in kwargs:
attachments = [doc_id for doc_id in kwargs["doc_ids"].split(",") if doc_id]
- attachments_= ""
+ attachments_ = ""
image_attachments = []
image_files = []
if "doc_ids" in messages[-1]:
@@ -547,6 +563,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
attachments_ = "\n\n".join(text_attachments)
prompt_config = dialog.prompt_config
+ include_reference_metadata, metadata_fields = _resolve_reference_metadata(prompt_config, request_payload=kwargs)
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
logging.debug(f"field_map retrieved: {field_map}")
# try to use sql if field mapping is good to go
@@ -555,6 +572,14 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
# For aggregate queries (COUNT, SUM, etc.), chunks may be empty but answer is still valid
if ans and (ans.get("reference", {}).get("chunks") or ans.get("answer")):
+ if include_reference_metadata and ans.get("reference", {}).get("chunks"):
+ if len(dialog.kb_ids) != 1 and any(not c.get("kb_id") for c in ans["reference"]["chunks"]):
+ logging.warning(
+                            "use_sql returned chunks without kb_id while dialog.kb_ids has %d entries; "
+                            "metadata enrichment will skip those chunks.",
+ len(dialog.kb_ids),
+ )
+ _enrich_chunks_with_document_metadata(ans["reference"]["chunks"], metadata_fields)
yield ans
return
else:
@@ -584,13 +609,14 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
questions = [await cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]
if dialog.meta_data_filter:
- metas = DocMetadataService.get_flatted_meta_by_kbs(dialog.kb_ids)
attachments = await apply_meta_data_filter(
dialog.meta_data_filter,
- metas,
+ None,
questions[-1],
chat_mdl,
attachments,
+ kb_ids=dialog.kb_ids,
+ metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs(dialog.kb_ids),
)
if prompt_config.get("keyword", False):
@@ -623,7 +649,8 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
internet_enabled=use_web_search,
)
queue = asyncio.Queue()
- async def callback(msg:str):
+
+ async def callback(msg: str):
nonlocal queue
await queue.put(msg + " ")
@@ -632,9 +659,9 @@ async def callback(msg:str):
while True:
msg = await queue.get()
                if msg.find("<think>") == 0:
-                    yield {"answer": "<think>", "reference": {}, "audio_binary": None, "final": False, "start_to_think": True}
+                    yield {"answer": "<think>", "reference": {}, "audio_binary": None, "final": False}
                elif msg.find("</think>") == 0:
-                    yield {"answer": "</think>", "reference": {}, "audio_binary": None, "final": False, "end_to_think": True}
+                    yield {"answer": "</think> ", "reference": {}, "audio_binary": None, "final": False}
break
else:
yield {"answer": msg, "reference": {}, "audio_binary": None, "final": False}
@@ -670,25 +697,31 @@ async def callback(msg:str):
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
if prompt_config.get("use_kg"):
default_chat_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT)
- ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
- LLMBundle(dialog.tenant_id, default_chat_model))
+ ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, default_chat_model))
if ck["content_with_weight"]:
kbinfos["chunks"].insert(0, ck)
+ if include_reference_metadata:
+ logging.debug(
+ "reference_metadata enrichment enabled for async_chat: chunk_count=%d metadata_fields=%s",
+ len(kbinfos.get("chunks", [])),
+ metadata_fields,
+ )
+ _enrich_chunks_with_document_metadata(kbinfos.get("chunks", []), metadata_fields)
+
knowledges = kb_prompt(kbinfos, max_tokens)
logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
retrieval_ts = timer()
if not knowledges and prompt_config.get("empty_response"):
empty_res = prompt_config["empty_response"]
- yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
- "audio_binary": tts(tts_mdl, empty_res), "final": True}
+ yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res), "final": True}
return
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
gen_conf = dialog.llm_setting
- msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)+attachments_}]
+ msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs) + attachments_}]
prompt4citation = ""
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
prompt4citation = citation_prompt()
@@ -783,8 +816,7 @@ def decorate_answer(answer):
if langfuse_tracer:
langfuse_generation = langfuse_tracer.start_generation(
- trace_context=trace_context, name="chat", model=llm_model_config["llm_name"],
- input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
+ trace_context=trace_context, name="chat", model=llm_model_config["llm_name"], input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
)
if stream:
@@ -802,7 +834,7 @@ def decorate_answer(answer):
yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "final": False}
full_answer = last_state.full_text if last_state else ""
if full_answer:
- final = decorate_answer(thought + full_answer)
+ final = decorate_answer(_extract_visible_answer(thought + full_answer))
final["final"] = True
final["audio_binary"] = None
yield final
@@ -821,6 +853,25 @@ def decorate_answer(answer):
async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
+ """Answer a natural-language question by generating and executing SQL against the document index.
+
+ Detects the active document engine (Infinity, OceanBase, or Elasticsearch), asks the
+ chat model to produce the appropriate SQL, injects a validated kb_id filter, executes
+ the query, and returns formatted results with optional source citations.
+
+ Args:
+ question: Natural-language question from the user.
+ field_map: Mapping of field names to types describing the indexed document schema.
+ tenant_id: Tenant identifier used to derive the target index/table name.
+ chat_mdl: LLM bundle used to generate SQL from the question.
+ quota: Whether to enforce token-quota checks (default True).
+ kb_ids: Optional list of knowledge-base UUIDs to restrict the query scope.
+
+ Returns:
+ A dict with keys ``answer`` (formatted response string), ``reference``
+ (dict of supporting document chunks and doc_aggs), and ``prompt``
+ (the system prompt used), or ``None`` if SQL generation or execution fails.
+ """
logging.debug(f"use_sql: Question: {question}")
# Determine which document engine we're using
@@ -831,12 +882,20 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N
else:
doc_engine = "es"
+ def _assert_valid_uuid(value: str, label: str = "id") -> None:
+ try:
+ uuid.UUID(str(value))
+ except (ValueError, AttributeError, TypeError):
+ logger.warning("SQL injection guard rejected invalid %s value (length=%d)", label, len(str(value)))
+ raise ValueError(f"Invalid {label} format: {value!r}")
+
# Construct the full table name
# For Elasticsearch: ragflow_{tenant_id} (kb_id is in WHERE clause)
# For Infinity: ragflow_{tenant_id}_{kb_id} (each KB has its own table)
base_table = index_name(tenant_id)
if doc_engine == "infinity" and kb_ids and len(kb_ids) == 1:
- # Infinity: append kb_id to table name
+ # Infinity: append kb_id to table name — validate before interpolating
+ _assert_valid_uuid(kb_ids[0], "kb_id")
table_name = f"{base_table}_{kb_ids[0]}"
logging.debug(f"use_sql: Using Infinity table name: {table_name}")
else:
@@ -847,13 +906,20 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N
expected_doc_name_column = "docnm" if doc_engine == "infinity" else "docnm_kwd"
def has_source_columns(columns):
+ """Return True if the result set contains the columns needed to build source citations."""
normalized_names = {str(col.get("name", "")).lower() for col in columns}
return "doc_id" in normalized_names and bool({"docnm_kwd", "docnm"} & normalized_names)
def is_aggregate_sql(sql_text):
+        """Return True if *sql_text* contains an aggregate function (COUNT, SUM, AVG, MAX, MIN) or a DISTINCT(...) call."""
return bool(re.search(r"(count|sum|avg|max|min|distinct)\s*\(", (sql_text or "").lower()))
def normalize_sql(sql):
+ """Strip LLM artefacts from *sql* and return a clean, executable SQL string.
+
+    Removes ``<think>...</think>`` reasoning blocks, Chinese reasoning markers, markdown
+ code fences, and trailing semicolons that some engines reject.
+ """
logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}")
    # Remove think blocks if present (format: <think>...</think>)
    sql = re.sub(r"<think>\n.*?\n</think>\s*", "", sql, flags=re.DOTALL)
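
A quick before/after for `normalize_sql`, exercising the regexes above. Note the `<think>` literals throughout this file were evidently lost to HTML-style tag stripping and are restored by reconstruction, so treat the exact pattern as a best-effort restoration:

```python
import re

raw = "<think>\nreason about the schema\n</think>\n```sql\nSELECT COUNT(*) FROM t;\n```"

sql = re.sub(r"<think>\n.*?\n</think>\s*", "", raw, flags=re.DOTALL)  # drop the reasoning block
sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE)         # drop fences (matches both here)
sql = re.sub(r"```\s*$", "", sql, flags=re.IGNORECASE)                # trailing fence, if one remains
sql = sql.rstrip().rstrip(";").strip()                                # ES SQL rejects trailing ';'

assert sql == "SELECT COUNT(*) FROM t"
```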
@@ -862,18 +928,28 @@ def normalize_sql(sql):
sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE)
sql = re.sub(r"```\s*$", "", sql, flags=re.IGNORECASE)
# Remove trailing semicolon that ES SQL parser doesn't like
- return sql.rstrip().rstrip(';').strip()
+ return sql.rstrip().rstrip(";").strip()
def add_kb_filter(sql):
+ """Inject a validated kb_id WHERE filter into *sql* for ES/OceanBase engines.
+
+ Infinity encodes the knowledge-base scope in the table name, so this
+ function is a no-op for that engine. All kb_id values are validated as
+ canonical UUIDs before interpolation to prevent SQL injection.
+ """
# Add kb_id filter for ES/OS only (Infinity already has it in table name)
if doc_engine == "infinity" or not kb_ids:
return sql
+ # Validate all kb_ids are UUIDs before interpolating into SQL
+ for kid in kb_ids:
+ _assert_valid_uuid(kid, "kb_id")
+
# Build kb_filter: single KB or multiple KBs with OR
if len(kb_ids) == 1:
kb_filter = f"kb_id = '{kb_ids[0]}'"
else:
- kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")"
+ kb_filter = "(" + " OR ".join([f"kb_id = '{kid}'" for kid in kb_ids]) + ")"
if "where " not in sql.lower():
o = sql.lower().split("order by")
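
A condensed, standalone version of the guard-then-interpolate pattern that `_assert_valid_uuid` and `add_kb_filter` implement. This is a simplification (no WHERE/ORDER BY splicing); it only shows why canonical-UUID validation is what blocks injection:

```python
import uuid

def safe_kb_filter(kb_ids: list[str]) -> str:
    """Build kb_id predicates, refusing anything that is not a canonical UUID."""
    for kid in kb_ids:
        uuid.UUID(str(kid))  # raises ValueError on malformed input, before any SQL is built
    if len(kb_ids) == 1:
        return f"kb_id = '{kb_ids[0]}'"
    return "(" + " OR ".join(f"kb_id = '{kid}'" for kid in kb_ids) + ")"

print(safe_kb_filter(["4f1c2e8a-0b6d-4c1e-9a2b-3d4e5f6a7b8c"]))
# safe_kb_filter(["x' OR 1=1 --"])  # ValueError: badly formed hexadecimal UUID string
```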
@@ -886,6 +962,7 @@ def add_kb_filter(sql):
return sql
def is_row_count_question(q: str) -> bool:
+ """Return True if *q* is asking for a total row count of a dataset or table."""
q = (q or "").lower()
if not re.search(r"\bhow many rows\b|\bnumber of rows\b|\brow count\b", q):
return False
@@ -895,11 +972,7 @@ def is_row_count_question(q: str) -> bool:
if doc_engine == "infinity":
# Build Infinity prompts with JSON extraction context
json_field_names = list(field_map.keys())
- row_count_override = (
- f"SELECT COUNT(*) AS rows FROM {table_name}"
- if is_row_count_question(question)
- else None
- )
+ row_count_override = f"SELECT COUNT(*) AS rows FROM {table_name}" if is_row_count_question(question) else None
sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column.
JSON Extraction: json_extract_string(chunk_data, '$.FieldName')
@@ -923,19 +996,12 @@ def is_row_count_question(q: str) -> bool:
{}
Question: {}
Write SQL using json_extract_string() with exact field names. Include doc_id, docnm for data queries. Only SQL.""".format(
- table_name,
- ", ".join(json_field_names),
- "\n".join([f" - {field}" for field in json_field_names]),
- question
+ table_name, ", ".join(json_field_names), "\n".join([f" - {field}" for field in json_field_names]), question
)
elif doc_engine == "oceanbase":
# Build OceanBase prompts with JSON extraction context
json_field_names = list(field_map.keys())
- row_count_override = (
- f"SELECT COUNT(*) AS rows FROM {table_name}"
- if is_row_count_question(question)
- else None
- )
+ row_count_override = f"SELECT COUNT(*) AS rows FROM {table_name}" if is_row_count_question(question) else None
sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column.
JSON Extraction: json_extract_string(chunk_data, '$.FieldName')
@@ -959,10 +1025,7 @@ def is_row_count_question(q: str) -> bool:
{}
Question: {}
Write SQL using json_extract_string() with exact field names. Include doc_id, docnm_kwd for data queries. Only SQL.""".format(
- table_name,
- ", ".join(json_field_names),
- "\n".join([f" - {field}" for field in json_field_names]),
- question
+ table_name, ", ".join(json_field_names), "\n".join([f" - {field}" for field in json_field_names]), question
)
else:
# Build ES/OS prompts with direct field access
@@ -980,11 +1043,7 @@ def is_row_count_question(q: str) -> bool:
Available fields:
{}
Question: {}
-Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.""".format(
- table_name,
- "\n".join([f" - {k} ({v})" for k, v in field_map.items()]),
- question
- )
+Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.""".format(table_name, "\n".join([f" - {k} ({v})" for k, v in field_map.items()]), question)
tried_times = 0
@@ -1022,13 +1081,7 @@ async def repair_table_for_missing_source_columns(previous_sql):
The previous SQL result is missing required source columns for citations.
Rewrite SQL to keep the same query intent and include doc_id and {} in the SELECT list.
For extracted JSON fields, use json_extract_string(chunk_data, '$.field_name').
-Return ONLY SQL.""".format(
- table_name,
- "\n".join([f" - {field}" for field in json_field_names]),
- question,
- previous_sql,
- expected_doc_name_column
- )
+Return ONLY SQL.""".format(table_name, "\n".join([f" - {field}" for field in json_field_names]), question, previous_sql, expected_doc_name_column)
else:
repair_prompt = """Table name: {}
Available fields:
@@ -1040,12 +1093,7 @@ async def repair_table_for_missing_source_columns(previous_sql):
The previous SQL result is missing required source columns for citations.
Rewrite SQL to keep the same query intent and include doc_id and docnm_kwd in the SELECT list.
-Return ONLY SQL.""".format(
- table_name,
- "\n".join([f" - {k} ({v})" for k, v in field_map.items()]),
- question,
- previous_sql
- )
+Return ONLY SQL.""".format(table_name, "\n".join([f" - {k} ({v})" for k, v in field_map.items()]), question, previous_sql)
return await get_table(custom_user_prompt=repair_prompt)
try:
@@ -1105,11 +1153,7 @@ async def repair_table_for_missing_source_columns(previous_sql):
logging.warning(f"use_sql: Non-aggregate SQL missing required source columns; retrying once. SQL: {sql}")
try:
repaired_tbl, repaired_sql = await repair_table_for_missing_source_columns(sql)
- if (
- repaired_tbl
- and len(repaired_tbl.get("rows", [])) > 0
- and has_source_columns(repaired_tbl.get("columns", []))
- ):
+ if repaired_tbl and len(repaired_tbl.get("rows", [])) > 0 and has_source_columns(repaired_tbl.get("columns", [])):
tbl, sql = repaired_tbl, repaired_sql
logging.info(f"use_sql: Source-column SQL repair succeeded. SQL: {sql}")
else:
@@ -1121,11 +1165,12 @@ async def repair_table_for_missing_source_columns(previous_sql):
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() == "doc_id"])
doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]])
+ kb_id_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() in ["kb_id", "kb_id_kwd"]])
logging.debug(f"use_sql: All columns: {[(i, c['name']) for i, c in enumerate(tbl['columns'])]}")
- logging.debug(f"use_sql: docid_idx={docid_idx}, doc_name_idx={doc_name_idx}")
+ logging.debug(f"use_sql: docid_idx={docid_idx}, doc_name_idx={doc_name_idx}, kb_id_idx={kb_id_idx}")
- column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
+ column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx | kb_id_idx)]
logging.debug(f"use_sql: column_idx={column_idx}")
logging.debug(f"use_sql: field_map={field_map}")
@@ -1137,9 +1182,9 @@ def map_column_name(col_name):
# First, try to extract AS alias from any expression (aggregate functions, json_extract_string, etc.)
# Pattern: anything AS alias_name
- as_match = re.search(r'\s+AS\s+([^\s,)]+)', col_name, re.IGNORECASE)
+ as_match = re.search(r"\s+AS\s+([^\s,)]+)", col_name, re.IGNORECASE)
if as_match:
- alias = as_match.group(1).strip('"\'')
+ alias = as_match.group(1).strip("\"'")
# Use the alias for display name lookup
if alias in field_map:
@@ -1176,11 +1221,7 @@ def map_column_name(col_name):
return result
# compose Markdown table
- columns = (
- "|" + "|".join(
- [map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + (
- "|Source|" if docid_idx and doc_name_idx else "|")
- )
+ columns = "|" + "|".join([map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + ("|Source|" if docid_idx and doc_name_idx else "|")
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
@@ -1221,8 +1262,11 @@ def map_column_name(col_name):
where_match = re.search(r"\bwhere\b(.+?)(?:\bgroup by\b|\border by\b|\blimit\b|$)", sql, re.IGNORECASE)
if where_match:
where_clause = where_match.group(1).strip()
- # Build a query to get doc_id and docnm_kwd with the same WHERE clause
- chunks_sql = f"select doc_id, docnm_kwd from {table_name} where {where_clause}"
+ # Build a query to get source fields with the same WHERE clause.
+ # Single-KB queries can derive kb_id from the dialog, while multi-KB
+ # ES/OS queries need the row value for metadata enrichment.
+ chunks_kb_column = ", kb_id" if not (kb_ids and len(kb_ids) == 1) else ""
+ chunks_sql = f"select doc_id, {expected_doc_name_column}{chunks_kb_column} from {table_name} where {where_clause}"
# Add LIMIT to avoid fetching too many chunks
if "limit" not in chunks_sql.lower():
chunks_sql += " limit 20"
@@ -1233,8 +1277,18 @@ def map_column_name(col_name):
# Build chunks reference - use case-insensitive matching
chunks_did_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() == "doc_id"), None)
chunks_dn_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]), None)
+ chunks_kb_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() in ["kb_id", "kb_id_kwd"]), None)
if chunks_did_idx is not None and chunks_dn_idx is not None:
- chunks = [{"doc_id": r[chunks_did_idx], "docnm_kwd": r[chunks_dn_idx]} for r in chunks_tbl["rows"]]
+ chunks = []
+ for r in chunks_tbl["rows"]:
+ chunk = {"doc_id": r[chunks_did_idx], "docnm_kwd": r[chunks_dn_idx]}
+ row_dict = {chunks_tbl["columns"][i]["name"]: r[i] for i in range(len(chunks_tbl["columns"])) if i < len(r)}
+ kb_id = _chunk_kb_id_for_doc(row_dict, kb_ids, chunk["doc_id"])
+ if kb_id:
+ chunk["kb_id"] = kb_id
+ elif chunks_kb_idx is not None:
+ chunk["kb_id"] = r[chunks_kb_idx]
+ chunks.append(chunk)
# Build doc_aggs
doc_aggs = {}
for r in chunks_tbl["rows"]:
@@ -1264,7 +1318,22 @@ def map_column_name(col_name):
result = {
"answer": "\n".join([columns, line, rows]),
"reference": {
- "chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
+ "chunks": [
+ {
+ key: value
+ for key, value in {
+ "doc_id": r[docid_idx],
+ "docnm_kwd": r[doc_name_idx],
+ "kb_id": _chunk_kb_id_for_doc(
+ {tbl["columns"][i]["name"]: r[i] for i in range(len(tbl["columns"])) if i < len(r)},
+ kb_ids,
+ r[docid_idx],
+ ),
+ }.items()
+ if value
+ }
+ for r in tbl["rows"]
+ ],
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()],
},
"prompt": sys_prompt,
@@ -1272,6 +1341,7 @@ def map_column_name(col_name):
logging.debug(f"use_sql: Returning answer with {len(result['reference']['chunks'])} chunks from {len(doc_aggs)} documents")
return result
+
def clean_tts_text(text: str) -> str:
if not text:
return ""
@@ -1281,15 +1351,7 @@ def clean_tts_text(text: str) -> str:
text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
emoji_pattern = re.compile(
- "[\U0001F600-\U0001F64F"
- "\U0001F300-\U0001F5FF"
- "\U0001F680-\U0001F6FF"
- "\U0001F1E0-\U0001F1FF"
- "\U00002700-\U000027BF"
- "\U0001F900-\U0001F9FF"
- "\U0001FA70-\U0001FAFF"
- "\U0001FAD0-\U0001FAFF]+",
- flags=re.UNICODE
+ "[\U0001f600-\U0001f64f\U0001f300-\U0001f5ff\U0001f680-\U0001f6ff\U0001f1e0-\U0001f1ff\U00002700-\U000027bf\U0001f900-\U0001f9ff\U0001fa70-\U0001faff\U0001fad0-\U0001faff]+", flags=re.UNICODE
)
text = emoji_pattern.sub("", text)
@@ -1301,6 +1363,7 @@ def clean_tts_text(text: str) -> str:
return text
+
def tts(tts_mdl, text):
if not tts_mdl or not text:
return None
@@ -1328,18 +1391,31 @@ def __init__(self) -> None:
self.buffer = ""
+def _extract_visible_answer(text: str) -> str:
+    text = text or ""
+    if "</think>" not in text:
+        return re.sub(r"</?think>", "", text)
+
+    thought, answer = text.rsplit("</think>", 1)
+    thought = re.sub(r"</?think>", "", thought).strip()
+    answer = re.sub(r"</?think>", "", answer)
+    if not thought:
+        return answer
+    return f"<think>{thought}</think> {answer}"
+
+
def _next_think_delta(state: _ThinkStreamState) -> str:
full_text = state.full_text
if full_text == state.last_full:
return ""
state.last_full = full_text
- delta_ans = full_text[state.last_idx:]
+ delta_ans = full_text[state.last_idx :]
    if delta_ans.find("<think>") == 0:
        state.last_idx += len("<think>")
        return ""
    if delta_ans.find("</think>") > 0:
-        delta_text = full_text[state.last_idx:state.last_idx + delta_ans.find("</think>")]
+        delta_text = full_text[state.last_idx : state.last_idx + delta_ans.find("</think>")]
        state.last_idx += delta_ans.find("</think>")
        return delta_text
    if delta_ans.endswith("</think> "):
@@ -1360,7 +1436,7 @@ async def _stream_with_think_delta(stream_iter, min_tokens: int = 16):
if not chunk:
continue
if chunk.startswith(state.last_model_full):
- new_part = chunk[len(state.last_model_full):]
+ new_part = chunk[len(state.last_model_full) :]
state.last_model_full = chunk
else:
new_part = chunk
@@ -1394,6 +1470,7 @@ async def _stream_with_think_delta(stream_iter, min_tokens: int = 16):
if state.endswith_think:
yield ("marker", " ", state)
+
async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
doc_ids = search_config.get("doc_ids", [])
rerank_mdl = None
@@ -1401,6 +1478,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
chat_llm_name = search_config.get("chat_id", chat_llm_name)
rerank_id = search_config.get("rerank_id", "")
meta_data_filter = search_config.get("meta_data_filter")
+ include_reference_metadata, metadata_fields = _resolve_reference_metadata(search_config)
kbs = KnowledgebaseService.get_by_ids(kb_ids)
embedding_list = list(set([kb.embd_id for kb in kbs]))
@@ -1419,8 +1497,15 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
if meta_data_filter:
- metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
- doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
+ doc_ids = await apply_meta_data_filter(
+ meta_data_filter,
+ None,
+ question,
+ chat_mdl,
+ doc_ids,
+ kb_ids=kb_ids,
+ metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs(kb_ids),
+ )
kbinfos = await retriever.retrieval(
question=question,
@@ -1435,8 +1520,15 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
doc_ids=doc_ids,
aggs=True,
rerank_mdl=rerank_mdl,
- rank_feature=label_question(question, kbs)
+ rank_feature=label_question(question, kbs),
)
+ if include_reference_metadata:
+ logging.debug(
+ "reference_metadata enrichment enabled for async_ask: chunk_count=%d metadata_fields=%s",
+ len(kbinfos.get("chunks", [])),
+ metadata_fields,
+ )
+ _enrich_chunks_with_document_metadata(kbinfos.get("chunks", []), metadata_fields)
knowledges = kb_prompt(kbinfos, max_tokens)
sys_prompt = PROMPT_JINJA_ENV.from_string(ASK_SUMMARY).render(knowledge="\n".join(knowledges))
@@ -1445,8 +1537,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
def decorate_answer(answer):
nonlocal knowledges, kbinfos, sys_prompt
- answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]],
- embd_mdl, tkweight=0.7, vtweight=0.3)
+ answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
if not recall_docs:
@@ -1472,7 +1563,7 @@ def decorate_answer(answer):
continue
yield {"answer": value, "reference": {}, "final": False}
full_answer = last_state.full_text if last_state else ""
- final = decorate_answer(full_answer)
+ final = decorate_answer(_extract_visible_answer(full_answer))
final["final"] = True
yield final
@@ -1505,8 +1596,15 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
rerank_mdl = LLMBundle(tenant_id, rerank_model_config)
if meta_data_filter:
- metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids)
- doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
+ doc_ids = await apply_meta_data_filter(
+ meta_data_filter,
+ None,
+ question,
+ chat_mdl,
+ doc_ids,
+ kb_ids=kb_ids,
+ metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs(kb_ids),
+ )
ranks = await settings.retriever.retrieval(
question=question,
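
Because the `<think>`/`</think>` string literals in `_extract_visible_answer` and `_next_think_delta` were lost to tag stripping and restored above by reconstruction, here is a standalone copy of the restored helper plus the behavior that reconstruction implies; if the upstream literals differ, these assertions are what would change:

```python
import re

def extract_visible_answer(text: str) -> str:
    """Standalone copy of the reconstructed helper above."""
    text = text or ""
    if "</think>" not in text:
        return re.sub(r"</?think>", "", text)
    thought, answer = text.rsplit("</think>", 1)
    thought = re.sub(r"</?think>", "", thought).strip()
    answer = re.sub(r"</?think>", "", answer)
    if not thought:
        return answer
    return f"<think>{thought}</think> {answer}"

assert extract_visible_answer("plain answer") == "plain answer"
assert extract_visible_answer("<think></think>42") == "42"          # empty thought is dropped
assert extract_visible_answer("<think>step 1</think>42") == "<think>step 1</think> 42"
```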
diff --git a/api/db/services/doc_metadata_service.py b/api/db/services/doc_metadata_service.py
index 7a9e435e072..1cf887c2d3f 100644
--- a/api/db/services/doc_metadata_service.py
+++ b/api/db/services/doc_metadata_service.py
@@ -454,19 +454,27 @@ def update_document_metadata(cls, doc_id: str, meta_fields: Dict) -> bool:
# Index exists - check if document exists
try:
doc_exists = settings.docStoreConn.get(
- index_name=index_name,
- id=doc_id,
- kb_id=kb_id
+ doc_id,
+ index_name,
+ [kb_id]
)
if doc_exists:
- # Document exists - use partial update
+ # Document exists - replace meta_fields entirely
+ # Use upsert to fully replace the meta_fields field
+ # (ES update with doc parameter does deep merge on object fields,
+ # which would retain old keys that should be removed)
settings.docStoreConn.es.update(
index=index_name,
id=doc_id,
refresh=True,
- doc={"meta_fields": processed_meta}
+ body={
+ "script": {
+ "source": "ctx._source.meta_fields = params.meta_fields",
+ "params": {"meta_fields": processed_meta}
+ }
+ }
)
- logging.debug(f"Successfully updated metadata for document {doc_id} using ES partial update")
+ logging.debug(f"Successfully updated metadata for document {doc_id} using ES script update")
return True
except Exception as e:
logging.debug(f"Document {doc_id} not found in index, will insert: {e}")
@@ -764,6 +772,140 @@ def get_flatted_meta_by_kbs(cls, kb_ids: List[str]) -> Dict:
logging.error(f"Error getting flattened metadata for KBs {kb_ids}: {e}")
return {}
+ @classmethod
+ def filter_doc_ids_by_meta_pushdown(
+ cls,
+ kb_ids: List[str],
+ filters: List[Dict],
+ logic: str = "and",
+ limit: int = 10000,
+ ) -> Optional[List[str]]:
+ """Run a metadata filter directly against ES, returning matching doc IDs.
+
+ Returns ``None`` to signal "push-down not viable, use the in-memory
+ ``meta_filter`` fallback". Reasons for ``None``:
+
+ - Active doc store is not Elasticsearch (Infinity / OceanBase have
+ different filter semantics for the JSON ``meta_fields`` column).
+ - One of the user filters cannot be expressed in ES DSL.
+ - The ES request itself failed (network, mapping, missing index).
+
+ On success returns the deduplicated, ordered list of document IDs the
+ ES query matched. Callers can union or intersect this with their own
+ base ``doc_ids`` rather than fetching the entire metadata table.
+ """
+ from common.metadata_es_filter import (
+ UnsupportedMetaFilter,
+ build_meta_filter_query,
+ extract_doc_ids,
+ is_pushdown_supported,
+ )
+
+ if not kb_ids:
+ return []
+
+ if settings.DOC_ENGINE_INFINITY:
+ # Infinity stores ``meta_fields`` as a JSON column without dotted
+ # field access; the in-memory path is still the reliable answer.
+ return None
+
+ es_client = getattr(settings.docStoreConn, "es", None)
+ if es_client is None:
+ return None
+
+ if not is_pushdown_supported(filters):
+ return None
+
+ try:
+ kb = Knowledgebase.get_by_id(kb_ids[0])
+ except Exception as e:
+ logging.warning(f"[meta_pushdown] cannot resolve tenant for kb {kb_ids[0]}: {e}")
+ return None
+ if not kb:
+ return None
+
+ tenant_id = kb.tenant_id
+ index_name = cls._get_doc_meta_index_name(tenant_id)
+
+ try:
+ if not settings.docStoreConn.index_exist(index_name, ""):
+                # No metadata index → no metadata-filtered docs. Return an
+                # empty list (rather than ``None``) so callers don't bounce
+                # back to the in-memory path and re-query MySQL for nothing.
+ return []
+ except Exception as e:
+ logging.warning(f"[meta_pushdown] index_exist check failed for {index_name}: {e}")
+ return None
+
+ try:
+ query_body = build_meta_filter_query(filters, logic, kb_ids)
+ except UnsupportedMetaFilter as e:
+ logging.debug(f"[meta_pushdown] falling back to in-memory: {e.reason}")
+ return None
+
+ # Only the doc id is needed downstream; trimming ``_source`` keeps the
+ # response small when the metadata blob is large.
+ request_body = {
+ **query_body,
+ "size": limit,
+ "_source": ["id"],
+ }
+
+ try:
+ response = es_client.search(index=index_name, body=request_body)
+ except Exception as e:
+ logging.warning(f"[meta_pushdown] ES query failed for {index_name}: {e}")
+ return None
+
+ doc_ids = extract_doc_ids(response if isinstance(response, dict) else dict(response))
+ # Preserve order while removing duplicates so caller-side de-dupe stays
+ # cheap.
+ seen: set[str] = set()
+ unique: List[str] = []
+ for did in doc_ids:
+ if did in seen:
+ continue
+ seen.add(did)
+ unique.append(did)
+
+ if len(unique) >= limit:
+ logging.warning(
+ f"[meta_pushdown] hit limit {limit} for KBs {kb_ids}; some matches may be missing"
+ )
+
+ logging.debug(f"[meta_pushdown] {len(unique)} matches for KBs {kb_ids}")
+ return unique
+
+ @classmethod
+ def get_metadata_keys_by_kbs(cls, kb_ids: List[str]) -> List[str]:
+ """
+ Get unique metadata field names across multiple knowledge bases.
+
+ Args:
+ kb_ids: List of knowledge base IDs
+
+ Returns:
+ Sorted list of unique metadata field names
+ """
+ if not kb_ids:
+ return []
+
+ logging.debug(f"get_metadata_keys_by_kbs start: n_kbs={len(kb_ids)}")
+ keys: set[str] = set()
+ try:
+ for kb_id in kb_ids:
+ results = cls._search_metadata(kb_id, condition={"kb_id": kb_id})
+ for _doc_id, doc in cls._iter_search_results(results):
+ doc_meta = cls._extract_metadata(doc)
+ if not isinstance(doc_meta, dict):
+ continue
+ keys.update(str(k) for k in doc_meta.keys())
+ logging.debug(f"get_metadata_keys_by_kbs end: n_keys={len(keys)}, kb_ids={kb_ids}")
+ return sorted(keys)
+ except Exception as e:
+ logging.error(f"Error getting metadata keys for KBs {kb_ids}: {e}")
+ return []
+
@classmethod
def get_metadata_for_documents(cls, doc_ids: Optional[List[str]], kb_id: str) -> Dict[str, Dict]:
"""
diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py
index 0c6e8b89195..7992cdb6105 100644
--- a/api/db/services/document_service.py
+++ b/api/db/services/document_service.py
@@ -13,15 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import asyncio
-import json
import logging
import random
-import re
-from concurrent.futures import ThreadPoolExecutor
-from copy import deepcopy
from datetime import datetime
-from io import BytesIO
import xxhash
from peewee import fn, Case, JOIN
@@ -33,13 +27,15 @@
from api.db.services.common_service import CommonService, retry_deadlock_operation
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.doc_metadata_service import DocMetadataService
+
+from common import settings
+from common.constants import ParserType, StatusEnum, TaskStatus, SVR_CONSUMER_GROUP_NAME, MAXIMUM_TASK_PAGE_NUMBER
+from common.doc_store.doc_store_base import OrderByExpr
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp, get_format_time
-from common.constants import LLMType, ParserType, StatusEnum, TaskStatus, SVR_CONSUMER_GROUP_NAME
-from rag.nlp import rag_tokenizer, search
+
+from rag.nlp import search
from rag.utils.redis_conn import REDIS_CONN
-from common.doc_store.doc_store_base import OrderByExpr
-from common import settings
class DocumentService(CommonService):
@@ -127,7 +123,7 @@ def check_doc_health(cls, tenant_id: str, filename):
@classmethod
@DB.connection_context()
- def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix, doc_id=None, name=None, doc_ids_filter=None, return_empty_metadata=False):
+ def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix, name=None, doc_ids=None, return_empty_metadata=False):
fields = cls.get_cls_model_fields()
if keywords:
docs = (
@@ -147,10 +143,8 @@ def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keyword
.join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)
.where(cls.model.kb_id == kb_id)
)
- if doc_id:
- docs = docs.where(cls.model.id == doc_id)
- if doc_ids_filter:
- docs = docs.where(cls.model.id.in_(doc_ids_filter))
+ if doc_ids:
+ docs = docs.where(cls.model.id.in_(doc_ids))
if run_status:
docs = docs.where(cls.model.run.in_(run_status))
if types:
@@ -429,6 +423,9 @@ def remove_document(cls, doc, tenant_id):
if not cls.delete_document_and_update_kb_counts(doc.id):
return True
+ chunk_index_name = search.index_name(tenant_id)
+ chunk_index_exists = settings.docStoreConn.index_exist(chunk_index_name, doc.kb_id)
+
# Cancel all running tasks first Using preset function in task_service.py --- set cancel flag in Redis
try:
cancel_all_task_of(doc.id)
@@ -444,7 +441,8 @@ def remove_document(cls, doc, tenant_id):
# Delete chunk images (non-critical, log and continue)
try:
- cls.delete_chunk_images(doc, tenant_id)
+ if chunk_index_exists:
+ cls.delete_chunk_images(doc, tenant_id)
except Exception as e:
logging.warning(f"Failed to delete chunk images for document {doc.id}: {e}")
@@ -458,7 +456,7 @@ def remove_document(cls, doc, tenant_id):
# Delete chunks from doc store - this is critical, log errors
try:
- settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
+ settings.docStoreConn.delete({"doc_id": doc.id}, chunk_index_name, doc.kb_id)
except Exception as e:
logging.error(f"Failed to delete chunks from doc store for document {doc.id}: {e}")
@@ -470,23 +468,24 @@ def remove_document(cls, doc, tenant_id):
# Cleanup knowledge graph references (non-critical, log and continue)
try:
- graph_source = settings.docStoreConn.get_fields(
- settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]),
- ["source_id"],
- )
- if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
- settings.docStoreConn.update(
- {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
- {"remove": {"source_id": doc.id}},
- search.index_name(tenant_id),
- doc.kb_id,
- )
- settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, {"removed_kwd": "Y"}, search.index_name(tenant_id), doc.kb_id)
- settings.docStoreConn.delete(
- {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
- search.index_name(tenant_id),
- doc.kb_id,
+ if chunk_index_exists:
+ graph_source = settings.docStoreConn.get_fields(
+ settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, chunk_index_name, [doc.kb_id]),
+ ["source_id"],
)
+ if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
+ settings.docStoreConn.update(
+ {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
+ {"remove": {"source_id": doc.id}},
+ chunk_index_name,
+ doc.kb_id,
+ )
+ settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, {"removed_kwd": "Y"}, chunk_index_name, doc.kb_id)
+ settings.docStoreConn.delete(
+ {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
+ chunk_index_name,
+ doc.kb_id,
+ )
except Exception as e:
logging.warning(f"Failed to cleanup knowledge graph for document {doc.id}: {e}")
@@ -679,17 +678,10 @@ def get_tenant_id_by_name(cls, name):
@classmethod
@DB.connection_context()
def accessible(cls, doc_id, user_id):
- docs = (
- cls.model.select(cls.model.id)
- .join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
- .join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id))
- .where(cls.model.id == doc_id, UserTenant.user_id == user_id)
- .paginate(0, 1)
- )
- docs = docs.dicts()
- if not docs:
+ e, doc = cls.get_by_id(doc_id)
+ if not e:
return False
- return True
+ return KnowledgebaseService.accessible(doc.kb_id, user_id)
@classmethod
@DB.connection_context()
@@ -1002,8 +994,8 @@ def new_task():
return {
"id": get_uuid(),
"doc_id": fake_doc_id,
- "from_page": 100000000,
- "to_page": 100000000,
+ "from_page": MAXIMUM_TASK_PAGE_NUMBER,
+ "to_page": MAXIMUM_TASK_PAGE_NUMBER,
"task_type": ty,
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty,
"begin_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
@@ -1027,138 +1019,3 @@ def get_queue_length(priority):
if not group_info:
return 0
return int(group_info.get("lag", 0) or 0)
-
-
-def doc_upload_and_parse(conversation_id, file_objs, user_id):
- from api.db.services.api_service import API4ConversationService
- from api.db.services.conversation_service import ConversationService
- from api.db.services.dialog_service import DialogService
- from api.db.services.file_service import FileService
- from api.db.services.llm_service import LLMBundle
- from api.db.services.user_service import TenantService
- from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type
- from rag.app import audio, email, naive, picture, presentation
-
- e, conv = ConversationService.get_by_id(conversation_id)
- if not e:
- e, conv = API4ConversationService.get_by_id(conversation_id)
- assert e, "Conversation not found!"
-
- e, dia = DialogService.get_by_id(conv.dialog_id)
- if not dia.kb_ids:
- raise LookupError("No dataset associated with this conversation. Please add a dataset before uploading documents")
- kb_id = dia.kb_ids[0]
- e, kb = KnowledgebaseService.get_by_id(kb_id)
- if not e:
- raise LookupError("Can't find this dataset!")
- if kb.tenant_embd_id:
- embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
- else:
- embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
- embd_mdl = LLMBundle(kb.tenant_id, embd_model_config, lang=kb.language)
-
- err, files = FileService.upload_document(kb, file_objs, user_id)
- assert not err, "\n".join(err)
-
- def dummy(prog=None, msg=""):
- pass
-
- FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email}
- parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text", "table_context_size": 0, "image_context_size": 0}
- exe = ThreadPoolExecutor(max_workers=12)
- threads = []
- doc_nm = {}
- for d, blob in files:
- doc_nm[d["id"]] = d["name"]
- for d, blob in files:
- kwargs = {"callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": kb.tenant_id, "lang": kb.language}
- threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))
-
- for (docinfo, _), th in zip(files, threads):
- docs = []
- doc = {"doc_id": docinfo["id"], "kb_id": [kb.id]}
- for ck in th.result():
- d = deepcopy(doc)
- d.update(ck)
- d["id"] = xxhash.xxh64((ck["content_with_weight"] + str(d["doc_id"])).encode("utf-8")).hexdigest()
- d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
- d["create_timestamp_flt"] = datetime.now().timestamp()
- if not d.get("image"):
- docs.append(d)
- continue
-
- output_buffer = BytesIO()
- if isinstance(d["image"], bytes):
- output_buffer = BytesIO(d["image"])
- else:
- d["image"].save(output_buffer, format="JPEG")
-
- settings.STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
- d["img_id"] = "{}-{}".format(kb.id, d["id"])
- d.pop("image", None)
- docs.append(d)
-
- parser_ids = {d["id"]: d["parser_id"] for d, _ in files}
- docids = [d["id"] for d, _ in files]
- chunk_counts = {id: 0 for id in docids}
- token_counts = {id: 0 for id in docids}
- es_bulk_size = 64
-
- def embedding(doc_id, cnts, batch_size=16):
- nonlocal embd_mdl, chunk_counts, token_counts
- vectors = []
- for i in range(0, len(cnts), batch_size):
- vts, c = embd_mdl.encode(cnts[i : i + batch_size])
- vectors.extend(vts.tolist())
- chunk_counts[doc_id] += len(cnts[i : i + batch_size])
- token_counts[doc_id] += c
- return vectors
-
- idxnm = search.index_name(kb.tenant_id)
- try_create_idx = True
-
- _, tenant = TenantService.get_by_id(kb.tenant_id)
- tenant_llm_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
- llm_bdl = LLMBundle(kb.tenant_id, tenant_llm_config)
- for doc_id in docids:
- cks = [c for c in docs if c["doc_id"] == doc_id]
-
- if parser_ids[doc_id] != ParserType.PICTURE.value:
- from rag.graphrag.general.mind_map_extractor import MindMapExtractor
-
- mindmap = MindMapExtractor(llm_bdl)
- try:
- mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]))
- mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
- if len(mind_map) < 32:
- raise Exception("Few content: " + mind_map)
- cks.append(
- {
- "id": get_uuid(),
- "doc_id": doc_id,
- "kb_id": [kb.id],
- "docnm_kwd": doc_nm[doc_id],
- "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
- "content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
- "content_with_weight": mind_map,
- "knowledge_graph_kwd": "mind_map",
- }
- )
- except Exception:
- logging.exception("Mind map generation error")
-
- vectors = embedding(doc_id, [c["content_with_weight"] for c in cks])
- assert len(cks) == len(vectors)
- for i, d in enumerate(cks):
- v = vectors[i]
- d["q_%d_vec" % len(v)] = v
- for b in range(0, len(cks), es_bulk_size):
- if try_create_idx:
- if not settings.docStoreConn.index_exist(idxnm, kb_id):
- settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]), kb.parser_id)
- try_create_idx = False
- settings.docStoreConn.insert(cks[b : b + es_bulk_size], idxnm, kb_id)
-
- DocumentService.increment_chunk_num(doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
-
- return [d["id"] for d, _ in files]
diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py
index 11940b88c21..db8ae4b72f5 100644
--- a/api/db/services/file_service.py
+++ b/api/db/services/file_service.py
@@ -23,17 +23,20 @@
from pathlib import Path
from typing import Union
+logger = logging.getLogger(__name__)
+
import xxhash
from peewee import fn
-from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileType
+from api.db import KNOWLEDGEBASE_FOLDER_NAME, SKILLS_FOLDER_NAME, FileType
from api.db.db_models import DB, Document, File, File2Document, Knowledgebase, Task
from api.db.services import duplicate_name
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from common.misc_utils import get_uuid
-from common.constants import TaskStatus, FileSource, ParserType
+from common.ssrf_guard import assert_url_is_safe
+from common.constants import TaskStatus, FileSource, ParserType, MAXIMUM_PAGE_NUMBER
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import TaskService
from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img, sanitize_path
@@ -188,23 +191,24 @@ def get_all_file_ids_by_tenant_id(cls, tenant_id):
@classmethod
@DB.connection_context()
- def create_folder(cls, file, parent_id, name, count):
- from api.apps import current_user
+ def create_folder(cls, file, parent_id, name, count, tenant_id, created_by):
# Recursively create folder structure
# Args:
# file: Current file object
# parent_id: Parent folder ID
# name: List of folder names to create
# count: Current depth in creation
+ # tenant_id: Tenant ID
+ # created_by: Created by user ID
# Returns:
# Created file object
if count > len(name) - 2:
return file
else:
file = cls.insert(
- {"id": get_uuid(), "parent_id": parent_id, "tenant_id": current_user.id, "created_by": current_user.id, "name": name[count], "location": "", "size": 0, "type": FileType.FOLDER.value}
+ {"id": get_uuid(), "parent_id": parent_id, "tenant_id": tenant_id, "created_by": created_by, "name": name[count], "location": "", "size": 0, "type": FileType.FOLDER.value}
)
- return cls.create_folder(file, file.id, name, count + 1)
+ return cls.create_folder(file, file.id, name, count + 1, tenant_id, created_by)
@classmethod
@DB.connection_context()
@@ -290,6 +294,28 @@ def new_a_file_from_kb(cls, tenant_id, name, parent_id, ty=FileType.FOLDER.value
cls.save(**file)
return file
+ @classmethod
+ @DB.connection_context()
+ def init_skills_folder(cls, root_id, tenant_id):
+ # Initialize skills folder if not exists
+ # Args:
+ # root_id: Root folder ID
+ # tenant_id: Tenant ID
+        if cls.model.select().where((cls.model.name == SKILLS_FOLDER_NAME) & (cls.model.parent_id == root_id)).exists():
+            return
+ file_id = get_uuid()
+ file = {
+ "id": file_id,
+ "parent_id": root_id,
+ "tenant_id": tenant_id,
+ "created_by": tenant_id,
+ "name": SKILLS_FOLDER_NAME,
+ "type": FileType.FOLDER.value,
+ "size": 0,
+ "location": "",
+ }
+ cls.save(**file)
+
@classmethod
@DB.connection_context()
def init_knowledgebase_docs(cls, root_id, tenant_id):
@@ -550,7 +576,7 @@ def dummy(prog=None, msg=""):
FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email}
parser_config = {"chunk_token_num": 16096, "delimiter": "\n!?;。;!?", "layout_recognize": layout_recognize or "Plain Text"}
- kwargs = {"lang": "English", "callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": current_user.id if current_user else tenant_id}
+ kwargs = {"lang": "English", "callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": MAXIMUM_PAGE_NUMBER, "tenant_id": current_user.id if current_user else tenant_id}
file_type = filename_type(filename)
if img_base64 and file_type == FileType.VISUAL.value:
return GptV4.image2base64(blob)
@@ -624,6 +650,26 @@ def delete_docs(cls, doc_ids, tenant_id):
return errors
+ _ALLOWED_SCHEMES = {"http", "https"}
+
+ @staticmethod
+ def _validate_url_for_crawl(url: str) -> tuple[str, str]:
+ """Raise ValueError if the URL is not safe to crawl (SSRF guard).
+
+ Delegates to :func:`common.ssrf_guard.assert_url_is_safe`, which
+ validates the scheme, hostname, and every DNS-resolved address, and
+ returns ``(hostname, resolved_ip)`` for DNS pinning.
+
+ Only the scheme and host (and port when present) are forwarded to the
+ guard so that credentials or query parameters in *url* are never
+ written to the log.
+ """
+ from urllib.parse import urlparse
+ parsed = urlparse(url)
+ port_suffix = f":{parsed.port}" if parsed.port else ""
+ redacted = f"{parsed.scheme}://{parsed.hostname}{port_suffix}"
+ return assert_url_is_safe(redacted, allowed_schemes=FileService._ALLOWED_SCHEMES)
+
@staticmethod
def upload_info(user_id, file, url: str|None=None):
def structured(filename, filetype, blob, content_type):
@@ -646,6 +692,53 @@ def structured(filename, filetype, blob, content_type):
}
if url:
+ import requests as _requests
+ from urllib.parse import urljoin as _urljoin
+
+ _MAX_CRAWL_REDIRECTS = 10
+
+ # Pre-resolve the full redirect chain so that AsyncWebCrawler never
+ # follows a server-sent redirect to an unvalidated (potentially
+ # internal) host. Each hop is SSRF-checked before being followed;
+ # the validated (hostname, ip) pairs are pinned via Chromium's
+ # --host-resolver-rules so the browser cannot re-resolve any of them
+ # through a fresh DNS query.
+ current_url = url
+ current_hostname, current_ip = FileService._validate_url_for_crawl(current_url)
+ # Accumulate MAP rules for every hostname we encounter in the chain.
+ host_pins: dict[str, str] = {current_hostname: current_ip}
+
+ for _ in range(_MAX_CRAWL_REDIRECTS):
+ try:
+ _resp = _requests.get(
+ current_url,
+ timeout=10,
+ allow_redirects=False,
+ )
+ except _requests.RequestException as _exc:
+ raise ValueError(f"Failed to fetch {current_url!r}: {_exc}") from _exc
+
+ if _resp.status_code not in (301, 302, 303, 307, 308):
+ break
+
+ _location = _resp.headers.get("Location")
+ if not _location:
+ break
+
+ _next_url = _urljoin(current_url, _location)
+ _next_hostname, _next_ip = FileService._validate_url_for_crawl(_next_url)
+ host_pins[_next_hostname] = _next_ip
+ current_url = _next_url
+ else:
+ raise ValueError(
+ f"Exceeded {_MAX_CRAWL_REDIRECTS} redirects fetching {url!r}"
+ )
+
+ # Build a single MAP rule string covering every validated hostname
+ # in the redirect chain. Chromium uses the pinned IP for each,
+ # skipping DNS entirely and eliminating the rebinding window.
+ _map_rules = ",".join(f"MAP {h} {ip}" for h, ip in host_pins.items())
+
from crawl4ai import (
AsyncWebCrawler,
BrowserConfig,
@@ -659,6 +752,7 @@ async def adownload():
browser_config = BrowserConfig(
headless=True,
verbose=False,
+ extra_args=[f"--host-resolver-rules={_map_rules}"],
)
async with AsyncWebCrawler(config=browser_config) as crawler:
crawler_config = CrawlerRunConfig(
@@ -668,8 +762,10 @@ async def adownload():
pdf=True,
screenshot=False
)
+ # Use the final resolved URL so the browser starts at the
+ # redirect destination rather than re-following the chain.
result: CrawlResult = await crawler.arun(
- url=url,
+ url=current_url,
config=crawler_config
)
return result
@@ -679,7 +775,7 @@ async def adownload():
filename += ".pdf"
return structured(filename, "pdf", page.pdf, page.response_headers["content-type"])
- return structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id)
+ return structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"])
DocumentService.check_doc_health(user_id, file.filename)
return structured(file.filename, filename_type(file.filename), file.read(), file.content_type)
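
The redirect pre-resolution block above is the heart of this SSRF fix: validate every hop yourself, then pin each validated hostname to its resolved IP through Chromium's `--host-resolver-rules` so the browser never performs a fresh DNS lookup. A compact sketch of the pattern, assuming `requests`; `validate` stands in for `common.ssrf_guard.assert_url_is_safe`:

```python
from urllib.parse import urljoin

import requests

REDIRECT_CODES = (301, 302, 303, 307, 308)

def resolve_chain(url, validate, max_redirects=10):
    """Follow redirects by hand, SSRF-checking every hop.

    `validate` must return (hostname, resolved_ip) or raise on unsafe input.
    """
    host, ip = validate(url)
    pins = {host: ip}  # hostname -> pinned IP for every hop seen
    for _ in range(max_redirects):
        resp = requests.get(url, timeout=10, allow_redirects=False)
        location = resp.headers.get("Location")
        if resp.status_code not in REDIRECT_CODES or not location:
            break
        url = urljoin(url, location)
        host, ip = validate(url)  # check each hop *before* following it
        pins[host] = ip
    else:
        raise ValueError(f"exceeded {max_redirects} redirects")
    # One comma-separated rule list pins every validated hostname, so the
    # browser skips DNS entirely (no rebinding window).
    rules = ",".join(f"MAP {h} {ip}" for h, ip in pins.items())
    return url, [f"--host-resolver-rules={rules}"]
```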
diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py
index c66d66a6821..a164287fa4e 100644
--- a/api/db/services/knowledgebase_service.py
+++ b/api/db/services/knowledgebase_service.py
@@ -18,7 +18,7 @@
from peewee import fn, JOIN
from api.db import TenantPermission
-from api.db.db_models import DB, Document, Knowledgebase, User, UserTenant, UserCanvas
+from api.db.db_models import DB, Document, Knowledgebase, User, UserCanvas
from api.db.services.common_service import CommonService
from common.time_utils import current_timestamp, datetime_format
from api.db.services import duplicate_name
@@ -485,13 +485,21 @@ def accessible(cls, kb_id, user_id):
# user_id: User ID
# Returns:
# Boolean indicating accessibility
- docs = cls.model.select(
- cls.model.id).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
- ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
- docs = docs.dicts()
- if not docs:
+ e, kb = cls.get_by_id(kb_id)
+ if not e:
return False
- return True
+
+ if kb.status != StatusEnum.VALID.value:
+ return False
+
+ if kb.tenant_id == user_id:
+ return True
+
+ if kb.permission != TenantPermission.TEAM.value:
+ return False
+
+ joined_tenants = TenantService.get_joined_tenants_by_user_id(user_id)
+ return any(tenant["tenant_id"] == kb.tenant_id for tenant in joined_tenants)
@classmethod
@DB.connection_context()
@@ -502,10 +510,10 @@ def get_kb_by_id(cls, kb_id, user_id):
# user_id: User ID
# Returns:
# List containing dataset information
- kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
- ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
- kbs = kbs.dicts()
- return list(kbs)
+ e, kb = cls.get_by_id(kb_id)
+ if not e or not cls.accessible(kb_id, user_id):
+ return []
+ return [kb.to_dict()]
@classmethod
@DB.connection_context()
@@ -516,10 +524,11 @@ def get_kb_by_name(cls, kb_name, user_id):
# user_id: User ID
# Returns:
# List containing dataset information
- kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
- ).where(cls.model.name == kb_name, UserTenant.user_id == user_id).paginate(0, 1)
- kbs = kbs.dicts()
- return list(kbs)
+ kbs = cls.query(name=kb_name, status=StatusEnum.VALID.value)
+ for kb in kbs:
+ if cls.accessible(kb.id, user_id):
+ return [kb.to_dict()]
+ return []
@classmethod
@DB.connection_context()
diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index 6058c6b69f7..60090bb0409 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -94,7 +94,7 @@ def bind_tools(self, toolcall_session, tools):
def encode(self, texts: list):
if self.langfuse:
- generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.model_config["llm_name"], input={"texts": texts})
+ generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="encode", model=self.model_config["llm_name"], input={"texts": texts})
safe_texts = []
for text in texts:
@@ -119,7 +119,7 @@ def encode(self, texts: list):
def encode_queries(self, query: str):
if self.langfuse:
- generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode_queries", model=self.model_config["llm_name"], input={"query": query})
+ generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="encode_queries", model=self.model_config["llm_name"], input={"query": query})
emd, used_tokens = self.mdl.encode_queries(query)
if self.model_config["llm_factory"] == "Builtin":
@@ -135,7 +135,7 @@ def encode_queries(self, query: str):
def similarity(self, query: str, texts: list):
if self.langfuse:
- generation = self.langfuse.start_generation(trace_context=self.trace_context, name="similarity", model=self.model_config["llm_name"], input={"query": query, "texts": texts})
+ generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="similarity", model=self.model_config["llm_name"], input={"query": query, "texts": texts})
sim, used_tokens = self.mdl.similarity(query, texts)
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
@@ -149,7 +149,7 @@ def similarity(self, query: str, texts: list):
def describe(self, image, max_tokens=300):
if self.langfuse:
- generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe", metadata={"model": self.model_config["llm_name"]})
+ generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="describe", metadata={"model": self.model_config["llm_name"]})
txt, used_tokens = self.mdl.describe(image)
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
@@ -163,7 +163,7 @@ def describe(self, image, max_tokens=300):
def describe_with_prompt(self, image, prompt):
if self.langfuse:
- generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe_with_prompt", metadata={"model": self.model_config["llm_name"], "prompt": prompt})
+ generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="describe_with_prompt", metadata={"model": self.model_config["llm_name"], "prompt": prompt})
txt, used_tokens = self.mdl.describe_with_prompt(image, prompt)
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
@@ -177,7 +177,7 @@ def describe_with_prompt(self, image, prompt):
def transcription(self, audio):
if self.langfuse:
- generation = self.langfuse.start_generation(trace_context=self.trace_context, name="transcription", metadata={"model": self.model_config["llm_name"]})
+ generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="transcription", metadata={"model": self.model_config["llm_name"]})
txt, used_tokens = self.mdl.transcription(audio)
if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens):
@@ -194,7 +194,7 @@ def stream_transcription(self, audio):
supports_stream = hasattr(mdl, "stream_transcription") and callable(getattr(mdl, "stream_transcription"))
if supports_stream:
if self.langfuse:
- generation = self.langfuse.start_generation(
+ generation = self.langfuse.start_observation(as_type="generation",
trace_context=self.trace_context,
name="stream_transcription",
metadata={"model": self.model_config["llm_name"]},
@@ -228,7 +228,7 @@ def stream_transcription(self, audio):
return
if self.langfuse:
- generation = self.langfuse.start_generation(
+ generation = self.langfuse.start_observation(as_type="generation",
trace_context=self.trace_context,
name="stream_transcription",
metadata={"model": self.model_config["llm_name"]},
@@ -253,7 +253,7 @@ def stream_transcription(self, audio):
def tts(self, text: str) -> Generator[bytes, None, None]:
if self.langfuse:
- generation = self.langfuse.start_generation(trace_context=self.trace_context, name="tts", input={"text": text})
+ generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="tts", input={"text": text})
for chunk in self.mdl.tts(text):
if isinstance(chunk, int):
@@ -376,7 +376,7 @@ async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kw
generation = None
if self.langfuse:
- generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.model_config["llm_name"], input={"system": system, "history": history})
+ generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="chat", model=self.model_config["llm_name"], input={"system": system, "history": history})
chat_partial = partial(base_fn, system, history, gen_conf)
use_kwargs = self._clean_param(chat_partial, **kwargs)
@@ -417,7 +417,7 @@ async def async_chat_streamly(self, system: str, history: list, gen_conf: dict =
generation = None
if self.langfuse:
- generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
+ generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
if stream_fn:
chat_partial = partial(stream_fn, system, history, gen_conf)
@@ -460,7 +460,7 @@ async def async_chat_streamly_delta(self, system: str, history: list, gen_conf:
generation = None
if self.langfuse:
- generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
+ generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history})
if stream_fn:
chat_partial = partial(stream_fn, system, history, gen_conf)
diff --git a/api/db/services/memory_service.py b/api/db/services/memory_service.py
index d2433d01d0e..530fc5ad9ea 100644
--- a/api/db/services/memory_service.py
+++ b/api/db/services/memory_service.py
@@ -92,6 +92,11 @@ def get_by_filter(cls, filter_dict: dict, keywords: str, page: int = 1, page_siz
memories = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id))
if filter_dict.get("tenant_id"):
memories = memories.where(cls.model.tenant_id.in_(filter_dict["tenant_id"]))
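+        # Limit results to memories owned by the requesting user or shared with the team.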
+ if filter_dict.get("accessible_user_id"):
+ memories = memories.where(
+ (cls.model.tenant_id == filter_dict["accessible_user_id"]) |
+ (cls.model.permissions == "team")
+ )
if filter_dict.get("memory_type"):
memory_type_int = calculate_memory_type(filter_dict["memory_type"])
memories = memories.where(cls.model.memory_type.bin_and(memory_type_int) > 0)
diff --git a/api/db/services/pipeline_operation_log_service.py b/api/db/services/pipeline_operation_log_service.py
index 344e2381b7e..ad90acb1f34 100644
--- a/api/db/services/pipeline_operation_log_service.py
+++ b/api/db/services/pipeline_operation_log_service.py
@@ -250,20 +250,16 @@ def get_file_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, des
@DB.connection_context()
def get_documents_info(cls, id):
fields = [Document.id, Document.name, Document.progress, Document.kb_id]
- return (
- cls.model.select(*fields)
- .join(Document, on=(cls.model.document_id == Document.id))
- .where(
- cls.model.id == id
- )
- .dicts()
- )
+ return cls.model.select(*fields).join(Document, on=(cls.model.document_id == Document.id)).where(cls.model.id == id).dicts()
@classmethod
@DB.connection_context()
- def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from=None, create_date_to=None):
+ def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from=None, create_date_to=None, keywords=None):
fields = cls.get_dataset_logs_fields()
- logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == GRAPH_RAPTOR_FAKE_DOC_ID))
+ if keywords:
+ logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == GRAPH_RAPTOR_FAKE_DOC_ID), (fn.LOWER(cls.model.document_name).contains(keywords.lower())))
+ else:
+ logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == GRAPH_RAPTOR_FAKE_DOC_ID))
if operation_status:
logs = logs.where(cls.model.operation_status.in_(operation_status))
diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py
index 80817323076..640c8fbd25e 100644
--- a/api/db/services/task_service.py
+++ b/api/db/services/task_service.py
@@ -29,7 +29,7 @@
from api.db.services.document_service import DocumentService
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp
-from common.constants import StatusEnum, TaskStatus
+from common.constants import StatusEnum, TaskStatus, MAXIMUM_PAGE_NUMBER, MAXIMUM_TASK_PAGE_NUMBER
from deepdoc.parser.excel_parser import RAGFlowExcelParser
from rag.utils.redis_conn import REDIS_CONN
from common import settings
@@ -37,6 +37,7 @@
CANVAS_DEBUG_DOC_ID = "dataflow_x"
GRAPH_RAPTOR_FAKE_DOC_ID = "graph_raptor_x"
+TASK_MAX_LOG_LENGTH = int(os.environ.get("TASK_MAX_LOG_LENGTH", 3000)) # TEXT column max is 64 KiB!
def trim_header_by_lines(text: str, max_length) -> str:
# Trim header text to maximum length while preserving line breaks
@@ -320,7 +321,7 @@ def update_progress(cls, id, info):
if os.environ.get("MACOS"):
if info["progress_msg"]:
- progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
+ progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], TASK_MAX_LOG_LENGTH)
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
if "progress" in info:
prog = info["progress"]
@@ -332,7 +333,7 @@ def update_progress(cls, id, info):
else:
with DB.lock("update_progress", -1):
if info["progress_msg"]:
- progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
+ progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], TASK_MAX_LOG_LENGTH)
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
if "progress" in info:
prog = info["progress"]
@@ -379,7 +380,7 @@ def new_task():
"doc_id": doc["id"],
"progress": 0.0,
"from_page": 0,
- "to_page": 100000000,
+ "to_page": MAXIMUM_TASK_PAGE_NUMBER,
"begin_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
}
@@ -395,8 +396,8 @@ def new_task():
if doc["parser_id"] == "paper":
page_size = doc["parser_config"].get("task_page_size") or 22
if doc["parser_id"] in ["one", "knowledge_graph"] or do_layout != "DeepDOC" or doc["parser_config"].get("toc_extraction", False):
- page_size = 10 ** 9
- page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
+ page_size = MAXIMUM_TASK_PAGE_NUMBER
+ page_ranges = doc["parser_config"].get("pages") or [(1, MAXIMUM_PAGE_NUMBER)]
for s, e in page_ranges:
s -= 1
s = max(0, s)
@@ -495,7 +496,7 @@ def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config:
return 0
task["chunk_ids"] = prev_task["chunk_ids"]
task["progress"] = 1.0
- if "from_page" in task and "to_page" in task and int(task['to_page']) - int(task['from_page']) >= 10 ** 6:
+ if "from_page" in task and "to_page" in task and (int(task['to_page']) - int(task['from_page']) >= 10 ** 6 or (int(task['from_page']) == MAXIMUM_TASK_PAGE_NUMBER and int(task['to_page']) == MAXIMUM_TASK_PAGE_NUMBER)):
task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): "
else:
task["progress_msg"] = ""
@@ -530,7 +531,7 @@ def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DE
id=task_id,
doc_id=doc_id,
from_page=0,
- to_page=100000000,
+ to_page=MAXIMUM_TASK_PAGE_NUMBER,
task_type="dataflow" if not rerun else "dataflow_rerun",
priority=priority,
begin_at= datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
diff --git a/api/db/services/tenant_llm_service.py b/api/db/services/tenant_llm_service.py
index a27f1352d44..ee2eab6648a 100644
--- a/api/db/services/tenant_llm_service.py
+++ b/api/db/services/tenant_llm_service.py
@@ -19,7 +19,7 @@
from peewee import IntegrityError
from langfuse import Langfuse
from common import settings
-from common.constants import MINERU_DEFAULT_CONFIG, MINERU_ENV_KEYS, PADDLEOCR_DEFAULT_CONFIG, PADDLEOCR_ENV_KEYS, LLMType
+from common.constants import MINERU_DEFAULT_CONFIG, MINERU_ENV_KEYS, OPENDATALOADER_DEFAULT_CONFIG, OPENDATALOADER_ENV_KEYS, PADDLEOCR_DEFAULT_CONFIG, PADDLEOCR_ENV_KEYS, LLMType
from api.db.db_models import DB, LLMFactories, TenantLLM
from api.db.services.common_service import CommonService
from api.db.services.langfuse_service import TenantLangfuseService
@@ -34,6 +34,42 @@ class LLMFactoriesService(CommonService):
class TenantLLMService(CommonService):
model = TenantLLM
+ @staticmethod
+ def _decode_api_key_config(raw_api_key: str) -> tuple[str, bool | None, str | None]:
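+        # `raw_api_key` may be a plain key string or a JSON object such as
+        # {"api_key": "...", "is_tools": true, ...}. Returns (api_key, is_tools flag or
+        # None, full JSON payload when it carries extra provider fields, else None).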
+ if not raw_api_key:
+ return raw_api_key, None, None
+
+ try:
+ parsed = json.loads(raw_api_key)
+ except Exception:
+ return raw_api_key, None, None
+
+ if not isinstance(parsed, dict):
+ return raw_api_key, None, None
+
+ is_tools = bool(parsed["is_tools"]) if "is_tools" in parsed else None
+ if set(parsed.keys()) <= {"api_key", "is_tools"}:
+ return parsed.get("api_key", ""), is_tools, None
+
+ return parsed.get("api_key", raw_api_key), is_tools, raw_api_key
+
+ @staticmethod
+ def _encode_api_key_config(raw_api_key: str, is_tools: bool | None) -> str:
+ if is_tools is None:
+ return raw_api_key
+
+ try:
+ parsed = json.loads(raw_api_key or "{}")
+ except Exception:
+ parsed = None
+
+ if isinstance(parsed, dict):
+ payload = dict(parsed)
+ payload["is_tools"] = bool(is_tools)
+ return json.dumps(payload)
+
+ return json.dumps({"api_key": raw_api_key or "", "is_tools": bool(is_tools)})
+
@classmethod
@DB.connection_context()
def get_api_key(cls, tenant_id, model_name, model_type=None):
@@ -123,6 +159,12 @@ def get_model_config(cls, tenant_id, llm_type, llm_name=None):
model_config = cls.get_api_key(tenant_id, mdlnm, llm_type)
if model_config:
model_config = model_config.to_dict()
+ api_key, is_tools, api_key_payload = cls._decode_api_key_config(model_config.get("api_key", ""))
+ model_config["api_key"] = api_key
+ if api_key_payload is not None:
+ model_config["api_key_payload"] = api_key_payload
+ if is_tools is not None:
+ model_config["is_tools"] = is_tools
elif llm_type == LLMType.EMBEDDING and fid == "Builtin" and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv("TEI_MODEL", ""):
embedding_cfg = settings.EMBEDDING_CFG
model_config = {"llm_factory": "Builtin", "api_key": embedding_cfg["api_key"], "llm_name": mdlnm, "api_base": embedding_cfg["base_url"]}
@@ -132,7 +174,7 @@ def get_model_config(cls, tenant_id, llm_type, llm_name=None):
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
if not llm and fid: # for some cases seems fid mismatch
llm = LLMService.query(llm_name=mdlnm)
- if llm:
+ if "is_tools" not in model_config and llm:
model_config["is_tools"] = llm[0].is_tools
return model_config
@@ -142,35 +184,36 @@ def model_instance(cls, model_config: dict, lang="Chinese", **kwargs):
if not model_config:
raise LookupError("Model config is required")
kwargs.update({"provider": model_config["llm_factory"]})
+ api_key = model_config.get("api_key_payload", model_config["api_key"])
if model_config["model_type"] == LLMType.EMBEDDING.value:
if model_config["llm_factory"] not in EmbeddingModel:
return None
- return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
+ return EmbeddingModel[model_config["llm_factory"]](api_key, model_config["llm_name"], base_url=model_config["api_base"])
elif model_config["model_type"] == LLMType.RERANK:
if model_config["llm_factory"] not in RerankModel:
return None
- return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
+ return RerankModel[model_config["llm_factory"]](api_key, model_config["llm_name"], base_url=model_config["api_base"])
elif model_config["model_type"] == LLMType.IMAGE2TEXT.value:
if model_config["llm_factory"] not in CvModel:
return None
- return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
+ return CvModel[model_config["llm_factory"]](api_key, model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
elif model_config["model_type"] == LLMType.CHAT.value:
if model_config["llm_factory"] not in ChatModel:
return None
- return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
+ return ChatModel[model_config["llm_factory"]](api_key, model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
elif model_config["model_type"] == LLMType.SPEECH2TEXT:
if model_config["llm_factory"] not in Seq2txtModel:
return None
- return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
+ return Seq2txtModel[model_config["llm_factory"]](key=api_key, model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
elif model_config["model_type"] == LLMType.TTS:
if model_config["llm_factory"] not in TTSModel:
return None
return TTSModel[model_config["llm_factory"]](
- model_config["api_key"],
+ api_key,
model_config["llm_name"],
base_url=model_config["api_base"],
)
@@ -179,7 +222,7 @@ def model_instance(cls, model_config: dict, lang="Chinese", **kwargs):
if model_config["llm_factory"] not in OcrModel:
return None
return OcrModel[model_config["llm_factory"]](
- key=model_config["api_key"],
+ key=api_key,
model_name=model_config["llm_name"],
base_url=model_config.get("api_base", ""),
**kwargs,
@@ -364,6 +407,67 @@ def _parse_api_key(raw: str) -> dict:
idx += 1
continue
+ @classmethod
+ def _collect_opendataloader_env_config(cls) -> dict | None:
+ cfg = dict(OPENDATALOADER_DEFAULT_CONFIG)
+ found = False
+ for key in OPENDATALOADER_ENV_KEYS:
+ val = os.environ.get(key)
+ if val:
+ found = True
+ cfg[key] = val
+ return cfg if found else None
+
+ @classmethod
+ @DB.connection_context()
+ def ensure_opendataloader_from_env(cls, tenant_id: str) -> str | None:
+ """
+ Ensure an OpenDataLoader OCR model exists for the tenant if env variables are present.
+ Return the existing or newly created llm_name, or None if env not set.
+ """
+ cfg = cls._collect_opendataloader_env_config()
+ if not cfg:
+ return None
+
+ saved_models = cls.query(tenant_id=tenant_id, llm_factory="OpenDataLoader", model_type=LLMType.OCR.value)
+
+ def _parse_api_key(raw: str) -> dict:
+ try:
+ return json.loads(raw or "{}")
+ except Exception:
+ return {}
+
+ for item in saved_models:
+ api_cfg = _parse_api_key(item.api_key)
+ normalized = {k: api_cfg.get(k, OPENDATALOADER_DEFAULT_CONFIG.get(k)) for k in OPENDATALOADER_ENV_KEYS}
+ if normalized == cfg:
+ return item.llm_name
+
+ used_names = {item.llm_name for item in saved_models}
+ idx = 1
+ base_name = "opendataloader-from-env"
+ while True:
+ candidate = f"{base_name}-{idx}"
+ if candidate in used_names:
+ idx += 1
+ continue
+ try:
+ cls.save(
+ tenant_id=tenant_id,
+ llm_factory="OpenDataLoader",
+ llm_name=candidate,
+ model_type=LLMType.OCR.value,
+ api_key=json.dumps(cfg),
+ api_base="",
+ max_tokens=0,
+ )
+ return candidate
+ except IntegrityError:
+ logging.warning("OpenDataLoader env model %s already exists for tenant %s, retry with next name", candidate, tenant_id)
+ used_names.add(candidate)
+ idx += 1
+ continue
+
@classmethod
@DB.connection_context()
def delete_by_tenant_id(cls, tenant_id):
@@ -397,7 +501,7 @@ def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs)
self.llm_name = model_config["llm_name"]
self.model_config = model_config
self.mdl = TenantLLMService.model_instance(model_config, lang=lang, **kwargs)
- assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, model_config["llm_type"], model_config["llm_name"])
+ assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, model_config["model_type"], model_config["llm_name"])
self.max_length = model_config.get("max_tokens", 8192)
self.is_tools = model_config.get("is_tools", False)
diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py
index fe6f6d0d445..a041ee0819f 100644
--- a/api/utils/api_utils.py
+++ b/api/utils/api_utils.py
@@ -325,7 +325,7 @@ async def wrapper(*args, **kwargs):
from common import settings
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
try:
- jwt = Serializer(secret_key=settings.SECRET_KEY)
+ jwt = Serializer(secret_key=settings.get_secret_key())
raw_token = str(jwt.loads(token))
user = UserService.query(access_token=raw_token, status=StatusEnum.VALID.value)
if user:
diff --git a/api/utils/health_utils.py b/api/utils/health_utils.py
index 288eb79ff67..34f098b8c92 100644
--- a/api/utils/health_utils.py
+++ b/api/utils/health_utils.py
@@ -293,7 +293,7 @@ def check_ragflow_server_alive():
url = f'http://{settings.HOST_IP}:{settings.HOST_PORT}/api/v1/system/ping'
if '0.0.0.0' in url:
url = url.replace('0.0.0.0', '127.0.0.1')
- response = requests.get(url)
+ response = requests.get(url, timeout=10)
if response.status_code == 200:
return {"status": "alive", "message": f"Confirm elapsed: {(timer() - start_time) * 1000.0:.1f} ms."}
else:
diff --git a/api/utils/reference_metadata_utils.py b/api/utils/reference_metadata_utils.py
new file mode 100644
index 00000000000..58d5beffb0a
--- /dev/null
+++ b/api/utils/reference_metadata_utils.py
@@ -0,0 +1,125 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def resolve_reference_metadata_preferences(
+ request_payload: dict | None = None,
+ config_payload: dict | None = None,
+) -> tuple[bool, set[str] | None]:
+ """
+ Resolve metadata include/fields from request and optional config.
+ Request values take precedence over config values.
+ Supports legacy request keys: include_metadata / metadata_fields.
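+
+    Illustrative example (made-up payloads):
+        >>> resolve_reference_metadata_preferences(
+        ...     {"reference_metadata": {"include": True, "fields": ["author"]}},
+        ...     {"reference_metadata": {"include": False}},
+        ... )
+        (True, {'author'})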
+ """
+ request_payload = request_payload or {}
+ config_payload = config_payload or {}
+
+ config_ref = config_payload.get("reference_metadata", {})
+ request_ref = request_payload.get("reference_metadata", {})
+
+ resolved: dict = {}
+ if isinstance(config_ref, dict):
+ resolved.update(config_ref)
+ if isinstance(request_ref, dict):
+ resolved.update(request_ref)
+
+ if "include_metadata" in request_payload:
+ resolved["include"] = bool(request_payload.get("include_metadata"))
+ if "metadata_fields" in request_payload:
+ resolved["fields"] = request_payload.get("metadata_fields")
+
+ include_metadata = bool(resolved.get("include", False))
+ fields = resolved.get("fields")
+ if fields is None:
+ return include_metadata, None
+ if not isinstance(fields, list):
+ logger.warning(
+ "reference_metadata.fields is not a list; include_metadata=%s fields=%r type=%s resolved=%r. "
+ "enrich_chunks_with_document_metadata will skip enrichment.",
+ include_metadata,
+ fields,
+ type(fields).__name__,
+ resolved,
+ )
+ return include_metadata, set()
+ return include_metadata, {f for f in fields if isinstance(f, str)}
+
+
+def enrich_chunks_with_document_metadata(
+ chunks: list[dict],
+ metadata_fields: set[str] | None = None,
+ *,
+ kb_field: str = "kb_id",
+ doc_field: str = "doc_id",
+ output_field: str = "document_metadata",
+) -> None:
+ """
+ Mutates chunk payloads in-place by attaching `document_metadata`.
+ Field names can be customized for different chunk schemas.
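+
+    Illustrative example (made-up ids; metadata comes from DocMetadataService):
+        chunks = [{"kb_id": "kb_1", "doc_id": "doc_1"}]
+        enrich_chunks_with_document_metadata(chunks, {"author", "year"})
+        # chunks[0]["document_metadata"] now holds the selected fields, when present.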
+ """
+ if metadata_fields is not None and not metadata_fields:
+ return
+
+ doc_ids_by_kb: dict[str, set[str]] = {}
+ for chunk in chunks:
+ kb_ids = chunk.get(kb_field)
+ doc_id = chunk.get(doc_field)
+ if not kb_ids or not doc_id:
+ continue
+ if isinstance(kb_ids, (list, tuple)):
+ for kid in kb_ids:
+ if kid:
+ doc_ids_by_kb.setdefault(kid, set()).add(doc_id)
+ else:
+ doc_ids_by_kb.setdefault(kb_ids, set()).add(doc_id)
+
+ if not doc_ids_by_kb:
+ return
+
+ # Resolve service lazily so callers/tests that swap service modules at runtime
+ # (e.g. via monkeypatch) don't get stuck with a stale class reference.
+ from api.db.services.doc_metadata_service import DocMetadataService
+ metadata_getter = getattr(DocMetadataService, "get_metadata_for_documents", None)
+ if not callable(metadata_getter):
+        logger.warning(
+ "DocMetadataService.get_metadata_for_documents is unavailable; "
+ "skipping metadata enrichment."
+ )
+ return
+
+ meta_by_doc: dict[str, dict] = {}
+ for kb_id, doc_ids in doc_ids_by_kb.items():
+ meta_map = metadata_getter(list(doc_ids), kb_id)
+ if meta_map:
+ meta_by_doc.update(meta_map)
+ logging.debug("Fetched metadata for %d docs in kb_id=%s", len(meta_map), kb_id)
+
+ for chunk in chunks:
+ doc_id = chunk.get(doc_field)
+ if not doc_id:
+ continue
+ meta = meta_by_doc.get(doc_id)
+ if not meta:
+ continue
+ if metadata_fields is not None:
+ meta = {k: v for k, v in meta.items() if k in metadata_fields}
+ if meta:
+ chunk[output_field] = meta
+ logging.debug("Enriched chunk for doc_id=%s with %d metadata fields: %s", doc_id, len(meta), list(meta.keys()))
diff --git a/api/utils/validation_utils.py b/api/utils/validation_utils.py
index acce4926277..94e0fa2ab83 100644
--- a/api/utils/validation_utils.py
+++ b/api/utils/validation_utils.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import logging
import math
import pathlib
import re
@@ -22,16 +23,7 @@
from uuid import UUID
from quart import Request
-from pydantic import (
- BaseModel,
- ConfigDict,
- Field,
- StringConstraints,
- ValidationError,
- field_validator,
- model_validator,
- ValidationInfo
-)
+from pydantic import BaseModel, ConfigDict, Field, StringConstraints, ValidationError, field_validator, model_validator, ValidationInfo
from pydantic_core import PydanticCustomError
from werkzeug.exceptions import BadRequest, UnsupportedMediaType
@@ -170,12 +162,13 @@ def validate_and_parse_request_args(request: Request, validator: type[BaseModel]
args = request.args.to_dict(flat=True)
# Handle ext parameter: parse JSON string to dict if it's a string
- if 'ext' in args and isinstance(args['ext'], str):
+ if "ext" in args and isinstance(args["ext"], str):
import json
+
try:
- args['ext'] = json.loads(args['ext'])
+ args["ext"] = json.loads(args["ext"])
except json.JSONDecodeError:
- pass # Keep the string and let validation handle the error
+ logging.debug("Failed to decode query arg 'ext' as JSON; passing raw value to validator")
try:
if extras is not None:
@@ -350,6 +343,7 @@ class RaptorConfig(Base):
threshold: Annotated[float, Field(default=0.1, ge=0.0, le=1.0)]
max_cluster: Annotated[int, Field(default=64, ge=1, le=1024)]
random_seed: Annotated[int, Field(default=0, ge=0)]
+ scope: Annotated[Literal["file", "dataset"], Field(default="file")]
auto_disable_for_structured_data: Annotated[bool, Field(default=True)]
ext: Annotated[dict, Field(default={})]
@@ -370,18 +364,17 @@ class ParentChildConfig(Base):
class AutoMetadataField(Base):
"""Schema for a single auto-metadata field configuration."""
- name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=255), Field(...)]
- type: Annotated[Literal["string", "list", "time"], Field(...)]
+ key: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=255), Field(...)]
+ type: Annotated[Literal["string", "list", "time", "number"], Field(...)]
description: Annotated[str | None, Field(default=None, max_length=65535)]
- examples: Annotated[list[str] | None, Field(default=None)]
- restrict_values: Annotated[bool, Field(default=False)]
+ enum: Annotated[list[str] | None, Field(default=None)]
class AutoMetadataConfig(Base):
"""Top-level auto-metadata configuration attached to a dataset."""
- enabled: Annotated[bool, Field(default=True)]
- fields: Annotated[list[AutoMetadataField], Field(default_factory=list)]
+ metadata: Annotated[list[AutoMetadataField], Field(default_factory=list)]
+ built_in_metadata: Annotated[list[AutoMetadataField], Field(default_factory=list)]
class ParserConfig(Base):
@@ -401,6 +394,7 @@ class ParserConfig(Base):
pages: Annotated[list[list[int]] | None, Field(default=None)]
ext: Annotated[dict, Field(default={})]
+
class UpdateDocumentReq(Base):
"""
Request model for updating a document.
@@ -408,9 +402,11 @@ class UpdateDocumentReq(Base):
This model validates the request parameters for updating a document,
including name, chunk method, enabled status, and other metadata.
"""
- model_config = ConfigDict(extra='ignore')
+
+ model_config = ConfigDict(extra="ignore")
name: Annotated[str | None, Field(default=None, max_length=65535)]
chunk_method: Annotated[str | None, Field(default=None, max_length=65535)]
+ pipeline_id: Annotated[str | None, Field(default=None, max_length=65535)]
enabled: Annotated[int | None, Field(default=None, ge=0, le=1)]
chunk_count: Annotated[int | None, Field(default=None, ge=0)]
token_count: Annotated[int | None, Field(default=None, ge=0)]
@@ -425,7 +421,7 @@ def validate_document_chunk_method(cls, chunk_method: str | None):
# Validate chunk method if present
valid_chunk_method = {"naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one", "knowledge_graph", "email", "tag"}
if chunk_method not in valid_chunk_method:
- raise PydanticCustomError("format_invalid", "`chunk_method` {chunk_method} doesn't exist", {"chunk_method":chunk_method})
+ raise PydanticCustomError("format_invalid", "`chunk_method` {chunk_method} doesn't exist", {"chunk_method": chunk_method})
return chunk_method
@@ -435,7 +431,7 @@ def validate_document_enabled(cls, enabled: str | None):
if enabled:
converted = int(enabled)
if converted < 0 or converted > 1:
- raise PydanticCustomError("format_invalid", "`enabled` value invalid, only accept 0 or 1 but is {enabled}", {"enabled":enabled})
+ raise PydanticCustomError("format_invalid", "`enabled` value invalid, only accept 0 or 1 but is {enabled}", {"enabled": enabled})
return enabled
@@ -450,11 +446,12 @@ def validate_document_meta_fields(cls, meta_fields: dict | None):
for k, v in meta_fields.items():
if isinstance(v, list):
if not all(isinstance(i, (str, int, float)) for i in v):
- raise PydanticCustomError("format_invalid", "The type is not supported in list: {v}", {"v":v})
+ raise PydanticCustomError("format_invalid", "The type is not supported in list: {v}", {"v": v})
elif not isinstance(v, (str, int, float)):
- raise PydanticCustomError("format_invalid", "The type is not supported: {v}", {"v":v})
+ raise PydanticCustomError("format_invalid", "The type is not supported: {v}", {"v": v})
return meta_fields
+
class CreateDatasetReq(Base):
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)]
avatar: Annotated[str | None, Field(default=None, max_length=65535)]
@@ -707,8 +704,7 @@ def validate_parser_dependency(self) -> "CreateDatasetReq":
@classmethod
def validate_chunk_method(cls, v: Any, handler, info: ValidationInfo) -> Any:
"""Wrap validation to unify error messages, including type errors (e.g. list)."""
- allowed = {"naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table",
- "tag", "resume"}
+ allowed = {"naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag", "resume"}
error_msg = "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table', 'tag' or 'resume'"
try:
# Run inner validation (type checking)
@@ -818,6 +814,70 @@ def validate_ids(cls, v_list: list[str] | None) -> list[str] | None:
class DeleteDatasetReq(DeleteReq): ...
+class DeleteDocumentReq(DeleteReq):
+ @field_validator("ids", mode="after")
+ @classmethod
+ def validate_ids(cls, v_list: list[str] | None) -> list[str] | None:
+ """
+ Validate document IDs without enforcing UUIDv1.
+
+ Connector-backed documents can use non-UUID identifiers, so we only
+ enforce uniqueness here and leave existence checks to the delete API.
+ """
+ if v_list is None:
+ return None
+
+ duplicates = [item for item, count in Counter(v_list).items() if count > 1]
+ if duplicates:
+ duplicates_str = ", ".join(duplicates)
+ raise PydanticCustomError(
+ "duplicate_uuids",
+ "Duplicate ids: '{duplicate_ids}'",
+ {"duplicate_ids": duplicates_str},
+ )
+
+ return v_list
+
+
+class SearchDatasetReq(BaseModel):
+ model_config = ConfigDict(extra="ignore")
+
+ question: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1), Field(...)]
+ doc_ids: Annotated[list[str], Field(default=[])]
+ page: Annotated[int, Field(default=1, ge=1)]
+ size: Annotated[int, Field(default=30, ge=1)]
+ top_k: Annotated[int, Field(default=1024, ge=1)]
+ similarity_threshold: Annotated[float, Field(default=0.0, ge=0.0, le=1.0)]
+ vector_similarity_weight: Annotated[float, Field(default=0.3, ge=0.0, le=1.0)]
+ use_kg: Annotated[bool, Field(default=False)]
+ cross_languages: Annotated[list[str], Field(default=[])]
+ keyword: Annotated[bool, Field(default=False)]
+ search_id: Annotated[str | None, Field(default=None)]
+ rerank_id: Annotated[str | None, Field(default=None)]
+ tenant_rerank_id: Annotated[int | None, Field(default=None)]
+ meta_data_filter: Annotated[dict | None, Field(default=None)]
+
+
+class SearchDatasetsReq(BaseModel):
+ model_config = ConfigDict(extra="ignore")
+
+ dataset_ids: Annotated[list[str], Field(..., min_length=1)]
+ question: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1), Field(...)]
+ doc_ids: Annotated[list[str], Field(default=[])]
+ page: Annotated[int, Field(default=1, ge=1)]
+ size: Annotated[int, Field(default=30, ge=1)]
+ top_k: Annotated[int, Field(default=1024, ge=1)]
+ similarity_threshold: Annotated[float, Field(default=0.0, ge=0.0, le=1.0)]
+ vector_similarity_weight: Annotated[float, Field(default=0.3, ge=0.0, le=1.0)]
+ use_kg: Annotated[bool, Field(default=False)]
+ cross_languages: Annotated[list[str], Field(default=[])]
+ keyword: Annotated[bool, Field(default=False)]
+ search_id: Annotated[str | None, Field(default=None)]
+ rerank_id: Annotated[str | None, Field(default=None)]
+ tenant_rerank_id: Annotated[str | None, Field(default=None)]
+ meta_data_filter: Annotated[dict | None, Field(default=None)]
+
+
class BaseListReq(BaseModel):
model_config = ConfigDict(extra="forbid")
@@ -841,6 +901,7 @@ class ListDatasetReq(BaseListReq):
# ---- File Management Request Models ----
+
class CreateFolderReq(Base):
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=255), Field(...)]
parent_id: Annotated[str | None, Field(default=None)]
@@ -856,7 +917,7 @@ class MoveFileReq(Base):
dest_file_id: Annotated[str | None, Field(default=None)]
new_name: Annotated[str | None, StringConstraints(strip_whitespace=True, min_length=1, max_length=255), Field(default=None)]
- @model_validator(mode='after')
+ @model_validator(mode="after")
def check_operation(self):
if not self.dest_file_id and not self.new_name:
raise ValueError("At least one of dest_file_id or new_name must be provided")
@@ -876,7 +937,7 @@ class ListFileReq(BaseModel):
desc: Annotated[bool, Field(default=True)]
-def validate_immutable_fields(update_doc_req:UpdateDocumentReq, doc):
+def validate_immutable_fields(update_doc_req: UpdateDocumentReq, doc):
"""
Validate that immutable fields have not been changed.
@@ -906,7 +967,7 @@ def validate_immutable_fields(update_doc_req:UpdateDocumentReq, doc):
return None, None
-def validate_document_name(req_doc_name:str, doc, docs_from_name):
+def validate_document_name(req_doc_name: str, doc, docs_from_name):
"""
Validate document name update.
@@ -937,6 +998,7 @@ def validate_document_name(req_doc_name:str, doc, docs_from_name):
return "Duplicated document name in the same dataset.", RetCode.DATA_ERROR
return None, None
+
def validate_chunk_method(doc, chunk_method=None):
"""
Validate chunk method update.
@@ -952,9 +1014,8 @@ def validate_chunk_method(doc, chunk_method=None):
A tuple of (error_message, error_code) if validation fails,
or (None, None) if validation passes.
"""
- if chunk_method is not None and len(chunk_method) == 0: # will not be detected in UpdateDocumentReq
+ if chunk_method is not None and len(chunk_method) == 0: # will not be detected in UpdateDocumentReq
return "`chunk_method` (empty string) is not valid", RetCode.DATA_ERROR
if doc.type == FileType.VISUAL or re.search(r"\.(ppt|pptx|pages)$", doc.name):
return "Not supported yet!", RetCode.DATA_ERROR
return None, None
-
diff --git a/api/utils/web_utils.py b/api/utils/web_utils.py
index 4cb13ff7e6f..23d2421862d 100644
--- a/api/utils/web_utils.py
+++ b/api/utils/web_utils.py
@@ -15,11 +15,8 @@
#
import base64
-import ipaddress
import json
import re
-import socket
-from urllib.parse import urlparse
import aiosmtplib
from email.mime.text import MIMEText
from email.header import Header
@@ -37,10 +34,10 @@
OTP_LENGTH = 4
-OTP_TTL_SECONDS = 5 * 60 # valid for 5 minutes
-ATTEMPT_LIMIT = 5 # maximum attempts
-ATTEMPT_LOCK_SECONDS = 30 * 60 # lock for 30 minutes
-RESEND_COOLDOWN_SECONDS = 60 # cooldown for 1 minute
+OTP_TTL_SECONDS = 5 * 60 # valid for 5 minutes
+ATTEMPT_LIMIT = 5 # maximum attempts
+ATTEMPT_LOCK_SECONDS = 30 * 60 # lock for 30 minutes
+RESEND_COOLDOWN_SECONDS = 60 # cooldown for 1 minute
CONTENT_TYPE_MAP = {
@@ -188,29 +185,16 @@ def __get_pdf_from_html(path: str, timeout: int, install_driver: bool, print_opt
return base64.b64decode(result["data"])
-def is_private_ip(ip: str) -> bool:
- try:
- ip_obj = ipaddress.ip_address(ip)
- return ip_obj.is_private
- except ValueError:
- return False
-
-
def is_valid_url(url: str) -> bool:
if not re.match(r"(https?)://[-A-Za-z0-9+&@#/%?=~_|!:,.;]+[-A-Za-z0-9+&@#/%=~_|]", url):
return False
- parsed_url = urlparse(url)
- hostname = parsed_url.hostname
+ from common.ssrf_guard import assert_url_is_safe
- if not hostname:
- return False
try:
- ip = socket.gethostbyname(hostname)
- if is_private_ip(ip):
- return False
- except socket.gaierror:
+ assert_url_is_safe(url)
+ return True
+ except ValueError:
return False
- return True
def safe_json_parse(data: str | dict) -> dict:
diff --git a/cmd/admin_server.go b/cmd/admin_server.go
index 9e876639164..3775d038b72 100644
--- a/cmd/admin_server.go
+++ b/cmd/admin_server.go
@@ -18,12 +18,14 @@ package main
import (
"context"
+ "errors"
"flag"
"fmt"
"net/http"
"os"
"os/signal"
"ragflow/internal/cache"
+ "ragflow/internal/common"
"ragflow/internal/engine"
"syscall"
"time"
@@ -33,33 +35,23 @@ import (
"ragflow/internal/admin"
"ragflow/internal/dao"
- "ragflow/internal/logger"
"ragflow/internal/server"
"ragflow/internal/utility"
)
-// AdminServer admin server
-type AdminServer struct {
- router *admin.Router
- handler *admin.Handler
- service *admin.Service
- engine *gin.Engine
- port string
-}
-
func main() {
var configPath string
flag.StringVar(&configPath, "config", "", "Path to configuration file")
flag.Parse()
// Initialize logger
- if err := logger.Init("info"); err != nil {
+ if err := common.Init("info"); err != nil {
panic("failed to initialize logger: " + err.Error())
}
// Initialize configuration
if err := server.Init(configPath); err != nil {
- logger.Error("Failed to initialize configuration", err)
+ common.Error("Failed to initialize configuration", err)
os.Exit(1)
}
@@ -67,15 +59,15 @@ func main() {
// Reinitialize logger with configured level if different
if cfg.Log.Level != "" && cfg.Log.Level != "info" {
- if err := logger.Init(cfg.Log.Level); err != nil {
- logger.Error("Failed to reinitialize logger with configured level", err)
+ if err := common.Init(cfg.Log.Level); err != nil {
+ common.Error("Failed to reinitialize logger with configured level", err)
}
}
// Set logger for server package
- server.SetLogger(logger.Logger)
+ server.SetLogger(common.Logger)
- logger.Info("Server mode", zap.String("mode", cfg.Server.Mode))
+ common.Info("Server mode", zap.String("mode", cfg.Server.Mode))
// Set Gin mode
if cfg.Server.Mode == "release" {
@@ -86,26 +78,26 @@ func main() {
// Initialize database
if err := dao.InitDB(); err != nil {
- logger.Error("Failed to initialize database", err)
+ common.Error("Failed to initialize database", err)
os.Exit(1)
}
// Initialize doc engine
if err := engine.Init(&cfg.DocEngine); err != nil {
- logger.Fatal("Failed to initialize doc engine", zap.Error(err))
+ common.Fatal("Failed to initialize doc engine", zap.Error(err))
}
defer engine.Close()
// Initialize Redis cache
if err := cache.Init(&cfg.Redis); err != nil {
- logger.Fatal("Failed to initialize Redis", zap.Error(err))
+ common.Fatal("Failed to initialize Redis", zap.Error(err))
}
defer cache.Close()
// Initialize server variables (runtime variables that can change during operation)
// This must be done after Cache is initialized
if err := server.InitVariables(cache.Get()); err != nil {
- logger.Warn("Failed to initialize server variables from Redis, using defaults", zap.String("error", err.Error()))
+ common.Warn("Failed to initialize server variables from Redis, using defaults", zap.String("error", err.Error()))
}
adminService := admin.NewService()
@@ -113,7 +105,7 @@ func main() {
// Initialize default admin user
if err := adminService.InitDefaultAdmin(); err != nil {
- logger.Error("Failed to initialize default admin user", err)
+ common.Error("Failed to initialize default admin user", err)
}
// Initialize router
@@ -129,7 +121,7 @@ func main() {
ginEngine.Use(gin.Recovery())
// Log request URL for every request
ginEngine.Use(func(c *gin.Context) {
- logger.Info("HTTP Request", zap.String("url", c.Request.URL.String()), zap.String("method", c.Request.Method))
+ common.Info("HTTP Request", zap.String("url", c.Request.URL.String()), zap.String("method", c.Request.Method))
c.Next()
})
@@ -144,13 +136,13 @@ func main() {
}
// Print RAGFlow version
- logger.Info("RAGFlow version", zap.String("version", utility.GetRAGFlowVersion()))
+ common.Info("RAGFlow version", zap.String("version", utility.GetRAGFlowVersion()))
// Print all configuration settings
server.PrintAll()
// Print RAGFlow Admin logo
- logger.Info("" +
+ common.Info("" +
"\n ____ ___ ______________ ___ __ _ \n" +
" / __ \\/ | / ____/ ____/ /___ _ __ / | ____/ /___ ___ (_)___ \n" +
" / /_/ / /| |/ / __/ /_ / / __ \\ | /| / / / /| |/ __ / __ `__ \\/ / __ \\ \n" +
@@ -159,10 +151,10 @@ func main() {
// Start server in a goroutine
go func() {
- logger.Info(fmt.Sprintf("Admin Go Version: %s", utility.GetRAGFlowVersion()))
- logger.Info(fmt.Sprintf("Starting RAGFlow admin server on port: %d", cfg.Admin.Port))
- if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
- logger.Fatal("Failed to start server", zap.Error(err))
+ common.Info(fmt.Sprintf("Admin Go Version: %s", utility.GetRAGFlowVersion()))
+ common.Info(fmt.Sprintf("Starting RAGFlow admin server on port: %d", cfg.Admin.Port))
+ if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
+ common.Fatal("Failed to start server", zap.Error(err))
}
}()
@@ -171,8 +163,8 @@ func main() {
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGUSR2)
sig := <-quit
- logger.Info("Received signal", zap.String("signal", sig.String()))
- logger.Info("Shutting down server...")
+ common.Info("Received signal", zap.String("signal", sig.String()))
+ common.Info("Shutting down server...")
// Create context with timeout for graceful shutdown
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
@@ -180,8 +172,8 @@ func main() {
// Shutdown server
if err := srv.Shutdown(ctx); err != nil {
- logger.Fatal("Server forced to shutdown", zap.Error(err))
+ common.Fatal("Server forced to shutdown", zap.Error(err))
}
- logger.Info("Server exited")
+ common.Info("Server exited")
}
diff --git a/cmd/ragflow_cli.go b/cmd/ragflow_cli.go
index bb18a5a44e2..cc2043687cc 100644
--- a/cmd/ragflow_cli.go
+++ b/cmd/ragflow_cli.go
@@ -4,6 +4,7 @@ import (
"fmt"
"os"
"os/signal"
+ "ragflow/internal/common"
"syscall"
"ragflow/internal/cli"
@@ -17,6 +18,15 @@ func main() {
os.Exit(1)
}
+ // Initialize logger with appropriate level
+ logLevel := "warn" // Default to warn (quiet mode)
+ if args.Verbose {
+ logLevel = "info"
+ }
+ if err = common.Init(logLevel); err != nil {
+ fmt.Printf("Warning: Failed to initialize logger: %v\n", err)
+ }
+
// Show help and exit
if args.ShowHelp {
cli.PrintUsage()
diff --git a/cmd/server_main.go b/cmd/server_main.go
index d1db4ad7622..e4a634e72af 100644
--- a/cmd/server_main.go
+++ b/cmd/server_main.go
@@ -2,6 +2,7 @@ package main
import (
"context"
+ "errors"
"flag"
"fmt"
"net/http"
@@ -23,7 +24,6 @@ import (
"ragflow/internal/dao"
"ragflow/internal/engine"
"ragflow/internal/handler"
- "ragflow/internal/logger"
"ragflow/internal/router"
"ragflow/internal/service"
"ragflow/internal/service/nlp"
@@ -55,81 +55,80 @@ func main() {
// Initialize logger with default level
// logger.Init("info"); // set debug log level
- if err := logger.Init("info"); err != nil {
+ if err := common.Init("info"); err != nil {
panic(fmt.Sprintf("Failed to initialize logger: %v", err))
}
// Initialize configuration
if err := server.Init(""); err != nil {
- logger.Fatal("Failed to initialize config", zap.Error(err))
+ common.Fatal("Failed to initialize config", zap.Error(err))
}
// Override port with command line argument if provided
+ config := server.GetConfig()
if portFlag > 0 {
- config := server.GetConfig()
config.Server.Port = portFlag
- logger.Info("Port overridden by command line argument", zap.Int("port", portFlag))
+ common.Info("Port overridden by command line argument", zap.Int("port", portFlag))
}
- // Load model providers configuration
- if err := server.LoadModelProviders(""); err != nil {
- logger.Fatal("Failed to load model providers", zap.Error(err))
+ if config.Server.Port == 0 {
+ common.Fatal("Server port is not configured. Please specify via --port flag or config file.")
}
- logger.Info("Model providers loaded", zap.Int("count", len(server.GetModelProviders())))
- config := server.GetConfig()
- if config.Server.Port == 0 {
- logger.Fatal("Server port is not configured. Please specify via --port flag or config file.")
+ // Load model providers configuration
+ if err := server.LoadModelProviders(""); err != nil {
+ common.Fatal("Failed to load model providers", zap.Error(err))
}
+ common.Info("Model providers loaded", zap.Int("count", len(server.GetModelProviders())))
// Reinitialize logger with configured level if different
if config.Log.Level != "" && config.Log.Level != "info" {
- if err := logger.Init(config.Log.Level); err != nil {
- logger.Error("Failed to reinitialize logger with configured level", err)
+ if err := common.Init(config.Log.Level); err != nil {
+ common.Error("Failed to reinitialize logger with configured level", err)
}
}
- server.SetLogger(logger.Logger)
+ server.SetLogger(common.Logger)
if config.Log.Level == "" {
- config.Log.Level = logger.GetLevel()
+ config.Log.Level = common.GetLevel()
}
- logger.Info("Server mode", zap.String("mode", config.Server.Mode))
+ common.Info("Server mode", zap.String("mode", config.Server.Mode))
// Print all configuration settings
server.PrintAll()
// Initialize database
if err := dao.InitDB(); err != nil {
- logger.Fatal("Failed to initialize database", zap.Error(err))
+ common.Fatal("Failed to initialize database", zap.Error(err))
}
// Initialize LLM factory data models from configuration file
if err := dao.InitLLMFactory(); err != nil {
- logger.Error("Failed to initialize LLM factory", err)
+ common.Error("Failed to initialize LLM factory", err)
} else {
- logger.Info("LLM factory initialized successfully")
+ common.Info("LLM factory initialized successfully")
}
// Initialize doc engine
if err := engine.Init(&config.DocEngine); err != nil {
- logger.Fatal("Failed to initialize doc engine", zap.Error(err))
+ common.Fatal("Failed to initialize doc engine", zap.Error(err))
}
defer engine.Close()
// Initialize Redis cache
if err := cache.Init(&config.Redis); err != nil {
- logger.Fatal("Failed to initialize Redis", zap.Error(err))
+ common.Fatal("Failed to initialize Redis", zap.Error(err))
}
defer cache.Close()
if err := storage.InitStorageFactory(); err != nil {
- logger.Fatal("Failed to initialize storage factory", zap.Error(err))
+ common.Fatal("Failed to initialize storage factory", zap.Error(err))
}
// Initialize server variables (runtime variables that can change during operation)
// This must be done after Cache is initialized
if err := server.InitVariables(cache.Get()); err != nil {
- logger.Warn("Failed to initialize server variables from Redis, using defaults", zap.String("error", err.Error()))
+ common.Warn("Failed to initialize server variables from Redis, using defaults", zap.String("error", err.Error()))
}
// Initialize admin status (default: unavailable=1)
@@ -140,19 +139,19 @@ func main() {
DictPath: "/usr/share/infinity/resource",
}
if err := tokenizer.Init(tokenizerCfg); err != nil {
- logger.Fatal("Failed to initialize tokenizer", zap.Error(err))
+ common.Fatal("Failed to initialize tokenizer", zap.Error(err))
}
defer tokenizer.Close()
// Initialize global QueryBuilder using tokenizer's DictPath
// This ensures the Synonym uses the same wordnet directory as tokenizer
if err := nlp.InitQueryBuilderFromTokenizer(tokenizerCfg.DictPath); err != nil {
- logger.Fatal("Failed to initialize query builder", zap.Error(err))
+ common.Fatal("Failed to initialize query builder", zap.Error(err))
}
startServer(config)
- logger.Info("Server exited")
+ common.Info("Server exited")
}
func startServer(config *server.Config) {
@@ -181,6 +180,9 @@ func startServer(config *server.Config) {
memoryService := service.NewMemoryService()
modelProviderService := service.NewModelProviderService()
+ // Initialize doc engine for skill search
+ docEngine := engine.Get()
+
// Initialize handler layer
authHandler := handler.NewAuthHandler()
userHandler := handler.NewUserHandler(userService)
@@ -197,10 +199,11 @@ func startServer(config *server.Config) {
searchHandler := handler.NewSearchHandler(searchService, userService)
fileHandler := handler.NewFileHandler(fileService, userService)
memoryHandler := handler.NewMemoryHandler(memoryService)
+ skillSearchHandler := handler.NewSkillSearchHandler(docEngine)
providerHandler := handler.NewProviderHandler(userService, modelProviderService)
// Initialize router
- r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, kbHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler, providerHandler)
+ r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, kbHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler, skillSearchHandler, providerHandler)
// Create Gin engine
ginEngine := gin.New()
@@ -214,45 +217,49 @@ func startServer(config *server.Config) {
// Setup routes
r.Setup(ginEngine)
- // Create HTTP server
+ // Create HTTP server with timeouts to prevent slow clients from blocking shutdown
addr := fmt.Sprintf(":%d", config.Server.Port)
srv := &http.Server{
- Addr: addr,
- Handler: ginEngine,
+ Addr: addr,
+ Handler: ginEngine,
+ ReadHeaderTimeout: 10 * time.Second,
+ ReadTimeout: 60 * time.Second,
+ WriteTimeout: 120 * time.Second,
+ IdleTimeout: 120 * time.Second,
}
// Start server in a goroutine
go func() {
- logger.Info(
+ common.Info(
"\n ____ ___ ______ ______ __\n" +
" / __ \\ / | / ____// ____// /____ _ __\n" +
" / /_/ // /| | / / __ / /_ / // __ \\| | /| / /\n" +
" / _, _// ___ |/ /_/ // __/ / // /_/ /| |/ |/ /\n" +
" /_/ |_|/_/ |_|\\____//_/ /_/ \\____/ |__/|__/\n",
)
- logger.Info(fmt.Sprintf("RAGFlow Go Version: %s", utility.GetRAGFlowVersion()))
- logger.Info(fmt.Sprintf("Server starting on port: %d", config.Server.Port))
- if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
- logger.Fatal("Failed to start server", zap.Error(err))
+ common.Info(fmt.Sprintf("RAGFlow Go Version: %s", utility.GetRAGFlowVersion()))
+ common.Info(fmt.Sprintf("Server starting on port: %d", config.Server.Port))
+ if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
+ common.Fatal("Failed to start server", zap.Error(err))
}
}()
// Get local IP address for heartbeat reporting
- localIP := utility.GetLocalIP()
- if localIP == "" {
- localIP = "127.0.0.1"
+ localIP, err := utility.GetLocalIP()
+ if err != nil {
+		common.Fatal("Failed to get local IP address", zap.Error(err))
}
// Initialize and start heartbeat reporter to admin server
heartbeatService := service.NewHeartbeatSender(
- logger.Logger,
+ common.Logger,
common.ServerTypeAPI,
fmt.Sprintf("ragflow-server-%d", config.Server.Port),
localIP,
config.Server.Port,
)
- if err := heartbeatService.InitHTTPClient(); err != nil {
- logger.Warn("Failed to initialize heartbeat service", zap.Error(err))
+ if err = heartbeatService.InitHTTPClient(); err != nil {
+ common.Warn("Failed to initialize heartbeat service", zap.Error(err))
} else {
// Start heartbeat reporter with 30 seconds interval
heartbeatReporter := utility.NewScheduledTask("Heartbeat reporter", 3*time.Second, func() {
@@ -272,15 +279,15 @@ func startServer(config *server.Config) {
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGUSR2)
sig := <-quit
- logger.Info(fmt.Sprintf("Receives %s signal to shutdown server", strings.ToUpper(sig.String())))
- logger.Info("Shutting down server...")
+ common.Info(fmt.Sprintf("Receives %s signal to shutdown server", strings.ToUpper(sig.String())))
+ common.Info("Shutting down server...")
// Create context with timeout for graceful shutdown
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Shutdown server
- if err := srv.Shutdown(ctx); err != nil {
- logger.Fatal("Server forced to shutdown", zap.Error(err))
+ if err = srv.Shutdown(ctx); err != nil {
+ common.Fatal("Server forced to shutdown", zap.Error(err))
}
}
diff --git a/common/constants.py b/common/constants.py
index b027908637d..5ab9acaa502 100644
--- a/common/constants.py
+++ b/common/constants.py
@@ -244,6 +244,12 @@ class ForgettingPolicy(StrEnum):
SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker"
TAG_FLD = "tag_feas"
+# Maximum page number used as "unlimited" sentinel value.
+# Parsing layer (chunk/Pdf.__call__) uses MAXIMUM_PAGE_NUMBER.
+# Task/DB layer (Task model) uses MAXIMUM_PAGE_NUMBER * 1000 to avoid collision with user-specified page ranges.
+MAXIMUM_PAGE_NUMBER = 100000
+MAXIMUM_TASK_PAGE_NUMBER = MAXIMUM_PAGE_NUMBER * 1000
+
MINERU_ENV_KEYS = ["MINERU_APISERVER", "MINERU_OUTPUT_DIR", "MINERU_BACKEND", "MINERU_SERVER_URL", "MINERU_DELETE_OUTPUT"]
MINERU_DEFAULT_CONFIG = {
@@ -260,3 +266,8 @@ class ForgettingPolicy(StrEnum):
"PADDLEOCR_ACCESS_TOKEN": None,
"PADDLEOCR_ALGORITHM": "PaddleOCR-VL",
}
+
+OPENDATALOADER_ENV_KEYS = ["OPENDATALOADER_APISERVER"]
+OPENDATALOADER_DEFAULT_CONFIG = {
+ "OPENDATALOADER_APISERVER": "",
+}
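
The two new constants mark "parse everything" at different layers, with the task-layer sentinel scaled up so it can never collide with a real user-specified page range. Below is a minimal sketch of how a caller might use them; `normalize_page_range` and `is_unlimited_task_range` are hypothetical helpers, not functions from this patch:

```python
# Sketch only: the constants mirror common/constants.py; the helpers are illustrative.
MAXIMUM_PAGE_NUMBER = 100000
MAXIMUM_TASK_PAGE_NUMBER = MAXIMUM_PAGE_NUMBER * 1000


def normalize_page_range(from_page: int | None, to_page: int | None) -> tuple[int, int]:
    """Map a missing upper bound to the parsing-layer sentinel."""
    start = from_page or 0
    end = to_page if to_page is not None else MAXIMUM_PAGE_NUMBER
    return start, end


def is_unlimited_task_range(to_page: int) -> bool:
    """The task/DB layer uses the larger sentinel so real ranges never collide with it."""
    return to_page >= MAXIMUM_TASK_PAGE_NUMBER


print(normalize_page_range(None, None))        # (0, 100000)
print(is_unlimited_task_range(100000 * 1000))  # True
```
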
diff --git a/common/data_source/airtable_connector.py b/common/data_source/airtable_connector.py
index 46dcf07ee47..f1ab3004036 100644
--- a/common/data_source/airtable_connector.py
+++ b/common/data_source/airtable_connector.py
@@ -8,8 +8,14 @@
from common.data_source.config import AIRTABLE_CONNECTOR_SIZE_THRESHOLD, INDEX_BATCH_SIZE, DocumentSource
from common.data_source.exceptions import ConnectorMissingCredentialError
-from common.data_source.interfaces import LoadConnector, PollConnector
-from common.data_source.models import Document, GenerateDocumentsOutput, SecondsSinceUnixEpoch
+from common.data_source.interfaces import LoadConnector, PollConnector, SlimConnectorWithPermSync
+from common.data_source.models import (
+ Document,
+ GenerateDocumentsOutput,
+ GenerateSlimDocumentOutput,
+ SecondsSinceUnixEpoch,
+ SlimDocument,
+)
from common.data_source.utils import extract_size_bytes, get_file_ext
class AirtableClientNotSetUpError(PermissionError):
@@ -19,7 +25,7 @@ def __init__(self) -> None:
)
-class AirtableConnector(LoadConnector, PollConnector):
+class AirtableConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
"""
Lightweight Airtable connector.
@@ -39,6 +45,43 @@ def __init__(
self._airtable_client: AirtableApi | None = None
self.size_threshold = AIRTABLE_CONNECTOR_SIZE_THRESHOLD
+ def _iter_attachment_entries(self) -> Generator[tuple[str, str, str, str, str | None, dict[str, Any]], None, None]:
+ if not self._airtable_client:
+ raise ConnectorMissingCredentialError("Airtable credentials not loaded")
+
+ table = self.airtable_client.table(self.base_id, self.table_name_or_id)
+ records = table.all()
+
+ logging.info(
+ f"Starting Airtable attachment scan for table {self.table_name_or_id}, "
+ f"{len(records)} records found."
+ )
+
+ for record in records:
+ record_id = record.get("id")
+ fields = record.get("fields", {})
+ created_time = record.get("createdTime")
+
+ for field_value in fields.values():
+ if not isinstance(field_value, list):
+ continue
+
+ for attachment in field_value:
+ filename = attachment.get("filename")
+ attachment_id = attachment.get("id")
+
+ if not record_id or not filename or not attachment_id:
+ continue
+
+ yield (
+ record_id,
+ attachment_id,
+ filename,
+ f"airtable:{record_id}:{attachment_id}",
+ created_time,
+ attachment,
+ )
+
# -------------------------
# Credentials
# -------------------------
@@ -64,69 +107,65 @@ def load_from_state(self) -> GenerateDocumentsOutput:
if not self._airtable_client:
raise ConnectorMissingCredentialError("Airtable credentials not loaded")
- table = self.airtable_client.table(self.base_id, self.table_name_or_id)
- records = table.all()
-
- logging.info(
- f"Starting Airtable blob ingestion for table {self.table_name_or_id}, "
- f"{len(records)} records found."
- )
-
batch: list[Document] = []
- for record in records:
- record_id = record.get("id")
- fields = record.get("fields", {})
- created_time = record.get("createdTime")
-
- for field_value in fields.values():
- # We only care about attachment fields (lists of dicts with url/filename)
- if not isinstance(field_value, list):
- continue
+ for record_id, attachment_id, filename, doc_id, created_time, attachment in self._iter_attachment_entries():
+ url = attachment.get("url")
+ if not url or not created_time:
+ continue
+
+ try:
+ resp = requests.get(url, timeout=30)
+ resp.raise_for_status()
+ content = resp.content
+ except Exception:
+ logging.exception(
+ f"Failed to download attachment {filename} "
+ f"(record={record_id})"
+ )
+ continue
+ size_bytes = extract_size_bytes(attachment)
+ if (
+ self.size_threshold is not None
+ and isinstance(size_bytes, int)
+ and size_bytes > self.size_threshold
+ ):
+ logging.warning(
+ f"{filename} exceeds size threshold of {self.size_threshold}. Skipping."
+ )
+ continue
+ batch.append(
+ Document(
+ id=doc_id,
+ blob=content,
+ source=DocumentSource.AIRTABLE,
+ semantic_identifier=filename,
+ extension=get_file_ext(filename),
+ size_bytes=size_bytes if size_bytes else 0,
+ doc_updated_at=datetime.strptime(created_time, "%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=timezone.utc)
+ )
+ )
+
+ if len(batch) >= self.batch_size:
+ yield batch
+ batch = []
- for attachment in field_value:
- url = attachment.get("url")
- filename = attachment.get("filename")
- attachment_id = attachment.get("id")
+ if batch:
+ yield batch
- if not url or not filename or not attachment_id:
- continue
+ def retrieve_all_slim_docs_perm_sync(
+ self,
+ callback: Any = None,
+ ) -> GenerateSlimDocumentOutput:
+ del callback
- try:
- resp = requests.get(url, timeout=30)
- resp.raise_for_status()
- content = resp.content
- except Exception:
- logging.exception(
- f"Failed to download attachment {filename} "
- f"(record={record_id})"
- )
- continue
- size_bytes = extract_size_bytes(attachment)
- if (
- self.size_threshold is not None
- and isinstance(size_bytes, int)
- and size_bytes > self.size_threshold
- ):
- logging.warning(
- f"{filename} exceeds size threshold of {self.size_threshold}. Skipping."
- )
- continue
- batch.append(
- Document(
- id=f"airtable:{record_id}:{attachment_id}",
- blob=content,
- source=DocumentSource.AIRTABLE,
- semantic_identifier=filename,
- extension=get_file_ext(filename),
- size_bytes=size_bytes if size_bytes else 0,
- doc_updated_at=datetime.strptime(created_time, "%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=timezone.utc)
- )
- )
+ batch: list[SlimDocument] = []
- if len(batch) >= self.batch_size:
- yield batch
- batch = []
+ for _, _, _, doc_id, _, _ in self._iter_attachment_entries():
+ batch.append(SlimDocument(id=doc_id))
+ if len(batch) >= self.batch_size:
+ yield batch
+ batch = []
if batch:
yield batch
@@ -165,4 +204,4 @@ def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch)
for doc in first_batch:
print(f"- {doc.semantic_identifier} ({doc.size_bytes} bytes)")
except StopIteration:
- print("No documents available in Dropbox.")
\ No newline at end of file
+ print("No documents available in Dropbox.")
diff --git a/common/data_source/asana_connector.py b/common/data_source/asana_connector.py
index 4143c0cba0d..e3aee9c4f04 100644
--- a/common/data_source/asana_connector.py
+++ b/common/data_source/asana_connector.py
@@ -1,13 +1,13 @@
from collections.abc import Iterator
import time
-from datetime import datetime
+from datetime import datetime, timezone
import logging
from typing import Any, Dict
import asana
import requests
from common.data_source.config import CONTINUE_ON_CONNECTOR_FAILURE, INDEX_BATCH_SIZE, DocumentSource
-from common.data_source.interfaces import LoadConnector, PollConnector
-from common.data_source.models import Document, GenerateDocumentsOutput, SecondsSinceUnixEpoch
+from common.data_source.interfaces import LoadConnector, PollConnector, SlimConnectorWithPermSync
+from common.data_source.models import Document, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SecondsSinceUnixEpoch, SlimDocument
from common.data_source.utils import extract_size_bytes, get_file_ext
@@ -63,6 +63,31 @@ def get_tasks(
) -> Iterator[AsanaTask]:
"""Get all tasks from the projects with the given gids that were modified since the given date.
If project_gids is None, get all tasks from all projects in the workspace."""
+ projects_list = self._get_project_gids_to_process(project_gids)
+ start_seconds = int(time.mktime(datetime.now().timetuple()))
+ for project_gid in projects_list:
+ for task in self._get_tasks_for_project(
+ project_gid, start_date, start_seconds
+ ):
+ yield task
+ logging.info(f"Completed fetching {self.task_count} tasks from Asana")
+ if self.api_error_count > 0:
+ logging.warning(
+ f"Encountered {self.api_error_count} API errors during task fetching"
+ )
+
+ def get_task_ids(
+ self, project_gids: list[str] | None, start_date: str
+ ) -> Iterator[str]:
+ """Get task gids without hydrating comments, users, or task text."""
+ projects_list = self._get_project_gids_to_process(project_gids)
+ for project_gid in projects_list:
+ for task_id in self._get_task_ids_for_project(project_gid, start_date):
+ yield task_id
+
+ def _get_project_gids_to_process(
+ self, project_gids: list[str] | None
+ ) -> list[str]:
logging.info("Starting to fetch Asana projects")
projects = self.project_api.get_projects(
opts={
@@ -70,7 +95,6 @@ def get_tasks(
"opt_fields": "gid,name,archived,modified_at",
}
)
- start_seconds = int(time.mktime(datetime.now().timetuple()))
projects_list = []
project_count = 0
for project_info in projects:
@@ -85,20 +109,9 @@ def get_tasks(
if project_count % 100 == 0:
logging.info(f"Processed {project_count} projects")
logging.info(f"Found {len(projects_list)} projects to process")
- for project_gid in projects_list:
- for task in self._get_tasks_for_project(
- project_gid, start_date, start_seconds
- ):
- yield task
- logging.info(f"Completed fetching {self.task_count} tasks from Asana")
- if self.api_error_count > 0:
- logging.warning(
- f"Encountered {self.api_error_count} API errors during task fetching"
- )
+ return projects_list
- def _get_tasks_for_project(
- self, project_gid: str, start_date: str, start_seconds: int
- ) -> Iterator[AsanaTask]:
+ def _get_project_to_process(self, project_gid: str) -> dict | None:
project = self.project_api.get_project(project_gid, opts={})
project_name = project.get("name", project_gid)
team = project.get("team") or {}
@@ -122,6 +135,35 @@ def _get_tasks_for_project(
f"Processing private project in configured team: {project_name} ({project_gid})"
)
+ return project
+
+ def _get_task_ids_for_project(
+ self, project_gid: str, start_date: str
+ ) -> Iterator[str]:
+ project = self._get_project_to_process(project_gid)
+ if project is None:
+ return
+
+ tasks_from_api = self.tasks_api.get_tasks_for_project(
+ project_gid,
+ {
+ "opt_fields": "gid",
+ "modified_since": start_date,
+ },
+ )
+ for data in tasks_from_api:
+ task_id = data.get("gid")
+ if task_id:
+ yield task_id
+
+ def _get_tasks_for_project(
+ self, project_gid: str, start_date: str, start_seconds: int
+ ) -> Iterator[AsanaTask]:
+ project = self._get_project_to_process(project_gid)
+ if project is None:
+ return
+
+ project_name = project.get("name", project_gid)
simple_start_date = start_date.split(".")[0].split("+")[0]
logging.info(
f"Fetching tasks modified since {simple_start_date} for project: {project_name} ({project_gid})"
@@ -242,7 +284,7 @@ def get_attachments(self, task_gid: str) -> list[dict]:
full = self.attachments_api.get_attachment(
attachment_gid=gid,
opts={
- "opt_fields": "name,download_url,size,created_at"
+ "opt_fields": "gid,name,download_url,size,created_at"
}
)
@@ -330,7 +372,7 @@ def get_time(self) -> str:
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
-class AsanaConnector(LoadConnector, PollConnector):
+class AsanaConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
def __init__(
self,
asana_workspace_id: str,
@@ -367,11 +409,22 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
- start_time = datetime.fromtimestamp(start).isoformat()
+ start_time = datetime.fromtimestamp(start, tz=timezone.utc).isoformat()
+ end_time = datetime.fromtimestamp(end, tz=timezone.utc) if end is not None else None
logging.info(f"Starting Asana poll from {start_time}")
docs_batch: list[Document] = []
tasks = self.asana_client.get_tasks(self.project_ids_to_index, start_time)
for task in tasks:
+ if end_time:
+ task_last_modified = task.last_modified
+ if task_last_modified.tzinfo is None:
+ task_last_modified = task_last_modified.replace(tzinfo=timezone.utc)
+ else:
+ task_last_modified = task_last_modified.astimezone(timezone.utc)
+
+ if task_last_modified >= end_time:
+ continue
+
docs = self._task_to_documents(task)
docs_batch.extend(docs)
@@ -390,6 +443,31 @@ def load_from_state(self) -> GenerateDocumentsOutput:
logging.info("Starting full index of all Asana tasks")
return self.poll_source(start=0, end=None)
+ def retrieve_all_slim_docs_perm_sync(
+ self,
+ callback: Any = None,
+ ) -> GenerateSlimDocumentOutput:
+ del callback
+
+ start_time = datetime.fromtimestamp(0, tz=timezone.utc).isoformat()
+ docs_batch: list[SlimDocument] = []
+
+ for task_id in self.asana_client.get_task_ids(self.project_ids_to_index, start_time):
+ attachments = self.asana_client.get_attachments(task_id)
+
+ for att in attachments:
+ attachment_gid = att.get("gid")
+ if not attachment_gid:
+ continue
+
+ docs_batch.append(SlimDocument(id=f"asana:{task_id}:{attachment_gid}"))
+ if len(docs_batch) >= self.batch_size:
+ yield docs_batch
+ docs_batch = []
+
+ if docs_batch:
+ yield docs_batch
+
def _task_to_documents(self, task: AsanaTask) -> list[Document]:
docs: list[Document] = []
@@ -456,4 +534,4 @@ def _task_to_documents(self, task: AsanaTask) -> list[Document]:
for docs in all_docs:
for doc in docs:
print(doc.id)
- logging.info("Asana connector test completed")
\ No newline at end of file
+ logging.info("Asana connector test completed")
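
The Asana poll now treats `start`/`end` as UTC and drops tasks whose `last_modified` falls at or after `end`, normalizing naive timestamps first. A small standalone sketch of that normalization and window check; the helper names are illustrative:

```python
from datetime import datetime, timezone


def as_utc(dt: datetime) -> datetime:
    """Treat naive datetimes as UTC; convert aware ones."""
    if dt.tzinfo is None:
        return dt.replace(tzinfo=timezone.utc)
    return dt.astimezone(timezone.utc)


def in_poll_window(last_modified: datetime, start_ts: float, end_ts: float | None) -> bool:
    """Keep tasks modified on or after start and strictly before end (end exclusive)."""
    ts = as_utc(last_modified)
    if ts < datetime.fromtimestamp(start_ts, tz=timezone.utc):
        return False
    if end_ts is not None and ts >= datetime.fromtimestamp(end_ts, tz=timezone.utc):
        return False
    return True


print(in_poll_window(datetime(2025, 1, 1, 12, 0), 0, None))  # naive input treated as UTC -> True
```
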
diff --git a/common/data_source/bitbucket/connector.py b/common/data_source/bitbucket/connector.py
index f355a8945fc..0557d2a5039 100644
--- a/common/data_source/bitbucket/connector.py
+++ b/common/data_source/bitbucket/connector.py
@@ -269,17 +269,11 @@ def validate_checkpoint_json(
def retrieve_all_slim_docs_perm_sync(
self,
- start: SecondsSinceUnixEpoch | None = None,
- end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> Iterator[list[SlimDocument]]:
"""Return only document IDs for all existing pull requests."""
batch: list[SlimDocument] = []
- params = self._build_params(
- fields=SLIM_PR_LIST_RESPONSE_FIELDS,
- start=start,
- end=end,
- )
+ params = self._build_params(fields=SLIM_PR_LIST_RESPONSE_FIELDS)
with self._client() as client:
for slug in self._iter_target_repositories(client):
for pr in self._iter_pull_requests_for_repo(
@@ -361,10 +355,7 @@ def validate_connector_settings(self) -> None:
start_time = datetime.fromtimestamp(0, tz=timezone.utc)
end_time = datetime.now(timezone.utc)
- for doc_batch in bitbucket.retrieve_all_slim_docs_perm_sync(
- start=start_time.timestamp(),
- end=end_time.timestamp(),
- ):
+ for doc_batch in bitbucket.retrieve_all_slim_docs_perm_sync():
for doc in doc_batch:
print(doc)
@@ -385,4 +376,4 @@ def validate_connector_settings(self) -> None:
except StopIteration as e:
bitbucket_checkpoint = e.value
break
-
\ No newline at end of file
+
diff --git a/common/data_source/blob_connector.py b/common/data_source/blob_connector.py
index 1ab39189d79..7505b878ba3 100644
--- a/common/data_source/blob_connector.py
+++ b/common/data_source/blob_connector.py
@@ -19,7 +19,13 @@
InsufficientPermissionsError
)
from common.data_source.interfaces import LoadConnector, PollConnector
-from common.data_source.models import Document, SecondsSinceUnixEpoch, GenerateDocumentsOutput
+from common.data_source.models import (
+ Document,
+ SecondsSinceUnixEpoch,
+ GenerateDocumentsOutput,
+ GenerateSlimDocumentOutput,
+ SlimDocument,
+)
class BlobStorageConnector(LoadConnector, PollConnector):
@@ -122,29 +128,7 @@ def _yield_blob_objects(
end: datetime,
) -> GenerateDocumentsOutput:
"""Generate bucket objects"""
- if self.s3_client is None:
- raise ConnectorMissingCredentialError("Blob storage")
-
- paginator = self.s3_client.get_paginator("list_objects_v2")
- pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)
-
- # Collect all objects first to count filename occurrences
- all_objects = []
- for page in pages:
- if "Contents" not in page:
- continue
- for obj in page["Contents"]:
- if obj["Key"].endswith("/"):
- continue
- last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)
- if start < last_modified <= end:
- all_objects.append(obj)
-
- # Count filename occurrences to determine which need full paths
- filename_counts: dict[str, int] = {}
- for obj in all_objects:
- file_name = os.path.basename(obj["Key"])
- filename_counts[file_name] = filename_counts.get(file_name, 0) + 1
+ all_objects, filename_counts = self._collect_blob_objects(start, end)
batch: list[Document] = []
for obj in all_objects:
@@ -162,20 +146,15 @@ def _yield_blob_objects(
f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
)
continue
-
+
try:
- blob = download_object(self.s3_client, self.bucket_name, key, self.size_threshold)
+ blob = download_object(
+ self.s3_client, self.bucket_name, key, self.size_threshold
+ )
if blob is None:
continue
- # Use full path only if filename appears multiple times
- if filename_counts.get(file_name, 0) > 1:
- relative_path = key
- if self.prefix and key.startswith(self.prefix):
- relative_path = key[len(self.prefix):]
- semantic_id = relative_path.replace('/', ' / ') if relative_path else file_name
- else:
- semantic_id = file_name
+ semantic_id = self._get_semantic_id(key, file_name, filename_counts)
batch.append(
Document(
@@ -185,7 +164,7 @@ def _yield_blob_objects(
semantic_identifier=semantic_id,
extension=get_file_ext(file_name),
doc_updated_at=last_modified,
- size_bytes=size_bytes if size_bytes else 0
+ size_bytes=size_bytes if size_bytes else 0,
)
)
if len(batch) == self.batch_size:
@@ -194,7 +173,76 @@ def _yield_blob_objects(
except Exception:
logging.exception(f"Error decoding object {key}")
-
+
+ if batch:
+ yield batch
+
+ def _collect_blob_objects(
+ self,
+ start: datetime,
+ end: datetime,
+ ) -> tuple[list[dict[str, Any]], dict[str, int]]:
+ """Collect object metadata for files in the requested window."""
+ if self.s3_client is None:
+ raise ConnectorMissingCredentialError("Blob storage")
+
+ paginator = self.s3_client.get_paginator("list_objects_v2")
+ pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)
+
+ # Collect all objects first to count filename occurrences
+ all_objects: list[dict[str, Any]] = []
+ for page in pages:
+ if "Contents" not in page:
+ continue
+ for obj in page["Contents"]:
+ if obj["Key"].endswith("/"):
+ continue
+ last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)
+ if start < last_modified <= end:
+ all_objects.append(obj)
+
+ filename_counts: dict[str, int] = {}
+ for obj in all_objects:
+ file_name = os.path.basename(obj["Key"])
+ filename_counts[file_name] = filename_counts.get(file_name, 0) + 1
+
+ return all_objects, filename_counts
+
+ def _get_semantic_id(
+ self,
+ key: str,
+ file_name: str,
+ filename_counts: dict[str, int],
+ ) -> str:
+ """Use full relative path only when filenames collide."""
+ if filename_counts.get(file_name, 0) > 1:
+ relative_path = key
+ if self.prefix and key.startswith(self.prefix):
+ relative_path = key[len(self.prefix):]
+ return relative_path.replace("/", " / ") if relative_path else file_name
+ return file_name
+
+ def retrieve_all_slim_docs_perm_sync(
+ self,
+ callback: Any = None,
+ ) -> GenerateSlimDocumentOutput:
+ """Return a full current snapshot of blob object IDs without downloading content."""
+ del callback
+
+ all_objects, _ = self._collect_blob_objects(
+ start=datetime(1970, 1, 1, tzinfo=timezone.utc),
+ end=datetime.now(timezone.utc),
+ )
+
+ batch: list[SlimDocument] = []
+ for obj in all_objects:
+ batch.append(
+ SlimDocument(id=f"{self.bucket_type}:{self.bucket_name}:{obj['Key']}")
+ )
+ if len(batch) == self.batch_size:
+ yield batch
+ batch = []
+
if batch:
yield batch
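
In the blob connector, `_collect_blob_objects` lists object metadata and `_get_semantic_id` only falls back to the full relative path when two objects share a file name. A compact sketch of that disambiguation rule, written as a standalone function rather than the connector's methods:

```python
import os
from collections import Counter


def semantic_ids(keys: list[str], prefix: str = "") -> dict[str, str]:
    """Plain file name when unique; ' / '-joined relative path when names collide."""
    counts = Counter(os.path.basename(k) for k in keys)
    out: dict[str, str] = {}
    for key in keys:
        name = os.path.basename(key)
        if counts[name] > 1:
            rel = key[len(prefix):] if prefix and key.startswith(prefix) else key
            out[key] = rel.replace("/", " / ") if rel else name
        else:
            out[key] = name
    return out


print(semantic_ids(["docs/a/report.pdf", "docs/b/report.pdf", "docs/readme.md"], prefix="docs/"))
# {'docs/a/report.pdf': 'a / report.pdf', 'docs/b/report.pdf': 'b / report.pdf', 'docs/readme.md': 'readme.md'}
```
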
diff --git a/common/data_source/box_connector.py b/common/data_source/box_connector.py
index 253029d3c92..cc44f356e87 100644
--- a/common/data_source/box_connector.py
+++ b/common/data_source/box_connector.py
@@ -1,7 +1,7 @@
"""Box connector"""
import logging
from datetime import datetime, timezone
-from typing import Any
+from typing import Any, Generator
from box_sdk_gen import BoxClient
from common.data_source.config import DocumentSource, INDEX_BATCH_SIZE
@@ -10,21 +10,21 @@
ConnectorValidationError,
)
from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
-from common.data_source.models import Document, GenerateDocumentsOutput
+from common.data_source.models import Document, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SlimDocument
from common.data_source.utils import get_file_ext
+
class BoxConnector(LoadConnector, PollConnector):
def __init__(self, folder_id: str, batch_size: int = INDEX_BATCH_SIZE, use_marker: bool = True) -> None:
self.batch_size = batch_size
self.folder_id = "0" if not folder_id else folder_id
self.use_marker = use_marker
-
+ self.box_client: BoxClient | None = None
def load_credentials(self, auth: Any):
self.box_client = BoxClient(auth=auth)
return None
-
def validate_connector_settings(self):
if self.box_client is None:
raise ConnectorMissingCredentialError("Box")
@@ -35,79 +35,41 @@ def validate_connector_settings(self):
logging.exception("[Box]: Failed to validate Box credentials")
raise ConnectorValidationError(f"Unexpected error during Box settings validation: {e}")
-
- def _yield_files_recursive(
- self,
- folder_id: str,
- start: SecondsSinceUnixEpoch | None,
- end: SecondsSinceUnixEpoch | None,
- relative_folder_path: str = "",
- ) -> GenerateDocumentsOutput:
-
+ def _iter_files_recursive(
+ self,
+ folder_id: str,
+ relative_folder_path: str = "",
+ ) -> Generator[tuple[Any, str], None, None]:
if self.box_client is None:
raise ConnectorMissingCredentialError("Box")
result = self.box_client.folders.get_folder_items(
folder_id=folder_id,
limit=self.batch_size,
- usemarker=self.use_marker
+ usemarker=self.use_marker,
)
while True:
- batch: list[Document] = []
for entry in result.entries:
- if entry.type == 'file' :
- file = self.box_client.files.get_file_by_id(
- entry.id
- )
- modified_time: SecondsSinceUnixEpoch | None = None
- raw_time = (
- getattr(file, "created_at", None)
- or getattr(file, "content_created_at", None)
- )
-
- if raw_time:
- modified_time = self._box_datetime_to_epoch_seconds(raw_time)
- if start is not None and modified_time <= start:
- continue
- if end is not None and modified_time > end:
- continue
-
- content_bytes = self.box_client.downloads.download_file(file.id)
+ if entry.type == "file":
+ file = self.box_client.files.get_file_by_id(entry.id)
semantic_identifier = (
f"{relative_folder_path} / {file.name}"
if relative_folder_path
else file.name
)
-
- batch.append(
- Document(
- id=f"box:{file.id}",
- blob=content_bytes.read(),
- source=DocumentSource.BOX,
- semantic_identifier=semantic_identifier,
- extension=get_file_ext(file.name),
- doc_updated_at=modified_time,
- size_bytes=file.size,
- metadata=file.metadata
- )
- )
- elif entry.type == 'folder':
+ yield file, semantic_identifier
+ elif entry.type == "folder":
child_relative_path = (
f"{relative_folder_path} / {entry.name}"
if relative_folder_path
else entry.name
)
- yield from self._yield_files_recursive(
+ yield from self._iter_files_recursive(
folder_id=entry.id,
- start=start,
- end=end,
- relative_folder_path=child_relative_path
+ relative_folder_path=child_relative_path,
)
- if batch:
- yield batch
-
if not result.next_marker:
break
@@ -115,9 +77,56 @@ def _yield_files_recursive(
folder_id=folder_id,
limit=self.batch_size,
marker=result.next_marker,
- usemarker=True
+ usemarker=True,
)
+ def _yield_files_recursive(
+ self,
+ folder_id: str,
+ start: SecondsSinceUnixEpoch | None,
+ end: SecondsSinceUnixEpoch | None,
+ relative_folder_path: str = "",
+ ) -> GenerateDocumentsOutput:
+ if self.box_client is None:
+ raise ConnectorMissingCredentialError("Box")
+
+ batch: list[Document] = []
+ for file, semantic_identifier in self._iter_files_recursive(
+ folder_id=folder_id,
+ relative_folder_path=relative_folder_path,
+ ):
+ modified_time: SecondsSinceUnixEpoch | None = None
+ raw_time = (
+ getattr(file, "created_at", None)
+ or getattr(file, "content_created_at", None)
+ )
+
+ if raw_time:
+ modified_time = self._box_datetime_to_epoch_seconds(raw_time)
+ if start is not None and modified_time <= start:
+ continue
+ if end is not None and modified_time > end:
+ continue
+
+ content_bytes = self.box_client.downloads.download_file(file.id)
+ batch.append(
+ Document(
+ id=f"box:{file.id}",
+ blob=content_bytes.read(),
+ source=DocumentSource.BOX,
+ semantic_identifier=semantic_identifier,
+ extension=get_file_ext(file.name),
+ doc_updated_at=modified_time,
+ size_bytes=file.size,
+ metadata=file.metadata,
+ )
+ )
+ if len(batch) >= self.batch_size:
+ yield batch
+ batch = []
+
+ if batch:
+ yield batch
def _box_datetime_to_epoch_seconds(self, dt: datetime) -> SecondsSinceUnixEpoch:
"""Convert a Box SDK datetime to Unix epoch seconds (UTC).
@@ -133,6 +142,21 @@ def _box_datetime_to_epoch_seconds(self, dt: datetime) -> SecondsSinceUnixEpoch:
return SecondsSinceUnixEpoch(int(dt.timestamp()))
+ def retrieve_all_slim_docs_perm_sync(
+ self,
+ callback: Any = None,
+ ) -> GenerateSlimDocumentOutput:
+ del callback
+
+ batch: list[SlimDocument] = []
+ for file, _semantic_identifier in self._iter_files_recursive(folder_id=self.folder_id):
+ batch.append(SlimDocument(id=f"box:{file.id}"))
+ if len(batch) >= self.batch_size:
+ yield batch
+ batch = []
+
+ if batch:
+ yield batch
def poll_source(self, start, end):
return self._yield_files_recursive(folder_id=self.folder_id, start=start, end=end)
diff --git a/common/data_source/confluence_connector.py b/common/data_source/confluence_connector.py
index abe55b5b275..ef0d6a77600 100644
--- a/common/data_source/confluence_connector.py
+++ b/common/data_source/confluence_connector.py
@@ -1904,8 +1904,6 @@ def retrieve_all_slim_docs(
def retrieve_all_slim_docs_perm_sync(
self,
- start: SecondsSinceUnixEpoch | None = None,
- end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
"""
@@ -1913,16 +1911,12 @@ def retrieve_all_slim_docs_perm_sync(
Does not fetch actual text. Used primarily for incremental permission sync.
"""
return self._retrieve_all_slim_docs(
- start=start,
- end=end,
callback=callback,
include_permissions=True,
)
def _retrieve_all_slim_docs(
self,
- start: SecondsSinceUnixEpoch | None = None,
- end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
include_permissions: bool = True,
) -> GenerateSlimDocumentOutput:
diff --git a/common/data_source/dingtalk_ai_table_connector.py b/common/data_source/dingtalk_ai_table_connector.py
index 66588d4d307..40dc44b61f5 100644
--- a/common/data_source/dingtalk_ai_table_connector.py
+++ b/common/data_source/dingtalk_ai_table_connector.py
@@ -22,8 +22,8 @@
from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource
from common.data_source.exceptions import ConnectorMissingCredentialError, ConnectorValidationError
-from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
-from common.data_source.models import Document, GenerateDocumentsOutput
+from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync
+from common.data_source.models import Document, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SlimDocument
logger = logging.getLogger(__name__)
@@ -38,7 +38,7 @@ def __init__(self) -> None:
super().__init__("DingTalk Notable client is not set up. Did you forget to call load_credentials()?")
-class DingTalkAITableConnector(LoadConnector, PollConnector):
+class DingTalkAITableConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
"""
DingTalk AI Table (Notable) connector for accessing table records.
@@ -75,6 +75,9 @@ def __init__(
self._client: NotableClient | None = None
self._access_token: str | None = None
+ def _document_id(self, sheet_id: str, record_id: str) -> str:
+ return f"{_DINGTALK_AI_TABLE_DOC_ID_PREFIX}{self.table_id}:{sheet_id}:{record_id}"
+
def _create_client(self) -> NotableClient:
"""Create DingTalk Notable API client."""
config = open_api_models.Config()
@@ -280,6 +283,8 @@ def _convert_record_to_document(
record_id = record.get("id", "unknown")
fields = record.get("fields", {})
+ doc_id = self._document_id(sheet_id, str(record_id))
+
# Convert fields to JSON string for blob content
content = json.dumps(fields, ensure_ascii=False, indent=2)
blob = content.encode("utf-8")
@@ -304,7 +309,7 @@ def _convert_record_to_document(
# Create document
doc = Document(
- id=f"{_DINGTALK_AI_TABLE_DOC_ID_PREFIX}{self.table_id}:{sheet_id}:{record_id}",
+ id=doc_id,
source=DocumentSource.DINGTALK_AI_TABLE,
semantic_identifier=semantic_identifier,
extension=".json",
@@ -316,6 +321,44 @@ def _convert_record_to_document(
return doc
+ def retrieve_all_slim_docs_perm_sync(
+ self,
+ callback: Any = None,
+ ) -> GenerateSlimDocumentOutput:
+ """
+ Enumerate current record IDs for all sheets without building document blobs.
+
+ IDs match :meth:`_convert_record_to_document` / full ingest.
+ """
+ del callback
+ logger.info(
+ "[DingTalk Notable]: slim snapshot table_id=%s operator_id=%s",
+ self.table_id,
+ self.operator_id,
+ )
+ sheets = self._get_all_sheets()
+ batch: list[SlimDocument] = []
+ for sheet in sheets:
+ sheet_id = sheet["id"]
+ next_token: str | None = None
+ while True:
+ records, next_token = self._list_records(
+ sheet_id=sheet_id,
+ next_token=next_token,
+ )
+ for record in records:
+ rid = record.get("id")
+ if not rid:
+ continue
+ batch.append(SlimDocument(id=self._document_id(sheet_id, str(rid))))
+ if len(batch) >= self.batch_size:
+ yield batch
+ batch = []
+ if not next_token:
+ break
+ if batch:
+ yield batch
+
def _yield_documents_from_table(
self,
start: SecondsSinceUnixEpoch | None = None,
diff --git a/common/data_source/discord_connector.py b/common/data_source/discord_connector.py
index e65a6324185..83b2b562f0e 100644
--- a/common/data_source/discord_connector.py
+++ b/common/data_source/discord_connector.py
@@ -13,8 +13,14 @@
from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource
from common.data_source.exceptions import ConnectorMissingCredentialError
-from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
-from common.data_source.models import Document, GenerateDocumentsOutput, TextSection
+from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync
+from common.data_source.models import (
+ Document,
+ GenerateDocumentsOutput,
+ GenerateSlimDocumentOutput,
+ SlimDocument,
+ TextSection,
+)
_DISCORD_DOC_ID_PREFIX = "DISCORD_"
_SNIPPET_LENGTH = 30
@@ -94,8 +100,12 @@ async def _fetch_filtered_channels(
async def _fetch_documents_from_channel(
channel: TextChannel,
start_time: datetime | None,
- end_time: datetime | None,
-) -> AsyncIterable[Document]:
+) -> AsyncIterable[DiscordMessage]:
+ """Yield raw Discord messages for one channel and its threads.
+
+ This stays at the message layer so callers can decide whether they need
+ full Document construction or only lightweight ID accounting.
+ """
# Discord's epoch starts at 2015-01-01
discord_epoch = datetime(2015, 1, 1, tzinfo=timezone.utc)
if start_time and start_time < discord_epoch:
@@ -109,39 +119,23 @@ async def _fetch_documents_from_channel(
async for channel_message in channel.history(
limit=None,
after=start_time,
- before=end_time,
):
# Skip messages that are not the default type
if channel_message.type != MessageType.default:
continue
- sections: list[TextSection] = [
- TextSection(
- text=channel_message.content,
- link=channel_message.jump_url,
- )
- ]
-
- yield _convert_message_to_document(channel_message, sections)
+ yield channel_message
for active_thread in channel.threads:
async for thread_message in active_thread.history(
limit=None,
after=start_time,
- before=end_time,
):
# Skip messages that are not the default type
if thread_message.type != MessageType.default:
continue
- sections = [
- TextSection(
- text=thread_message.content,
- link=thread_message.jump_url,
- )
- ]
-
- yield _convert_message_to_document(thread_message, sections)
+ yield thread_message
async for archived_thread in channel.archived_threads(
limit=None,
@@ -149,20 +143,12 @@ async def _fetch_documents_from_channel(
async for thread_message in archived_thread.history(
limit=None,
after=start_time,
- before=end_time,
):
# Skip messages that are not the default type
if thread_message.type != MessageType.default:
continue
- sections = [
- TextSection(
- text=thread_message.content,
- link=thread_message.jump_url,
- )
- ]
-
- yield _convert_message_to_document(thread_message, sections)
+ yield thread_message
def _manage_async_retrieval(
@@ -171,20 +157,23 @@ def _manage_async_retrieval(
channel_names: list[str],
server_ids: list[int],
start: datetime | None = None,
- end: datetime | None = None,
-) -> Iterable[Document]:
+) -> Iterable[DiscordMessage]:
+ """Bridge the async Discord client into a synchronous iterator.
+
+ `start` is only used as a lower bound for the underlying fetch. Callers
+ that need a narrower time window should apply their own filtering while
+ iterating so the same full scan can also support deleted-file sync.
+ """
# parse requested_start_date_string to datetime
pull_date: datetime | None = datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(tzinfo=timezone.utc) if requested_start_date_string else None
- # Set start_time to the most recent of start and pull_date, or whichever is provided
+ # Keep the configured start date as the full-scan lower bound.
start_time = max(filter(None, [start, pull_date])) if start or pull_date else None
-
- end_time: datetime | None = end
proxy_url: str | None = os.environ.get("https_proxy") or os.environ.get("http_proxy")
if proxy_url:
logging.info(f"Using proxy for Discord: {proxy_url}")
- async def _async_fetch() -> AsyncIterable[Document]:
+ async def _async_fetch() -> AsyncIterable[DiscordMessage]:
intents = Intents.default()
intents.message_content = True
async with Client(intents=intents, proxy=proxy_url) as cli:
@@ -198,15 +187,13 @@ async def _async_fetch() -> AsyncIterable[Document]:
)
for channel in filtered_channels:
- async for doc in _fetch_documents_from_channel(
+ async for message in _fetch_documents_from_channel(
channel=channel,
start_time=start_time,
- end_time=end_time,
):
- print(doc)
- yield doc
+ yield message
- def run_and_yield() -> Iterable[Document]:
+ def run_and_yield() -> Iterable[DiscordMessage]:
loop = asyncio.new_event_loop()
try:
# Get the async generator
@@ -228,7 +215,7 @@ def run_and_yield() -> Iterable[Document]:
return run_and_yield()
-class DiscordConnector(LoadConnector, PollConnector):
+class DiscordConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
"""Discord connector for accessing Discord messages and channels"""
def __init__(
@@ -251,12 +238,28 @@ def discord_bot_token(self) -> str:
raise ConnectorMissingCredentialError("Discord")
return self._discord_bot_token
- def _manage_doc_batching(
+ def _iter_merged_documents(
self,
start: datetime | None = None,
end: datetime | None = None,
) -> GenerateDocumentsOutput:
- doc_batch = []
+ """Build merged Discord documents for the requested polling window."""
+ doc_batch: list[Document] = []
+
+ def _message_created_at(message: DiscordMessage) -> datetime:
+ created_at = message.created_at
+ if created_at.tzinfo is None:
+ return created_at.replace(tzinfo=timezone.utc)
+ return created_at.astimezone(timezone.utc)
+
+ def _is_in_window(message: DiscordMessage) -> bool:
+ created_at = _message_created_at(message)
+ if start is not None and created_at < start:
+ return False
+ if end is not None and created_at >= end:
+ return False
+ return True
+
def merge_batch():
nonlocal doc_batch
id = doc_batch[0].id
@@ -280,14 +283,23 @@ def merge_batch():
size_bytes=size_bytes,
)
- for doc in _manage_async_retrieval(
+ for message in _manage_async_retrieval(
token=self.discord_bot_token,
requested_start_date_string=self.requested_start_date_string,
channel_names=self.channel_names,
server_ids=self.server_ids,
start=start,
- end=end,
):
+ if not _is_in_window(message):
+ continue
+
+ sections = [
+ TextSection(
+ text=message.content,
+ link=message.jump_url,
+ )
+ ]
+ doc = _convert_message_to_document(message, sections)
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield [merge_batch()]
@@ -296,6 +308,13 @@ def merge_batch():
if doc_batch:
yield [merge_batch()]
+ def _manage_doc_batching(
+ self,
+ start: datetime | None = None,
+ end: datetime | None = None,
+ ) -> GenerateDocumentsOutput:
+ yield from self._iter_merged_documents(start=start, end=end)
+
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._discord_bot_token = credentials["discord_bot_token"]
return None
@@ -316,6 +335,41 @@ def load_from_state(self) -> Any:
"""Load messages from Discord state"""
return self._manage_doc_batching(None, None)
+ def retrieve_all_slim_docs_perm_sync(
+ self,
+ callback: Any = None,
+ ) -> GenerateSlimDocumentOutput:
+ del callback
+ slim_doc_batch: list[SlimDocument] = []
+ full_scan_batch_size = 0
+ full_scan_batch_first_id: str | None = None
+
+ for message in _manage_async_retrieval(
+ token=self.discord_bot_token,
+ requested_start_date_string=self.requested_start_date_string,
+ channel_names=self.channel_names,
+ server_ids=self.server_ids,
+ start=None,
+ ):
+ if full_scan_batch_first_id is None:
+ full_scan_batch_first_id = f"{_DISCORD_DOC_ID_PREFIX}{message.id}"
+ full_scan_batch_size += 1
+
+ if full_scan_batch_size >= self.batch_size:
+ slim_doc_batch.append(SlimDocument(id=full_scan_batch_first_id))
+ full_scan_batch_size = 0
+ full_scan_batch_first_id = None
+
+ if len(slim_doc_batch) >= self.batch_size:
+ yield slim_doc_batch
+ slim_doc_batch = []
+
+ if full_scan_batch_first_id is not None:
+ slim_doc_batch.append(SlimDocument(id=full_scan_batch_first_id))
+
+ if slim_doc_batch:
+ yield slim_doc_batch
+
if __name__ == "__main__":
import os
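
The Discord connector keeps `_manage_async_retrieval` as the single bridge from the async client to a synchronous iterator: it drives the event loop manually and pulls one message at a time. A minimal standalone sketch of that bridging pattern, with no Discord APIs involved:

```python
import asyncio
from typing import AsyncIterator, Iterator


async def produce() -> AsyncIterator[int]:
    # Stand-in for the async message stream coming from the Discord client.
    for i in range(5):
        await asyncio.sleep(0)
        yield i


def run_and_yield() -> Iterator[int]:
    """Drive an async generator from synchronous code, one item per loop run."""
    loop = asyncio.new_event_loop()
    try:
        agen = produce()
        while True:
            try:
                yield loop.run_until_complete(agen.__anext__())
            except StopAsyncIteration:
                break
    finally:
        loop.close()


print(list(run_and_yield()))  # [0, 1, 2, 3, 4]
```
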
diff --git a/common/data_source/dropbox_connector.py b/common/data_source/dropbox_connector.py
index 0e7131d8f3b..43ab08f4b06 100644
--- a/common/data_source/dropbox_connector.py
+++ b/common/data_source/dropbox_connector.py
@@ -14,14 +14,14 @@
ConnectorValidationError,
InsufficientPermissionsError,
)
-from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
-from common.data_source.models import Document, GenerateDocumentsOutput
+from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync
+from common.data_source.models import Document, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SlimDocument
from common.data_source.utils import get_file_ext
logger = logging.getLogger(__name__)
-class DropboxConnector(LoadConnector, PollConnector):
+class DropboxConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
"""Dropbox connector for accessing Dropbox files and folders"""
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
@@ -87,57 +87,48 @@ def _yield_files_recursive(
if self.dropbox_client is None:
raise ConnectorMissingCredentialError("Dropbox")
- # Collect all files first to count filename occurrences
- all_files = []
- self._collect_files_recursive(path, start, end, all_files)
-
+ all_files: list[FileMetadata] = []
+ self._collect_file_entries_recursive(path, start, end, all_files)
+
# Count filename occurrences
filename_counts: dict[str, int] = {}
- for entry, _ in all_files:
+ for entry in all_files:
filename_counts[entry.name] = filename_counts.get(entry.name, 0) + 1
-
+
# Process files in batches
batch: list[Document] = []
- for entry, downloaded_file in all_files:
- modified_time = entry.client_modified
- if modified_time.tzinfo is None:
- modified_time = modified_time.replace(tzinfo=timezone.utc)
- else:
- modified_time = modified_time.astimezone(timezone.utc)
-
- # Use full path only if filename appears multiple times
- if filename_counts.get(entry.name, 0) > 1:
- # Remove leading slash and replace slashes with ' / '
- relative_path = entry.path_display.lstrip('/')
- semantic_id = relative_path.replace('/', ' / ') if relative_path else entry.name
- else:
- semantic_id = entry.name
-
+ for entry in all_files:
+ try:
+ downloaded_file = self._download_file(entry.path_display)
+ except Exception:
+ logger.exception(f"[Dropbox]: Error downloading file {entry.path_display}")
+ continue
+
batch.append(
Document(
id=f"dropbox:{entry.id}",
blob=downloaded_file,
source=DocumentSource.DROPBOX,
- semantic_identifier=semantic_id,
+ semantic_identifier=self._get_semantic_identifier(entry, filename_counts),
extension=get_file_ext(entry.name),
- doc_updated_at=modified_time,
+ doc_updated_at=self._normalize_modified_time(entry.client_modified),
size_bytes=entry.size if getattr(entry, "size", None) is not None else len(downloaded_file),
)
)
-
+
if len(batch) == self.batch_size:
yield batch
batch = []
-
+
if batch:
yield batch
- def _collect_files_recursive(
+ def _collect_file_entries_recursive(
self,
path: str,
start: SecondsSinceUnixEpoch | None,
end: SecondsSinceUnixEpoch | None,
- all_files: list,
+ all_files: list[FileMetadata],
) -> None:
"""Recursively collect all files matching time criteria."""
if self.dropbox_client is None:
@@ -152,33 +143,56 @@ def _collect_files_recursive(
while True:
for entry in result.entries:
if isinstance(entry, FileMetadata):
- modified_time = entry.client_modified
- if modified_time.tzinfo is None:
- modified_time = modified_time.replace(tzinfo=timezone.utc)
- else:
- modified_time = modified_time.astimezone(timezone.utc)
-
- time_as_seconds = modified_time.timestamp()
+ time_as_seconds = self._normalize_modified_time(entry.client_modified).timestamp()
if start is not None and time_as_seconds <= start:
continue
if end is not None and time_as_seconds > end:
continue
- try:
- downloaded_file = self._download_file(entry.path_display)
- all_files.append((entry, downloaded_file))
- except Exception:
- logger.exception(f"[Dropbox]: Error downloading file {entry.path_display}")
- continue
+ all_files.append(entry)
elif isinstance(entry, FolderMetadata):
- self._collect_files_recursive(entry.path_lower, start, end, all_files)
+ self._collect_file_entries_recursive(entry.path_lower, start, end, all_files)
if not result.has_more:
break
result = self.dropbox_client.files_list_folder_continue(result.cursor)
+ def _normalize_modified_time(self, modified_time):
+ if modified_time.tzinfo is None:
+ return modified_time.replace(tzinfo=timezone.utc)
+ return modified_time.astimezone(timezone.utc)
+
+ def _get_semantic_identifier(self, entry: FileMetadata, filename_counts: dict[str, int]) -> str:
+ if filename_counts.get(entry.name, 0) <= 1:
+ return entry.name
+
+ relative_path = entry.path_display.lstrip("/")
+ return relative_path.replace("/", " / ") if relative_path else entry.name
+
+ def retrieve_all_slim_docs_perm_sync(
+ self,
+ callback: Any = None,
+ ) -> GenerateSlimDocumentOutput:
+ del callback
+
+ if self.dropbox_client is None:
+ raise ConnectorMissingCredentialError("Dropbox")
+
+ all_files: list[FileMetadata] = []
+ self._collect_file_entries_recursive("", None, None, all_files)
+
+ batch: list[SlimDocument] = []
+ for entry in all_files:
+ batch.append(SlimDocument(id=f"dropbox:{entry.id}"))
+ if len(batch) >= self.batch_size:
+ yield batch
+ batch = []
+
+ if batch:
+ yield batch
+
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> GenerateDocumentsOutput:
"""Poll Dropbox for recent file changes"""
if self.dropbox_client is None:
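
The Dropbox refactor collects `FileMetadata` entries first and defers downloads to the batching loop, so a failed download skips one document instead of aborting the whole traversal. A minimal sketch of that split; `fetch` and `fake_fetch` are illustrative stand-ins for `_download_file`:

```python
import logging
from typing import Callable, Iterator


def yield_documents(
    paths: list[str],
    fetch: Callable[[str], bytes],
    batch_size: int = 2,
) -> Iterator[list[tuple[str, bytes]]]:
    """Download lazily per entry; a failure skips that entry, not the whole run."""
    batch: list[tuple[str, bytes]] = []
    for path in paths:
        try:
            blob = fetch(path)
        except Exception:
            logging.exception("Error downloading %s", path)
            continue
        batch.append((path, blob))
        if len(batch) >= batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


def fake_fetch(path: str) -> bytes:
    if path == "/broken":
        raise OSError("download failed")
    return path.encode()


for b in yield_documents(["/a", "/broken", "/b", "/c"], fake_fetch):
    print(b)  # [('/a', b'/a'), ('/b', b'/b')], then [('/c', b'/c')]
```
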
diff --git a/common/data_source/github/connector.py b/common/data_source/github/connector.py
index 258e2cf8b46..2d65c995e6b 100644
--- a/common/data_source/github/connector.py
+++ b/common/data_source/github/connector.py
@@ -964,11 +964,9 @@ def retrieve_slim_document(
def retrieve_all_slim_docs_perm_sync(
self,
- start: SecondsSinceUnixEpoch | None = None,
- end: SecondsSinceUnixEpoch | None = None,
callback: Any = None,
) -> GenerateSlimDocumentOutput:
- yield from self.retrieve_slim_document(start=start, end=end, callback=callback)
+ yield from self.retrieve_slim_document(callback=callback)
def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint:
return GithubConnectorCheckpoint(
diff --git a/common/data_source/gitlab_connector.py b/common/data_source/gitlab_connector.py
index 0d2c0dab775..dae24992b49 100644
--- a/common/data_source/gitlab_connector.py
+++ b/common/data_source/gitlab_connector.py
@@ -20,8 +20,11 @@
from common.data_source.interfaces import LoadConnector
from common.data_source.interfaces import PollConnector
from common.data_source.interfaces import SecondsSinceUnixEpoch
+from common.data_source.interfaces import SlimConnectorWithPermSync
from common.data_source.models import BasicExpertInfo
from common.data_source.models import Document
+from common.data_source.models import GenerateSlimDocumentOutput
+from common.data_source.models import SlimDocument
from common.data_source.utils import get_file_ext
T = TypeVar("T")
@@ -158,7 +161,7 @@ def _should_exclude(path: str) -> bool:
return any(fnmatch.fnmatch(path, pattern) for pattern in exclude_patterns)
-class GitlabConnector(LoadConnector, PollConnector):
+class GitlabConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
def __init__(
self,
project_owner: str,
@@ -313,6 +316,67 @@ def poll_source(
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
return self._fetch_from_gitlab(start_datetime, end_datetime)
+ def retrieve_all_slim_docs_perm_sync(self, callback: Any = None) -> GenerateSlimDocumentOutput:
+ if self.gitlab_client is None:
+ raise ConnectorMissingCredentialError("Gitlab")
+
+ project: Project = self.gitlab_client.projects.get(
+ f"{self.project_owner}/{self.project_name}"
+ )
+
+ slim_batch: list[SlimDocument] = []
+
+ def append_doc(doc_id: str):
+ slim_batch.append(SlimDocument(id=doc_id))
+ if len(slim_batch) >= self.batch_size:
+ batch = slim_batch[:]
+ slim_batch.clear()
+ return batch
+ return None
+
+ if self.include_code_files:
+ default_branch = project.default_branch
+ queue = deque([""])
+ while queue:
+ current_path = queue.popleft()
+ files = project.repository_tree(path=current_path, all=True)
+ for file in files:
+ if _should_exclude(file["path"]):
+ continue
+ if file["type"] == "tree":
+ queue.append(file["path"])
+ continue
+ if file["type"] != "blob":
+ continue
+
+ file_url = f"{self.gitlab_client.url}/{self.project_owner}/{self.project_name}/-/blob/{default_branch}/{file['path']}"
+ batch = append_doc(file_url)
+ if batch:
+ yield batch
+
+ if self.include_mrs:
+ merge_requests = project.mergerequests.list(
+ state=self.state_filter,
+ iterator=True,
+ )
+ for mr in merge_requests:
+ batch = append_doc(mr.web_url)
+ if batch:
+ yield batch
+
+ if self.include_issues:
+ issues = project.issues.list(
+ state=self.state_filter,
+ iterator=True,
+ )
+ for issue in issues:
+ batch = append_doc(issue.web_url)
+ if batch:
+ yield batch
+
+ if slim_batch:
+ yield slim_batch
+
if __name__ == "__main__":
import os
@@ -337,4 +401,4 @@ def poll_source(
document_batches = connector.load_from_state()
for f in document_batches:
print("Batch:", f)
- print("Finished loading from state.")
\ No newline at end of file
+ print("Finished loading from state.")
diff --git a/common/data_source/gmail_connector.py b/common/data_source/gmail_connector.py
index 1421f9f4bf1..ea4dd993ae0 100644
--- a/common/data_source/gmail_connector.py
+++ b/common/data_source/gmail_connector.py
@@ -270,12 +270,10 @@ def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch)
def retrieve_all_slim_docs_perm_sync(
self,
- start: SecondsSinceUnixEpoch | None = None,
- end: SecondsSinceUnixEpoch | None = None,
callback=None,
) -> GenerateSlimDocumentOutput:
"""Retrieve slim documents for permission synchronization."""
- query = build_time_range_query(start, end)
+ query = build_time_range_query()
doc_batch = []
for user_email in self._get_all_user_emails():
@@ -343,4 +341,4 @@ def retrieve_all_slim_docs_perm_sync(
print(f)
print("\n\n")
except Exception as e:
- logging.exception(f"Error loading credentials: {e}")
\ No newline at end of file
+ logging.exception(f"Error loading credentials: {e}")
diff --git a/common/data_source/google_drive/connector.py b/common/data_source/google_drive/connector.py
index b44c28d74db..479c60e0b63 100644
--- a/common/data_source/google_drive/connector.py
+++ b/common/data_source/google_drive/connector.py
@@ -159,6 +159,7 @@ def __init__(
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
self._creds_dict: dict[str, Any] | None = None
+ self._all_drive_ids_cache: set[str] | None = None
# ids of folders and shared drives that have been traversed
self._retrieved_folder_and_drive_ids: set[str] = set()
@@ -211,6 +212,7 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None
self.include_files_shared_with_me = True
self._creds_dict = new_creds_dict
+ self._all_drive_ids_cache = None
return new_creds_dict
@@ -249,7 +251,11 @@ def _get_all_user_emails(self) -> list[str]:
return user_emails
def get_all_drive_ids(self) -> set[str]:
- return self._get_all_drives_for_user(self.primary_admin_email)
+ if self._all_drive_ids_cache is None:
+ self._all_drive_ids_cache = self._get_all_drives_for_user(
+ self.primary_admin_email
+ )
+ return set(self._all_drive_ids_cache)
def _get_all_drives_for_user(self, user_email: str) -> set[str]:
drive_service = get_drive_service(self.creds, user_email)
@@ -265,7 +271,14 @@ def _get_all_drives_for_user(self, user_email: str) -> set[str]:
all_drive_ids.add(drive["id"])
if not all_drive_ids:
- self.logger.warning("No drives found even though indexing shared drives was requested.")
+ if self._requested_shared_drive_ids:
+ self.logger.warning(
+ "No shared drives found for user %s while resolving requested shared drives.",
+ user_email,
+ )
+ elif self.include_shared_drives:
+ log_fn = self.logger.warning if is_service_account else self.logger.info
+ log_fn("No shared drives found for user %s.", user_email)
return all_drive_ids
@@ -1087,8 +1100,6 @@ def _extract_slim_docs_from_google_drive(
def retrieve_all_slim_docs_perm_sync(
self,
- start: SecondsSinceUnixEpoch | None = None,
- end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
try:
@@ -1096,8 +1107,6 @@ def retrieve_all_slim_docs_perm_sync(
while checkpoint.completion_stage != DriveRetrievalStage.DONE:
yield from self._extract_slim_docs_from_google_drive(
checkpoint=checkpoint,
- start=start,
- end=end,
)
self.logger.info("Drive perm sync: Slim doc retrieval complete")
diff --git a/common/data_source/google_util/resource.py b/common/data_source/google_util/resource.py
index eb060e46883..ba4199cb078 100644
--- a/common/data_source/google_util/resource.py
+++ b/common/data_source/google_util/resource.py
@@ -85,9 +85,19 @@ def _get_google_service(
if isinstance(creds, ServiceAccountCredentials):
# NOTE: https://developers.google.com/identity/protocols/oauth2/service-account#error-codes
creds = creds.with_subject(user_email)
- service = build(service_name, service_version, credentials=creds)
+ service = build(
+ service_name,
+ service_version,
+ credentials=creds,
+ cache_discovery=False,
+ )
elif isinstance(creds, OAuthCredentials):
- service = build(service_name, service_version, credentials=creds)
+ service = build(
+ service_name,
+ service_version,
+ credentials=creds,
+ cache_discovery=False,
+ )
return service
diff --git a/common/data_source/imap_connector.py b/common/data_source/imap_connector.py
index f682676e8ed..a8c1988f6ce 100644
--- a/common/data_source/imap_connector.py
+++ b/common/data_source/imap_connector.py
@@ -1,5 +1,6 @@
import copy
import email
+import hashlib
from email.header import decode_header
import imaplib
import logging
@@ -12,14 +13,26 @@
from enum import Enum
from typing import Any
from typing import cast
-import uuid
import bs4
from pydantic import BaseModel
from common.data_source.config import IMAP_CONNECTOR_SIZE_THRESHOLD, DocumentSource
-from common.data_source.interfaces import CheckpointOutput, CheckpointedConnectorWithPermSync, CredentialsConnector, CredentialsProviderInterface
-from common.data_source.models import BasicExpertInfo, ConnectorCheckpoint, Document, ExternalAccess, SecondsSinceUnixEpoch
+from common.data_source.interfaces import (
+ CheckpointOutput,
+ CheckpointedConnectorWithPermSync,
+ CredentialsConnector,
+ CredentialsProviderInterface,
+)
+from common.data_source.models import (
+ BasicExpertInfo,
+ ConnectorCheckpoint,
+ Document,
+ ExternalAccess,
+ GenerateSlimDocumentOutput,
+ SecondsSinceUnixEpoch,
+ SlimDocument,
+)
_DEFAULT_IMAP_PORT_NUMBER = int(os.environ.get("IMAP_PORT", 993))
_IMAP_OKAY_STATUS = "OK"
@@ -86,9 +99,6 @@ def _parse_date(date_str: str | None) -> datetime | None:
except (TypeError, ValueError):
return None
- message_id = _decode(header=Header.MESSAGE_ID_HEADER)
- if not message_id:
- message_id = f""
# It's possible for the subject line to not exist or be an empty string.
subject = _decode(header=Header.SUBJECT_HEADER) or "Unknown Subject"
from_ = _decode(header=Header.FROM_HEADER)
@@ -97,11 +107,27 @@ def _parse_date(date_str: str | None) -> datetime | None:
to = _decode(header=Header.DELIVERED_TO_HEADER)
cc = _decode(header=Header.CC_HEADER)
date_str = _decode(header=Header.DATE_HEADER)
- date = _parse_date(date_str=date_str)
+ parsed_date = _parse_date(date_str=date_str)
+ date = parsed_date
if not date:
date = datetime.now(tz=timezone.utc)
+ message_id = _decode(header=Header.MESSAGE_ID_HEADER)
+ if not message_id:
+ message_id = _build_stable_generated_message_id(
+ email_msg=email_msg,
+ subject=subject,
+ sender=from_ or "",
+ recipients=to or "",
+ cc=cc or "",
+ date_key=(
+ _as_utc(parsed_date).isoformat()
+ if parsed_date
+ else (date_str or "")
+ ),
+ )
+
# If any of the above are `None`, model validation will fail.
# Therefore, no guards (i.e.: `if is None: raise RuntimeError(..)`) were written.
return cls.model_validate(
@@ -269,12 +295,7 @@ def _load_from_checkpoint(
continue
email_headers = EmailHeaders.from_email_msg(email_msg=email_msg)
- msg_dt = email_headers.date
- if msg_dt.tzinfo is None:
- msg_dt = msg_dt.replace(tzinfo=timezone.utc)
- else:
- msg_dt = msg_dt.astimezone(timezone.utc)
-
+ msg_dt = _as_utc(email_headers.date)
start_dt = datetime.fromtimestamp(start, tz=timezone.utc)
end_dt = datetime.fromtimestamp(end, tz=timezone.utc)
@@ -339,6 +360,64 @@ def load_from_checkpoint_with_perm_sync(
start=start, end=end, checkpoint=checkpoint, include_perm_sync=True
)
+ def retrieve_all_slim_docs_perm_sync(
+ self,
+ start: SecondsSinceUnixEpoch | None = None,
+ end: SecondsSinceUnixEpoch | None = None,
+ callback: Any = None,
+ ) -> GenerateSlimDocumentOutput:
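+ """List slim ids for every email and attachment in the configured mailboxes.
+
+ Ids use the same format as full indexing, so the caller can prune indexed
+ documents whose id no longer appears in this snapshot.
+ """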
+ del callback
+ mail_client = self._get_mail_client()
+ start_ts = start if start is not None else 0
+ end_ts = (
+ end if end is not None else datetime.now(tz=timezone.utc).timestamp()
+ )
+ start_dt = datetime.fromtimestamp(start_ts, tz=timezone.utc)
+ end_dt = datetime.fromtimestamp(end_ts, tz=timezone.utc)
+
+ if self._mailboxes:
+ mailboxes = _sanitize_mailbox_names(self._mailboxes)
+ else:
+ mailboxes = _sanitize_mailbox_names(
+ _fetch_all_mailboxes_for_email_account(mail_client=mail_client)
+ )
+
+ slim_doc_batch: list[SlimDocument] = []
+ for mailbox in mailboxes:
+ email_ids = _fetch_email_ids_in_mailbox(
+ mail_client=mail_client,
+ mailbox=mailbox,
+ start=start_ts,
+ end=end_ts,
+ )
+ _select_mailbox(mail_client=mail_client, mailbox=mailbox)
+
+ for email_id in email_ids:
+ email_msg = _fetch_email(mail_client=mail_client, email_id=email_id)
+ if not email_msg:
+ logging.warning(f"Failed to fetch message {email_id=}; skipping")
+ continue
+
+ email_headers = EmailHeaders.from_email_msg(email_msg=email_msg)
+ msg_dt = _as_utc(email_headers.date)
+ if not (start_dt < msg_dt <= end_dt):
+ continue
+
+ slim_doc_batch.append(SlimDocument(id=email_headers.id))
+ for att in extract_attachments(email_msg):
+ slim_doc_batch.append(
+ SlimDocument(
+ id=_attachment_document_id(email_headers.id, att)
+ )
+ )
+
+ if len(slim_doc_batch) >= _PAGE_SIZE:
+ yield slim_doc_batch
+ slim_doc_batch = []
+
+ if slim_doc_batch:
+ yield slim_doc_batch
+
def _fetch_all_mailboxes_for_email_account(mail_client: imaplib.IMAP4_SSL) -> list[str]:
status, mailboxes_data = mail_client.list('""', "*")
@@ -435,6 +514,39 @@ def _fetch_email(mail_client: imaplib.IMAP4_SSL, email_id: str) -> Message | Non
return email.message_from_bytes(raw_email)
+def _as_utc(dt: datetime) -> datetime:
+ if dt.tzinfo is None:
+ return dt.replace(tzinfo=timezone.utc)
+ return dt.astimezone(timezone.utc)
+
+
+def _build_stable_generated_message_id(
+ email_msg: Message,
+ subject: str,
+ sender: str,
+ recipients: str,
+ cc: str,
+ date_key: str,
+) -> str:
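+ """Derive a deterministic stand-in for a missing Message-ID header.
+
+ The digest covers the key headers, the body text and the raw message bytes,
+ so the same email always maps to the same document id across sync runs.
+ """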
+ body = _extract_email_body_text(email_msg)
+ raw_digest = hashlib.sha256(email_msg.as_bytes()).hexdigest()
+ body_digest = hashlib.sha256(body.encode("utf-8")).hexdigest()
+ digest = hashlib.sha256(
+ "\n".join(
+ [
+ subject,
+ date_key,
+ sender,
+ recipients,
+ cc,
+ body_digest,
+ raw_digest,
+ ]
+ ).encode("utf-8")
+ ).hexdigest()
+ return f"generated:{digest}"
+
+
def _convert_email_headers_and_body_into_document(
email_msg: Message,
email_headers: EmailHeaders,
@@ -544,6 +656,13 @@ def decode_mime_filename(raw: str | None) -> str | None:
return "".join(decoded)
+
+def _attachment_document_id(parent_doc_id: str, att: dict) -> str:
+ raw_filename = att["filename"]
+ filename = decode_mime_filename(raw_filename) or "attachment.bin"
+ return f"{parent_doc_id}#att:{filename}"
+
+
def attachment_to_document(
parent_doc: Document,
att: dict,
@@ -554,7 +673,7 @@ def attachment_to_document(
ext = "." + filename.split(".")[-1] if "." in filename else ""
return Document(
- id=f"{parent_doc.id}#att:{filename}",
+ id=_attachment_document_id(parent_doc.id, att),
source=DocumentSource.IMAP,
semantic_identifier=filename,
extension=ext,
@@ -574,6 +693,15 @@ def _parse_email_body(
email_msg: Message,
email_headers: EmailHeaders,
) -> str:
+ body = _extract_email_body_text(email_msg)
+ if not body:
+ logging.warning(
+ f"Email with {email_headers.id=} has an empty body; returning an empty string"
+ )
+ return body
+
+
+def _extract_email_body_text(email_msg: Message) -> str:
body = None
for part in email_msg.walk():
if part.is_multipart():
@@ -598,9 +726,6 @@ def _parse_email_body(
continue
if not body:
- logging.warning(
- f"Email with {email_headers.id=} has an empty body; returning an empty string"
- )
return ""
soup = bs4.BeautifulSoup(markup=body, features="html.parser")
@@ -636,6 +761,7 @@ def _parse_singular_addr(raw_header: str) -> tuple[str, str]:
if __name__ == "__main__":
import time
+ import uuid
from types import TracebackType
from common.data_source.utils import load_all_docs_from_checkpoint_connector
diff --git a/common/data_source/interfaces.py b/common/data_source/interfaces.py
index b68a40c1e1a..324293baaba 100644
--- a/common/data_source/interfaces.py
+++ b/common/data_source/interfaces.py
@@ -60,8 +60,6 @@ class SlimConnectorWithPermSync(ABC):
@abstractmethod
def retrieve_all_slim_docs_perm_sync(
self,
- start: SecondsSinceUnixEpoch | None = None,
- end: SecondsSinceUnixEpoch | None = None,
callback: Any = None,
) -> Generator[list[SlimDocument], None, None]:
"""Retrieve all simplified documents (with permission sync)"""
diff --git a/common/data_source/jira/connector.py b/common/data_source/jira/connector.py
index db3c3f8942d..aa4082f4149 100644
--- a/common/data_source/jira/connector.py
+++ b/common/data_source/jira/connector.py
@@ -149,7 +149,10 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None
else:
logger.warning("[Jira] Scoped token requested but Jira base URL does not appear to be an Atlassian Cloud domain; scoped token ignored.")
- user_email = credentials.get("jira_user_email") or credentials.get("username")
+ user_email = (
+ credentials.get("jira_user_email")
+ or credentials.get("jira_username")
+ )
api_token = credentials.get("jira_api_token") or credentials.get("token") or credentials.get("api_token")
password = credentials.get("jira_password") or credentials.get("password")
rest_api_version = credentials.get("rest_api_version")
@@ -377,16 +380,14 @@ def validate_checkpoint_json(self, checkpoint_json: str) -> JiraCheckpoint:
def retrieve_all_slim_docs_perm_sync(
self,
- start: SecondsSinceUnixEpoch | None = None,
- end: SecondsSinceUnixEpoch | None = None,
- callback: Any = None, # noqa: ARG002 - maintained for interface compatibility
+ callback: Any = None, # noqa: ARG002 - kept for interface compatibility
) -> Generator[list[SlimDocument], None, None]:
"""Return lightweight references to Jira issues (used for permission syncing)."""
if not self.jira_client:
raise ConnectorMissingCredentialError("Jira")
- start_ts = start if start is not None else 0
- end_ts = end if end is not None else datetime.now(timezone.utc).timestamp()
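+ # The slim/permission-sync interface no longer passes a time window, so the JQL always spans the full history.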
+ start_ts = 0
+ end_ts = datetime.now(timezone.utc).timestamp()
jql = self._build_jql(start_ts, end_ts)
checkpoint = self.build_dummy_checkpoint()
@@ -962,7 +963,16 @@ def main(config: dict[str, Any] | None = None) -> None:
if not base_url:
raise RuntimeError("Jira base URL must be provided via config or CLI arguments.")
- if not (credentials.get("jira_api_token") or (credentials.get("jira_user_email") and credentials.get("jira_password"))):
+    if not (
+        credentials.get("jira_api_token")
+        or (
+            (credentials.get("jira_user_email") or credentials.get("jira_username"))
+            and credentials.get("jira_password")
+        )
+    ):
raise RuntimeError("Provide either an API token or both email/password for Jira authentication.")
connector_options = {
diff --git a/common/data_source/moodle_connector.py b/common/data_source/moodle_connector.py
index 39efcf07be0..850ce5815d1 100644
--- a/common/data_source/moodle_connector.py
+++ b/common/data_source/moodle_connector.py
@@ -21,14 +21,19 @@
LoadConnector,
PollConnector,
SecondsSinceUnixEpoch,
+ SlimConnectorWithPermSync,
+)
+from common.data_source.models import (
+ Document,
+ GenerateSlimDocumentOutput,
+ SlimDocument,
)
-from common.data_source.models import Document
from common.data_source.utils import batch_generator, rl_requests
logger = logging.getLogger(__name__)
-class MoodleConnector(LoadConnector, PollConnector):
+class MoodleConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
"""Moodle LMS connector for accessing course content"""
def __init__(self, moodle_url: str, batch_size: int = INDEX_BATCH_SIZE) -> None:
@@ -137,6 +142,78 @@ def poll_source(
self._get_updated_content(courses, start, end)
)
+ @staticmethod
+ def _slim_doc_id_for_module(module) -> Optional[str]:
+ """Return the indexed document id for a Moodle module, or None.
+
+ The id format must match the ids produced by the _process_*
+ helpers below. Module types we never ingest (label, url), and
+ modules without an id, return None.
+ """
+ mtype = getattr(module, "modname", None)
+ mid = getattr(module, "id", None)
+ if not mtype or mid is None:
+ return None
+ if mtype in ("label", "url"):
+ return None
+ if mtype == "resource":
+ return f"moodle_resource_{mid}"
+ if mtype == "forum":
+ return f"moodle_forum_{mid}"
+ if mtype == "page":
+ return f"moodle_page_{mid}"
+ if mtype == "book":
+ return f"moodle_book_{mid}"
+ if mtype in ("assign", "quiz"):
+ return f"moodle_{mtype}_{mid}"
+ return None
+
+ def retrieve_all_slim_docs_perm_sync(
+ self,
+ callback: Any = None,
+ ) -> GenerateSlimDocumentOutput:
+ """List the ids of every Moodle module that could be indexed.
+
+ This is a lightweight pass over courses and modules with no file
+ downloads. The caller compares the returned ids against the index
+ and removes any indexed document whose id is not in this list.
+ """
+ del callback
+ if not self.moodle_client:
+ raise ConnectorMissingCredentialError("Moodle client not initialized")
+
+ logger.info("Starting Moodle slim snapshot for stale-document cleanup")
+ courses = self._get_enrolled_courses()
+ if not courses:
+ logger.warning("No courses found for slim snapshot")
+ return
+
+ batch: list[SlimDocument] = []
+ total = 0
+ for course in courses:
+ try:
+ contents = self._get_course_contents(course.id)
+ for section in contents:
+ for module in section.modules:
+ slim_id = self._slim_doc_id_for_module(module)
+ if slim_id is None:
+ continue
+ batch.append(SlimDocument(id=slim_id))
+ total += 1
+ if len(batch) >= self.batch_size:
+ yield batch
+ batch = []
+ except Exception as e:
+ self._log_error(
+ f"slim snapshot for course {getattr(course, 'fullname', '?')}",
+ e,
+ )
+
+ if batch:
+ yield batch
+
+ logger.info(f"Moodle slim snapshot completed: {total} documents listed")
+
@retry(tries=3, delay=1, backoff=2)
def _get_enrolled_courses(self) -> list:
if not self.moodle_client:
diff --git a/common/data_source/notion_connector.py b/common/data_source/notion_connector.py
index 30536dfb944..ea3d6d07646 100644
--- a/common/data_source/notion_connector.py
+++ b/common/data_source/notion_connector.py
@@ -28,9 +28,11 @@
from common.data_source.models import (
Document,
GenerateDocumentsOutput,
+ GenerateSlimDocumentOutput,
NotionBlock,
NotionPage,
NotionSearchResponse,
+ SlimDocument,
TextSection,
)
from common.data_source.utils import (
@@ -433,6 +435,45 @@ def _read_blocks(self, base_block_id: str, page_last_edited_time: Optional[str]
return result_blocks, child_pages, attachments
+ def _read_slim_blocks(self, base_block_id: str) -> tuple[list[str], list[str]]:
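+ """Collect child-page ids and attachment block ids under a block, without reading any text content."""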
+ child_pages: list[str] = []
+ attachment_ids: list[str] = []
+ cursor = None
+
+ while True:
+ data = self._fetch_child_blocks(base_block_id, cursor)
+
+ if data is None:
+ return child_pages, attachment_ids
+
+ for result in data["results"]:
+ result_block_id = result["id"]
+ result_type = result["type"]
+
+ if result_type in {"file", "image", "pdf", "video", "audio"}:
+ attachment_ids.append(result_block_id)
+
+ if result["has_children"]:
+ if result_type == "child_page":
+ child_pages.append(result_block_id)
+ else:
+ nested_child_pages, nested_attachment_ids = self._read_slim_blocks(
+ result_block_id
+ )
+ child_pages.extend(nested_child_pages)
+ attachment_ids.extend(nested_attachment_ids)
+
+ if result_type == "child_database" and self.recursive_index_enabled:
+ _, inner_child_pages = self._read_pages_from_database(result_block_id)
+ child_pages.extend(inner_child_pages)
+
+ if data["next_cursor"] is None:
+ break
+
+ cursor = data["next_cursor"]
+
+ return child_pages, attachment_ids
+
def _read_page_title(self, page: NotionPage) -> Optional[str]:
"""Extracts the title from a Notion page."""
if hasattr(page, "database_name") and page.database_name:
@@ -552,6 +593,79 @@ def _recursive_load(self, start: SecondsSinceUnixEpoch | None = None, end: Secon
pages = [self._fetch_page(page_id=self.root_page_id)]
yield from batch_generator(self._read_pages(pages, start, end), self.batch_size)
+ def _read_pages_for_slim_docs(
+ self,
+ pages: list[NotionPage],
+ slim_indexed_pages: set[str],
+ ) -> Generator[SlimDocument, None, None]:
+ all_child_page_ids: list[str] = []
+
+ for page in pages:
+ if isinstance(page, dict):
+ page = NotionPage(**page)
+ if page.id in slim_indexed_pages:
+ continue
+
+ child_page_ids, attachment_ids = self._read_slim_blocks(page.id)
+ all_child_page_ids.extend(child_page_ids)
+ slim_indexed_pages.add(page.id)
+
+ yield SlimDocument(id=page.id)
+ for attachment_id in attachment_ids:
+ yield SlimDocument(id=attachment_id)
+
+ if self.recursive_index_enabled and all_child_page_ids:
+ for child_page_batch_ids in batch_generator(all_child_page_ids, INDEX_BATCH_SIZE):
+ child_page_batch = [
+ self._fetch_page(page_id)
+ for page_id in child_page_batch_ids
+ if page_id not in slim_indexed_pages
+ ]
+ yield from self._read_pages_for_slim_docs(
+ child_page_batch,
+ slim_indexed_pages,
+ )
+
+ def retrieve_all_slim_docs_perm_sync(
+ self,
+ callback: Any = None,
+ ) -> GenerateSlimDocumentOutput:
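+ """Yield slim ids for every reachable page and attachment: recursively from the
+ root page when one is configured, otherwise via a workspace-wide page search."""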
+ slim_indexed_pages: set[str] = set()
+
+ if self.recursive_index_enabled and self.root_page_id:
+ root_pages = [self._fetch_page(page_id=self.root_page_id)]
+ yield from batch_generator(
+ self._read_pages_for_slim_docs(root_pages, slim_indexed_pages),
+ self.batch_size,
+ )
+ return
+
+ query_dict = {
+ "filter": {"property": "object", "value": "page"},
+ "page_size": 100,
+ }
+
+ slim_batch: list[SlimDocument] = []
+ while True:
+ db_res = self._search_notion(query_dict)
+ pages = [NotionPage(**page) for page in db_res.results]
+
+ for doc in self._read_pages_for_slim_docs(pages, slim_indexed_pages):
+ slim_batch.append(doc)
+ if len(slim_batch) >= self.batch_size:
+ yield slim_batch
+ slim_batch = []
+ if callback:
+ callback.progress("notion_slim_document", 1)
+
+ if db_res.has_more:
+ query_dict["start_cursor"] = db_res.next_cursor
+ else:
+ break
+
+ if slim_batch:
+ yield slim_batch
+
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Applies integration token to headers."""
self.headers["Authorization"] = f"Bearer {credentials['notion_integration_token']}"
@@ -653,4 +767,4 @@ def validate_connector_settings(self) -> None:
document_batches = connector.load_from_state()
for doc_batch in document_batches:
for doc in doc_batch:
- print(doc)
\ No newline at end of file
+ print(doc)
diff --git a/common/data_source/rdbms_connector.py b/common/data_source/rdbms_connector.py
index 05628501c65..9811d2064dc 100644
--- a/common/data_source/rdbms_connector.py
+++ b/common/data_source/rdbms_connector.py
@@ -1,5 +1,6 @@
"""RDBMS (MySQL/PostgreSQL) data source connector for importing data from relational databases."""
+import copy
import hashlib
import json
import logging
@@ -12,8 +13,13 @@
ConnectorMissingCredentialError,
ConnectorValidationError,
)
-from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch
-from common.data_source.models import Document
+from common.data_source.interfaces import (
+ LoadConnector,
+ PollConnector,
+ SecondsSinceUnixEpoch,
+ SlimConnectorWithPermSync,
+)
+from common.data_source.models import Document, SlimDocument
class DatabaseType(str, Enum):
@@ -22,15 +28,18 @@ class DatabaseType(str, Enum):
POSTGRESQL = "postgresql"
-class RDBMSConnector(LoadConnector, PollConnector):
+class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
"""
- RDBMS connector for importing data from MySQL and PostgreSQL databases.
-
- This connector allows users to:
- 1. Connect to a MySQL or PostgreSQL database
- 2. Execute a SQL query to extract data
- 3. Map columns to content (for vectorization) and metadata
- 4. Sync data in batch or incremental mode using a timestamp column
+ Import rows from MySQL or PostgreSQL into documents.
+
+ The flow is:
+ 1. Connect to the configured database.
+ 2. Read rows from a custom SQL query, or from every table when no query is provided.
+ 3. Build document content from the selected content columns.
+ 4. Copy the selected metadata columns into document metadata.
+ 5. Use the configured ID column as the stable document ID, or hash the content when no ID column is set.
+ 6. For incremental sync, treat the timestamp column as an ordered cursor and compare values only with greater-than/less-than, regardless of their concrete type.
+ 7. For deleted-file sync, read a slim snapshot of current row IDs and let the sync worker remove stale documents.
"""
def __init__(
self,
@@ -73,6 +82,9 @@ def __init__(
self._connection = None
self._credentials: Dict[str, Any] = {}
+ self._sync_connector_id: str | None = None
+ self._sync_config: Dict[str, Any] | None = None
+ self._pending_sync_cursor_value: Any = None
def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None:
"""Load database credentials."""
@@ -160,98 +172,175 @@ def _get_tables(self) -> list[str]:
finally:
cursor.close()
- def _build_query_with_time_filter(
+
+ def _get_base_queries(self) -> list[str]:
+ if self.query:
+ return [self.query.rstrip(";")]
+ return [f"SELECT * FROM {table}" for table in self._get_tables()]
+
+
+ def _wrap_query(self, base_query: str, select_clause: str = "*") -> str:
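+ # Wrap the base query as a subquery so callers can filter and project on ragflow_src.<column>.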
+ return f"SELECT {select_clause} FROM ({base_query}) AS ragflow_src"
+
+
+ @staticmethod
+ def serialize_cursor_value(value: Any) -> Any:
+ # Example:
+ # - int cursor 42 is stored as 42
+ # - datetime cursor 2026-05-07T12:34:56+00:00 is stored as
+ # {"__ragflow_rdbms_cursor_type__": "datetime", "value": "..."}
+ # Only datetime needs wrapping because connector config is JSON.
+ if isinstance(value, datetime):
+ return {
+ "__ragflow_rdbms_cursor_type__": "datetime",
+ "value": value.isoformat(),
+ }
+ return value
+
+
+ @staticmethod
+ def deserialize_cursor_value(value: Any) -> Any:
+ # Reverse the datetime wrapper above.
+ # Non-datetime cursors such as int/str/float are returned as-is.
+ if (
+ isinstance(value, dict)
+ and value.get("__ragflow_rdbms_cursor_type__") == "datetime"
+ ):
+ return datetime.fromisoformat(value["value"])
+ return value
+
+
+ def _format_sql_value(self, value: Any) -> str:
+ if isinstance(value, datetime):
+ if value.tzinfo is None:
+ value = value.replace(tzinfo=timezone.utc)
+ if self.db_type == DatabaseType.MYSQL:
+ rendered = value.astimezone(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
+ else:
+ rendered = value.astimezone(timezone.utc).isoformat()
+ return f"'{rendered}'"
+ if isinstance(value, bool):
+ if self.db_type == DatabaseType.POSTGRESQL:
+ return "TRUE" if value else "FALSE"
+ return "1" if value else "0"
+ if isinstance(value, (int, float)):
+ return str(value)
+ if isinstance(value, str):
+ return "'" + value.replace("'", "''") + "'"
+ raise ConnectorValidationError(
+ f"Unsupported timestamp cursor value type: {type(value).__name__}"
+ )
+
+
+ def _build_time_filtered_query(
self,
- start: Optional[datetime] = None,
- end: Optional[datetime] = None,
+ base_query: str,
+ start: Any = None,
+ end: Any = None,
) -> str:
- """Build the query with optional time filtering for incremental sync."""
- if not self.query:
- return "" # Will be handled by table discovery
- base_query = self.query.rstrip(";")
-
if not self.timestamp_column or (start is None and end is None):
- return base_query
-
- has_where = "where" in base_query.lower()
- connector = " AND" if has_where else " WHERE"
-
- time_conditions = []
+ return self._wrap_query(base_query)
+
+ conditions = []
if start is not None:
- if self.db_type == DatabaseType.MYSQL:
- time_conditions.append(f"{self.timestamp_column} > '{start.strftime('%Y-%m-%d %H:%M:%S')}'")
- else:
- time_conditions.append(f"{self.timestamp_column} > '{start.isoformat()}'")
-
+ conditions.append(
+ f"ragflow_src.{self.timestamp_column} > {self._format_sql_value(start)}"
+ )
if end is not None:
- if self.db_type == DatabaseType.MYSQL:
- time_conditions.append(f"{self.timestamp_column} <= '{end.strftime('%Y-%m-%d %H:%M:%S')}'")
- else:
- time_conditions.append(f"{self.timestamp_column} <= '{end.isoformat()}'")
-
- if time_conditions:
- return f"{base_query}{connector} {' AND '.join(time_conditions)}"
-
- return base_query
+ conditions.append(
+ f"ragflow_src.{self.timestamp_column} <= {self._format_sql_value(end)}"
+ )
- def _row_to_document(self, row: Union[tuple, list, Dict[str, Any]], column_names: list) -> Document:
- """Convert a database row to a Document."""
- row_dict = dict(zip(column_names, row)) if isinstance(row, (list, tuple)) else row
-
+ query = self._wrap_query(base_query)
+ if conditions:
+ query = f"{query} WHERE {' AND '.join(conditions)}"
+ return query
+
+
+ def _build_max_timestamp_query(self, base_query: str) -> str:
+ return (
+ f"SELECT MAX(ragflow_src.{self.timestamp_column}) "
+ f"FROM ({base_query}) AS ragflow_src"
+ )
+
+
+ def _build_slim_query(self, base_query: str) -> str:
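+ # The slim snapshot only selects the columns needed to rebuild the document id:
+ # the id column when configured, otherwise the content columns that feed the hash.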
+ columns = [self.id_column] if self.id_column else self.content_columns
+ select_clause = ", ".join(f"ragflow_src.{column}" for column in columns)
+ return self._wrap_query(base_query, select_clause)
+
+
+ def _build_content(self, row_dict: Dict[str, Any]) -> str:
content_parts = []
for col in self.content_columns:
- if col in row_dict and row_dict[col] is not None:
- value = row_dict[col]
- if isinstance(value, (dict, list)):
- value = json.dumps(value, ensure_ascii=False)
- # Use brackets around field name and put value on a new line
- # so that TxtParser preserves field boundaries after chunking.
- content_parts.append(f"【{col}】:\n{value}")
-
- content = "\n\n".join(content_parts)
-
- if self.id_column and self.id_column in row_dict:
- doc_id = f"{self.db_type}:{self.database}:{row_dict[self.id_column]}"
- else:
- content_hash = hashlib.md5(content.encode()).hexdigest()
- doc_id = f"{self.db_type}:{self.database}:{content_hash}"
-
+ if col not in row_dict or row_dict[col] is None:
+ continue
+ value = row_dict[col]
+ if isinstance(value, (dict, list)):
+ value = json.dumps(value, ensure_ascii=False)
+ content_parts.append(f"【{col}】:\n{value}")
+ return "\n\n".join(content_parts)
+
+
+ def _build_document_id_from_row(self, row_dict: Dict[str, Any]) -> str:
+ if self.id_column and self.id_column in row_dict and row_dict[self.id_column] is not None:
+ return f"{self.db_type}:{self.database}:{row_dict[self.id_column]}"
+ content = self._build_content(row_dict)
+ content_hash = hashlib.md5(content.encode()).hexdigest()
+ return f"{self.db_type}:{self.database}:{content_hash}"
+
+
+ def _row_to_document(
+ self,
+ row: Union[tuple, list, Dict[str, Any]],
+ column_names: list[str],
+ ) -> Document:
+ """Convert a database row to a Document."""
+ row_dict = dict(zip(column_names, row)) if isinstance(row, (list, tuple)) else row
+ content = self._build_content(row_dict)
metadata = {}
for col in self.metadata_columns:
- if col in row_dict and row_dict[col] is not None:
- value = row_dict[col]
- if isinstance(value, datetime):
- value = value.isoformat()
- elif isinstance(value, (dict, list)):
- value = json.dumps(value, ensure_ascii=False)
- else:
- value = str(value)
- metadata[col] = value
-
+ if col not in row_dict or row_dict[col] is None:
+ continue
+ value = row_dict[col]
+ if isinstance(value, datetime):
+ value = value.isoformat()
+ elif isinstance(value, (dict, list)):
+ value = json.dumps(value, ensure_ascii=False)
+ else:
+ value = str(value)
+ metadata[col] = value
+
doc_updated_at = datetime.now(timezone.utc)
- if self.timestamp_column and self.timestamp_column in row_dict:
+ if self.timestamp_column and self.timestamp_column in row_dict and row_dict[self.timestamp_column] is not None:
ts_value = row_dict[self.timestamp_column]
if isinstance(ts_value, datetime):
if ts_value.tzinfo is None:
doc_updated_at = ts_value.replace(tzinfo=timezone.utc)
else:
- doc_updated_at = ts_value
-
+ doc_updated_at = ts_value.astimezone(timezone.utc)
+
first_content_col = self.content_columns[0] if self.content_columns else "record"
- semantic_id = str(row_dict.get(first_content_col, "database_record")).replace("\n", " ").replace("\r", " ").strip()[:100]
+ semantic_id = (
+ str(row_dict.get(first_content_col, "database_record"))
+ .replace("\n", " ")
+ .replace("\r", " ")
+ .strip()[:100]
+ )
+ blob = content.encode("utf-8")
-
return Document(
- id=doc_id,
- blob=content.encode("utf-8"),
+ id=self._build_document_id_from_row(row_dict),
+ blob=blob,
source=DocumentSource(self.db_type.value),
semantic_identifier=semantic_id,
extension=".txt",
doc_updated_at=doc_updated_at,
- size_bytes=len(content.encode("utf-8")),
+ size_bytes=len(blob),
metadata=metadata if metadata else None,
)
+
def _yield_documents_from_query(
self,
query: str,
@@ -288,30 +377,146 @@ def _yield_documents_from_query(
pass
cursor.close()
+
+ def _yield_slim_documents_from_query(
+ self,
+ query: str,
+ ) -> Generator[list[SlimDocument], None, None]:
+ connection = self._get_connection()
+ cursor = connection.cursor()
+
+ try:
+ logging.debug(f"Executing slim query: {query[:200]}...")
+ cursor.execute(query)
+ column_names = [desc[0] for desc in cursor.description]
+
+ batch: list[SlimDocument] = []
+ for row in cursor:
+ row_dict = dict(zip(column_names, row)) if isinstance(row, (list, tuple)) else row
+ batch.append(SlimDocument(id=self._build_document_id_from_row(row_dict)))
+ if len(batch) >= self.batch_size:
+ yield batch
+ batch = []
+
+ if batch:
+ yield batch
+ finally:
+ try:
+ cursor.fetchall()
+ except Exception:
+ pass
+ cursor.close()
+
+
+ def get_max_cursor_value(self) -> Any:
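+ """Return the largest timestamp/cursor value across all base queries, or None when unavailable."""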
+ if not self.timestamp_column:
+ return None
+
+ max_cursor_value = None
+ connection = self._get_connection()
+ cursor = connection.cursor()
+
+ try:
+ for base_query in self._get_base_queries():
+ query = self._build_max_timestamp_query(base_query)
+ logging.debug(f"Executing max timestamp query: {query[:200]}...")
+ cursor.execute(query)
+ row = cursor.fetchone()
+ if row is None or row[0] is None:
+ continue
+ if max_cursor_value is None or row[0] > max_cursor_value:
+ max_cursor_value = row[0]
+ finally:
+ cursor.close()
+
+ return max_cursor_value
+
+
def _yield_documents(
self,
- start: Optional[datetime] = None,
- end: Optional[datetime] = None,
+ start: Any = None,
+ end: Any = None,
) -> Generator[list[Document], None, None]:
"""Generate documents from database query results."""
- if self.query:
- query = self._build_query_with_time_filter(start, end)
- yield from self._yield_documents_from_query(query)
- else:
- tables = self._get_tables()
- logging.info(f"No query specified. Loading all {len(tables)} tables: {tables}")
- for table in tables:
- query = f"SELECT * FROM {table}"
- logging.info(f"Loading table: {table}")
+ base_queries = self._get_base_queries()
+ if not self.query:
+ logging.info(f"No query specified. Loading all {len(base_queries)} tables.")
+
+ try:
+ for base_query in base_queries:
+ query = self._build_time_filtered_query(base_query, start, end)
yield from self._yield_documents_from_query(query)
-
- self._close_connection()
+ finally:
+ self._close_connection()
+
def load_from_state(self) -> Generator[list[Document], None, None]:
"""Load all documents from the database (full sync)."""
logging.debug(f"Loading all records from {self.db_type} database: {self.database}")
return self._yield_documents()
+
+ def retrieve_all_slim_docs_perm_sync(
+ self,
+ callback: Any = None,
+ ) -> Generator[list[SlimDocument], None, None]:
+ del callback
+
+ base_queries = self._get_base_queries()
+ if not self.query:
+ logging.info(f"No query specified. Retrieving slim documents from all {len(base_queries)} tables.")
+
+ try:
+ for base_query in base_queries:
+ yield from self._yield_slim_documents_from_query(
+ self._build_slim_query(base_query)
+ )
+ finally:
+ self._close_connection()
+
+ def prepare_sync_state(self, connector_id: str, config: Dict[str, Any]) -> None:
+ self._sync_connector_id = connector_id
+ self._sync_config = copy.deepcopy(config)
+ if not self.timestamp_column:
+ self._pending_sync_cursor_value = None
+ return
+ self._pending_sync_cursor_value = self.get_max_cursor_value()
+
+
+ def get_saved_sync_cursor_value(self) -> Any:
+ if self._sync_config is None:
+ return None
+ return self.deserialize_cursor_value(self._sync_config.get("sync_cursor_value"))
+
+
+ def persist_sync_state(self) -> None:
+ if not self.timestamp_column or self._sync_connector_id is None or self._sync_config is None:
+ return
+
+ from api.db.services.connector_service import ConnectorService
+
+ updated_conf = copy.deepcopy(self._sync_config)
+ updated_conf["sync_cursor_value"] = self.serialize_cursor_value(
+ self._pending_sync_cursor_value
+ )
+ ConnectorService.update_by_id(self._sync_connector_id, {"config": updated_conf})
+ self._sync_config = updated_conf
+
+
+ def load_from_cursor_range(
+ self,
+ start_value: Any = None,
+ end_value: Any = None,
+ ) -> Generator[list[Document], None, None]:
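+ # A missing end cursor, or one that has not advanced past the start, means there is nothing new to read.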
+ if end_value is None:
+ self._close_connection()
+ return iter(())
+ if start_value is not None and end_value <= start_value:
+ self._close_connection()
+ return iter(())
+ return self._yield_documents(start_value, end_value)
+
+
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> Generator[list[Document], None, None]:
@@ -322,16 +527,8 @@ def poll_source(
"Falling back to full sync."
)
return self.load_from_state()
-
- start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
- end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
-
- logging.debug(
- f"Polling {self.db_type} database {self.database} "
- f"from {start_datetime} to {end_datetime}"
- )
-
- return self._yield_documents(start_datetime, end_datetime)
+ return self._yield_documents(start, end)
+
def validate_connector_settings(self) -> None:
"""Validate connector settings by testing the connection."""
diff --git a/common/data_source/rss_connector.py b/common/data_source/rss_connector.py
index 85471407abc..6fad756d73b 100644
--- a/common/data_source/rss_connector.py
+++ b/common/data_source/rss_connector.py
@@ -1,44 +1,29 @@
import hashlib
-import ipaddress
-import socket
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from time import struct_time
from typing import Any
-from urllib.parse import urlparse
+from urllib.parse import urljoin, urlparse
import bs4
import feedparser
import requests
from common.data_source.config import INDEX_BATCH_SIZE, REQUEST_TIMEOUT_SECONDS, DocumentSource
-from common.data_source.interfaces import LoadConnector, PollConnector
-from common.data_source.models import Document, GenerateDocumentsOutput, SecondsSinceUnixEpoch
+from common.data_source.interfaces import LoadConnector, PollConnector, SlimConnectorWithPermSync
+from common.data_source.models import (
+ Document,
+ GenerateDocumentsOutput,
+ GenerateSlimDocumentOutput,
+ SecondsSinceUnixEpoch,
+ SlimDocument,
+)
+from common.ssrf_guard import assert_url_is_safe, pin_dns as _pin_dns
+_MAX_REDIRECTS = 10
-def _is_private_ip(ip: str) -> bool:
- try:
- ip_obj = ipaddress.ip_address(ip)
- return ip_obj.is_private or ip_obj.is_link_local or ip_obj.is_loopback
- except ValueError:
- return False
-
-def _validate_url_no_ssrf(url: str) -> None:
- parsed = urlparse(url)
- hostname = parsed.hostname
- if not hostname:
- raise ValueError("URL must have a valid hostname")
-
- try:
- ip = socket.gethostbyname(hostname)
- if _is_private_ip(ip):
- raise ValueError(f"URL resolves to private/internal IP address: {ip}")
- except socket.gaierror as e:
- raise ValueError(f"Failed to resolve hostname: {hostname}") from e
-
-
-class RSSConnector(LoadConnector, PollConnector):
+class RSSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
def __init__(self, feed_url: str, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.feed_url = feed_url.strip()
self.batch_size = batch_size
@@ -61,6 +46,25 @@ def load_from_state(self) -> GenerateDocumentsOutput:
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> GenerateDocumentsOutput:
yield from self._load_entries(start=start, end=end)
+ def retrieve_all_slim_docs_perm_sync(
+ self,
+ callback: Any = None,
+ ) -> GenerateSlimDocumentOutput:
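+ # One slim id per feed entry, derived from the same stable key used for full documents.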
+ del callback
+
+ feed = self._read_feed(require_entries=False)
+ batch: list[SlimDocument] = []
+
+ for entry in feed.entries:
+ batch.append(SlimDocument(id=self._build_document_id(entry)))
+
+ if len(batch) >= self.batch_size:
+ yield batch
+ batch = []
+
+ if batch:
+ yield batch
+
def _load_entries(
self,
start: SecondsSinceUnixEpoch | None = None,
@@ -87,7 +91,8 @@ def _load_entries(
if batch:
yield batch
- def _validate_feed_url(self) -> None:
+ def _validate_feed_url(self) -> tuple[str, str]:
+ """Validate ``self.feed_url`` and return ``(hostname, resolved_ip)``."""
if not self.feed_url:
raise ValueError("feed_url is required")
@@ -95,7 +100,7 @@ def _validate_feed_url(self) -> None:
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
raise ValueError("feed_url must be a valid http or https URL")
- _validate_url_no_ssrf(self.feed_url)
+ return assert_url_is_safe(self.feed_url)
def _read_feed(self, require_entries: bool) -> Any:
if self._cached_feed is not None:
@@ -103,15 +108,38 @@ def _read_feed(self, require_entries: bool) -> Any:
raise ValueError("RSS feed contains no entries")
return self._cached_feed
- self._validate_feed_url()
+ # Validate once to get the pinned IP for the initial request.
+ current_hostname, current_ip = self._validate_feed_url()
+ current_url = self.feed_url
+
+ # Follow redirects manually: each hop is validated and DNS-pinned
+ # *before* the connection is made, closing the TOCTOU rebinding window
+ # that existed when allow_redirects=True was used with post-hoc checks.
+ response: requests.Response | None = None
+ for _ in range(_MAX_REDIRECTS + 1):
+ with _pin_dns(current_hostname, current_ip):
+ response = requests.get(
+ current_url,
+ timeout=REQUEST_TIMEOUT_SECONDS,
+ allow_redirects=False,
+ )
+
+ if response.status_code not in (301, 302, 303, 307, 308):
+ break
+
+ location = response.headers.get("Location")
+ if not location:
+ break # broken redirect (no Location header); fall through and let feed validation report the failure
+
+ redirect_url = urljoin(current_url, location)
+ # Validate redirect target before following it.
+ current_hostname, current_ip = assert_url_is_safe(redirect_url)
+ current_url = redirect_url
+ else:
+ raise ValueError(f"Exceeded {_MAX_REDIRECTS} redirects fetching {self.feed_url!r}")
- response = requests.get(self.feed_url, timeout=REQUEST_TIMEOUT_SECONDS, allow_redirects=True)
response.raise_for_status()
- final_url = getattr(response, "url", self.feed_url)
- if final_url != self.feed_url and urlparse(final_url).hostname:
- _validate_url_no_ssrf(final_url)
-
feed = feedparser.parse(response.content)
if getattr(feed, "bozo", False) and not feed.entries:
error = getattr(feed, "bozo_exception", None)
@@ -127,7 +155,7 @@ def _read_feed(self, require_entries: bool) -> Any:
def _build_document(self, entry: Any, updated_at: datetime) -> Document:
link = (entry.get("link") or "").strip()
title = (entry.get("title") or "").strip()
- stable_key = (entry.get("id") or link or title or self.feed_url).strip()
+ stable_key = self._resolve_stable_key(entry)
semantic_identifier = title or link or stable_key
content = self._build_content(entry, semantic_identifier)
blob = content.encode("utf-8")
@@ -149,7 +177,7 @@ def _build_document(self, entry: Any, updated_at: datetime) -> Document:
metadata["categories"] = categories
return Document(
- id=f"rss:{hashlib.md5(stable_key.encode('utf-8')).hexdigest()}",
+ id=self._build_document_id(entry),
source=DocumentSource.RSS,
semantic_identifier=semantic_identifier,
extension=".txt",
@@ -177,6 +205,15 @@ def _build_content(self, entry: Any, semantic_identifier: str) -> str:
return "\n\n".join(part for part in parts if part).strip()
+ def _build_document_id(self, entry: Any) -> str:
+ stable_key = self._resolve_stable_key(entry)
+ return f"rss:{hashlib.md5(stable_key.encode('utf-8')).hexdigest()}"
+
+ def _resolve_stable_key(self, entry: Any) -> str:
+ link = (entry.get("link") or "").strip()
+ title = (entry.get("title") or "").strip()
+ return (entry.get("id") or link or title or self.feed_url).strip()
+
def _resolve_entry_time(self, entry: Any) -> datetime:
for field in ("updated_parsed", "published_parsed"):
value = entry.get(field)
diff --git a/common/data_source/seafile_connector.py b/common/data_source/seafile_connector.py
index ef7afeecf47..66bcf954fde 100644
--- a/common/data_source/seafile_connector.py
+++ b/common/data_source/seafile_connector.py
@@ -20,17 +20,19 @@
CredentialExpiredError,
InsufficientPermissionsError,
)
-from common.data_source.interfaces import LoadConnector, PollConnector
+from common.data_source.interfaces import LoadConnector, PollConnector, SlimConnectorWithPermSync
from common.data_source.models import (
Document,
SecondsSinceUnixEpoch,
GenerateDocumentsOutput,
+ GenerateSlimDocumentOutput,
SeafileSyncScope,
+ SlimDocument,
)
logger = logging.getLogger(__name__)
-class SeaFileConnector(LoadConnector, PollConnector):
+class SeaFileConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
"""SeaFile connector supporting account-, library- and directory-level sync.
API endpoints used:
@@ -357,8 +359,18 @@ def _get_repo_info(self) -> Optional[dict]:
return self._get_repo_info_via_account(self.repo_id)
@retry(tries=3, delay=1, backoff=2)
- def _get_directory_entries(self, repo_id: str, path: str = "/") -> list[dict]:
- """List directory contents using the appropriate endpoint."""
+ def _get_directory_entries(
+ self,
+ repo_id: str,
+ path: str = "/",
+ *,
+ raise_on_failure: bool = False,
+ ) -> list[dict]:
+ """List directory contents using the appropriate endpoint.
+
+ When ``raise_on_failure`` is True (used for slim snapshots), HTTP/API errors
+ propagate so callers do not treat a failed listing as an empty directory.
+ """
try:
if self._use_repo_token:
# GET /api/v2.1/via-repo-token/dir/?path=/foo
@@ -380,6 +392,8 @@ def _get_directory_entries(self, repo_id: str, path: str = "/") -> list[dict]:
logger.warning(
"Error fetching directory %s in repo %s: %s", path, repo_id, e,
)
+ if raise_on_failure:
+ raise
return []
@retry(tries=3, delay=1, backoff=2)
@@ -412,9 +426,14 @@ def _list_files_recursive(
path: str,
start: datetime,
end: datetime,
+ *,
+ filter_by_mtime: bool = True,
+ strict_listing: bool = False,
) -> list[tuple[str, dict, dict]]:
files = []
- entries = self._get_directory_entries(repo_id, path)
+ entries = self._get_directory_entries(
+ repo_id, path, raise_on_failure=strict_listing,
+ )
for entry in entries:
entry_type = entry.get("type")
@@ -424,15 +443,33 @@ def _list_files_recursive(
if entry_type == "dir":
files.extend(
self._list_files_recursive(
- repo_id, repo_name, entry_path, start, end,
+ repo_id,
+ repo_name,
+ entry_path,
+ start,
+ end,
+ filter_by_mtime=filter_by_mtime,
+ strict_listing=strict_listing,
)
)
elif entry_type == "file":
modified = self._parse_mtime(entry.get("mtime"))
- if start < modified <= end:
+ if filter_by_mtime:
+ if start < modified <= end:
+ files.append(
+ (
+ entry_path,
+ entry,
+ {"id": repo_id, "name": repo_name},
+ )
+ )
+ else:
files.append(
- (entry_path, entry,
- {"id": repo_id, "name": repo_name})
+ (
+ entry_path,
+ entry,
+ {"id": repo_id, "name": repo_name},
+ )
)
return files
@@ -473,6 +510,8 @@ def _yield_seafile_documents(
try:
files = self._list_files_recursive(
lib["id"], lib["name"], root, start, end,
+ filter_by_mtime=True,
+ strict_listing=False,
)
all_files.extend(files)
except Exception as e:
@@ -539,4 +578,59 @@ def poll_source(
for batch in self._yield_seafile_documents(start_dt, end_dt):
yield batch
-
\ No newline at end of file
+ def retrieve_all_slim_docs_perm_sync(
+ self,
+ callback: Any = None,
+ ) -> GenerateSlimDocumentOutput:
+ """Full snapshot of file IDs eligible for indexing (no downloads).
+
+ Uses ``seafile:{repo_id}:{file_id}`` matching :meth:`_yield_seafile_documents`.
+ Listing uses strict directory reads (errors propagate) so partial snapshots
+ are never treated as authoritative for stale-document cleanup.
+ """
+ del callback
+ logger.info(
+ "Starting SeaFile slim snapshot: scope=%s url=%s",
+ self.sync_scope.value,
+ self.seafile_url,
+ )
+
+ libraries = self._resolve_libraries_to_scan()
+ all_files: list[tuple[str, dict, dict]] = []
+ for lib in libraries:
+ root = self._root_path_for_repo(lib["id"])
+ span_start = datetime(1970, 1, 1, tzinfo=timezone.utc)
+ span_end = datetime.now(timezone.utc)
+ listed = self._list_files_recursive(
+ lib["id"],
+ lib["name"],
+ root,
+ span_start,
+ span_end,
+ filter_by_mtime=False,
+ strict_listing=True,
+ )
+ all_files.extend(listed)
+
+ batch: list[SlimDocument] = []
+ total = 0
+ for file_path, file_entry, library in all_files:
+ file_size = file_entry.get("size", 0)
+ if file_size > self.size_threshold:
+ continue
+ file_id = file_entry.get("id", "")
+ repo_id = library["id"]
+ batch.append(SlimDocument(id=f"seafile:{repo_id}:{file_id}"))
+ total += 1
+ if len(batch) >= self.batch_size:
+ yield batch
+ batch = []
+
+ if batch:
+ yield batch
+
+ logger.info(
+ "Completed SeaFile slim snapshot: %d documents (listed_paths=%d)",
+ total,
+ len(all_files),
+ )
diff --git a/common/data_source/sharepoint_connector.py b/common/data_source/sharepoint_connector.py
index 7bc8e3410dc..e5684023c15 100644
--- a/common/data_source/sharepoint_connector.py
+++ b/common/data_source/sharepoint_connector.py
@@ -112,10 +112,8 @@ def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
def retrieve_all_slim_docs_perm_sync(
self,
- start: SecondsSinceUnixEpoch | None = None,
- end: SecondsSinceUnixEpoch | None = None,
callback: Any = None,
) -> Any:
"""Retrieve all simplified documents with permission sync"""
# Simplified implementation
- return []
\ No newline at end of file
+ return []
diff --git a/common/data_source/slack_connector.py b/common/data_source/slack_connector.py
index 5fabc3d00fb..162826762cd 100644
--- a/common/data_source/slack_connector.py
+++ b/common/data_source/slack_connector.py
@@ -528,8 +528,6 @@ def set_credentials_provider(self, credentials_provider: Any) -> None:
def retrieve_all_slim_docs_perm_sync(
self,
- start: SecondsSinceUnixEpoch | None = None,
- end: SecondsSinceUnixEpoch | None = None,
callback: Any = None,
) -> GenerateSlimDocumentOutput:
if self.client is None:
@@ -662,4 +660,4 @@ def get_credentials(self):
connector.validate_connector_settings()
print("Slack connector settings validated successfully")
except Exception as e:
- print(f"Validation failed: {e}")
\ No newline at end of file
+ print(f"Validation failed: {e}")
diff --git a/common/data_source/teams_connector.py b/common/data_source/teams_connector.py
index 0b4cd564252..98b472667a0 100644
--- a/common/data_source/teams_connector.py
+++ b/common/data_source/teams_connector.py
@@ -106,10 +106,8 @@ def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
def retrieve_all_slim_docs_perm_sync(
self,
- start: SecondsSinceUnixEpoch | None = None,
- end: SecondsSinceUnixEpoch | None = None,
callback: Any = None,
) -> Any:
"""Retrieve all simplified documents with permission sync"""
# Simplified implementation
- return []
\ No newline at end of file
+ return []
diff --git a/common/data_source/webdav_connector.py b/common/data_source/webdav_connector.py
index b860c0b61ae..6ea6558ad5b 100644
--- a/common/data_source/webdav_connector.py
+++ b/common/data_source/webdav_connector.py
@@ -17,11 +17,11 @@
CredentialExpiredError,
InsufficientPermissionsError
)
-from common.data_source.interfaces import LoadConnector, OnyxExtensionType, PollConnector
-from common.data_source.models import Document, SecondsSinceUnixEpoch, GenerateDocumentsOutput
+from common.data_source.interfaces import LoadConnector, OnyxExtensionType, PollConnector, SlimConnectorWithPermSync
+from common.data_source.models import Document, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SecondsSinceUnixEpoch, SlimDocument
-class WebDAVConnector(LoadConnector, PollConnector):
+class WebDAVConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
"""WebDAV connector for syncing files from WebDAV servers"""
def __init__(
@@ -102,17 +102,20 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None
return None
def _list_files_recursive(
- self,
+ self,
path: str,
start: datetime,
end: datetime,
+ *,
+ filter_by_mtime: bool = True,
) -> list[tuple[str, dict]]:
"""Recursively list all files in the given path
Args:
path: Path to list files from
- start: Start datetime for filtering
- end: End datetime for filtering
+ start: Start datetime for filtering (ignored when ``filter_by_mtime`` is False)
+ end: End datetime for filtering (ignored when ``filter_by_mtime`` is False)
+ filter_by_mtime: When False, include every supported file without applying the mtime window
Returns:
List of tuples containing (file_path, file_info)
@@ -134,7 +137,14 @@ def _list_files_recursive(
if item.get('type') == 'directory':
try:
- files.extend(self._list_files_recursive(item_path, start, end))
+ files.extend(
+ self._list_files_recursive(
+ item_path,
+ start,
+ end,
+ filter_by_mtime=filter_by_mtime,
+ )
+ )
except Exception as e:
logging.error(f"Error recursing into directory {item_path}: {e}")
continue
@@ -168,10 +178,13 @@ def _list_files_recursive(
logging.debug(f"File {item_path}: modified={modified}, start={start}, end={end}, include={start < modified <= end}")
- if start < modified <= end:
- files.append((item_path, item))
+ if filter_by_mtime:
+ if start < modified <= end:
+ files.append((item_path, item))
+ else:
+ logging.debug(f"File {item_path} filtered out by time range")
else:
- logging.debug(f"File {item_path} filtered out by time range")
+ files.append((item_path, item))
except Exception as e:
logging.error(f"Error processing file {item_path}: {e}")
continue
@@ -323,6 +336,61 @@ def poll_source(
for batch in self._yield_webdav_documents(start_datetime, end_datetime):
yield batch
+ def retrieve_all_slim_docs_perm_sync(
+ self,
+ callback: Any = None,
+ ) -> GenerateSlimDocumentOutput:
+ """Full-tree snapshot of indexed paths for stale-document reconciliation.
+
+ Uses the same ``webdav:{base_url}:{file_path}`` ids as :meth:`_yield_webdav_documents`,
+ without downloading file contents.
+ """
+ del callback
+ if self.client is None:
+ raise ConnectorMissingCredentialError("WebDAV client not initialized")
+
+ logging.info(
+ "Starting WebDAV slim snapshot: base_url=%s path=%s",
+ self.base_url,
+ self.remote_path,
+ )
+
+ files = self._list_files_recursive(
+ self.remote_path,
+ datetime(1970, 1, 1, tzinfo=timezone.utc),
+ datetime.now(timezone.utc),
+ filter_by_mtime=False,
+ )
+ batch: list[SlimDocument] = []
+ total = 0
+ for file_path, file_info in files:
+ file_name = os.path.basename(file_path)
+ if not self._is_supported_file(file_name):
+ continue
+ size_bytes = file_info.get("size", 0)
+ if (
+ self.size_threshold is not None
+ and isinstance(size_bytes, int)
+ and size_bytes > self.size_threshold
+ ):
+ continue
+ batch.append(
+ SlimDocument(id=f"webdav:{self.base_url}:{file_path}")
+ )
+ total += 1
+ if len(batch) >= self.batch_size:
+ yield batch
+ batch = []
+
+ if batch:
+ yield batch
+
+ logging.info(
+ "Completed WebDAV slim snapshot: %d documents (listed_paths=%d)",
+ total,
+ len(files),
+ )
+
def validate_connector_settings(self) -> None:
"""Validate WebDAV connector settings.
diff --git a/common/data_source/zendesk_connector.py b/common/data_source/zendesk_connector.py
index 85b3426fe3f..c357b500fb7 100644
--- a/common/data_source/zendesk_connector.py
+++ b/common/data_source/zendesk_connector.py
@@ -246,6 +246,18 @@ def _article_to_document(
)
+def _is_indexable_article(article: dict[str, Any]) -> bool:
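+ """An article is indexable when it has a body, is not a draft, and carries no skip label."""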
+ body = article.get("body")
+ return (
+ bool(body)
+ and not article.get("draft")
+ and not any(
+ label in ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
+ for label in article.get("label_names") or []
+ )
+ )
+
+
def _get_comment_text(
comment: dict[str, Any],
author_map: dict[str, BasicExpertInfo],
@@ -333,6 +345,10 @@ def _ticket_to_document(
)
+def _is_indexable_ticket(ticket: dict[str, Any]) -> bool:
+ return ticket.get("status") != "deleted"
+
+
class ZendeskConnectorCheckpoint(ConnectorCheckpoint):
# We use cursor-based paginated retrieval for articles
after_cursor_articles: str | None
@@ -419,14 +435,7 @@ def _retrieve_articles(
has_more = response.has_more
after_cursor = response.meta.get("after_cursor")
for article in articles:
- if (
- article.get("body") is None
- or article.get("draft")
- or any(
- label in ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
- for label in article.get("label_names", [])
- )
- ):
+ if not _is_indexable_article(article):
continue
try:
@@ -498,7 +507,7 @@ def _retrieve_tickets(
has_more = ticket_response.has_more
next_start_time = ticket_response.meta["end_time"]
for ticket in tickets:
- if ticket.get("status") == "deleted":
+ if not _is_indexable_ticket(ticket):
continue
try:
@@ -553,16 +562,14 @@ def _retrieve_tickets(
def retrieve_all_slim_docs_perm_sync(
self,
- start: SecondsSinceUnixEpoch | None = None,
- end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
slim_doc_batch: list[SlimDocument] = []
if self.content_type == "articles":
- articles = _get_articles(
- self.client, start_time=int(start) if start else None
- )
+ articles = _get_articles(self.client)
for article in articles:
+ if not _is_indexable_article(article):
+ continue
slim_doc_batch.append(
SlimDocument(
id=f"article:{article['id']}",
@@ -572,10 +579,10 @@ def retrieve_all_slim_docs_perm_sync(
yield slim_doc_batch
slim_doc_batch = []
elif self.content_type == "tickets":
- tickets = _get_tickets(
- self.client, start_time=int(start) if start else None
- )
+ tickets = _get_tickets(self.client)
for ticket in tickets:
+ if not _is_indexable_ticket(ticket):
+ continue
slim_doc_batch.append(
SlimDocument(
id=f"zendesk_ticket_{ticket['id']}",
@@ -664,4 +671,4 @@ def build_dummy_checkpoint(self) -> ZendeskConnectorCheckpoint:
checkpoint = next_checkpoint
if any_doc:
- break
\ No newline at end of file
+ break
diff --git a/common/doc_store/infinity_conn_base.py b/common/doc_store/infinity_conn_base.py
index 20baa34a60a..af8493b82b2 100644
--- a/common/doc_store/infinity_conn_base.py
+++ b/common/doc_store/infinity_conn_base.py
@@ -16,10 +16,12 @@
import logging
import os
+import random
import re
import json
import time
from abc import abstractmethod
+from typing import Callable, TypeVar
import infinity
from infinity.common import ConflictType
@@ -32,6 +34,117 @@
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr
+# Concurrent CREATE/DROP TABLE on the same Infinity instance can race on
+# Infinity's RocksDB-backed catalog counters (e.g. ``db|1|next_table_id``).
+# When two writers touch the counter at the same instant, Infinity surfaces
+# error 9003 / "Resource busy" instead of waiting on a lock — turning a
+# user-visible operation into an avoidable failure under modest concurrency
+# (two users creating a knowledge base at the same time, batch onboarding,
+# multi-replica deployments, …).
+#
+# We retry the metadata path (CREATE TABLE / CREATE INDEX / DROP TABLE) on
+# this specific error with exponential backoff + jitter. The wrapped calls
+# already use ``ConflictType.Ignore``, so re-running them on retry is
+# idempotent. The retry budget is intentionally bounded (5 attempts,
+# ~1.5s worst case) so a genuine outage still surfaces quickly.
+#
+# Tunable from the environment:
+# INFINITY_META_RETRY_MAX default 5
+# INFINITY_META_RETRY_BASE_DELAY_MS default 50
+
+_T = TypeVar("_T")
+
+# Infinity error code 9003 is raised on RocksDB transaction contention. It is
+# not in the SDK's ErrorCode enum yet, so we keep the literal here.
+_INFINITY_RESOURCE_BUSY_CODE = 9003
+
+
+def _int_env(name: str, default: int) -> int:
+ """Read an int from the environment without crashing on bad input.
+
+ A misconfigured ``INFINITY_META_RETRY_MAX=`` (empty value) or non-numeric
+ string would otherwise raise ``ValueError`` at module import time and
+ take down every backend worker. We log and fall back to the default
+ instead.
+ """
+ raw = os.getenv(name)
+ if raw is None or raw == "":
+ return default
+ try:
+ return int(raw)
+ except ValueError:
+ logging.getLogger(__name__).warning(
+ "Ignoring invalid %s=%r, falling back to %d", name, raw, default,
+ )
+ return default
+
+
+_META_RETRY_MAX = _int_env("INFINITY_META_RETRY_MAX", 5)
+_META_RETRY_BASE_DELAY_MS = _int_env("INFINITY_META_RETRY_BASE_DELAY_MS", 50)
+
+
+def _is_meta_contention_error(exc: BaseException) -> bool:
+ """Return True iff ``exc`` is the RocksDB metadata-counter "Resource busy".
+
+ Prefer the numeric error code when the SDK exposes one — substring matching
+ on ``str(exc)`` is the fallback for older SDKs that surface only a tuple
+ or a plain string. Both surfaces are observed in the wild today.
+ """
+ code = getattr(exc, "error_code", None)
+ if code is None:
+ # Some Infinity SDK paths raise a plain ``Exception(9003, "...")``
+ # whose ``args[0]`` carries the code.
+ args = getattr(exc, "args", None)
+ if isinstance(args, tuple) and args:
+ code = args[0]
+ if code == _INFINITY_RESOURCE_BUSY_CODE:
+ return True
+ msg = str(exc)
+ return "Resource busy" in msg and "rocksdb" in msg.lower()
+
+
+def _retry_on_meta_contention(
+ op_name: str,
+ operation: Callable[[], _T],
+ *,
+ logger: logging.Logger | None = None,
+ max_attempts: int = _META_RETRY_MAX,
+ base_delay_ms: int = _META_RETRY_BASE_DELAY_MS,
+) -> _T:
+ """Run ``operation`` and retry on RocksDB "Resource busy" errors.
+
+ Exponential backoff with ±50% jitter to avoid a thundering herd when many
+ workers retry simultaneously. Any exception that does not match
+ :func:`_is_meta_contention_error` is re-raised immediately so genuine
+ failures still surface fast.
+ """
+ log = logger or logging.getLogger(__name__)
+ last_exc: BaseException | None = None
+ for attempt in range(max_attempts):
+ try:
+ return operation()
+ except Exception as exc:
+ if not _is_meta_contention_error(exc):
+ raise
+ last_exc = exc
+ if attempt == max_attempts - 1:
+ break
+ base = (base_delay_ms / 1000.0) * (2 ** attempt)
+ sleep_for = base + random.uniform(0, base * 0.5)
+ log.info(
+ "INFINITY meta contention on %s (attempt %d/%d), "
+ "retrying in %.3fs: %s",
+ op_name, attempt + 1, max_attempts, sleep_for, exc,
+ )
+ time.sleep(sleep_for)
+ log.warning(
+ "INFINITY meta contention on %s exhausted %d attempts: %s",
+ op_name, max_attempts, last_exc,
+ )
+ assert last_exc is not None
+ raise last_exc
+
+
class InfinityConnectionBase(DocStoreConnection):
def __init__(self, mapping_file_name: str = "infinity_mapping.json", logger_name: str = "ragflow.infinity_conn", table_name_prefix: str="ragflow_"):
from common.doc_store.infinity_conn_pool import INFINITY_CONN
@@ -173,7 +286,15 @@ def exists(cln):
cond = list()
for k, v in condition.items():
- if not isinstance(k, str) or not v:
+ if not isinstance(k, str):
+ continue
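+ # available_int is a legitimate 0/1 filter; the generic "skip falsy values"
+ # check below would otherwise drop available_int == 0.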
+ if k == "available_int":
+ if v == 0:
+ cond.append("available_int=0")
+ elif v == 1:
+ cond.append("available_int=1")
+ continue
+ if not v:
continue
if self.field_keyword(k):
if isinstance(v, list):
@@ -266,7 +387,11 @@ def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_
inf_conn = self.connPool.get_conn()
try:
- inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
+ inf_db = _retry_on_meta_contention(
+ f"create_database({self.dbName})",
+ lambda: inf_conn.create_database(self.dbName, ConflictType.Ignore),
+ logger=self.logger,
+ )
# Use configured schema
fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
@@ -285,24 +410,32 @@ def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_
vector_name = f"q_{vector_size}_vec"
schema[vector_name] = {"type": f"vector,{vector_size},float"}
- inf_table = inf_db.create_table(
- table_name,
- schema,
- ConflictType.Ignore,
+ inf_table = _retry_on_meta_contention(
+ f"create_table({table_name})",
+ lambda: inf_db.create_table(
+ table_name,
+ schema,
+ ConflictType.Ignore,
+ ),
+ logger=self.logger,
)
- inf_table.create_index(
- "q_vec_idx",
- IndexInfo(
- vector_name,
- IndexType.Hnsw,
- {
- "M": "16",
- "ef_construction": "50",
- "metric": "cosine",
- "encode": "lvq",
- },
+ _retry_on_meta_contention(
+ f"create_index(q_vec_idx, {table_name})",
+ lambda: inf_table.create_index(
+ "q_vec_idx",
+ IndexInfo(
+ vector_name,
+ IndexType.Hnsw,
+ {
+ "M": "16",
+ "ef_construction": "50",
+ "metric": "cosine",
+ "encode": "lvq",
+ },
+ ),
+ ConflictType.Ignore,
),
- ConflictType.Ignore,
+ logger=self.logger,
)
for field_name, field_info in schema.items():
if field_info["type"] != "varchar" or "analyzer" not in field_info:
@@ -311,10 +444,15 @@ def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_
if isinstance(analyzers, str):
analyzers = [analyzers]
for analyzer in analyzers:
- inf_table.create_index(
- f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
- IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
- ConflictType.Ignore,
+ idx_name = f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}"
+ _retry_on_meta_contention(
+ f"create_index({idx_name}, {table_name})",
+ lambda fn=field_name, an=analyzer, name=idx_name: inf_table.create_index(
+ name,
+ IndexInfo(fn, IndexType.FullText, {"ANALYZER": an}),
+ ConflictType.Ignore,
+ ),
+ logger=self.logger,
)
# Create secondary indexes for fields with index_type
@@ -323,10 +461,14 @@ def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_
continue
index_config = field_info["index_type"]
if isinstance(index_config, str) and index_config == "secondary":
- inf_table.create_index(
- f"sec_{field_name}",
- IndexInfo(field_name, IndexType.Secondary),
- ConflictType.Ignore,
+ _retry_on_meta_contention(
+ f"create_index(sec_{field_name}, {table_name})",
+ lambda fn=field_name: inf_table.create_index(
+ f"sec_{fn}",
+ IndexInfo(fn, IndexType.Secondary),
+ ConflictType.Ignore,
+ ),
+ logger=self.logger,
)
self.logger.info(f"INFINITY created secondary index sec_{field_name} for field {field_name}")
elif isinstance(index_config, dict):
@@ -334,10 +476,14 @@ def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_
params = {}
if "cardinality" in index_config:
params = {"cardinality": index_config["cardinality"]}
- inf_table.create_index(
- f"sec_{field_name}",
- IndexInfo(field_name, IndexType.Secondary, params),
- ConflictType.Ignore,
+ _retry_on_meta_contention(
+ f"create_index(sec_{field_name}, {table_name})",
+ lambda fn=field_name, p=params: inf_table.create_index(
+ f"sec_{fn}",
+ IndexInfo(fn, IndexType.Secondary, p),
+ ConflictType.Ignore,
+ ),
+ logger=self.logger,
)
self.logger.info(f"INFINITY created secondary index sec_{field_name} for field {field_name} with params {params}")
@@ -355,18 +501,26 @@ def create_doc_meta_idx(self, index_name: str):
"""
table_name = index_name
inf_conn = self.connPool.get_conn()
- inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
try:
+ inf_db = _retry_on_meta_contention(
+ f"create_database({self.dbName})",
+ lambda: inf_conn.create_database(self.dbName, ConflictType.Ignore),
+ logger=self.logger,
+ )
fp_mapping = os.path.join(get_project_base_directory(), "conf", "doc_meta_infinity_mapping.json")
if not os.path.exists(fp_mapping):
self.logger.error(f"Document metadata mapping file not found at {fp_mapping}")
return False
with open(fp_mapping) as f:
schema = json.load(f)
- inf_db.create_table(
- table_name,
- schema,
- ConflictType.Ignore,
+ _retry_on_meta_contention(
+ f"create_table({table_name})",
+ lambda: inf_db.create_table(
+ table_name,
+ schema,
+ ConflictType.Ignore,
+ ),
+ logger=self.logger,
)
# Create secondary indexes on id and kb_id for better query performance
@@ -392,14 +546,14 @@ def create_doc_meta_idx(self, index_name: str):
except Exception as e:
self.logger.warning(f"Failed to create index on kb_id for {table_name}: {e}")
- self.connPool.release_conn(inf_conn)
self.logger.debug(f"INFINITY created document metadata table {table_name} with secondary indexes")
return True
except Exception as e:
- self.connPool.release_conn(inf_conn)
self.logger.exception(f"Error creating document metadata table {table_name}: {e}")
return False
+ finally:
+ self.connPool.release_conn(inf_conn)
def delete_idx(self, index_name: str, dataset_id: str):
if index_name.startswith("ragflow_doc_meta_"):
@@ -409,7 +563,11 @@ def delete_idx(self, index_name: str, dataset_id: str):
inf_conn = self.connPool.get_conn()
try:
db_instance = inf_conn.get_database(self.dbName)
- db_instance.drop_table(table_name, ConflictType.Ignore)
+ _retry_on_meta_contention(
+ f"drop_table({table_name})",
+ lambda: db_instance.drop_table(table_name, ConflictType.Ignore),
+ logger=self.logger,
+ )
self.logger.info(f"INFINITY dropped table {table_name}")
finally:
self.connPool.release_conn(inf_conn)
diff --git a/common/metadata_es_filter.py b/common/metadata_es_filter.py
new file mode 100644
index 00000000000..afe0f27386e
--- /dev/null
+++ b/common/metadata_es_filter.py
@@ -0,0 +1,580 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Translate RAGflow document-metadata filter lists into Elasticsearch DSL.
+
+The legacy ``common.metadata_utils.meta_filter`` evaluates user-defined
+metadata conditions in Python after loading every document's metadata into
+memory. That works for small knowledge bases but degrades badly past a few
+thousand documents. This module produces an equivalent ES bool query so the
+filtering can be pushed down to the search engine.
+
+Operators handled here mirror ``meta_filter`` exactly. When a filter cannot be
+translated (unknown operator, malformed value, list-typed input that the
+in-memory code special-cases) the translator raises
+:class:`UnsupportedMetaFilter` so callers fall back to the in-memory path
+without silently changing semantics.
+"""
+
+from __future__ import annotations
+
+import ast
+import re
+from dataclasses import dataclass, field
+from typing import Any, Dict, Iterable, List, Optional, Sequence
+
+# Field prefix in the doc-metadata ES index. Every user metadata key lives at
+# ``meta_fields.`` thanks to the dynamic object mapping in
+# ``conf/doc_meta_es_mapping.json``.
+META_FIELDS_PREFIX = "meta_fields"
+
+# Strict ``YYYY-MM-DD`` recogniser, kept consistent with the legacy in-memory
+# path. Mismatched-type comparisons (string vs date, list vs scalar) fall back
+# to in-memory semantics rather than guess at the right ES coercion.
+_DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$")
+
+# Operators that the legacy filter exposes. Anything outside this set is a bug
+# elsewhere; surface it instead of silently no-op'ing.
+SUPPORTED_OPERATORS: frozenset[str] = frozenset(
+ {
+ "=",
+ "≠",
+ ">",
+ "<",
+ "≥",
+ "≤",
+ "in",
+ "not in",
+ "contains",
+ "not contains",
+ "start with",
+ "end with",
+ "empty",
+ "not empty",
+ }
+)
+
+# ES range comparators keyed by RAGflow operator.
+_RANGE_OPS: Dict[str, str] = {
+ ">": "gt",
+ "<": "lt",
+ "≥": "gte",
+ "≤": "lte",
+}
+
+# Negative operators that diverge from ``meta_filter`` on multi-valued metadata
+# fields. The in-memory path checks each value bucket independently, so a doc
+# whose field is ``[a, b]`` matches ``≠ a`` (because the ``b`` bucket satisfies
+# the predicate). ``must_not term: a`` in ES would exclude that doc outright.
+# Without a cheap way to prove a field is single-valued at query time we refuse
+# push-down for these operators and let the in-memory fallback handle them.
+# ``not contains`` is not in this set: ``all(not contains)`` is equivalent to
+# ``not any(contains)``, so ``must_not wildcard *X*`` matches the legacy
+# semantics on both single- and multi-valued fields.
+MULTIVALUE_UNSAFE_NEGATIVE_OPS: frozenset[str] = frozenset({"≠", "not in"})
+
+
+class UnsupportedMetaFilter(Exception):
+ """Raised when a metadata filter cannot be expressed as ES DSL.
+
+ Carries the filter that failed so callers can log a precise reason and the
+ in-memory fallback can pick up unchanged.
+ """
+
+ def __init__(self, reason: str, filter_clause: Optional[Dict[str, Any]] = None) -> None:
+ super().__init__(reason)
+ self.reason = reason
+ self.filter_clause = filter_clause
+
+
+@dataclass
+class TranslatedFilter:
+ """A single user filter rendered as one or more ES bool clauses.
+
+ A clause that wants the field to be present (``≠``, ``not in``, range,
+ ``not contains``) goes into ``must`` so the negation does not accidentally
+ match documents missing the key. ``must_not`` carries the actual rejection.
+ Pure positive filters (``=``, ``contains``, ``in``, ``exists``) fill
+ ``must`` only.
+ """
+
+ must: List[Dict[str, Any]] = field(default_factory=list)
+ must_not: List[Dict[str, Any]] = field(default_factory=list)
+
+ def to_clauses(self) -> List[Dict[str, Any]]:
+ """Collapse to the ES clauses this filter contributes to a parent bool.
+
+ Always emits a single atomic clause when there is anything to emit:
+ a multi-clause ``must`` (e.g. range = ``exists`` + ``range``) gets
+ wrapped in its own ``bool`` so an OR-logic parent ``should`` can't
+ match on just one half of the filter. A pure single positive clause
+ is returned unwrapped because there is nothing to break apart.
+ """
+ if not self.must and not self.must_not:
+ return []
+ if not self.must_not:
+ if len(self.must) == 1:
+ return list(self.must)
+ # Multi-clause positive filter — keep it atomic for OR parents.
+ return [{"bool": {"must": list(self.must)}}]
+ # Negative semantics always need wrapping so they survive being OR'd
+ # with siblings.
+ return [{"bool": {"must": list(self.must), "must_not": list(self.must_not)}}]
+
+
+@dataclass
+class MetaFilterPushdownPlan:
+ """Composed ES bool query body for an entire RAGflow filter request."""
+
+ logic: str
+ translated: List[TranslatedFilter] = field(default_factory=list)
+
+ def is_empty(self) -> bool:
+ return not self.translated
+
+ def to_query(self, kb_ids: Sequence[str]) -> Dict[str, Any]:
+ """Render the full ES query body, scoped to the given KB ids.
+
+ The KB filter is always a ``terms`` clause so the query can serve any
+ number of knowledge bases without rewriting the caller.
+ """
+ kb_clause = {"terms": {"kb_id": list(kb_ids)}}
+
+ if self.is_empty():
+ return {"query": {"bool": {"filter": [kb_clause]}}}
+
+ sub_clauses = [t.to_clauses() for t in self.translated]
+ flat_clauses: List[Dict[str, Any]] = [c for group in sub_clauses for c in group]
+
+ if self.logic == "or":
+ inner = {
+ "bool": {
+ "should": flat_clauses,
+ "minimum_should_match": 1,
+ }
+ }
+ else:
+ inner = {"bool": {"must": flat_clauses}}
+
+ return {
+ "query": {
+ "bool": {
+ "filter": [kb_clause, inner],
+ }
+ }
+ }
+
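+# Shape sketch: a single "=" filter under "and" logic renders (key and value
+# are hypothetical) as
+#
+#     {"query": {"bool": {"filter": [
+#         {"terms": {"kb_id": ["kb1"]}},
+#         {"bool": {"must": [{"term": {"meta_fields.author.keyword":
+#             {"value": "alice", "case_insensitive": True}}}]}}]}}}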
+
+class MetaFilterTranslator:
+ """Translate one user filter clause at a time into ES DSL fragments.
+
+ Stateless aside from configuration; safe to instantiate once per request
+ or share at module scope.
+ """
+
+ def __init__(self, prefix: str = META_FIELDS_PREFIX) -> None:
+ self.prefix = prefix
+
+ def field_name(self, key: str) -> str:
+ """Compose the dotted ES field path for a user metadata key."""
+ return f"{self.prefix}.{key}"
+
+ def translate(self, flt: Dict[str, Any]) -> TranslatedFilter:
+ """Translate a single filter dict into ES bool clauses.
+
+ Raises ``UnsupportedMetaFilter`` for malformed input or operator/value
+ combinations the legacy in-memory path treats as a special case (e.g.
+ list-of-strings membership in ``in``/``not in``).
+ """
+ op = flt.get("op")
+ key = flt.get("key")
+ value = flt.get("value")
+
+ if not key or not isinstance(key, str):
+ raise UnsupportedMetaFilter("filter is missing a string key", flt)
+ if op not in SUPPORTED_OPERATORS:
+ raise UnsupportedMetaFilter(f"unknown operator {op!r}", flt)
+
+ field_path = self.field_name(key)
+
+ if op == "empty":
+ return self._translate_empty(field_path)
+ if op == "not empty":
+ return self._translate_not_empty(field_path)
+ if op == "=":
+ return self._translate_equal(field_path, value, flt)
+ if op == "≠":
+ return self._translate_not_equal(field_path, value, flt)
+ if op in _RANGE_OPS:
+ return self._translate_range(field_path, op, value, flt)
+ if op == "in":
+ return self._translate_in(field_path, value, flt)
+ if op == "not in":
+ return self._translate_not_in(field_path, value, flt)
+ if op == "contains":
+ return self._translate_contains(field_path, value, flt)
+ if op == "not contains":
+ return self._translate_not_contains(field_path, value, flt)
+ if op == "start with":
+ return self._translate_start_with(field_path, value, flt)
+ if op == "end with":
+ return self._translate_end_with(field_path, value, flt)
+
+ # Unreachable: SUPPORTED_OPERATORS gate above covers every branch.
+ raise UnsupportedMetaFilter(f"no handler for operator {op!r}", flt)
+
+ def _translate_empty(self, field_path: str) -> TranslatedFilter:
+ # "empty" matches documents whose value is missing OR equals "" — same
+ # falsy semantics the in-memory ``not input`` check enforces. The
+ # blank-string check has to target ``.keyword`` because the analyzed
+ # text field drops empty values during tokenisation, leaving no token
+ # for ``term: ""`` to match.
+ return TranslatedFilter(
+ must=[
+ {
+ "bool": {
+ "should": [
+ {"bool": {"must_not": [{"exists": {"field": field_path}}]}},
+ {"term": {_keyword_path(field_path): ""}},
+ ],
+ "minimum_should_match": 1,
+ }
+ }
+ ]
+ )
+
+ def _translate_not_empty(self, field_path: str) -> TranslatedFilter:
+ return TranslatedFilter(
+ must=[{"exists": {"field": field_path}}],
+ must_not=[{"term": {_keyword_path(field_path): ""}}],
+ )
+
+ def _translate_equal(self, field_path: str, value: Any, flt: Dict[str, Any]) -> TranslatedFilter:
+ coerced = _coerce_scalar(value, flt)
+ return TranslatedFilter(must=[_term_or_match(field_path, coerced)])
+
+ def _translate_not_equal(self, field_path: str, value: Any, flt: Dict[str, Any]) -> TranslatedFilter:
+ coerced = _coerce_scalar(value, flt)
+ return TranslatedFilter(
+ must=[{"exists": {"field": field_path}}],
+ must_not=[_term_or_match(field_path, coerced)],
+ )
+
+ def _translate_range(self, field_path: str, op: str, value: Any, flt: Dict[str, Any]) -> TranslatedFilter:
+ coerced = _coerce_range_value(value, flt)
+ return TranslatedFilter(
+ must=[
+ {"exists": {"field": field_path}},
+ {"range": {field_path: {_RANGE_OPS[op]: coerced}}},
+ ]
+ )
+
+ def _translate_in(self, field_path: str, value: Any, flt: Dict[str, Any]) -> TranslatedFilter:
+ members = _csv_or_list(value, flt)
+ return TranslatedFilter(must=[_terms_string_or_numeric(field_path, members)])
+
+ def _translate_not_in(self, field_path: str, value: Any, flt: Dict[str, Any]) -> TranslatedFilter:
+ members = _csv_or_list(value, flt)
+ return TranslatedFilter(
+ must=[{"exists": {"field": field_path}}],
+ must_not=[_terms_string_or_numeric(field_path, members)],
+ )
+
+ def _translate_contains(self, field_path: str, value: Any, flt: Dict[str, Any]) -> TranslatedFilter:
+ text = _coerce_string(value, flt)
+ return TranslatedFilter(must=[_wildcard(field_path, f"*{_escape_wildcard(text)}*")])
+
+ def _translate_not_contains(self, field_path: str, value: Any, flt: Dict[str, Any]) -> TranslatedFilter:
+ text = _coerce_string(value, flt)
+ return TranslatedFilter(
+ must=[{"exists": {"field": field_path}}],
+ must_not=[_wildcard(field_path, f"*{_escape_wildcard(text)}*")],
+ )
+
+ def _translate_start_with(self, field_path: str, value: Any, flt: Dict[str, Any]) -> TranslatedFilter:
+ text = _coerce_string(value, flt)
+ return TranslatedFilter(
+ must=[{"prefix": {_keyword_path(field_path): {"value": text, "case_insensitive": True}}}]
+ )
+
+ def _translate_end_with(self, field_path: str, value: Any, flt: Dict[str, Any]) -> TranslatedFilter:
+ text = _coerce_string(value, flt)
+ return TranslatedFilter(must=[_wildcard(field_path, f"*{_escape_wildcard(text)}")])
+
+
+def build_meta_filter_query(
+ filters: Sequence[Dict[str, Any]],
+ logic: str,
+ kb_ids: Sequence[str],
+ translator: Optional[MetaFilterTranslator] = None,
+) -> Dict[str, Any]:
+ """Top-level helper: translate every filter and render the ES query body.
+
+ Raises ``UnsupportedMetaFilter`` if any filter cannot be expressed.
+ """
+ plan = plan_pushdown(filters, logic, translator=translator)
+ return plan.to_query(kb_ids)
+
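+# Caller-side sketch (the index name and ES client are assumptions, not part
+# of this module):
+#
+#     try:
+#         body = build_meta_filter_query(
+#             [{"key": "author", "op": "=", "value": "Alice"}], "and", ["kb1"])
+#         resp = es_client.search(index="ragflow_doc_meta", body=body)
+#         doc_ids = extract_doc_ids(resp)
+#     except UnsupportedMetaFilter:
+#         doc_ids = None  # route the request to the in-memory meta_filter path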
+
+def plan_pushdown(
+ filters: Sequence[Dict[str, Any]],
+ logic: str,
+ translator: Optional[MetaFilterTranslator] = None,
+) -> MetaFilterPushdownPlan:
+ """Translate every filter in turn, building a single composed plan.
+
+ Separated from ``build_meta_filter_query`` so callers can inspect or
+ augment the plan before binding it to a KB scope.
+ """
+ if logic not in {"and", "or"}:
+ raise UnsupportedMetaFilter(f"unknown logic {logic!r}")
+
+ t = translator or MetaFilterTranslator()
+ plan = MetaFilterPushdownPlan(logic=logic)
+ for flt in filters:
+ plan.translated.append(t.translate(flt))
+ return plan
+
+
+def is_pushdown_supported(filters: Sequence[Dict[str, Any]]) -> bool:
+ """Cheap pre-check: do all filters look translatable without coercion?
+
+ Used by the routing layer to skip the heavier ``plan_pushdown`` call when
+ the request obviously needs the in-memory fallback.
+
+ Operators in :data:`MULTIVALUE_UNSAFE_NEGATIVE_OPS` are rejected here so a
+ single such filter forces the whole request to in-memory evaluation, which
+ is the only place we can replicate the per-bucket semantics over
+ multi-valued metadata fields.
+ """
+ for flt in filters:
+ op = flt.get("op")
+ if op not in SUPPORTED_OPERATORS:
+ return False
+ if op in MULTIVALUE_UNSAFE_NEGATIVE_OPS:
+ return False
+ if not isinstance(flt.get("key"), str) or not flt.get("key"):
+ return False
+ return True
+
+
+def extract_doc_ids(es_response: Dict[str, Any]) -> List[str]:
+ """Pull doc IDs out of an ES search response shaped like ``{hits:{hits:[...]}}``.
+
+ Tolerates both the dict-typed ES 7+ response and the dict-coerced
+ ``ObjectApiResponse`` returned by the elasticsearch python client.
+ """
+ hits_root = es_response.get("hits") if isinstance(es_response, dict) else None
+ if not hits_root:
+ # ``ObjectApiResponse`` is dict-like; subscripting still works when the isinstance check above fails.
+ try:
+ hits_root = es_response["hits"]
+ except Exception:
+ return []
+
+ raw_hits: Iterable[Dict[str, Any]]
+ if isinstance(hits_root, dict):
+ raw_hits = hits_root.get("hits", []) or []
+ else:
+ raw_hits = []
+
+ out: List[str] = []
+ for hit in raw_hits:
+ if not isinstance(hit, dict):
+ continue
+ # ``id`` is mirrored into ``_source`` by the metadata writer; ``_id``
+ # is the canonical identifier. Prefer ``_id`` so renames in the source
+ # field name don't break us.
+ doc_id = hit.get("_id")
+ if not doc_id:
+ source = hit.get("_source") or {}
+ doc_id = source.get("id") or source.get("doc_id")
+ if doc_id:
+ out.append(str(doc_id))
+ return out
+
+
+# ---------------------------------------------------------------------------
+# Value coercion helpers
+# ---------------------------------------------------------------------------
+
+
+def _coerce_scalar(value: Any, flt: Dict[str, Any]) -> Any:
+ """Mirror the legacy ``ast.literal_eval`` then ``str.lower()`` flow.
+
+ The in-memory filter parses values as Python literals when possible (so
+ ``"5"`` becomes ``5``) and lower-cases strings. For ES ``term`` queries we
+ need the same coercion or numeric data won't match.
+ """
+ if value is None:
+ raise UnsupportedMetaFilter("scalar comparison value is None", flt)
+ if isinstance(value, (list, dict)):
+ raise UnsupportedMetaFilter("scalar comparison value is non-scalar", flt)
+
+ s = str(value).strip()
+ if _DATE_RE.match(s):
+ return s
+ try:
+ parsed = ast.literal_eval(s)
+ except Exception:
+ parsed = s
+ if isinstance(parsed, str):
+ return parsed.lower()
+ if isinstance(parsed, (int, float, bool)):
+ return parsed
+ return s.lower()
+
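+# Concrete coercions under the rules above: "5" -> 5, "True" -> True,
+# "Alice Smith" -> "alice smith", "2024-01-02" -> "2024-01-02" (dates verbatim).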
+
+def _coerce_range_value(value: Any, flt: Dict[str, Any]) -> Any:
+ """Range comparisons accept dates verbatim and numbers parsed via literal_eval.
+
+ Strings that aren't numeric or ISO dates are pushed through as-is — ES
+ will compare them lexically against keyword fields, which is the same
+ behaviour as the in-memory ``input >= value`` Python comparison after the
+ original ``ast.literal_eval`` failure path.
+ """
+ if value is None:
+ raise UnsupportedMetaFilter("range comparison value is None", flt)
+ s = str(value).strip()
+ if _DATE_RE.match(s):
+ return s
+ try:
+ parsed = ast.literal_eval(s)
+ except Exception:
+ return s
+ if isinstance(parsed, (int, float)):
+ return parsed
+ return s
+
+
+def _coerce_string(value: Any, flt: Dict[str, Any]) -> str:
+ """String operators (contains/start with/end with) need a non-empty string."""
+ if value is None:
+ raise UnsupportedMetaFilter("string-operator value is None", flt)
+ if isinstance(value, (list, dict)):
+ raise UnsupportedMetaFilter("string-operator value must be a scalar", flt)
+ s = str(value)
+ if not s:
+ raise UnsupportedMetaFilter("string-operator value is empty", flt)
+ return s
+
+
+def _csv_or_list(value: Any, flt: Dict[str, Any]) -> List[Any]:
+ """``in`` / ``not in`` accept either a real list or a comma-separated string.
+
+ The legacy in-memory path applies ``ast.literal_eval`` to the value too.
+ Mirror that for parity, then trim whitespace and lower-case any strings.
+ """
+ if value is None:
+ raise UnsupportedMetaFilter("membership value is None", flt)
+
+ if isinstance(value, (list, tuple)):
+ members = list(value)
+ elif isinstance(value, str):
+ try:
+ parsed = ast.literal_eval(value)
+ except Exception:
+ parsed = value
+ if isinstance(parsed, (list, tuple)):
+ members = list(parsed)
+ else:
+ members = [m.strip() for m in value.split(",") if m.strip()]
+ else:
+ members = [value]
+
+ if not members:
+ raise UnsupportedMetaFilter("membership value resolved to empty list", flt)
+
+ normalised: List[Any] = []
+ for m in members:
+ if isinstance(m, str):
+ normalised.append(m.lower().strip())
+ else:
+ normalised.append(m)
+ return normalised
+
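+# e.g. _csv_or_list("Red, Green", flt) -> ["red", "green"];
+# _csv_or_list("[1, 2]", flt) -> [1, 2]; _csv_or_list(["A", 3], flt) -> ["a", 3].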
+
+def _keyword_path(field_path: str) -> str:
+ """Sub-field used for exact-match string queries.
+
+ Dynamic mapping under ``meta_fields`` indexes string values as ``text``
+ with a ``.keyword`` multi-field. ``term``/``terms``/``prefix``/``wildcard``
+ against the analyzed parent breaks for any multi-word value because the
+ inverted index stores per-token entries, not the original phrase. Routing
+ string queries through ``.keyword`` keeps semantics aligned with the
+ in-memory ``meta_filter`` (full-string compare after lower-casing).
+ """
+ return f"{field_path}.keyword"
+
+
+def _term_or_match(field_path: str, value: Any) -> Dict[str, Any]:
+ """Exact-match clause that respects how dynamic mapping indexes the value.
+
+ String values target the ``.keyword`` sub-field with ``case_insensitive``
+ so phrase values still match (the in-memory path lower-cases before
+ comparing). Numeric / bool values target the parent path because numeric
+ fields have no ``.keyword`` sub-field under default dynamic mapping.
+ """
+ if isinstance(value, str):
+ return {
+ "term": {
+ _keyword_path(field_path): {
+ "value": value,
+ "case_insensitive": True,
+ }
+ }
+ }
+ return {"term": {field_path: value}}
+
+
+def _terms_string_or_numeric(field_path: str, members: List[Any]) -> Dict[str, Any]:
+ """``in``/``not in`` payload that mirrors ``_term_or_match`` per element.
+
+ ES ``terms`` does not accept ``case_insensitive``, so for string members we
+ expand into a ``bool: should`` of case-insensitive ``term`` queries on the
+ keyword sub-field. Pure-numeric / bool member lists keep the cheaper
+ ``terms`` form on the parent path.
+ """
+ if all(not isinstance(m, str) for m in members):
+ return {"terms": {field_path: members}}
+ return {
+ "bool": {
+ "should": [_term_or_match(field_path, m) for m in members],
+ "minimum_should_match": 1,
+ }
+ }
+
+
+def _wildcard(field_path: str, pattern: str) -> Dict[str, Any]:
+ """Wildcard runs against ``.keyword`` so the original phrase is searched.
+
+ ``wildcard`` against an analyzed text field walks per-token entries, which
+ drops phrase context (``Alice Wonderland`` becomes tokens ``alice``,
+ ``wonderland``). The ``.keyword`` sub-field preserves the full original
+ string, matching the in-memory ``str.find`` semantics.
+ """
+ return {
+ "wildcard": {
+ _keyword_path(field_path): {
+ "value": pattern,
+ "case_insensitive": True,
+ }
+ }
+ }
+
+
+def _escape_wildcard(text: str) -> str:
+ """Escape the two ES wildcard metacharacters so user input stays literal."""
+ return text.replace("\\", "\\\\").replace("*", "\\*").replace("?", "\\?")
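+
+# e.g. _escape_wildcard("a*b?") returns 'a\\*b\\?' (literally a\*b\?), so a
+# "contains" filter on "a*b?" matches the literal substring instead of
+# expanding the wildcards.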
diff --git a/common/metadata_utils.py b/common/metadata_utils.py
index c919bd186af..c2fc90b5414 100644
--- a/common/metadata_utils.py
+++ b/common/metadata_utils.py
@@ -42,6 +42,13 @@ def convert_conditions(metadata_condition):
def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
doc_ids = set([])
+ def normalize_string_values(value):
+ if isinstance(value, str):
+ return value.lower()
+ if isinstance(value, list):
+ return [item.lower() if isinstance(item, str) else item for item in value]
+ return value
+
def filter_out(v2docs, operator, value):
ids = []
for input, docids in v2docs.items():
@@ -96,10 +103,8 @@ def filter_out(v2docs, operator, value):
value = value.lower()
else:
# Non-comparison operators: maintain original logic
- if isinstance(input, str):
- input = input.lower()
- if isinstance(value, str):
- value = value.lower()
+ input = normalize_string_values(input)
+ value = normalize_string_values(value)
matched = False
try:
@@ -161,11 +166,13 @@ def filter_out(v2docs, operator, value):
async def apply_meta_data_filter(
meta_data_filter: dict | None,
- metas: dict,
- question: str,
+ metas: dict | None = None,
+ question: str = "",
chat_mdl: Any = None,
base_doc_ids: list[str] | None = None,
manual_value_resolver: Callable[[dict], dict] | None = None,
+ kb_ids: list[str] | None = None,
+ metas_loader: Callable[[], dict] | None = None,
) -> list[str] | None:
"""
Apply metadata filtering rules and return the filtered doc_ids.
@@ -175,6 +182,20 @@ async def apply_meta_data_filter(
- semi_auto: generate conditions using selected metadata keys only
- manual: directly filter based on provided conditions
+ When ``kb_ids`` is supplied and the active doc store is Elasticsearch the
+ generated filter conditions are pushed down to ES via
+ ``DocMetadataService.filter_doc_ids_by_meta_pushdown`` instead of being
+ evaluated in Python over ``metas``. The in-memory ``meta_filter`` path
+ remains the fallback so callers without a KB scope, or backends without
+ push-down support, behave exactly as before.
+
+ ``metas`` may be supplied eagerly or via ``metas_loader``. The loader is
+ only invoked when the metadata dict is actually needed — i.e. for the LLM
+ context in ``auto`` / ``semi_auto`` modes, or as the in-memory fallback
+ when push-down can't service a request. ``manual`` mode that lands on the
+ push-down path therefore skips the expensive
+ ``get_flatted_meta_by_kbs`` round-trip entirely.
+
Returns:
list of doc_ids, ["-999"] when manual filters yield no result, or None
when auto/semi_auto filters return empty.
@@ -188,9 +209,28 @@ async def apply_meta_data_filter(
method = meta_data_filter.get("method")
+ # Memoised metadata loader. ``_get_metas`` materialises the dict at most
+ # once per call; downstream branches that never reach an in-memory eval
+ # leave the loader untouched.
+ cached_metas: dict | None = metas
+
+ def _get_metas() -> dict:
+ nonlocal cached_metas
+ if cached_metas is None:
+ cached_metas = metas_loader() if metas_loader else {}
+ return cached_metas
+
+ def _evaluate(conditions: list[dict], logic: str) -> list[str]:
+ """Run conditions through ES push-down when possible, in-memory otherwise."""
+ if conditions and kb_ids:
+ pushed = _try_meta_pushdown(kb_ids, conditions, logic)
+ if pushed is not None:
+ return pushed
+ return meta_filter(_get_metas(), conditions, logic)
+
if method == "auto":
- filters: dict = await gen_meta_filter(chat_mdl, metas, question)
- doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
+ filters: dict = await gen_meta_filter(chat_mdl, _get_metas(), question)
+ doc_ids.extend(_evaluate(filters["conditions"], filters.get("logic", "and")))
if not doc_ids:
return None
elif method == "semi_auto":
@@ -207,23 +247,47 @@ async def apply_meta_data_filter(
constraints[key] = op
if selected_keys:
- filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
+ current_metas = _get_metas()
+ filtered_metas = {key: current_metas[key] for key in selected_keys if key in current_metas}
if filtered_metas:
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question, constraints=constraints)
- doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
+ doc_ids.extend(_evaluate(filters["conditions"], filters.get("logic", "and")))
if not doc_ids:
return None
elif method == "manual":
filters = meta_data_filter.get("manual", [])
if manual_value_resolver:
filters = [manual_value_resolver(flt) for flt in filters]
- doc_ids.extend(meta_filter(metas, filters, meta_data_filter.get("logic", "and")))
+ doc_ids.extend(_evaluate(filters, meta_data_filter.get("logic", "and")))
if filters and not doc_ids:
doc_ids = ["-999"]
return doc_ids
+def _try_meta_pushdown(
+ kb_ids: list[str],
+ conditions: list[dict],
+ logic: str,
+) -> list[str] | None:
+ """Attempt the ES push-down path; return ``None`` to fall back in-memory.
+
+ Lazy-imports ``DocMetadataService`` so this module stays usable in
+ environments where the API/db layer hasn't been wired up (e.g. unit tests
+ that exercise ``meta_filter`` directly).
+ """
+ try:
+ from api.db.services.doc_metadata_service import DocMetadataService
+ except Exception as e:
+ logging.debug(f"[apply_meta_data_filter] push-down disabled, import failed: {e}")
+ return None
+ try:
+ return DocMetadataService.filter_doc_ids_by_meta_pushdown(kb_ids, conditions, logic)
+ except Exception as e:
+ logging.warning(f"[apply_meta_data_filter] push-down errored, falling back: {e}")
+ return None
+
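+# Caller-side sketch (service wiring is hypothetical): pass kb_ids plus a lazy
+# loader so the metadata dict is only materialised when push-down can't run.
+#
+#     doc_ids = await apply_meta_data_filter(
+#         meta_data_filter,
+#         question=question,
+#         chat_mdl=chat_mdl,
+#         kb_ids=kb_ids,
+#         metas_loader=lambda: get_flatted_meta_by_kbs(kb_ids),
+#     )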
+
def dedupe_list(values: list) -> list:
seen = set()
deduped = []
diff --git a/common/parser_config_utils.py b/common/parser_config_utils.py
index 0bc7ffc28b3..daf91cc8e1a 100644
--- a/common/parser_config_utils.py
+++ b/common/parser_config_utils.py
@@ -29,5 +29,8 @@ def normalize_layout_recognizer(layout_recognizer_raw: Any) -> tuple[Any, str |
elif lowered.endswith("@paddleocr"):
parser_model_name = layout_recognizer_raw.rsplit("@", 1)[0]
layout_recognizer = "PaddleOCR"
+ elif lowered.endswith("@opendataloader"):
+ parser_model_name = layout_recognizer_raw.rsplit("@", 1)[0]
+ layout_recognizer = "OpenDataLoader"
return layout_recognizer, parser_model_name
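+# e.g. normalize_layout_recognizer("MyLayout@opendataloader")  # hypothetical input
+#      -> ("OpenDataLoader", "MyLayout")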
diff --git a/common/settings.py b/common/settings.py
index 2b67dc34d72..49693b93701 100644
--- a/common/settings.py
+++ b/common/settings.py
@@ -17,6 +17,8 @@
import json
import secrets
import logging
+from datetime import date
+
from common.constants import RAG_FLOW_SERVICE_NAME
from common.file_utils import get_project_base_directory
from common.config_utils import get_base_config, decrypt_database_config
@@ -43,6 +45,8 @@
import memory.utils.infinity_conn as memory_infinity_conn
import memory.utils.ob_conn as memory_ob_conn
+TIMEZONE = os.getenv("TZ", "Asia/Shanghai")
+
LLM = None
LLM_FACTORY = None
LLM_BASE_URL = None
@@ -137,6 +141,24 @@ def get_svr_queue_name(priority: int) -> str:
def get_svr_queue_names():
return [get_svr_queue_name(priority) for priority in [1, 0]]
+def init_secret_key():
+ secret_key = os.environ.get("RAGFLOW_SECRET_KEY")
+ if secret_key and len(secret_key) >= 32:
+ return secret_key
+
+ # Fall back to a configured secret key, ignoring the legacy date-string default.
+ configured_key = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key")
+ if configured_key and configured_key != str(date.today()) and len(configured_key) >= 32:
+ return configured_key
+ return None
+
+
+def get_secret_key():
+ global SECRET_KEY
+ if SECRET_KEY is None:
+ return _get_or_create_secret_key()
+ return SECRET_KEY
+
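+# Together, init_secret_key() and get_secret_key() resolve the key in this
+# order: RAGFLOW_SECRET_KEY env var (>= 32 chars), then the configured service
+# secret_key (ignoring the legacy date default), then a Redis-coordinated
+# auto-generated key via _get_or_create_secret_key().
+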
def _get_or_create_secret_key():
# secret_key = os.environ.get("RAGFLOW_SECRET_KEY")
# if secret_key and len(secret_key) >= 32:
@@ -152,7 +174,8 @@ def _get_or_create_secret_key():
generated_key = secrets.token_hex(32)
secret_key = REDIS_CONN.get_or_create_secret_key("ragflow:system:secret_key", generated_key)
- logging.warning("SECURITY WARNING: Using auto-generated SECRET_KEY.")
+ if generated_key == secret_key:
+ logging.warning("SECURITY WARNING: Using auto-generated SECRET_KEY.")
return secret_key
class StorageFactory:
@@ -243,7 +266,7 @@ def init_settings():
HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
global SECRET_KEY
- SECRET_KEY = _get_or_create_secret_key()
+ SECRET_KEY = init_secret_key()
# authentication
diff --git a/common/ssrf_guard.py b/common/ssrf_guard.py
new file mode 100644
index 00000000000..b60bcd4bc99
--- /dev/null
+++ b/common/ssrf_guard.py
@@ -0,0 +1,172 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Shared SSRF-guard utilities.
+
+Uses only the standard library so it can be imported from both ``api/`` and
+``common/`` without pulling in any heavyweight dependencies.
+"""
+
+import ipaddress
+import logging
+import socket
+import threading
+from contextlib import contextmanager
+from urllib.parse import urlparse
+
+logger = logging.getLogger(__name__)
+
+# ---------------------------------------------------------------------------
+# DNS pinning — closes the TOCTOU / rebinding window between SSRF validation
+# and the actual TCP connection. The monkey-patch is a no-op for any host
+# that has no active pin, so it cannot affect unrelated code.
+# ---------------------------------------------------------------------------
+
+_tl = threading.local()
+_global_dns_pins: dict[str, str] = {}
+_global_pin_lock = threading.Lock()
+_orig_getaddrinfo = socket.getaddrinfo
+
+
+def _getaddrinfo_with_pins(host, port, *args, **kwargs):
+ # Thread-local pins (synchronous callers: requests.get in the same thread)
+ local_pins: dict = getattr(_tl, "dns_pins", {})
+ if host in local_pins:
+ ip = local_pins[host]
+ family = socket.AF_INET6 if ":" in ip else socket.AF_INET
+ return [(family, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", (ip, port or 0))]
+ # Process-global pins (async callers whose DNS resolves in executor threads)
+ with _global_pin_lock:
+ ip = _global_dns_pins.get(host)
+ if ip is not None:
+ family = socket.AF_INET6 if ":" in ip else socket.AF_INET
+ return [(family, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", (ip, port or 0))]
+ return _orig_getaddrinfo(host, port, *args, **kwargs)
+
+
+socket.getaddrinfo = _getaddrinfo_with_pins
+
+
+@contextmanager
+def pin_dns(hostname: str, ip: str):
+ """Pin *hostname* → *ip* in the current thread for the duration of this context.
+
+ Use for synchronous ``requests.get()`` callers to prevent DNS rebinding
+ between SSRF validation and the actual TCP connection.
+ """
+ pins = _tl.__dict__.setdefault("dns_pins", {})
+ pins[hostname] = ip
+ try:
+ yield
+ finally:
+ pins.pop(hostname, None)
+
+
+@contextmanager
+def pin_dns_global(hostname: str, ip: str):
+ """Pin *hostname* → *ip* across all threads for the duration of this context.
+
+ Use for async callers (e.g. asyncio-based crawlers) where DNS resolution
+ may happen in thread-pool executor threads rather than the calling thread.
+ """
+ with _global_pin_lock:
+ _global_dns_pins[hostname] = ip
+ try:
+ yield
+ finally:
+ with _global_pin_lock:
+ _global_dns_pins.pop(hostname, None)
+
+
+_DEFAULT_ALLOWED_SCHEMES: frozenset[str] = frozenset({"http", "https"})
+
+
+def _effective_ip(
+ ip: ipaddress.IPv4Address | ipaddress.IPv6Address,
+) -> ipaddress.IPv4Address | ipaddress.IPv6Address:
+ """Return the IPv4 equivalent for IPv4-mapped IPv6 addresses, unchanged otherwise.
+
+ Without this normalization ``::ffff:127.0.0.1`` would pass ``is_global``
+ as an IPv6Address in some Python versions, bypassing the loopback check.
+ """
+ if isinstance(ip, ipaddress.IPv6Address):
+ mapped = ip.ipv4_mapped
+ if mapped is not None:
+ return mapped
+ return ip
+
+
+def assert_url_is_safe(
+ url: str,
+ *,
+ allowed_schemes: frozenset[str] = _DEFAULT_ALLOWED_SCHEMES,
+) -> tuple[str, str]:
+ """Raise ``ValueError`` if *url* is not safe to fetch (SSRF guard).
+
+ Checks performed in order:
+
+ 1. Scheme is in *allowed_schemes*.
+ 2. Hostname is present.
+ 3. **Every** address returned by ``getaddrinfo`` is globally routable
+ (``ip.is_global``). This is an allowlist approach: it catches private,
+ loopback, link-local, reserved, multicast, and all other
+ special-purpose ranges rather than individual deny-list flags.
+ IPv4-mapped IPv6 addresses (e.g. ``::ffff:127.0.0.1``) are normalised
+ to their IPv4 form via :func:`_effective_ip` before the check.
+
+ Returns ``(hostname, resolved_ip)`` — the first validated public IP string
+ — so the caller can **pin** that address in its HTTP client and prevent
+ DNS-rebinding attacks (the hostname is resolved exactly once).
+ """
+ parsed = urlparse(url)
+ scheme = parsed.scheme
+ if scheme not in allowed_schemes:
+ logger.warning(
+ "SSRF guard blocked URL with disallowed scheme: scheme=%r url=%r",
+ scheme,
+ url,
+ )
+ raise ValueError(f"Disallowed URL scheme: {scheme!r}. Only {sorted(allowed_schemes)} are allowed.")
+
+ hostname = parsed.hostname
+ if not hostname:
+ logger.warning("SSRF guard blocked URL with missing host: url=%r", url)
+ raise ValueError("URL is missing a host.")
+
+ try:
+ addr_infos = socket.getaddrinfo(hostname, None)
+ except socket.gaierror as exc:
+ logger.warning("SSRF guard could not resolve hostname=%r reason=%s", hostname, exc)
+ raise ValueError(f"Could not resolve hostname {hostname!r}: {exc}") from exc
+
+ resolved_ip: str | None = None
+ for _family, _type, _proto, _canonname, sockaddr in addr_infos:
+ raw_ip = ipaddress.ip_address(sockaddr[0])
+ eff_ip = _effective_ip(raw_ip)
+ if not eff_ip.is_global:
+ logger.warning(
+ "SSRF guard blocked URL: hostname=%r resolved to non-public address=%s",
+ hostname,
+ raw_ip,
+ )
+ raise ValueError(f"URL resolves to a non-public address ({raw_ip}), which is not allowed.")
+ if resolved_ip is None:
+ resolved_ip = str(raw_ip)
+
+ if resolved_ip is None:
+ logger.warning("SSRF guard blocked URL: hostname=%r resolved to no addresses", hostname)
+ raise ValueError(f"Hostname {hostname!r} resolved to no addresses.")
+
+ return hostname, resolved_ip
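+
+
+# End-to-end usage sketch for synchronous callers (the URL and the
+# ``requests`` client are illustrative, not mandated by this module):
+#
+#     hostname, ip = assert_url_is_safe("https://example.com/feed")
+#     with pin_dns(hostname, ip):
+#         resp = requests.get("https://example.com/feed", timeout=10)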
diff --git a/conf/infinity_mapping.json b/conf/infinity_mapping.json
index 77d26dd9604..5f7ed80f261 100644
--- a/conf/infinity_mapping.json
+++ b/conf/infinity_mapping.json
@@ -38,5 +38,6 @@
"removed_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"doc_type_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
"toc_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
- "raptor_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"}
+ "raptor_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
+ "raptor_layer_int": {"type": "integer", "default": 0}
}
diff --git a/conf/llm_factories.json b/conf/llm_factories.json
index 0cadfe3679d..2fc12803d78 100644
--- a/conf/llm_factories.json
+++ b/conf/llm_factories.json
@@ -377,7 +377,7 @@
"tags": "LLM,TEXT EMBEDDING,TEXT RE-RANK,TTS,SPEECH2TEXT,MODERATION",
"status": "1",
"rank": "950",
- "url" : "https://dashscope.aliyuncs.com/compatible-mode/v1",
+ "url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"llm": [
{
"llm_name": "qwen3.5-122b-a10b",
@@ -421,13 +421,6 @@
"model_type": "chat",
"is_tools": false
},
- {
- "llm_name": "deepseek-r1-distill-qwen-7b",
- "tags": "LLM,CHAT,32K",
- "max_tokens": 32768,
- "model_type": "chat",
- "is_tools": false
- },
{
"llm_name": "deepseek-r1-distill-qwen-14b",
"tags": "LLM,CHAT,32K",
@@ -1134,16 +1127,16 @@
"url": "https://api.deepseek.com/v1",
"llm": [
{
- "llm_name": "deepseek-chat",
+ "llm_name": "deepseek-v4-flash",
"tags": "LLM,CHAT,",
- "max_tokens": 64000,
+ "max_tokens": 1000000,
"model_type": "chat",
"is_tools": true
},
{
- "llm_name": "deepseek-reasoner",
+ "llm_name": "deepseek-v4-pro",
"tags": "LLM,CHAT,",
- "max_tokens": 64000,
+ "max_tokens": 1000000,
"model_type": "chat",
"is_tools": true
}
@@ -1557,53 +1550,52 @@
"rank": "980",
"llm": [
{
- "llm_name": "gemini-3-pro-preview",
- "tags": "LLM,CHAT,1M,IMAGE2TEXT",
- "max_tokens": 1048576,
- "model_type": "image2text",
- "is_tools": true
+ "llm_name": "gemini-3-pro-preview",
+ "tags": "LLM,CHAT,1M,IMAGE2TEXT",
+ "max_tokens": 1048576,
+ "model_type": "image2text",
+ "is_tools": true
},
{
- "llm_name": "gemini-2.5-flash",
- "tags": "LLM,CHAT,1024K,IMAGE2TEXT",
- "max_tokens": 1048576,
- "model_type": "image2text",
- "is_tools": true
+ "llm_name": "gemini-2.5-flash",
+ "tags": "LLM,CHAT,1024K,IMAGE2TEXT",
+ "max_tokens": 1048576,
+ "model_type": "image2text",
+ "is_tools": true
},
{
- "llm_name": "gemini-2.5-pro",
- "tags": "LLM,CHAT,IMAGE2TEXT,1024K",
- "max_tokens": 1048576,
- "model_type": "image2text",
- "is_tools": true
+ "llm_name": "gemini-2.5-pro",
+ "tags": "LLM,CHAT,IMAGE2TEXT,1024K",
+ "max_tokens": 1048576,
+ "model_type": "image2text",
+ "is_tools": true
},
{
- "llm_name": "gemini-2.5-flash-lite",
- "tags": "LLM,CHAT,1024K,IMAGE2TEXT",
- "max_tokens": 1048576,
- "model_type": "image2text",
- "is_tools": true
+ "llm_name": "gemini-2.5-flash-lite",
+ "tags": "LLM,CHAT,1024K,IMAGE2TEXT",
+ "max_tokens": 1048576,
+ "model_type": "image2text",
+ "is_tools": true
},
{
- "llm_name": "gemini-2.0-flash",
- "tags": "LLM,CHAT,1024K",
- "max_tokens": 1048576,
- "model_type": "image2text",
- "is_tools": true
+ "llm_name": "gemini-2.0-flash",
+ "tags": "LLM,CHAT,1024K",
+ "max_tokens": 1048576,
+ "model_type": "image2text",
+ "is_tools": true
},
{
- "llm_name": "gemini-2.0-flash-lite",
- "tags": "LLM,CHAT,1024K",
- "max_tokens": 1048576,
- "model_type": "image2text",
- "is_tools": true
+ "llm_name": "gemini-2.0-flash-lite",
+ "tags": "LLM,CHAT,1024K",
+ "max_tokens": 1048576,
+ "model_type": "image2text",
+ "is_tools": true
},
-
{
- "llm_name": "gemini-embedding-001",
- "tags": "TEXT EMBEDDING",
- "max_tokens": 2048,
- "model_type": "embedding"
+ "llm_name": "gemini-embedding-001",
+ "tags": "TEXT EMBEDDING",
+ "max_tokens": 2048,
+ "model_type": "embedding"
}
]
},
@@ -2949,20 +2941,6 @@
"model_type": "chat",
"is_tools": true
},
- {
- "llm_name": "Pro/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
- "tags": "LLM,CHAT,32k",
- "max_tokens": 32000,
- "model_type": "chat",
- "is_tools": true
- },
- {
- "llm_name": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
- "tags": "LLM,CHAT,32k",
- "max_tokens": 32000,
- "model_type": "chat",
- "is_tools": true
- },
{
"llm_name": "deepseek-ai/DeepSeek-V2.5",
"tags": "LLM,CHAT,32k",
@@ -4247,13 +4225,6 @@
"model_type": "chat",
"is_tools": false
},
- {
- "llm_name": "DeepSeek-R1-Distill-Qwen-7B",
- "tags": "LLM,CHAT",
- "max_tokens": 65792,
- "model_type": "chat",
- "is_tools": false
- },
{
"llm_name": "DeepSeek-R1-Distill-Qwen-1.5B",
"tags": "LLM,CHAT",
@@ -6255,6 +6226,14 @@
"rank": "910",
"llm": []
},
+ {
+ "name": "OpenDataLoader",
+ "logo": "",
+ "tags": "OCR",
+ "status": "1",
+ "rank": "920",
+ "llm": []
+ },
{
"name": "n1n",
"logo": "",
@@ -6293,6 +6272,435 @@
}
]
},
+ {
+ "name": "Astraflow",
+ "logo": "",
+ "tags": "LLM,TEXT EMBEDDING",
+ "status": "1",
+ "rank": "250",
+ "url": "https://api-us-ca.umodelverse.ai/v1",
+ "llm": [
+ {
+ "llm_name": "claude-opus-4-7",
+ "tags": "LLM,CHAT,200k",
+ "max_tokens": 200000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "claude-opus-4-6",
+ "tags": "LLM,CHAT,200k",
+ "max_tokens": 200000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "claude-sonnet-4-5-20250929",
+ "tags": "LLM,CHAT,200k",
+ "max_tokens": 200000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "claude-haiku-4-5-20251001",
+ "tags": "LLM,CHAT,200k",
+ "max_tokens": 200000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "gpt-5.4",
+ "tags": "LLM,CHAT,400k",
+ "max_tokens": 400000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "gpt-5.4-mini",
+ "tags": "LLM,CHAT,400k",
+ "max_tokens": 400000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "gpt-5.4-nano",
+ "tags": "LLM,CHAT,400k",
+ "max_tokens": 400000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "gpt-4o-mini",
+ "tags": "LLM,CHAT,128k",
+ "max_tokens": 128000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "Qwen/Qwen3-Max",
+ "tags": "LLM,CHAT,131k",
+ "max_tokens": 131072,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "Qwen/Qwen3-Coder",
+ "tags": "LLM,CHAT,131k",
+ "max_tokens": 131072,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "Qwen/Qwen3-32B",
+ "tags": "LLM,CHAT,131k",
+ "max_tokens": 131072,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "Qwen/Qwen3-VL-235B-A22B-Instruct",
+ "tags": "LLM,CHAT,131k",
+ "max_tokens": 131072,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "kimi-k2.6",
+ "tags": "LLM,CHAT,200k",
+ "max_tokens": 200000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "glm-5.1",
+ "tags": "LLM,CHAT,128k",
+ "max_tokens": 128000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "MiniMax-M2.7",
+ "tags": "LLM,CHAT,1M",
+ "max_tokens": 1000000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "MiniMax-M2",
+ "tags": "LLM,CHAT,1M",
+ "max_tokens": 1000000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "gemini-2.5-pro",
+ "tags": "LLM,CHAT,1M",
+ "max_tokens": 1000000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "gemini-2.5-flash",
+ "tags": "LLM,CHAT,1M",
+ "max_tokens": 1000000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "qwen3-embedding-8b",
+ "tags": "TEXT EMBEDDING,8K",
+ "max_tokens": 8192,
+ "model_type": "embedding",
+ "is_tools": false
+ },
+ {
+ "llm_name": "text-embedding-3-large",
+ "tags": "TEXT EMBEDDING,8K",
+ "max_tokens": 8191,
+ "model_type": "embedding",
+ "is_tools": false
+ },
+ {
+ "llm_name": "text-embedding-ada-002",
+ "tags": "TEXT EMBEDDING,8K",
+ "max_tokens": 8191,
+ "model_type": "embedding",
+ "is_tools": false
+ }
+ ]
+ },
+ {
+ "name": "FuturMix",
+ "logo": "",
+ "tags": "LLM,TEXT EMBEDDING,IMAGE2TEXT,SPEECH2TEXT,TTS,TEXT RE-RANK",
+ "status": "1",
+ "rank": "248",
+ "url": "https://futurmix.ai/v1",
+ "llm": [
+ {
+ "llm_name": "claude-sonnet-4-20250514",
+ "tags": "LLM,CHAT,200k",
+ "max_tokens": 200000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "claude-3.5-haiku",
+ "tags": "LLM,CHAT,200k",
+ "max_tokens": 200000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "gpt-4o",
+ "tags": "LLM,CHAT,128k",
+ "max_tokens": 128000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "gpt-4o-mini",
+ "tags": "LLM,CHAT,128k",
+ "max_tokens": 128000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "gemini-2.5-flash",
+ "tags": "LLM,CHAT,1M",
+ "max_tokens": 1000000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "gemini-2.0-flash",
+ "tags": "LLM,CHAT,1M",
+ "max_tokens": 1000000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "deepseek-chat",
+ "tags": "LLM,CHAT,64k",
+ "max_tokens": 65536,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "deepseek-reasoner",
+ "tags": "LLM,CHAT,64k",
+ "max_tokens": 65536,
+ "model_type": "chat",
+ "is_tools": false
+ },
+ {
+ "llm_name": "gpt-4o",
+ "tags": "IMAGE2TEXT,CHAT,128k",
+ "max_tokens": 128000,
+ "model_type": "image2text",
+ "is_tools": true
+ },
+ {
+ "llm_name": "text-embedding-3-small",
+ "tags": "TEXT EMBEDDING,8K",
+ "max_tokens": 8191,
+ "model_type": "embedding",
+ "is_tools": false
+ },
+ {
+ "llm_name": "text-embedding-3-large",
+ "tags": "TEXT EMBEDDING,8K",
+ "max_tokens": 8191,
+ "model_type": "embedding",
+ "is_tools": false
+ },
+ {
+ "llm_name": "tts-1",
+ "tags": "TTS",
+ "max_tokens": 4096,
+ "model_type": "tts",
+ "is_tools": false
+ },
+ {
+ "llm_name": "tts-1-hd",
+ "tags": "TTS",
+ "max_tokens": 4096,
+ "model_type": "tts",
+ "is_tools": false
+ },
+ {
+ "llm_name": "whisper-1",
+ "tags": "SPEECH2TEXT",
+ "max_tokens": 25000000,
+ "model_type": "speech2text",
+ "is_tools": false
+ },
+ {
+ "llm_name": "jina-reranker-v2-base-multilingual",
+ "tags": "RE-RANK,8k",
+ "max_tokens": 8192,
+ "model_type": "rerank",
+ "is_tools": false
+ }
+ ]
+ },
+ {
+ "name": "Astraflow-CN",
+ "logo": "",
+ "tags": "LLM,TEXT EMBEDDING",
+ "status": "1",
+ "rank": "249",
+ "url": "https://api.modelverse.cn/v1",
+ "llm": [
+ {
+ "llm_name": "claude-opus-4-7",
+ "tags": "LLM,CHAT,200k",
+ "max_tokens": 200000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "claude-opus-4-6",
+ "tags": "LLM,CHAT,200k",
+ "max_tokens": 200000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "claude-sonnet-4-5-20250929",
+ "tags": "LLM,CHAT,200k",
+ "max_tokens": 200000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "claude-haiku-4-5-20251001",
+ "tags": "LLM,CHAT,200k",
+ "max_tokens": 200000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "gpt-5.4",
+ "tags": "LLM,CHAT,400k",
+ "max_tokens": 400000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "gpt-5.4-mini",
+ "tags": "LLM,CHAT,400k",
+ "max_tokens": 400000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "gpt-5.4-nano",
+ "tags": "LLM,CHAT,400k",
+ "max_tokens": 400000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "gpt-4o-mini",
+ "tags": "LLM,CHAT,128k",
+ "max_tokens": 128000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "Qwen/Qwen3-Max",
+ "tags": "LLM,CHAT,131k",
+ "max_tokens": 131072,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "Qwen/Qwen3-Coder",
+ "tags": "LLM,CHAT,131k",
+ "max_tokens": 131072,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "Qwen/Qwen3-32B",
+ "tags": "LLM,CHAT,131k",
+ "max_tokens": 131072,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "Qwen/Qwen3-VL-235B-A22B-Instruct",
+ "tags": "LLM,CHAT,131k",
+ "max_tokens": 131072,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "kimi-k2.6",
+ "tags": "LLM,CHAT,200k",
+ "max_tokens": 200000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "glm-5.1",
+ "tags": "LLM,CHAT,128k",
+ "max_tokens": 128000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "MiniMax-M2.7",
+ "tags": "LLM,CHAT,1M",
+ "max_tokens": 1000000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "MiniMax-M2",
+ "tags": "LLM,CHAT,1M",
+ "max_tokens": 1000000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "gemini-2.5-pro",
+ "tags": "LLM,CHAT,1M",
+ "max_tokens": 1000000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "gemini-2.5-flash",
+ "tags": "LLM,CHAT,1M",
+ "max_tokens": 1000000,
+ "model_type": "chat",
+ "is_tools": true
+ },
+ {
+ "llm_name": "qwen3-embedding-8b",
+ "tags": "TEXT EMBEDDING,8K",
+ "max_tokens": 8192,
+ "model_type": "embedding",
+ "is_tools": false
+ },
+ {
+ "llm_name": "text-embedding-3-large",
+ "tags": "TEXT EMBEDDING,8K",
+ "max_tokens": 8191,
+ "model_type": "embedding",
+ "is_tools": false
+ },
+ {
+ "llm_name": "text-embedding-ada-002",
+ "tags": "TEXT EMBEDDING,8K",
+ "max_tokens": 8191,
+ "model_type": "embedding",
+ "is_tools": false
+ }
+ ]
+ },
{
"name": "Avian",
"logo": "",
diff --git a/conf/mapping.json b/conf/mapping.json
index f32acb02bc3..495f7c7763c 100644
--- a/conf/mapping.json
+++ b/conf/mapping.json
@@ -92,7 +92,7 @@
{
"kwd": {
"match_pattern": "regex",
- "match": "^(.*_(kwd|id|ids|uid|uids)|uid)$",
+ "match": "^(.*_(kwd|id|ids|uid|uids)|uid|id)$",
"mapping": {
"type": "keyword",
"similarity": "boolean",
diff --git a/conf/models/aliyun.json b/conf/models/aliyun.json
new file mode 100644
index 00000000000..51adef5d748
--- /dev/null
+++ b/conf/models/aliyun.json
@@ -0,0 +1,52 @@
+{
+ "name": "Aliyun",
+ "url": {
+ "default": "https://dashscope.aliyuncs.com",
+ "singapore": "https://dashscope-intl.aliyuncs.com",
+ "us": "https://dashscope-us.aliyuncs.com"
+ },
+ "url_suffix": {
+ "chat": "compatible-mode/v1/chat/completions",
+ "embedding": "compatible-mode/v1/embeddings",
+ "rerank": "compatible-api/v1/reranks",
+ "models": "api/v1/deployments/models"
+ },
+ "models": [
+ {
+ "name": "qwen-flash",
+ "max_tokens": 995904,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "text-embedding-v4",
+ "max_tokens": 8192,
+ "model_types": [
+ "embedding"
+ ]
+ },
+ {
+ "name": "text-embedding-v3",
+ "max_tokens": 8192,
+ "model_types": [
+ "embedding"
+ ]
+ },
+ {
+ "name": "qwen3-rerank",
+ "max_tokens": 8192,
+ "model_types": [
+ "rerank"
+ ]
+ }
+ ],
+ "features": {
+ "thinking": {
+ "default_value": true,
+ "supported_models": [
+ "qwen-flash"
+ ]
+ }
+ }
+}
\ No newline at end of file
diff --git a/conf/models/deepseek.json b/conf/models/deepseek.json
new file mode 100644
index 00000000000..146e11862a9
--- /dev/null
+++ b/conf/models/deepseek.json
@@ -0,0 +1,36 @@
+{
+ "name": "DeepSeek",
+ "url": {
+ "default": "https://api.deepseek.com"
+ },
+ "url_suffix": {
+ "chat": "chat/completions",
+ "models": "models",
+ "balance": "user/balance"
+ },
+ "class": "deepseek",
+ "models": [
+ {
+ "name": "deepseek-v4-flash",
+ "max_tokens": 1048576,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "deepseek-v4-pro",
+ "max_tokens": 1048576,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/conf/models/gitee.json b/conf/models/gitee.json
new file mode 100644
index 00000000000..630106592f2
--- /dev/null
+++ b/conf/models/gitee.json
@@ -0,0 +1,44 @@
+{
+ "name": "Gitee",
+ "url": {
+ "default": "https://api.moark.com/v1"
+ },
+ "url_suffix": {
+ "chat": "chat/completions",
+ "models": "models",
+ "status": "",
+ "balance": "tokens/packages/balance",
+ "embedding": "embedding",
+ "rerank": "rerank"
+ },
+ "models": [
+ {
+ "name": "qwen3-8b",
+ "max_tokens": 32768,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "qwen3-0.6b",
+ "max_tokens": 32768,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "glm-4.7-flash",
+ "max_tokens": 204800,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "BAAI/bge-reranker-v2-m3",
+ "max_tokens": 8192,
+ "model_types": [
+ "rerank"
+ ]
+ }
+ ]
+}
\ No newline at end of file
diff --git a/conf/models/google.json b/conf/models/google.json
new file mode 100644
index 00000000000..2e4cf30525f
--- /dev/null
+++ b/conf/models/google.json
@@ -0,0 +1,37 @@
+{
+ "name": "Google",
+ "url": {
+ "default": "https://generativelanguage.googleapis.com"
+ },
+ "url_suffix": {
+ "models": "v1beta/models"
+ },
+ "class": "gemini",
+ "models": [
+ {
+ "name": "gemini-2.5-flash",
+ "max_tokens": 1048576,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ }
+ ],
+ "features": {
+ "thinking": {
+ "default_value": true,
+ "supported_models": [
+ "gemini-2.5-flash"
+ ]
+ },
+ "reasoning_effort": {
+ "default_value": "high",
+ "supported_modes": [
+ "gemini-2.5-flash"
+ ]
+ }
+ }
+}
\ No newline at end of file
diff --git a/conf/models/huggingface.json b/conf/models/huggingface.json
new file mode 100644
index 00000000000..c46ab4a46bd
--- /dev/null
+++ b/conf/models/huggingface.json
@@ -0,0 +1,21 @@
+{
+ "name": "HuggingFace",
+ "url": {
+ "default": "https://router.huggingface.co/v1/"
+ },
+ "url-suffix": {
+ "chat": "chat/completions",
+ "models": "models",
+ "embedding": "hf-inference/models"
+ },
+ "class": "huggingface",
+ "models": [
+ {
+ "name": "openai/gpt-oss-120b:fastest",
+ "max_tokens": 32768,
+ "model_types": [
+ "chat"
+ ]
+ }
+ ]
+}
\ No newline at end of file
diff --git a/conf/models/lmstudio.json b/conf/models/lmstudio.json
new file mode 100644
index 00000000000..a22cbb982fe
--- /dev/null
+++ b/conf/models/lmstudio.json
@@ -0,0 +1,8 @@
+{
+ "name": "lmstudio",
+ "url_suffix": {
+ "chat": "chat/completions",
+ "models": "models"
+ },
+ "class": "local"
+}
\ No newline at end of file
diff --git a/conf/models/minimax.json b/conf/models/minimax.json
new file mode 100644
index 00000000000..31760ac2597
--- /dev/null
+++ b/conf/models/minimax.json
@@ -0,0 +1,104 @@
+{
+ "name": "MiniMax",
+ "url": {
+ "default": "https://api.minimaxi.com/",
+ "global": "https://api.minimax.io/"
+ },
+ "url_suffix": {
+ "chat": "v1/text/chatcompletion_v2",
+ "models": "v1/models",
+ "tts": "v1/t2a_v2",
+ "files": "v1/files/list"
+ },
+ "class": "minimax",
+ "models": [
+ {
+ "name": "minimax-m2.7",
+ "max_tokens": 204800,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "minimax-m2.7-highspeed",
+ "max_tokens": 204800,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "minimax-m2.5",
+ "max_tokens": 204800,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "minimax-m2.5-highspeed",
+ "max_tokens": 204800,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "minimax-m2.1",
+ "max_tokens": 204800,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "minimax-m2.1-highspeed",
+ "max_tokens": 204800,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "minimax-m2",
+ "max_tokens": 204800,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "minimax-m2-her",
+ "max_tokens": 65536,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/conf/models/moonshot.json b/conf/models/moonshot.json
new file mode 100644
index 00000000000..b9df95e0c22
--- /dev/null
+++ b/conf/models/moonshot.json
@@ -0,0 +1,84 @@
+{
+ "name": "Moonshot",
+ "url": {
+ "default": "https://api.moonshot.cn/v1"
+ },
+ "url_suffix": {
+ "chat": "chat/completions",
+ "models": "models",
+ "balance": "users/me/balance"
+ },
+ "class": "kimi",
+ "models": [
+ {
+ "name": "kimi-k2.6",
+ "max_tokens": 262144,
+ "model_types": [
+ "chat",
+ "vision"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "kimi-k2.5",
+ "max_tokens": 262144,
+ "model_types": [
+ "chat",
+ "vision"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "moonshot-v1-8k",
+ "max_tokens": 8000,
+ "model_types": [
+ "chat",
+ "vision"
+ ]
+ },
+ {
+ "name": "moonshot-v1-32k",
+ "max_tokens": 32000,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "moonshot-v1-128k",
+ "max_tokens": 128000,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "moonshot-v1-8k-vision-preview",
+ "max_tokens": 8000,
+ "model_types": [
+ "chat",
+ "vision"
+ ]
+ },
+ {
+ "name": "moonshot-v1-32k-vision-preview",
+ "max_tokens": 32000,
+ "model_types": [
+ "chat",
+ "vision"
+ ]
+ },
+ {
+ "name": "moonshot-v1-128k-vision-preview",
+ "max_tokens": 128000,
+ "model_types": [
+ "chat",
+ "vision"
+ ]
+ }
+ ]
+}
\ No newline at end of file
diff --git a/conf/models/nvidia.json b/conf/models/nvidia.json
new file mode 100644
index 00000000000..8ba81f1fd3f
--- /dev/null
+++ b/conf/models/nvidia.json
@@ -0,0 +1,461 @@
+{
+ "name": "Nvidia",
+ "url": {
+ "default": "https://integrate.api.nvidia.com/v1"
+ },
+ "url_suffix": {
+ "chat": "chat/completions",
+ "models": "models"
+ },
+ "class": "nvidia",
+ "models": [
+ {
+ "name": "abacusai/dracarys-llama-3.1-70b-instruct",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "bytedance/seed-oss-36b-instruct",
+ "max_tokens": 32768,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "deepseek-ai/deepseek-v4-flash",
+ "max_tokens": 1048576,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "deepseek-ai/deepseek-v4-pro",
+ "max_tokens": 1048576,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "deepseek-ai/deepseek-v3.2",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "deepseek-ai/deepseek-v3.1",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "google/codegemma-7b",
+ "max_tokens": 8192,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "google/gemma-2-2b-it",
+ "max_tokens": 8192,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "google/gemma-4-31b-it",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "google/gemma-7b",
+ "max_tokens": 8192,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "ibm/granite-3.3-8b-instruct",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "meta/llama-3.1-405b-instruct",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "meta/llama-3.2-90b-vision-instruct",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat",
+ "vision"
+ ]
+ },
+ {
+ "name": "meta/llama-4-maverick-17b-128e-instruct",
+ "max_tokens": 1048576,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "microsoft/phi-4-mini-flash-reasoning",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "minimaxai/minimax-m2.1",
+ "max_tokens": 204800,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "minimaxai/minimax-m2.5",
+ "max_tokens": 204800,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "minimaxai/minimax-m2.7",
+ "max_tokens": 204800,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "mistralai/devstral-2-123b-instruct-2512",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "mistralai/magistral-small-2506",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "mistralai/mistral-7b-instruct-v0.3",
+ "max_tokens": 32768,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "mistralai/mistral-large-3-675b-instruct-2512",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "mistralai/mistral-medium-3-5-128b",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat",
+ "vision"
+ ]
+ },
+ {
+ "name": "mistralai/mistral-nemotron",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "mistralai/mixtral-8x22b-instruct",
+ "max_tokens": 65536,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "moonshotai/kimi-k2.5",
+ "max_tokens": 262144,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "moonshotai/kimi-k2.6",
+ "max_tokens": 262144,
+ "model_types": [
+ "chat",
+ "vision"
+ ]
+ },
+ {
+ "name": "moonshotai/kimi-k2-instruct",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "moonshotai/kimi-k2-instruct-0905",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "moonshotai/kimi-k2-thinking",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "nvidia/gliner-pii",
+ "max_tokens": 4096,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "nvidia/llama-3.1-nemoguard-8b-content-safety",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "nvidia/llama-3.1-nemoguard-8b-topic-control",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "nvidia/llama-3.1-nemotron-nano-8b-v1",
+ "max_tokens": 8192,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "nvidia/llama-3.1-nemotron-safety-guard-8b-v3",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "nvidia/llama-3.1-nemotron-ultra-253b-v1",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "nvidia/llama-3.2-nemoretriever-1b-vlm-embed-v1",
+ "max_tokens": 8192,
+ "model_types": [
+ "embedding"
+ ]
+ },
+ {
+ "name": "nvidia/llama-3.3-nemotron-super-49b-v1",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "nvidia/llama-3.3-nemotron-super-49b-v1.5",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "nvidia/nemoguard-jailbreak-detect",
+ "max_tokens": 4096,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "nvidia/nemotron-3-nano-30b-a3b",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "nvidia/nemotron-3-nano-omni-30b-a3b-reasoning",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat",
+ "vision"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "nvidia/nemotron-3-super-120b-a12b",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "nvidia/nemotron-content-safety-reasoning-4b",
+ "max_tokens": 8192,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "nvidia/nemotron-mini-4b-instruct",
+ "max_tokens": 4096,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "nvidia/nvidia-nemotron-nano-9b-v2",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "nvidia/riva-translate-4b-instruct-v1_1",
+ "max_tokens": 4096,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "nvidia/usdcode",
+ "max_tokens": 8192,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "openai/gpt-oss-120b",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "qwen/qwen2.5-coder-7b-instruct",
+ "max_tokens": 32768,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "qwen/qwen3-5-122b-a10b",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "qwen/qwen3-235b-a22b",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "qwen/qwen3-coder-480b-a35b-instruct",
+ "max_tokens": 262144,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "z-ai/glm-5",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "z-ai/glm-5.1",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "z-ai/glm-4.7",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/conf/models/ollama.json b/conf/models/ollama.json
new file mode 100644
index 00000000000..ed0a1e011b9
--- /dev/null
+++ b/conf/models/ollama.json
@@ -0,0 +1,8 @@
+{
+ "name": "ollama",
+ "url_suffix": {
+ "chat": "chat/completions",
+ "models": "models"
+ },
+ "class": "local"
+}
\ No newline at end of file
diff --git a/conf/models/openai.json b/conf/models/openai.json
index f89c6c0d1db..696c6f93b3c 100644
--- a/conf/models/openai.json
+++ b/conf/models/openai.json
@@ -4,8 +4,10 @@
"default": "https://api.openai.com/v1"
},
"url_suffix": {
- "chat": "chat/completions"
+ "chat": "chat/completions",
+ "models": "models"
},
+ "class": "gpt",
"models": [
{
"name": "gpt-5.2-pro",
@@ -13,8 +15,7 @@
"model_types": [
"chat",
"vision"
- ],
- "features": {}
+ ]
},
{
"name": "gpt-5.2",
@@ -22,8 +23,7 @@
"model_types": [
"chat",
"vision"
- ],
- "features": {}
+ ]
},
{
"name": "gpt-5.1",
@@ -31,8 +31,7 @@
"model_types": [
"chat",
"vision"
- ],
- "features": {}
+ ]
},
{
"name": "gpt-5.1-chat-latest",
@@ -40,8 +39,7 @@
"model_types": [
"chat",
"vision"
- ],
- "features": {}
+ ]
},
{
"name": "gpt-5",
@@ -49,8 +47,7 @@
"model_types": [
"chat",
"vision"
- ],
- "features": {}
+ ]
},
{
"name": "gpt-5-mini",
@@ -58,8 +55,7 @@
"model_types": [
"chat",
"vision"
- ],
- "features": {}
+ ]
},
{
"name": "gpt-5-nano",
@@ -67,8 +63,7 @@
"model_types": [
"chat",
"vision"
- ],
- "features": {}
+ ]
},
{
"name": "gpt-5-chat-latest",
@@ -76,8 +71,7 @@
"model_types": [
"chat",
"vision"
- ],
- "features": {}
+ ]
},
{
"name": "gpt-4.1",
@@ -85,8 +79,7 @@
"model_types": [
"chat",
"vision"
- ],
- "features": {}
+ ]
},
{
"name": "gpt-4.1-mini",
@@ -94,8 +87,7 @@
"model_types": [
"chat",
"vision"
- ],
- "features": {}
+ ]
},
{
"name": "gpt-4.1-nano",
@@ -103,43 +95,14 @@
"model_types": [
"chat",
"vision"
- ],
- "features": {}
+ ]
},
{
"name": "gpt-4.5-preview",
"max_tokens": 128000,
"model_types": [
"chat"
- ],
- "features": {}
- },
- {
- "name": "o3",
- "max_tokens": 200000,
- "model_types": [
- "chat",
- "vision"
- ],
- "features": {}
- },
- {
- "name": "o4-mini",
- "max_tokens": 200000,
- "model_types": [
- "chat",
- "vision"
- ],
- "features": {}
- },
- {
- "name": "o4-mini-high",
- "max_tokens": 200000,
- "model_types": [
- "chat",
- "vision"
- ],
- "features": {}
+ ]
},
{
"name": "gpt-4o-mini",
@@ -147,8 +110,7 @@
"model_types": [
"chat",
"vision"
- ],
- "features": {}
+ ]
},
{
"name": "gpt-4o",
@@ -156,88 +118,77 @@
"model_types": [
"chat",
"vision"
- ],
- "features": {}
+ ]
},
{
"name": "gpt-3.5-turbo",
"max_tokens": 4096,
"model_types": [
"chat"
- ],
- "features": {}
+ ]
},
{
"name": "gpt-3.5-turbo-16k-0613",
"max_tokens": 16385,
"model_types": [
"chat"
- ],
- "features": {}
+ ]
},
{
"name": "text-embedding-ada-002",
"max_tokens": 8191,
"model_types": [
"embedding"
- ],
- "features": {}
+ ]
},
{
"name": "text-embedding-3-small",
"max_tokens": 8191,
"model_types": [
"embedding"
- ],
- "features": {}
+ ]
},
{
"name": "text-embedding-3-large",
"max_tokens": 8191,
"model_types": [
"embedding"
- ],
- "features": {}
+ ]
},
{
"name": "whisper-1",
"max_tokens": 26214400,
"model_types": [
"asr"
- ],
- "features": {}
+ ]
},
{
"name": "gpt-4",
"max_tokens": 8191,
"model_types": [
"chat"
- ],
- "features": {}
+ ]
},
{
"name": "gpt-4-turbo",
"max_tokens": 8191,
"model_types": [
"chat"
- ],
- "features": {}
+ ]
},
{
"name": "gpt-4-32k",
"max_tokens": 32768,
"model_types": [
"chat"
- ],
- "features": {}
+ ]
},
{
"name": "tts-1",
"max_tokens": 2048,
"model_types": [
"tts"
- ],
- "features": {}
+ ]
}
]
}
\ No newline at end of file
diff --git a/conf/models/openrouter.json b/conf/models/openrouter.json
new file mode 100644
index 00000000000..6af1e2d15df
--- /dev/null
+++ b/conf/models/openrouter.json
@@ -0,0 +1,49 @@
+{
+ "name": "OpenRouter",
+ "url": {
+ "default": "https://openrouter.ai/api/v1"
+ },
+ "url_suffix": {
+ "chat": "chat/completions",
+ "models": "models",
+ "embedding": "embeddings",
+ "rerank": "rerank",
+ "balance": "credits"
+ },
+ "class": "openrouter",
+ "models": [
+ {
+ "name": "google/gemma-4-31b-it",
+ "max_tokens": 262144,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "minimax/minimax-m2.5",
+ "max_tokens": 196608,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "tencent/hy3-preview",
+ "max_tokens": 262144,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/conf/models/siliconflow.json b/conf/models/siliconflow.json
new file mode 100644
index 00000000000..4da3e0dcab8
--- /dev/null
+++ b/conf/models/siliconflow.json
@@ -0,0 +1,50 @@
+{
+ "name": "SiliconFlow",
+ "url": {
+ "default": "https://api.siliconflow.cn/v1"
+ },
+ "url_suffix": {
+ "chat": "chat/completions",
+ "models": "models",
+ "embedding": "embeddings",
+ "rerank": "rerank",
+ "balance": "user/info"
+ },
+ "models": [
+ {
+ "name": "qwen/qwen3-8b",
+ "max_tokens": 32768,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "qwen/qwen3.5-4b",
+ "max_tokens": 262144,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "tencent/hunyuan-mt-7b",
+ "max_tokens": 32768,
+ "model_types": [
+ "chat"
+ ]
+ },
+ {
+ "name": "BAAI/bge-reranker-v2-m3",
+ "max_tokens": 8192,
+ "model_types": [
+ "rerank"
+ ]
+ },
+ {
+ "name": "Qwen/Qwen3-Embedding-0.6B",
+ "max_tokens": 8192,
+ "model_types": [
+ "embedding"
+ ]
+ }
+ ]
+}
diff --git a/conf/models/vllm.json b/conf/models/vllm.json
new file mode 100644
index 00000000000..96ec1a2403b
--- /dev/null
+++ b/conf/models/vllm.json
@@ -0,0 +1,8 @@
+{
+ "name": "vllm",
+ "url_suffix": {
+ "chat": "chat/completions",
+ "models": "models"
+ },
+ "class": "local"
+}
\ No newline at end of file
diff --git a/conf/models/volcengine.json b/conf/models/volcengine.json
new file mode 100644
index 00000000000..96a6004097a
--- /dev/null
+++ b/conf/models/volcengine.json
@@ -0,0 +1,32 @@
+{
+ "name": "VolcEngine",
+ "url": {
+ "default": "https://ark.cn-beijing.volces.com/api/v3"
+ },
+ "url_suffix": {
+ "chat": "chat/completions",
+ "files": "files",
+ "embedding": "embeddings/multimodal"
+ },
+ "class": "volcengine",
+ "models": [
+ {
+ "name": "doubao-seed-2-0-pro-260215",
+ "max_tokens": 262144,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "doubao-embedding-vision-250615",
+ "max_tokens": 131072,
+ "model_types": [
+ "embedding"
+ ]
+ }
+ ]
+}
\ No newline at end of file
diff --git a/conf/models/xai.json b/conf/models/xai.json
index 5e12776c92e..41fe7978f12 100644
--- a/conf/models/xai.json
+++ b/conf/models/xai.json
@@ -6,42 +6,37 @@
"url_suffix": {
"chat": "chat/completions"
},
+ "class": "grok",
"models": [
{
"name": "grok-4",
"max_tokens": 256000,
- "model_types": ["chat"],
- "features": {}
+ "model_types": ["chat"]
},
{
"name": "grok-3",
"max_tokens": 131072,
- "model_types": ["chat"],
- "features": {}
+ "model_types": ["chat"]
},
{
"name": "grok-3-fast",
"max_tokens": 131072,
- "model_types": ["chat"],
- "features": {}
+ "model_types": ["chat"]
},
{
"name": "grok-3-mini",
"max_tokens": 131072,
- "model_types": ["chat"],
- "features": {}
+ "model_types": ["chat"]
},
{
"name": "grok-3-mini-mini-fast",
"max_tokens": 131072,
- "model_types": ["chat"],
- "features": {}
+ "model_types": ["chat"]
},
{
"name": "grok-2-vision",
"max_tokens": 32768,
- "model_types": ["vision"],
- "features": {}
+ "model_types": ["vision"]
}
]
}
\ No newline at end of file
diff --git a/conf/models/zhipu-ai.json b/conf/models/zhipu-ai.json
index b38624bffe2..d1bbac649fd 100644
--- a/conf/models/zhipu-ai.json
+++ b/conf/models/zhipu-ai.json
@@ -7,66 +7,144 @@
"chat": "chat/completions",
"async_chat": "async/chat/completions",
"async_result": "async-result",
- "embedding": "embedding",
- "rerank": "rerank"
+ "embedding": "embeddings",
+ "rerank": "rerank",
+ "files": "files"
},
+ "class": "glm",
"models": [
+ {
+ "name": "glm-5",
+ "max_tokens": 204800,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "glm-5-turbo",
+ "max_tokens": 204800,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "glm-5v-turbo",
+ "max_tokens": 204800,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
{
"name": "glm-4.7",
- "max_tokens": 128000,
+ "max_tokens": 204800,
"model_types": [
"chat"
],
- "features": {}
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
},
{
- "name": "glm-4.5",
- "max_tokens": 128000,
+ "name": "glm-4.7-flashx",
+ "max_tokens": 204800,
"model_types": [
"chat"
],
- "features": {}
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "glm-4.6",
+ "max_tokens": 204800,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
},
{
"name": "glm-4.6v-Flash",
- "max_tokens": 128000,
+ "max_tokens": 131072,
"model_types": [
"chat",
"vision"
],
- "features": {}
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
+ },
+ {
+ "name": "glm-4.5",
+ "max_tokens": 131072,
+ "model_types": [
+ "chat"
+ ],
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
},
{
"name": "glm-4.5-x",
- "max_tokens": 128000,
+ "max_tokens": 131072,
"model_types": [
"chat"
],
- "features": {}
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
},
{
"name": "glm-4.5-air",
- "max_tokens": 128000,
+ "max_tokens": 131072,
"model_types": [
"chat"
],
- "features": {}
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
},
{
"name": "glm-4.5-airx",
- "max_tokens": 128000,
+ "max_tokens": 131072,
"model_types": [
"chat"
],
- "features": {}
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
},
{
"name": "glm-4.5-flash",
- "max_tokens": 128000,
+ "max_tokens": 131072,
"model_types": [
"chat"
],
- "features": {}
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
},
{
"name": "glm-4.5v",
@@ -74,168 +152,119 @@
"model_types": [
"vision"
],
- "features": {}
+ "thinking": {
+ "default_value": true,
+ "clear_thinking": true
+ }
},
{
"name": "glm-4-plus",
- "max_tokens": 128000,
+ "max_tokens": 131072,
"model_types": [
"chat"
- ],
- "features": {}
+ ]
},
{
"name": "glm-4-0520",
- "max_tokens": 128000,
+ "max_tokens": 131072,
"model_types": [
"chat"
- ],
- "features": {}
+ ]
},
{
"name": "glm-4",
- "max_tokens": 128000,
+ "max_tokens": 131072,
"model_types": [
"chat"
- ],
- "features": {}
+ ]
},
{
"name": "glm-4-airx",
"max_tokens": 8000,
"model_types": [
"chat"
- ],
- "features": {}
+ ]
},
{
"name": "glm-4-air",
- "max_tokens": 128000,
+ "max_tokens": 131072,
"model_types": [
"chat"
- ],
- "features": {}
+ ]
},
{
"name": "glm-4-flash",
- "max_tokens": 128000,
+ "max_tokens": 131072,
"model_types": [
"chat"
- ],
- "features": {}
+ ]
},
{
"name": "glm-4-flashx",
- "max_tokens": 128000,
+ "max_tokens": 131072,
"model_types": [
"chat"
- ],
- "features": {}
+ ]
},
{
"name": "glm-4-long",
"max_tokens": 1000000,
"model_types": [
"chat"
- ],
- "features": {}
- },
- {
- "name": "glm-3-turbo",
- "max_tokens": 128000,
- "model_types": [
- "chat"
- ],
- "features": {}
+ ]
},
{
"name": "glm-4v",
"max_tokens": 2000,
"model_types": [
"vision"
- ],
- "features": {}
+ ]
},
{
"name": "glm-4-9b",
"max_tokens": 8192,
"model_types": [
"chat"
- ],
- "features": {}
+ ]
},
{
"name": "embedding-2",
"max_tokens": 512,
"model_types": [
"embedding"
- ],
- "features": {}
+ ]
},
{
"name": "embedding-3",
"max_tokens": 512,
"model_types": [
"embedding"
- ],
- "features": {}
+ ]
},
{
- "name": "glm-asr",
+ "name": "glm-asr-2512",
"max_tokens": 4096,
"model_types": [
"asr"
- ],
- "features": {}
+ ]
},
{
"name": "glm-tts",
"model_types": [
"tts"
- ],
- "features": {}
+ ]
},
{
"name": "glm-ocr",
"model_types": [
"ocr"
- ],
- "features": {}
+ ]
},
{
- "name": "glm-rerank",
+ "name": "rerank",
"model_types": [
"rerank"
- ],
- "features": {}
- }
- ],
- "features": {
- "thinking": {
- "default_value": true,
- "supported_models": [
- "glm-5.1",
- "glm-5",
- "glm-5v-turbo",
- "glm-4.7",
- "glm-4.6",
- "glm-4.6v",
- "glm-4.5",
- "glm-4.5v"
- ]
- },
- "clear_thinking": {
- "default_value": true,
- "supported_models": [
- "glm-5.1",
- "glm-5",
- "glm-5v-turbo",
- "glm-4.7",
- "glm-4.6",
- "glm-4.6v",
- "glm-4.5",
- "glm-4.5v"
]
}
- }
+ ]
}
\ No newline at end of file
diff --git a/conf/skill_es_mapping.json b/conf/skill_es_mapping.json
new file mode 100644
index 00000000000..a9d3cba8699
--- /dev/null
+++ b/conf/skill_es_mapping.json
@@ -0,0 +1,136 @@
+{
+ "settings": {
+ "index": {
+ "number_of_shards": 1,
+ "number_of_replicas": 0,
+ "refresh_interval": "1000ms"
+ },
+ "similarity": {
+ "scripted_sim": {
+ "type": "scripted",
+ "script": {
+ "source": "double idf = Math.log(1+(field.docCount-term.docFreq+0.5)/(term.docFreq + 0.5))/Math.log(1+((field.docCount-0.5)/1.5)); return query.boost * idf * Math.min(doc.freq, 1);"
+ }
+ }
+ }
+ },
+ "mappings": {
+ "dynamic": false,
+ "properties": {
+ "skill_id": {
+ "type": "keyword",
+ "store": true
+ },
+ "space_id": {
+ "type": "keyword",
+ "store": true
+ },
+ "folder_id": {
+ "type": "keyword",
+ "store": true
+ },
+ "name": {
+ "type": "text",
+ "index": false,
+ "store": true
+ },
+ "name_tks": {
+ "type": "text",
+ "similarity": "scripted_sim",
+ "analyzer": "whitespace",
+ "store": true
+ },
+ "tags": {
+ "type": "text",
+ "index": false,
+ "store": true
+ },
+ "tags_tks": {
+ "type": "text",
+ "similarity": "scripted_sim",
+ "analyzer": "whitespace",
+ "store": true
+ },
+ "description": {
+ "type": "text",
+ "index": false,
+ "store": true
+ },
+ "description_tks": {
+ "type": "text",
+ "similarity": "scripted_sim",
+ "analyzer": "whitespace",
+ "store": true
+ },
+ "content": {
+ "type": "text",
+ "index": false,
+ "store": true
+ },
+ "content_tks": {
+ "type": "text",
+ "similarity": "scripted_sim",
+ "analyzer": "whitespace",
+ "store": true
+ },
+ "q_3072_vec": {
+ "type": "dense_vector",
+ "dims": 3072,
+ "index": true,
+ "similarity": "cosine"
+ },
+ "q_2560_vec": {
+ "type": "dense_vector",
+ "dims": 2560,
+ "index": true,
+ "similarity": "cosine"
+ },
+ "q_1536_vec": {
+ "type": "dense_vector",
+ "dims": 1536,
+ "index": true,
+ "similarity": "cosine"
+ },
+ "q_1024_vec": {
+ "type": "dense_vector",
+ "dims": 1024,
+ "index": true,
+ "similarity": "cosine"
+ },
+ "q_768_vec": {
+ "type": "dense_vector",
+ "dims": 768,
+ "index": true,
+ "similarity": "cosine"
+ },
+ "q_512_vec": {
+ "type": "dense_vector",
+ "dims": 512,
+ "index": true,
+ "similarity": "cosine"
+ },
+ "q_256_vec": {
+ "type": "dense_vector",
+ "dims": 256,
+ "index": true,
+ "similarity": "cosine"
+ },
+ "version": {
+ "type": "keyword",
+ "store": true
+ },
+ "status": {
+ "type": "keyword",
+ "store": true
+ },
+ "create_time": {
+ "type": "long",
+ "store": true
+ },
+ "update_time": {
+ "type": "long",
+ "store": true
+ }
+ }
+ }
+}
diff --git a/conf/skill_infinity_mapping.json b/conf/skill_infinity_mapping.json
new file mode 100644
index 00000000000..4e4766ea8f5
--- /dev/null
+++ b/conf/skill_infinity_mapping.json
@@ -0,0 +1,64 @@
+{
+ "skill_id": {
+ "type": "varchar",
+ "default": "",
+ "index_type": "secondary"
+ },
+ "space_id": {
+ "type": "varchar",
+ "default": "",
+ "index_type": "secondary"
+ },
+ "folder_id": {
+ "type": "varchar",
+ "default": ""
+ },
+ "name": {
+ "type": "varchar",
+ "default": "",
+ "analyzer": [
+ "rag-coarse",
+ "rag-fine"
+ ]
+ },
+ "tags": {
+ "type": "varchar",
+ "default": "",
+ "analyzer": [
+ "rag-coarse",
+ "rag-fine"
+ ]
+ },
+ "description": {
+ "type": "varchar",
+ "default": "",
+ "analyzer": [
+ "rag-coarse",
+ "rag-fine"
+ ]
+ },
+ "content": {
+ "type": "varchar",
+ "default": "",
+ "analyzer": [
+ "rag-coarse",
+ "rag-fine"
+ ]
+ },
+ "version": {
+ "type": "varchar",
+ "default": "1.0.0"
+ },
+ "status": {
+ "type": "varchar",
+ "default": "1"
+ },
+ "create_time": {
+ "type": "bigint",
+ "default": 0
+ },
+ "update_time": {
+ "type": "bigint",
+ "default": 0
+ }
+}
\ No newline at end of file
diff --git a/deepdoc/parser/docling_parser.py b/deepdoc/parser/docling_parser.py
index a2ebc400255..948a7acb0cd 100644
--- a/deepdoc/parser/docling_parser.py
+++ b/deepdoc/parser/docling_parser.py
@@ -30,10 +30,12 @@
import requests
from PIL import Image
+from common.constants import MAXIMUM_PAGE_NUMBER
+
try:
from docling.document_converter import DocumentConverter
except Exception:
- DocumentConverter = None
+ DocumentConverter = None
try:
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
@@ -44,6 +46,7 @@ class RAGFlowPdfParser:
from deepdoc.parser.utils import extract_pdf_outlines
+
class DoclingContentType(str, Enum):
IMAGE = "image"
TABLE = "table"
@@ -124,7 +127,7 @@ def check_installation(self, docling_server_url: Optional[str] = None) -> bool:
self.logger.error(f"[Docling] init DocumentConverter failed: {e}")
return False
- def __images__(self, fnm, zoomin: int = 1, page_from=0, page_to=600, callback=None):
+ def __images__(self, fnm, zoomin: int = 1, page_from=0, page_to=MAXIMUM_PAGE_NUMBER, callback=None):
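+        # page_to now defaults to the shared MAXIMUM_PAGE_NUMBER constant instead of a hard-coded 600.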
self.page_from = page_from
self.page_to = page_to
bytes_io = None
@@ -350,6 +353,13 @@ def _parse_pdf_remote(
docling_server_url: Optional[str] = None,
request_timeout: Optional[int] = None,
):
+ """
+ Parses a PDF document using a remote Docling server.
+
+    Tries the conversion endpoints (/v1/convert/source, /v1alpha/convert/source)
+    with native chunking options first to prevent token overflow, and gracefully
+    falls back to the same endpoints without chunking if the server rejects them.
+ """
server_url = self._effective_server_url(docling_server_url)
if not server_url:
raise RuntimeError("[Docling] DOCLING_SERVER_URL is not configured.")
@@ -372,36 +382,48 @@ def _parse_pdf_remote(
filename = Path(filepath).name or "input.pdf"
b64 = base64.b64encode(pdf_bytes).decode("ascii")
- v1_payload = {
- "options": {
- "from_formats": ["pdf"],
- "to_formats": ["json", "md", "text"],
- },
- "sources": [
- {
- "kind": "file",
- "filename": filename,
- "base64_string": b64,
- }
- ],
+
+ # Standard fallback payloads (no chunking)
+ v1_payload_standard = {
+ "options": {"from_formats": ["pdf"], "to_formats": ["json", "md", "text"]},
+ "sources": [{"kind": "file", "filename": filename, "base64_string": b64}],
+ }
+ v1alpha_payload_standard = {
+ "options": {"from_formats": ["pdf"], "to_formats": ["json", "md", "text"]},
+ "file_sources": [{"filename": filename, "base64_string": b64}],
+ }
+
+    # Chunking payloads matching the docling-serve chunking request schema
+ chunking_opts = {
+ "from_formats": ["pdf"],
+ "to_formats": ["json", "md", "text"],
+ "do_chunking": True,
+ "chunking_options": {
+ "max_tokens": 512,
+ "overlap": 50,
+ "tokenizer": "sentencepiece" # Required by Docling contract
+ }
}
- v1alpha_payload = {
- "options": {
- "from_formats": ["pdf"],
- "to_formats": ["json", "md", "text"],
- },
- "file_sources": [
- {
- "filename": filename,
- "base64_string": b64,
- }
- ],
+ v1_payload_chunked = {
+ "options": chunking_opts,
+ "sources": [{"kind": "file", "filename": filename, "base64_string": b64}],
}
+ v1alpha_payload_chunked = {
+ "options": chunking_opts,
+ "file_sources": [{"filename": filename, "base64_string": b64}],
+ }
+
errors = []
response_json = None
- for endpoint, payload in (
- ("/v1/convert/source", v1_payload),
- ("/v1alpha/convert/source", v1alpha_payload),
+ is_chunked_response = False
+
+    # Try the chunked payloads first, then fall back to the standard ones if the server is older
+ for endpoint, payload, chunk_flag in (
+ ("/v1/convert/source", v1_payload_chunked, True),
+ ("/v1alpha/convert/source", v1alpha_payload_chunked, True),
+ ("/v1/convert/source", v1_payload_standard, False),
+ ("/v1alpha/convert/source", v1alpha_payload_standard, False),
):
try:
resp = requests.post(
@@ -411,20 +433,57 @@ def _parse_pdf_remote(
)
if resp.status_code < 300:
response_json = resp.json()
+ is_chunked_response = chunk_flag
+
+ if chunk_flag:
+ self.logger.info(f"[Docling] Successfully used native chunking on: {endpoint}")
+ else:
+ self.logger.info(f"[Docling] Chunking unavailable, fell back to standard: {endpoint}")
break
+
+ # If chunking request is rejected (e.g., 422 Unprocessable Entity on older servers),
+ # log it and let the loop naturally fall back to the standard payload.
+ if chunk_flag:
+ self.logger.warning(f"[Docling] Server rejected chunking parameters: HTTP {resp.status_code}")
+ continue
+
errors.append(f"{endpoint}: HTTP {resp.status_code} {resp.text[:300]}")
+
except Exception as exc:
+ self.logger.error(f"[Docling] Request error on {endpoint}: {exc}")
errors.append(f"{endpoint}: {exc}")
if response_json is None:
raise RuntimeError("[Docling] remote convert failed: " + " | ".join(errors))
+ sections: list[tuple[str, ...]] = []
+ tables = []
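+    # Note: the native chunking path below returns early, leaving tables empty.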
+
+    # Handle a native chunked response.
+ if is_chunked_response:
+ # The chunking endpoint returns an array of chunk items
+ chunks = response_json if isinstance(response_json, list) else response_json.get("results", [])
+ for chunk_data in chunks:
+ if not isinstance(chunk_data, dict):
+ continue
+ # Depending on the exact docling-serve spec, the text might be nested
+ chunk_text = chunk_data.get("text", "")
+ if not chunk_text and isinstance(chunk_data.get("chunk"), dict):
+ chunk_text = chunk_data["chunk"].get("text", "")
+
+ if isinstance(chunk_text, str) and chunk_text.strip():
+ # Feed the pre-sliced chunks directly into RAGFlow's expected format
+ sections.extend(self._sections_from_remote_text(chunk_text, parse_method=parse_method))
+
+ if callback:
+ callback(0.95, f"[Docling] Native chunks received: {len(sections)}")
+ return sections, tables
+
+    # Fallback: standard RAGFlow parsing for older docling servers.
docs = self._extract_remote_document_entries(response_json)
if not docs:
raise RuntimeError("[Docling] remote response does not contain parsed documents.")
- sections: list[tuple[str, ...]] = []
- tables = []
for doc in docs:
md = doc.get("md_content")
txt = doc.get("text_content")
diff --git a/deepdoc/parser/docx_parser.py b/deepdoc/parser/docx_parser.py
index 0257a320f7f..2d56729b744 100644
--- a/deepdoc/parser/docx_parser.py
+++ b/deepdoc/parser/docx_parser.py
@@ -21,6 +21,7 @@
from rag.nlp import rag_tokenizer
from io import BytesIO
import logging
+from common.constants import MAXIMUM_PAGE_NUMBER
from docx.image.exceptions import (
InvalidImageStreamError,
UnexpectedEndOfFileError,
@@ -158,7 +159,7 @@ def blockType(b):
return lines
return ["\n".join(lines)]
- def __call__(self, fnm, from_page=0, to_page=100000000):
+ def __call__(self, fnm, from_page=0, to_page=MAXIMUM_PAGE_NUMBER):
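+        # to_page defaults to the shared MAXIMUM_PAGE_NUMBER cap instead of a magic number.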
self.doc = Document(fnm) if isinstance(
fnm, str) else Document(BytesIO(fnm))
pn = 0 # parsed page
diff --git a/deepdoc/parser/html_parser.py b/deepdoc/parser/html_parser.py
index f4d360c6413..7462ad99e9f 100644
--- a/deepdoc/parser/html_parser.py
+++ b/deepdoc/parser/html_parser.py
@@ -52,7 +52,7 @@ def parser_txt(cls, txt, chunk_token_num):
raise TypeError("txt type should be string!")
temp_sections = []
- soup = BeautifulSoup(txt, "html5lib")
+ soup = BeautifulSoup(txt, "html.parser")
# delete