diff --git a/.gitignore b/.gitignore index 906c13dbfa4..f65d204fb24 100644 --- a/.gitignore +++ b/.gitignore @@ -231,3 +231,4 @@ internal/cpp/cmake-build-debug/ # Go server build output bin/* !bin/.gitkeep +.claude/settings.local.json \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md index 82d23b99039..b558df135a1 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -35,7 +35,7 @@ The project uses **uv** for dependency management. 1. **Setup Environment**: ```bash uv sync --python 3.12 --all-extras - uv run download_deps.py + uv run python3 download_deps.py ``` 2. **Run Server**: diff --git a/CLAUDE.md b/CLAUDE.md index f42613a6697..81888ba3d71 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -52,7 +52,7 @@ RAGFlow is an open-source RAG (Retrieval-Augmented Generation) engine based on d ```bash # Install Python dependencies uv sync --python 3.12 --all-extras -uv run download_deps.py +uv run python3 download_deps.py pre-commit install # Start dependent services diff --git a/README.md b/README.md index 4574d64554d..5f8bed3db16 100644 --- a/README.md +++ b/README.md @@ -10,9 +10,9 @@ 繁體版中文自述文件 日本語のREADME 한국어 + README en Français Bahasa Indonesia Português(Brasil) - README en Français README in Arabic Türkçe README

@@ -22,10 +22,10 @@ follow on X(Twitter) - Static Badge + Static Badge - docker pull infiniflow/ragflow:v0.25.0 + docker pull infiniflow/ragflow:v0.25.2 Latest Release @@ -39,11 +39,10 @@

+ Cloud | Document | Roadmap | - Twitter | - Discord | - Demo + Discord

@@ -58,11 +57,11 @@ 📕 Table of Contents - 💡 [What is RAGFlow?](#-what-is-ragflow) -- 🎮 [Demo](#-demo) +- 🎮 [Get Started](#-get-started) - 📌 [Latest Updates](#-latest-updates) - 🌟 [Key Features](#-key-features) - 🔎 [System Architecture](#-system-architecture) -- 🎬 [Get Started](#-get-started) +- 🎬 [Self-Hosting](#-self-hosting) - 🔧 [Configurations](#-configurations) - 🔧 [Build a Docker image](#-build-a-docker-image) - 🔨 [Launch service from source for development](#-launch-service-from-source-for-development) @@ -77,9 +76,9 @@ [RAGFlow](https://ragflow.io/) is a leading open-source Retrieval-Augmented Generation ([RAG](https://ragflow.io/basics/what-is-rag)) engine that fuses cutting-edge RAG with Agent capabilities to create a superior context layer for LLMs. It offers a streamlined RAG workflow adaptable to enterprises of any scale. Powered by a converged [context engine](https://ragflow.io/basics/what-is-agent-context-engine) and pre-built agent templates, RAGFlow enables developers to transform complex data into high-fidelity, production-ready AI systems with exceptional efficiency and precision. -## 🎮 Demo +## 🎮 Get Started -Try our demo at [https://cloud.ragflow.io](https://cloud.ragflow.io). +Try our cloud service at [https://cloud.ragflow.io](https://cloud.ragflow.io).
@@ -88,6 +87,7 @@ Try our demo at [https://cloud.ragflow.io](https://cloud.ragflow.io). ## 🔥 Latest Updates +- 2026-04-24 Supports DeepSeek v4. - 2026-03-24 [RAGFlow Skill on OpenClaw](https://clawhub.ai/yingfeng/ragflow-skill) — Provides an official skill for accessing RAGFlow datasets via OpenClaw. - 2025-12-26 Supports 'Memory' for AI agent. - 2025-11-19 Supports Gemini 3 Pro. @@ -144,7 +144,7 @@ releases! 🌟
-## 🎬 Get Started +## 🎬 Self-Hosting ### 📝 Prerequisites @@ -192,12 +192,12 @@ releases! 🌟 > All Docker images are built for x86 platforms. We don't currently offer Docker images for ARM64. > If you are on an ARM64 platform, follow [this guide](https://ragflow.io/docs/dev/build_docker_image) to build a Docker image compatible with your system. -> The command below downloads the `v0.25.0` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.25.0`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server. +> The command below downloads the `v0.25.2` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.25.2`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server. ```bash $ cd ragflow/docker - # git checkout v0.25.0 + # git checkout v0.25.2 # Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases) # This step ensures the **entrypoint.sh** file in the code matches the Docker image version. @@ -405,7 +405,7 @@ See the [RAGFlow Roadmap 2026](https://github.com/infiniflow/ragflow/issues/1224 ## 🏄 Community - [Discord](https://discord.gg/NjYzJD3GM3) -- [Twitter](https://twitter.com/infiniflowai) +- [X](https://x.com/infiniflowai) - [GitHub Discussions](https://github.com/orgs/infiniflow/discussions) ## 🙌 Contributing diff --git a/README_ar.md b/README_ar.md index d03fa2a1eee..a02003d8342 100644 --- a/README_ar.md +++ b/README_ar.md @@ -10,9 +10,9 @@ 繁體版中文自述文件 日本語のREADME 한국어 + README en Français Bahasa Indonesia Português(Brasil) - README en Français README in Arabic Türkçe README

@@ -22,10 +22,10 @@ follow on X(Twitter) - Static Badge + Static Badge - docker pull infiniflow/ragflow:v0.25.0 + docker pull infiniflow/ragflow:v0.25.2 Latest Release @@ -39,11 +39,10 @@

+ Cloud | Document | Roadmap | - Twitter | - Discord | - Demo + Discord

@@ -58,11 +57,11 @@ 📕 جدول المحتويات - 💡 [ما هو RAGFlow؟](#-what-is-ragflow) -- 🎮 [Demo](#-demo) +- 🎮 [ابدأ](#-get-started) - 📌 [آخر التحديثات](#-latest-updates) - 🌟 [الميزات الرئيسية](#-key-features) - 🔎 [بنية النظام](#-system-architecture) -- 🎬 [ابدأ](#-get-started) +- 🎬 [الاستضافة الذاتية](#-self-hosting) - 🔧 [التكوينات](#-configurations) - 🔧 [إنشاء صورة Docker](#-build-a-docker-image) - 🔨 [إطلاق الخدمة من المصدر للتطوير](#-launch-service-from-source-for-development) @@ -77,7 +76,7 @@ يُعد مشروع [RAGFlow](https://ragflow.io/) محركًا رائدًا ومفتوح المصدر للاسترجاع المعزز بالتوليد (RAG)، ويجمع أحدث تقنيات RAG مع قدرات الوكلاء لبناء طبقة سياق متقدمة لنماذج LLMs. يوفّر سير عمل RAG مبسّطًا وقابلًا للتكيّف مع المؤسسات بمختلف أحجامها. وبالاعتماد على [محرك سياق موحّد](https://ragflow.io/basics/what-is-agent-context-engine) وقوالب وكلاء جاهزة، يتيح RAGFlow للمطورين تحويل البيانات المعقّدة إلى أنظمة AI عالية الدقة وجاهزة للإنتاج بكفاءة وموثوقية. -## 🎮 Demo +## 🎮 ابدأ جرّب النسخة التجريبية على [https://cloud.ragflow.io](https://cloud.ragflow.io). @@ -88,8 +87,9 @@ ## 🔥 آخر التحديثات -- 2026-03-24 [RAGFlow Skill on OpenClaw](https://clawhub.ai/yingfeng/ragflow-skill) — توفر مهارة رسمية للوصول إلى مجموعات بيانات RAGFlow عبر OpenClaw. -- 2025-12-26 يدعم ميزة "Memory" لوكلاء الذكاء الاصطناعي. +- 24-04-2026 يدعم DeepSeek v4. +- 24-03-2026 [RAGFlow Skill on OpenClaw](https://clawhub.ai/yingfeng/ragflow-skill) — توفر مهارة رسمية للوصول إلى مجموعات بيانات RAGFlow عبر OpenClaw. +- 26-12-2025 يدعم ميزة "Memory" لوكلاء الذكاء الاصطناعي. - 11-11-2025 يدعم Gemini 3 Pro. - 12-11-2025 يدعم مزامنة البيانات من Confluence، S3، Notion، Discord، Google Drive. - 23-10-2025 يدعم MinerU وDocling كطرق لتحليل المستندات. @@ -144,7 +144,7 @@
-## 🎬 ابدأ +## 🎬 الاستضافة الذاتية ### 📝 المتطلبات الأساسية @@ -192,12 +192,12 @@ > جميع الصور Docker مصممة لمنصات x86. لا نعرض حاليًا صور Docker لـ ARM64. > إذا كنت تستخدم نظامًا أساسيًا ARM64، فاتبع [هذا الدليل](https://ragflow.io/docs/dev/build_docker_image) لإنشاء صورة Docker متوافقة مع نظامك. -> يقوم الأمر أدناه بتنزيل إصدار `v0.25.0` من الصورة RAGFlow Docker. راجع الجدول التالي للحصول على أوصاف لإصدارات RAGFlow المختلفة. لتنزيل إصدار RAGFlow مختلف عن `v0.25.0`، قم بتحديث المتغير `RAGFLOW_IMAGE` وفقًا لذلك في **docker/.env** قبل استخدام `docker compose` لبدء تشغيل الخادم. +> يقوم الأمر أدناه بتنزيل إصدار `v0.25.2` من الصورة RAGFlow Docker. راجع الجدول التالي للحصول على أوصاف لإصدارات RAGFlow المختلفة. لتنزيل إصدار RAGFlow مختلف عن `v0.25.2`، قم بتحديث المتغير `RAGFLOW_IMAGE` وفقًا لذلك في **docker/.env** قبل استخدام `docker compose` لبدء تشغيل الخادم. ```bash $ cd ragflow/docker - # git checkout v0.25.0 + # git checkout v0.25.2 # Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases) # This step ensures the **entrypoint.sh** file in the code matches the Docker image version. @@ -405,7 +405,7 @@ docker build --platform linux/amd64 \ ## 🏄 المجتمع - [Discord](https://discord.gg/NjYzJD3GM3) -- [Twitter](https://twitter.com/infiniflowai) +- [X](https://x.com/infiniflowai) - [مناقشات جيثب](https://github.com/orgs/infiniflow/discussions) ## 🙌 المساهمة diff --git a/README_fr.md b/README_fr.md index 301cbba2853..37253de7e60 100644 --- a/README_fr.md +++ b/README_fr.md @@ -10,9 +10,9 @@ 繁體版中文自述文件 日本語のREADME 한국어 + README en Français Bahasa Indonesia Português(Brasil) - README en Français README in Arabic Türkçe README

@@ -22,10 +22,10 @@ suivre sur X(Twitter) - Badge statique + Badge statique - docker pull infiniflow/ragflow:v0.25.0 + docker pull infiniflow/ragflow:v0.25.2 Dernière version @@ -39,11 +39,10 @@

+ Cloud | Documentation | Roadmap | - Twitter | - Discord | - Démo + Discord

@@ -58,11 +57,11 @@ 📕 Table des matières - 💡 [Qu'est-ce que RAGFlow?](#-quest-ce-que-ragflow) -- 🎮 [Démo](#-démo) +- 🎮 [Démarrage](#-démarrage) - 📌 [Dernières mises à jour](#-dernières-mises-à-jour) - 🌟 [Fonctionnalités clés](#-fonctionnalités-clés) - 🔎 [Architecture du système](#-architecture-du-système) -- 🎬 [Démarrage](#-démarrage) +- 🎬 [Auto-hébergement](#-auto-hébergement) - 🔧 [Configurations](#-configurations) - 🔧 [Construire une image Docker](#-construire-une-image-docker) - 🔨 [Lancer le service depuis les sources pour le développement](#-lancer-le-service-depuis-les-sources-pour-le-développement) @@ -77,9 +76,9 @@ [RAGFlow](https://ragflow.io/) est un moteur de [RAG](https://ragflow.io/basics/what-is-rag) (Retrieval-Augmented Generation) open-source de premier plan qui fusionne les technologies RAG de pointe avec des capacités Agent pour créer une couche de contexte supérieure pour les LLM. Il offre un flux de travail RAG rationalisé, adaptable aux entreprises de toute taille. Alimenté par un [moteur de contexte](https://ragflow.io/basics/what-is-agent-context-engine) convergent et des modèles d'agents préconstruits, RAGFlow permet aux développeurs de transformer des données complexes en systèmes d'IA haute-fidélité, prêts pour la production, avec une efficacité et une précision exceptionnelles. -## 🎮 Démo +## 🎮 Démarrage -Essayez notre démo sur [https://cloud.ragflow.io](https://cloud.ragflow.io). +Essayez notre service cloud sur [https://cloud.ragflow.io](https://cloud.ragflow.io).
@@ -88,6 +87,7 @@ Essayez notre démo sur [https://cloud.ragflow.io](https://cloud.ragflow.io). ## 🔥 Dernières mises à jour +- 24-04-2026 Prise en charge de DeepSeek v4. - 24-03-2026 [RAGFlow Skill on OpenClaw](https://clawhub.ai/yingfeng/ragflow-skill) — Fournit un skill officiel pour accéder aux datasets RAGFlow via OpenClaw. - 26-12-2025 Prise en charge de la « Mémoire » pour l'agent IA. - 19-11-2025 Prise en charge de Gemini 3 Pro. @@ -142,7 +142,7 @@ Essayez notre démo sur [https://cloud.ragflow.io](https://cloud.ragflow.io).
-## 🎬 Démarrage +## 🎬 Auto-hébergement ### 📝 Prérequis @@ -189,12 +189,12 @@ Essayez notre démo sur [https://cloud.ragflow.io](https://cloud.ragflow.io). > Toutes les images Docker sont construites pour les plateformes x86. Nous ne proposons pas actuellement d'images Docker pour ARM64. > Si vous êtes sur une plateforme ARM64, suivez [ce guide](https://ragflow.io/docs/dev/build_docker_image) pour construire une image Docker compatible avec votre système. -> La commande ci-dessous télécharge l'édition `v0.25.0` de l'image Docker RAGFlow. Consultez le tableau suivant pour les descriptions des différentes éditions de RAGFlow. Pour télécharger une édition de RAGFlow différente de `v0.25.0`, mettez à jour la variable `RAGFLOW_IMAGE` dans **docker/.env** avant d'utiliser `docker compose` pour démarrer le serveur. +> La commande ci-dessous télécharge l'édition `v0.25.2` de l'image Docker RAGFlow. Consultez le tableau suivant pour les descriptions des différentes éditions de RAGFlow. Pour télécharger une édition de RAGFlow différente de `v0.25.2`, mettez à jour la variable `RAGFLOW_IMAGE` dans **docker/.env** avant d'utiliser `docker compose` pour démarrer le serveur. ```bash $ cd ragflow/docker - # git checkout v0.25.0 + # git checkout v0.25.2 # Optionnel : utiliser un tag stable (voir les versions : https://github.com/infiniflow/ragflow/releases) # Cette étape garantit que le fichier **entrypoint.sh** dans le code correspond à la version de l'image Docker. @@ -396,7 +396,7 @@ Voir la [Feuille de route RAGFlow 2026](https://github.com/infiniflow/ragflow/is ## 🏄 Communauté - [Discord](https://discord.gg/NjYzJD3GM3) -- [Twitter](https://twitter.com/infiniflowai) +- [X](https://x.com/infiniflowai) - [GitHub Discussions](https://github.com/orgs/infiniflow/discussions) ## 🙌 Contribuer diff --git a/README_id.md b/README_id.md index e275e1b6264..d2cecfcfc5a 100644 --- a/README_id.md +++ b/README_id.md @@ -10,9 +10,9 @@ 繁體中文版自述文件 日本語のREADME 한국어 + README en Français Bahasa Indonesia Português(Brasil) - README en Français README in Arabic Türkçe README

@@ -22,10 +22,10 @@ Ikuti di X (Twitter) - Lencana Daring + Lencana Daring - docker pull infiniflow/ragflow:v0.25.0 + docker pull infiniflow/ragflow:v0.25.2 Rilis Terbaru @@ -39,11 +39,10 @@

+ Cloud | Dokumentasi | Peta Jalan | - Twitter | - Discord | - Demo + Discord

@@ -58,11 +57,11 @@ 📕 Daftar Isi - 💡 [Apa Itu RAGFlow?](#-apa-itu-ragflow) -- 🎮 [Demo](#-demo) +- 🎮 [Mulai](#-mulai) - 📌 [Pembaruan Terbaru](#-pembaruan-terbaru) - 🌟 [Fitur Utama](#-fitur-utama) - 🔎 [Arsitektur Sistem](#-arsitektur-sistem) -- 🎬 [Mulai](#-mulai) +- 🎬 [Pengelolaan Mandiri](#-pengelolaan-mandiri) - 🔧 [Konfigurasi](#-konfigurasi) - 🔧 [Membangun Image Docker](#-membangun-docker-image) - 🔨 [Meluncurkan aplikasi dari Sumber untuk Pengembangan](#-meluncurkan-aplikasi-dari-sumber-untuk-pengembangan) @@ -77,9 +76,9 @@ [RAGFlow](https://ragflow.io/) adalah mesin [RAG](https://ragflow.io/basics/what-is-rag) (Retrieval-Augmented Generation) open-source terkemuka yang mengintegrasikan teknologi RAG mutakhir dengan kemampuan Agent untuk menciptakan lapisan kontekstual superior bagi LLM. Menyediakan alur kerja RAG yang efisien dan dapat diadaptasi untuk perusahaan segala skala. Didukung oleh mesin konteks terkonvergensi dan template Agent yang telah dipra-bangun, RAGFlow memungkinkan pengembang mengubah data kompleks menjadi sistem AI kesetiaan-tinggi dan siap-produksi dengan efisiensi dan presisi yang luar biasa. -## 🎮 Demo +## 🎮 Mulai -Coba demo kami di [https://cloud.ragflow.io](https://cloud.ragflow.io). +Coba layanan cloud kami di [https://cloud.ragflow.io](https://cloud.ragflow.io).
@@ -88,6 +87,7 @@ Coba demo kami di [https://cloud.ragflow.io](https://cloud.ragflow.io). ## 🔥 Pembaruan Terbaru +- 2026-04-24 Mendukung DeepSeek v4. - 2026-03-24 [RAGFlow Skill on OpenClaw](https://clawhub.ai/yingfeng/ragflow-skill) — Menyediakan skill resmi untuk mengakses dataset RAGFlow melalui OpenClaw. - 2025-12-26 Mendukung 'Memori' untuk agen AI. - 2025-11-19 Mendukung Gemini 3 Pro. @@ -144,7 +144,7 @@ Coba demo kami di [https://cloud.ragflow.io](https://cloud.ragflow.io).
-## 🎬 Mulai +## 🎬 Pengelolaan Mandiri ### 📝 Prasyarat @@ -192,12 +192,12 @@ Coba demo kami di [https://cloud.ragflow.io](https://cloud.ragflow.io). > Semua gambar Docker dibangun untuk platform x86. Saat ini, kami tidak menawarkan gambar Docker untuk ARM64. > Jika Anda menggunakan platform ARM64, [silakan gunakan panduan ini untuk membangun gambar Docker yang kompatibel dengan sistem Anda](https://ragflow.io/docs/dev/build_docker_image). -> Perintah di bawah ini mengunduh edisi v0.25.0 dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.25.0, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server. +> Perintah di bawah ini mengunduh edisi v0.25.2 dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.25.2, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server. ```bash $ cd ragflow/docker - # git checkout v0.25.0 + # git checkout v0.25.2 # Opsional: gunakan tag stabil (lihat releases: https://github.com/infiniflow/ragflow/releases) # This steps ensures the **entrypoint.sh** file in the code matches the Docker image version. @@ -377,7 +377,7 @@ Lihat [Roadmap RAGFlow 2026](https://github.com/infiniflow/ragflow/issues/12241) ## 🏄 Komunitas - [Discord](https://discord.gg/NjYzJD3GM3) -- [Twitter](https://twitter.com/infiniflowai) +- [X](https://x.com/infiniflowai) - [GitHub Discussions](https://github.com/orgs/infiniflow/discussions) ## 🙌 Kontribusi diff --git a/README_ja.md b/README_ja.md index 84f42b05876..1d4100d2eda 100644 --- a/README_ja.md +++ b/README_ja.md @@ -10,9 +10,9 @@ 繁體中文版自述文件 日本語のREADME 한국어 + README en Français Bahasa Indonesia Português(Brasil) - README en Français README in Arabic Türkçe README

@@ -22,10 +22,10 @@ follow on X(Twitter) - Static Badge + Static Badge - docker pull infiniflow/ragflow:v0.25.0 + docker pull infiniflow/ragflow:v0.25.2 Latest Release @@ -39,11 +39,10 @@

+ Cloud | Document | Roadmap | - Twitter | - Discord | - Demo + Discord

@@ -58,9 +57,9 @@ [RAGFlow](https://ragflow.io/) は、先進的な[RAG](https://ragflow.io/basics/what-is-rag)(Retrieval-Augmented Generation)技術と Agent 機能を融合し、大規模言語モデル(LLM)に優れたコンテキスト層を構築する最先端のオープンソース RAG エンジンです。あらゆる規模の企業に対応可能な合理化された RAG ワークフローを提供し、統合型[コンテキストエンジン](https://ragflow.io/basics/what-is-agent-context-engine)と事前構築されたAgentテンプレートにより、開発者が複雑なデータを驚異的な効率性と精度で高精細なプロダクションレディAIシステムへ変換することを可能にします。 -## 🎮 Demo +## 🎮 はじめに -デモをお試しください:[https://cloud.ragflow.io](https://cloud.ragflow.io)。 +当社のクラウドサービスをぜひお試しください:[https://cloud.ragflow.io](https://cloud.ragflow.io)。
@@ -69,6 +68,7 @@ ## 🔥 最新情報 +- 2026-04-24 DeepSeek v4 をサポート。 - 2026-03-24 [RAGFlow Skill on OpenClaw](https://clawhub.ai/yingfeng/ragflow-skill) — OpenClaw経由でRAGFlowデータセットにアクセスする公式スキルを提供。 - 2025-12-26 AIエージェントの「メモリ」機能をサポート。 - 2025-11-19 Gemini 3 Proをサポートしています。 @@ -125,7 +125,7 @@
-## 🎬 初期設定 +## 🎬 セルフホスティング ### 📝 必要条件 @@ -172,12 +172,12 @@ > 現在、公式に提供されているすべての Docker イメージは x86 アーキテクチャ向けにビルドされており、ARM64 用の Docker イメージは提供されていません。 > ARM64 アーキテクチャのオペレーティングシステムを使用している場合は、[このドキュメント](https://ragflow.io/docs/dev/build_docker_image)を参照して Docker イメージを自分でビルドしてください。 -> 以下のコマンドは、RAGFlow Docker イメージの v0.25.0 エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.25.0 とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。 +> 以下のコマンドは、RAGFlow Docker イメージの v0.25.2 エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.25.2 とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。 ```bash $ cd ragflow/docker - # git checkout v0.25.0 + # git checkout v0.25.2 # 任意: 安定版タグを利用 (一覧: https://github.com/infiniflow/ragflow/releases) # この手順は、コード内の entrypoint.sh ファイルが Docker イメージのバージョンと一致していることを確認します。 @@ -377,7 +377,7 @@ docker build --platform linux/amd64 \ ## 🏄 コミュニティ - [Discord](https://discord.gg/NjYzJD3GM3) -- [Twitter](https://twitter.com/infiniflowai) +- [X](https://x.com/infiniflowai) - [GitHub Discussions](https://github.com/orgs/infiniflow/discussions) ## 🙌 コントリビュート diff --git a/README_ko.md b/README_ko.md index 578e247e9fa..2d293a44f72 100644 --- a/README_ko.md +++ b/README_ko.md @@ -10,9 +10,9 @@ 繁體版中文自述文件 日本語のREADME 한국어 + README en Français Bahasa Indonesia Português(Brasil) - README en Français README in Arabic Türkçe README

@@ -22,10 +22,10 @@ follow on X(Twitter) - Static Badge + Static Badge - docker pull infiniflow/ragflow:v0.25.0 + docker pull infiniflow/ragflow:v0.25.2 Latest Release @@ -39,11 +39,10 @@

+ Cloud | Document | Roadmap | - Twitter | - Discord | - Demo + Discord

@@ -59,9 +58,9 @@ [RAGFlow](https://ragflow.io/) 는 최첨단 [RAG](https://ragflow.io/basics/what-is-rag)(Retrieval-Augmented Generation)와 Agent 기능을 융합하여 대규모 언어 모델(LLM)을 위한 우수한 컨텍스트 계층을 생성하는 선도적인 오픈소스 RAG 엔진입니다. 모든 규모의 기업에 적용 가능한 효율적인 RAG 워크플로를 제공하며, 통합 [컨텍스트 엔진](https://ragflow.io/basics/what-is-agent-context-engine)과 사전 구축된 Agent 템플릿을 통해 개발자들이 복잡한 데이터를 예외적인 효율성과 정밀도로 고급 구현도의 프로덕션 준비 완료 AI 시스템으로 변환할 수 있도록 지원합니다. -## 🎮 데모 +## 🎮 시작하기 -데모를 [https://cloud.ragflow.io](https://cloud.ragflow.io)에서 실행해 보세요. +[https://cloud.ragflow.io](https://cloud.ragflow.io)에서 저희 클라우드 서비스를 이용해 보세요.
@@ -70,6 +69,7 @@ ## 🔥 업데이트 +- 2026-04-24 DeepSeek v4를 지원합니다. - 2026-03-24 [RAGFlow Skill on OpenClaw](https://clawhub.ai/yingfeng/ragflow-skill) — OpenClaw를 통해 RAGFlow 데이터셋에 접근하는 공식 스킬 제공. - 2025-12-26 AI 에이전트의 '메모리' 기능 지원. - 2025-11-19 Gemini 3 Pro를 지원합니다. @@ -126,7 +126,7 @@
-## 🎬 시작하기 +## 🎬 자체 호스팅 ### 📝 사전 준비 사항 @@ -174,12 +174,12 @@ > 모든 Docker 이미지는 x86 플랫폼을 위해 빌드되었습니다. 우리는 현재 ARM64 플랫폼을 위한 Docker 이미지를 제공하지 않습니다. > ARM64 플랫폼을 사용 중이라면, [시스템과 호환되는 Docker 이미지를 빌드하려면 이 가이드를 사용해 주세요](https://ragflow.io/docs/dev/build_docker_image). - > 아래 명령어는 RAGFlow Docker 이미지의 v0.25.0 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.25.0과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오. + > 아래 명령어는 RAGFlow Docker 이미지의 v0.25.2 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.25.2와 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오. ```bash $ cd ragflow/docker - # git checkout v0.25.0 + # git checkout v0.25.2 # Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases) # 이 단계는 코드의 entrypoint.sh 파일이 Docker 이미지 버전과 일치하도록 보장합니다. @@ -381,7 +381,7 @@ docker build --platform linux/amd64 \ ## 🏄 커뮤니티 - [Discord](https://discord.gg/NjYzJD3GM3) -- [Twitter](https://twitter.com/infiniflowai) +- [X](https://x.com/infiniflowai) - [GitHub Discussions](https://github.com/orgs/infiniflow/discussions) ## 🙌 컨트리뷰션 diff --git a/README_pt_br.md b/README_pt_br.md index 88f34b19532..c830f1facd8 100644 --- a/README_pt_br.md +++ b/README_pt_br.md @@ -10,9 +10,9 @@ 繁體版中文自述文件 日本語のREADME 한국어 + README en Français Bahasa Indonesia Português(Brasil) - README en Français README in Arabic Türkçe README

@@ -22,10 +22,10 @@ seguir no X(Twitter) - Badge Estático + Badge Estático - docker pull infiniflow/ragflow:v0.25.0 + docker pull infiniflow/ragflow:v0.25.2 Última Versão @@ -39,11 +39,10 @@

+ Cloud | Documentação | Roadmap | - Twitter | - Discord | - Demo + Discord

@@ -58,11 +57,11 @@ 📕 Índice - 💡 [O que é o RAGFlow?](#-o-que-é-o-ragflow) -- 🎮 [Demo](#-demo) +- 🎮 [Primeiros Passos](#-primeiros-passos) - 📌 [Últimas Atualizações](#-últimas-atualizações) - 🌟 [Principais Funcionalidades](#-principais-funcionalidades) - 🔎 [Arquitetura do Sistema](#-arquitetura-do-sistema) -- 🎬 [Primeiros Passos](#-primeiros-passos) +- 🎬 [Auto-hospedagem](#-auto-hospedagem) - 🔧 [Configurações](#-configurações) - 🔧 [Construir uma imagem docker sem incorporar modelos](#-construir-uma-imagem-docker-sem-incorporar-modelos) - 🔧 [Construir uma imagem docker incluindo modelos](#-construir-uma-imagem-docker-incluindo-modelos) @@ -78,9 +77,9 @@ [RAGFlow](https://ragflow.io/) é um mecanismo de [RAG](https://ragflow.io/basics/what-is-rag) (Retrieval-Augmented Generation) open-source líder que fusiona tecnologias RAG de ponta com funcionalidades Agent para criar uma camada contextual superior para LLMs. Oferece um fluxo de trabalho RAG otimizado adaptável a empresas de qualquer escala. Alimentado por [um motor de contexto](https://ragflow.io/basics/what-is-agent-context-engine) convergente e modelos Agent pré-construídos, o RAGFlow permite que desenvolvedores transformem dados complexos em sistemas de IA de alta fidelidade e pronto para produção com excepcional eficiência e precisão. -## 🎮 Demo +## 🎮 Primeiros Passos -Experimente nossa demo em [https://cloud.ragflow.io](https://cloud.ragflow.io). +Experimente o nosso serviço na nuvem em [https://cloud.ragflow.io](https://cloud.ragflow.io).
@@ -89,6 +88,7 @@ Experimente nossa demo em [https://cloud.ragflow.io](https://cloud.ragflow.io). ## 🔥 Últimas Atualizações +- 24-04-2026 Suporta DeepSeek v4. - 24-03-2026 [RAGFlow Skill on OpenClaw](https://clawhub.ai/yingfeng/ragflow-skill) — Fornece um skill oficial para acessar datasets do RAGFlow via OpenClaw. - 26-12-2025 Suporte à função 'Memória' para agentes de IA. - 19-11-2025 Suporta Gemini 3 Pro. @@ -145,7 +145,7 @@ Experimente nossa demo em [https://cloud.ragflow.io](https://cloud.ragflow.io).
-## 🎬 Primeiros Passos +## 🎬 Auto-hospedagem ### 📝 Pré-requisitos @@ -192,12 +192,12 @@ Experimente nossa demo em [https://cloud.ragflow.io](https://cloud.ragflow.io). > Todas as imagens Docker são construídas para plataformas x86. Atualmente, não oferecemos imagens Docker para ARM64. > Se você estiver usando uma plataforma ARM64, por favor, utilize [este guia](https://ragflow.io/docs/dev/build_docker_image) para construir uma imagem Docker compatível com o seu sistema. - > O comando abaixo baixa a edição`v0.25.0` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.25.0`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor. + > O comando abaixo baixa a edição`v0.25.2` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.25.2`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor. ```bash $ cd ragflow/docker - # git checkout v0.25.0 + # git checkout v0.25.2 # Opcional: use uma tag estável (veja releases: https://github.com/infiniflow/ragflow/releases) # Esta etapa garante que o arquivo entrypoint.sh no código corresponda à versão da imagem do Docker. @@ -394,7 +394,7 @@ Veja o [RAGFlow Roadmap 2026](https://github.com/infiniflow/ragflow/issues/12241 ## 🏄 Comunidade - [Discord](https://discord.gg/NjYzJD3GM3) -- [Twitter](https://twitter.com/infiniflowai) +- [X](https://x.com/infiniflowai) - [GitHub Discussions](https://github.com/orgs/infiniflow/discussions) ## 🙌 Contribuindo diff --git a/README_tr.md b/README_tr.md index 89be2c0d790..c022dcbf7a1 100644 --- a/README_tr.md +++ b/README_tr.md @@ -10,9 +10,9 @@ 繁體版中文自述文件 日本語のREADME 한국어 + README en Français Bahasa Indonesia Português(Brasil) - README en Français README in Arabic Türkçe README

@@ -22,10 +22,10 @@ X(Twitter)'da takip et - Çevrimiçi Demo + Çevrimiçi Demo - docker pull infiniflow/ragflow:v0.25.0 + docker pull infiniflow/ragflow:v0.25.2 Son Sürüm @@ -39,11 +39,10 @@

+ Cloud | Dokümantasyon | Yol Haritası | - Twitter | - Discord | - Demo + Discord

@@ -58,11 +57,11 @@ 📕 İçindekiler - 💡 [RAGFlow Nedir?](#-ragflow-nedir) -- 🎮 [Demo](#-demo) +- 🎮 [Başlarken](#-başlarken) - 📌 [Son Güncellemeler](#-son-güncellemeler) - 🌟 [Temel Özellikler](#-temel-özellikler) - 🔎 [Sistem Mimarisi](#-sistem-mimarisi) -- 🎬 [Başlarken](#-başlarken) +- 🎬 [Kendi Sunucusunda Barındırma](#-kendi-sunucusunda-barındırma) - 🔧 [Yapılandırmalar](#-yapılandırmalar) - 🔧 [Docker İmajı Oluşturma](#-docker-i̇majı-oluşturma) - 🔨 [Geliştirme İçin Kaynaktan Hizmet Başlatma](#-geliştirme-i̇çin-kaynaktan-hizmet-başlatma) @@ -77,9 +76,9 @@ [RAGFlow](https://ragflow.io/), derin doküman anlayışına dayalı, açık kaynaklı ve öncü bir Artırılmış Üretim ile Bilgi Erişimi ([RAG](https://ragflow.io/basics/what-is-rag)) motorudur. En son RAG teknolojisini Ajan yetenekleriyle birleştirerek LLM'ler için üstün bir bağlam katmanı oluşturur. Her ölçekteki kuruluşa uyarlanabilir, kolaylaştırılmış bir RAG iş akışı sunar. Yakınsanmış bir [bağlam motoru](https://ragflow.io/basics/what-is-agent-context-engine) ve hazır ajan şablonlarıyla donatılmış RAGFlow, geliştiricilerin karmaşık verileri yüksek doğrulukta, üretime hazır yapay zeka sistemlerine olağanüstü verimlilik ve hassasiyetle dönüştürmesini sağlar. -## 🎮 Demo +## 🎮 Başlarken -Demomuzu [https://cloud.ragflow.io](https://cloud.ragflow.io) adresinden deneyebilirsiniz. +Bulut hizmetimizi [https://cloud.ragflow.io](https://cloud.ragflow.io) adresinden deneyin.
@@ -88,6 +87,7 @@ Demomuzu [https://cloud.ragflow.io](https://cloud.ragflow.io) adresinden deneyeb ## 🔥 Son Güncellemeler +- 2026-04-24 DeepSeek v4 desteği. - 2026-03-24 [RAGFlow Skill on OpenClaw](https://clawhub.ai/yingfeng/ragflow-skill) — OpenClaw üzerinden RAGFlow veri setlerine erişmek için resmi bir skill sağlar. - 2025-12-26 Yapay zeka ajanı için 'Bellek' desteği eklendi. - 2025-11-19 Gemini 3 Pro desteği eklendi. @@ -142,7 +142,7 @@ Demomuzu [https://cloud.ragflow.io](https://cloud.ragflow.io) adresinden deneyeb
-## 🎬 Başlarken +## 🎬 Kendi Sunucusunda Barındırma ### 📝 Ön Koşullar @@ -190,12 +190,12 @@ Demomuzu [https://cloud.ragflow.io](https://cloud.ragflow.io) adresinden deneyeb > Tüm Docker imajları x86 platformları için oluşturulmuştur. Şu anda ARM64 için Docker imajı sunmuyoruz. > ARM64 platformundaysanız, sisteminizle uyumlu bir Docker imajı oluşturmak için [bu kılavuzu](https://ragflow.io/docs/dev/build_docker_image) takip edin. -> Aşağıdaki komut RAGFlow Docker imajının `v0.25.0` sürümünü indirir. Farklı RAGFlow sürümleri için aşağıdaki tabloya bakın. `v0.25.0` dışında bir sürüm indirmek için, `docker compose` ile sunucuyu başlatmadan önce **docker/.env** dosyasındaki `RAGFLOW_IMAGE` değişkenini güncelleyin. +> Aşağıdaki komut RAGFlow Docker imajının `v0.25.2` sürümünü indirir. Farklı RAGFlow sürümleri için aşağıdaki tabloya bakın. `v0.25.2` dışında bir sürüm indirmek için, `docker compose` ile sunucuyu başlatmadan önce **docker/.env** dosyasındaki `RAGFLOW_IMAGE` değişkenini güncelleyin. ```bash $ cd ragflow/docker - # git checkout v0.25.0 + # git checkout v0.25.2 # İsteğe bağlı: Kararlı bir etiket kullanın (sürümler: https://github.com/infiniflow/ragflow/releases) # Bu adım, koddaki **entrypoint.sh** dosyasının Docker imaj sürümüyle eşleşmesini sağlar. @@ -400,7 +400,7 @@ docker build --platform linux/amd64 \ ## 🏄 Topluluk - [Discord](https://discord.gg/NjYzJD3GM3) -- [Twitter](https://twitter.com/infiniflowai) +- [X](https://x.com/infiniflowai) - [GitHub Tartışmalar](https://github.com/orgs/infiniflow/discussions) ## 🙌 Katkıda Bulunma diff --git a/README_tzh.md b/README_tzh.md index 14e5fb9d408..172c54a2955 100644 --- a/README_tzh.md +++ b/README_tzh.md @@ -10,9 +10,9 @@ 繁體版中文自述文件 日本語のREADME 한국어 + README en Français Bahasa Indonesia Português(Brasil) - README en Français README in Arabic Türkçe README

@@ -22,10 +22,10 @@ follow on X(Twitter) - Static Badge + Static Badge - docker pull infiniflow/ragflow:v0.25.0 + docker pull infiniflow/ragflow:v0.25.2 Latest Release @@ -39,11 +39,10 @@

+ Cloud | Document | Roadmap | - Twitter | - Discord | - Demo + Discord

@@ -58,11 +57,11 @@ 📕 目錄 - 💡 [RAGFlow 是什麼?](#-RAGFlow-是什麼) -- 🎮 [Demo-試用](#-demo-試用) +- 🎮 [快速開始](#-快速開始) - 📌 [近期更新](#-近期更新) - 🌟 [主要功能](#-主要功能) - 🔎 [系統架構](#-系統架構) -- 🎬 [快速開始](#-快速開始) +- 🎬 [自行架設](#-自行架設) - 🔧 [系統配置](#-系統配置) - 🔨 [以原始碼啟動服務](#-以原始碼啟動服務) - 📚 [技術文檔](#-技術文檔) @@ -77,9 +76,9 @@ [RAGFlow](https://ragflow.io/) 是一款領先的開源 [RAG](https://ragflow.io/basics/what-is-rag)(Retrieval-Augmented Generation)引擎,通過融合前沿的 RAG 技術與 Agent 能力,為大型語言模型提供卓越的上下文層。它提供可適配任意規模企業的端到端 RAG 工作流,憑藉融合式[上下文引擎](https://ragflow.io/basics/what-is-agent-context-engine)與預置的 Agent 模板,助力開發者以極致效率與精度將複雜數據轉化為高可信、生產級的人工智能系統。 -## 🎮 Demo 試用 +## 🎮 快速開始 -請登入網址 [https://cloud.ragflow.io](https://cloud.ragflow.io) 試用 demo。 +請登入網址 [https://cloud.ragflow.io](https://cloud.ragflow.io) 試用雲服務。
@@ -88,6 +87,7 @@ ## 🔥 近期更新 +- 2026-04-24 支援 DeepSeek v4 版本。 - 2026-03-24 發布 [RAGFlow 官方 Skill](https://clawhub.ai/yingfeng/ragflow-skill) — 提供官方 Skill 以透過 OpenClaw 訪問 RAGFlow 數據集。 - 2025-12-26 支援AI代理的「記憶」功能。 - 2025-11-19 支援 Gemini 3 Pro。 @@ -144,7 +144,7 @@
-## 🎬 快速開始 +## 🎬 自行架設 ### 📝 前提條件 @@ -191,12 +191,12 @@ > 所有 Docker 映像檔都是為 x86 平台建置的。目前,我們不提供 ARM64 平台的 Docker 映像檔。 > 如果您使用的是 ARM64 平台,請使用 [這份指南](https://ragflow.io/docs/dev/build_docker_image) 來建置適合您系統的 Docker 映像檔。 -> 執行以下指令會自動下載 RAGFlow Docker 映像 `v0.25.0`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.25.0` 的 Docker 映像,請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。 +> 執行以下指令會自動下載 RAGFlow Docker 映像 `v0.25.2`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.25.2` 的 Docker 映像,請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。 ```bash $ cd ragflow/docker - # git checkout v0.25.0 + # git checkout v0.25.2 # 可選:使用穩定版標籤(查看發佈:https://github.com/infiniflow/ragflow/releases) # 此步驟確保程式碼中的 entrypoint.sh 檔案與 Docker 映像版本一致。 @@ -407,8 +407,8 @@ docker build --platform linux/amd64 \ ## 🏄 開源社群 -- [Discord](https://discord.gg/zd4qPW6t) -- [Twitter](https://twitter.com/infiniflowai) +- [Discord](https://discord.gg/NjYzJD3GM3) +- [X](https://x.com/infiniflowai) - [GitHub Discussions](https://github.com/orgs/infiniflow/discussions) ## 🙌 貢獻指南 diff --git a/README_zh.md b/README_zh.md index 473794a934f..72de8935d49 100644 --- a/README_zh.md +++ b/README_zh.md @@ -10,9 +10,9 @@ 繁體版中文自述文件 日本語のREADME 한국어 + README en Français Bahasa Indonesia Português(Brasil) - README en Français README in Arabic Türkçe README

@@ -22,10 +22,10 @@ follow on X(Twitter) - Static Badge + Static Badge - docker pull infiniflow/ragflow:v0.25.0 + docker pull infiniflow/ragflow:v0.25.2 Latest Release @@ -39,11 +39,10 @@

+ Cloud | Document | Roadmap | - Twitter | - Discord | - Demo + Discord

@@ -58,11 +57,11 @@ 📕 目录 - 💡 [RAGFlow 是什么?](#-RAGFlow-是什么) -- 🎮 [Demo](#-demo) +- 🎮 [快速开始](#-快速开始) - 📌 [近期更新](#-近期更新) - 🌟 [主要功能](#-主要功能) - 🔎 [系统架构](#-系统架构) -- 🎬 [快速开始](#-快速开始) +- 🎬 [自主托管](#-自主托管) - 🔧 [系统配置](#-系统配置) - 🔨 [以源代码启动服务](#-以源代码启动服务) - 📚 [技术文档](#-技术文档) @@ -77,9 +76,9 @@ [RAGFlow](https://ragflow.io/) 是一款领先的开源检索增强生成([RAG](https://ragflow.io/basics/what-is-rag))引擎,通过融合前沿的 RAG 技术与 Agent 能力,为大型语言模型提供卓越的上下文层。它提供可适配任意规模企业的端到端 RAG 工作流,凭借融合式[上下文引擎](https://ragflow.io/basics/what-is-agent-context-engine)与预置的 Agent 模板,助力开发者以极致效率与精度将复杂数据转化为高可信、生产级的人工智能系统。 -## 🎮 Demo 试用 +## 🎮 快速开始 -请登录网址 [https://cloud.ragflow.io](https://cloud.ragflow.io) 试用 demo。 +请登录网址 [https://cloud.ragflow.io](https://cloud.ragflow.io) 体验云服务。
@@ -88,8 +87,9 @@ ## 🔥 近期更新 +- 2026-04-24 支持 DeepSeek v4. - 2026-03-24 发布 [RAGFlow 官方 Skill](https://clawhub.ai/yingfeng/ragflow-skill) — 提供官方 Skill 以通过 OpenClaw 访问 RAGFlow 数据集。 -- 2025-12-26 支持AI代理的"记忆"功能。 +- 2025-12-26 支持 AI 代理的"记忆"功能。 - 2025-11-19 支持 Gemini 3 Pro。 - 2025-11-12 支持从 Confluence、S3、Notion、Discord、Google Drive 进行数据同步。 - 2025-10-23 支持 MinerU 和 Docling 作为文档解析方法。 @@ -144,7 +144,7 @@
-## 🎬 快速开始 +## 🎬 自主托管 ### 📝 前提条件 @@ -192,12 +192,12 @@ > 请注意,目前官方提供的所有 Docker 镜像均基于 x86 架构构建,并不提供基于 ARM64 的 Docker 镜像。 > 如果你的操作系统是 ARM64 架构,请参考[这篇文档](https://ragflow.io/docs/dev/build_docker_image)自行构建 Docker 镜像。 - > 运行以下命令会自动下载 RAGFlow Docker 镜像 `v0.25.0`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.25.0` 的 Docker 镜像,请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。 + > 运行以下命令会自动下载 RAGFlow Docker 镜像 `v0.25.2`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.25.2` 的 Docker 镜像,请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。 ```bash $ cd ragflow/docker - # git checkout v0.25.0 + # git checkout v0.25.2 # 可选:使用稳定版本标签(查看发布:https://github.com/infiniflow/ragflow/releases) # 这一步确保代码中的 entrypoint.sh 文件与 Docker 镜像的版本保持一致。 @@ -410,8 +410,8 @@ docker build --platform linux/amd64 \ ## 🏄 开源社区 -- [Discord](https://discord.gg/zd4qPW6t) -- [Twitter](https://twitter.com/infiniflowai) +- [Discord](https://discord.gg/NjYzJD3GM3) +- [X](https://x.com/infiniflowai) - [GitHub Discussions](https://github.com/orgs/infiniflow/discussions) ## 🙌 贡献指南 diff --git a/admin/client/README.md b/admin/client/README.md index f71033d6482..cac7425aad8 100644 --- a/admin/client/README.md +++ b/admin/client/README.md @@ -48,7 +48,7 @@ It consists of a server-side Service and a command-line client (CLI), both imple 1. Ensure the Admin Service is running. 2. Install ragflow-cli. ```bash - pip install ragflow-cli==0.25.0 + pip install ragflow-cli==0.25.2 ``` 3. Launch the CLI client: ```bash diff --git a/admin/client/pyproject.toml b/admin/client/pyproject.toml index 48391a836d8..5f70bb1b188 100644 --- a/admin/client/pyproject.toml +++ b/admin/client/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ragflow-cli" -version = "0.25.0" +version = "0.25.2" description = "Admin Service's client of [RAGFlow](https://github.com/infiniflow/ragflow). The Admin Service provides user management and system monitoring. 
" authors = [{ name = "Lynn", email = "lynn_inf@hotmail.com" }] license = { text = "Apache License, Version 2.0" } diff --git a/admin/client/ragflow_client.py b/admin/client/ragflow_client.py index b9f04783ced..148af4b45fe 100644 --- a/admin/client/ragflow_client.py +++ b/admin/client/ragflow_client.py @@ -1215,12 +1215,12 @@ def chat_on_session(self, command): # Prepare payload for completion API # Note: stream parameter is not sent, server defaults to stream=True payload = { - "conversation_id": session_id, + "session_id": session_id, "messages": [{"role": "user", "content": message}] } - response = self.http_client.request("POST", "/conversation/completion", json_body=payload, - use_api_base=False, auth_kind="web", stream=True) + response = self.http_client.request("POST", "/chat/completions", json_body=payload, + use_api_base=True, auth_kind="web", stream=True) if response.status_code != 200: print(f"Fail to chat on session, status code: {response.status_code}") @@ -1325,7 +1325,7 @@ def parse_dataset_docs(self, command_dict): print(f"Documents {document_names} not found in {dataset_name}") payload = {"doc_ids": document_ids, "run": 1} - response = self.http_client.request("POST", "/document/run", json_body=payload, use_api_base=False, + response = self.http_client.request("POST", "/documents/ingest", json_body=payload, use_api_base=True, auth_kind="web") res_json = response.json() if response.status_code == 200 and res_json["code"] == 0: @@ -1351,7 +1351,7 @@ def parse_dataset(self, command_dict): document_ids.append(doc["id"]) payload = {"doc_ids": document_ids, "run": 1} - response = self.http_client.request("POST", "/document/run", json_body=payload, use_api_base=False, + response = self.http_client.request("POST", "/documents/ingest", json_body=payload, use_api_base=True, auth_kind="web") res_json = response.json() if response.status_code == 200 and res_json["code"] == 0: diff --git a/admin/client/uv.lock b/admin/client/uv.lock index 83868d9a20f..0bf404a2308 100644 --- a/admin/client/uv.lock +++ b/admin/client/uv.lock @@ -188,7 +188,7 @@ wheels = [ [[package]] name = "ragflow-cli" -version = "0.25.0" +version = "0.25.2" source = { virtual = "." 
} dependencies = [ { name = "beartype" }, diff --git a/admin/server/auth.py b/admin/server/auth.py index bd3c0c058ae..0aa96d0e37d 100644 --- a/admin/server/auth.py +++ b/admin/server/auth.py @@ -58,7 +58,7 @@ def load_user(web_request): return None # Decode JWT to get the UUID access_token - jwt = Serializer(secret_key=settings.SECRET_KEY) + jwt = Serializer(secret_key=settings.get_secret_key()) access_token = str(jwt.loads(jwt_token)) if not access_token or not access_token.strip(): diff --git a/agent/canvas.py b/agent/canvas.py index 65303ca9e9e..ab6d0ba9ff1 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -354,23 +354,21 @@ def reset(self, mem=False): key = k[4:] if key in self.variables: variable = self.variables[key] - if variable["type"] == "string": - self.globals[k] = "" - variable["value"] = "" - elif variable["type"] == "number": - self.globals[k] = 0 - variable["value"] = 0 - elif variable["type"] == "boolean": - self.globals[k] = False - variable["value"] = False - elif variable["type"] == "object": - self.globals[k] = {} - variable["value"] = {} - elif variable["type"].startswith("array"): - self.globals[k] = [] - variable["value"] = [] + value = variable.get("value") + if value is not None: + self.globals[k] = value else: - self.globals[k] = "" + var_type = variable.get("type", "") + if var_type == "number": + self.globals[k] = 0 + elif var_type == "boolean": + self.globals[k] = False + elif var_type == "object": + self.globals[k] = {} + elif var_type.startswith("array"): + self.globals[k] = [] + else: # "string" or unknown + self.globals[k] = "" else: self.globals[k] = "" @@ -381,8 +379,10 @@ async def run(self, **kwargs): self.message_id = get_uuid() created_at = int(time.time()) self.add_user_input(kwargs.get("query")) + path_set = set(self.path) for k, cpn in self.components.items(): - self.components[k]["obj"].reset(True) + if k in path_set: + self.components[k]["obj"].reset(True) if kwargs.get("webhook_payload"): for k, cpn in self.components.items(): diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py index 56f23afe350..859064046d6 100644 --- a/agent/component/agent_with_tools.py +++ b/agent/component/agent_with_tools.py @@ -145,7 +145,8 @@ def get_meta(self) -> dict[str, Any]: self._param.function_name = self._id.split("-->")[-1] m = super().get_meta() if hasattr(self._param, "user_prompt") and self._param.user_prompt: - m["function"]["parameters"]["properties"]["user_prompt"] = self._param.user_prompt + # Keep the JSON schema valid; user_prompt is a string field, not a schema node. 
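The `reset` change in `agent/canvas.py` above now prefers a declared variable's configured `value` and only falls back to an empty default chosen by its declared type, and the companion `run()` change resets only the components that sit on the executed path. A minimal standalone sketch of that fallback, assuming the variable dict shape used in the diff; this helper is illustrative and not part of the codebase:

```python
# Minimal sketch of the reset fallback: keep the configured "value" when present,
# otherwise pick an empty default by declared type (mirrors the new reset() logic).
from typing import Any

def reset_global(variable: dict[str, Any]) -> Any:
    value = variable.get("value")
    if value is not None:
        return value                      # keep the configured default
    var_type = variable.get("type", "")
    if var_type == "number":
        return 0
    if var_type == "boolean":
        return False
    if var_type == "object":
        return {}
    if var_type.startswith("array"):
        return []
    return ""                             # "string" or unknown types

assert reset_global({"type": "number"}) == 0
assert reset_global({"type": "array<string>"}) == []
assert reset_global({"type": "string", "value": "hello"}) == "hello"
```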
+ m["function"]["parameters"]["properties"]["user_prompt"]["default"] = self._param.user_prompt return m def get_input_form(self) -> dict[str, dict]: @@ -276,10 +277,13 @@ async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt= return if delta.find("**ERROR**") >= 0: if self.get_exception_default_value(): - self.set_output("content", self.get_exception_default_value()) - yield self.get_exception_default_value() + fallback = self.get_exception_default_value() + self.set_output("content", fallback) + yield fallback else: self.set_output("_ERROR", delta) + self.set_output("content", delta) + yield delta return if not need2cite or cited: yield delta diff --git a/agent/component/docs_generator.py b/agent/component/docs_generator.py index d51b0ea591e..ce7a3abad59 100644 --- a/agent/component/docs_generator.py +++ b/agent/component/docs_generator.py @@ -1,3 +1,4 @@ +import base64 import logging import json import os @@ -48,6 +49,7 @@ def __init__(self): self.watermark_text = "" self.add_page_numbers = True self.add_timestamp = True + self.include_download_info_in_content = False self.font_size = 12 self.outputs = { "download": {"value": "", "type": "string"}, @@ -113,6 +115,7 @@ def _invoke(self, **kwargs): raise Exception("Document file is empty") file_size = len(file_bytes) + file_base64 = base64.b64encode(file_bytes).decode("utf-8") doc_id = get_uuid() settings.STORAGE_IMPL.put(self._canvas.get_tenant_id(), doc_id, file_bytes) @@ -128,6 +131,8 @@ def _invoke(self, **kwargs): "filename": filename, "mime_type": mime_type, "size": file_size, + "base64": file_base64, + "include_download_info_in_content": self._param.include_download_info_in_content, } self.set_output("download", json.dumps(download_info)) return download_info diff --git a/agent/component/invoke.py b/agent/component/invoke.py index 0dce464ebf0..4faaa7d0135 100644 --- a/agent/component/invoke.py +++ b/agent/component/invoke.py @@ -179,10 +179,7 @@ def _build_headers(self, kwargs: dict) -> dict: if not isinstance(headers, dict): raise ValueError("Invoke headers must be a JSON object.") - return { - key: self._resolve_header_text(value, kwargs) if isinstance(value, str) else value - for key, value in headers.items() - } + return {key: self._resolve_header_text(value, kwargs) if isinstance(value, str) else value for key, value in headers.items()} def _build_proxies(self) -> dict | None: if not re.sub(r"https?:?/?/?", "", self._param.proxy): @@ -215,7 +212,7 @@ def _format_response(self, response) -> str: # HtmlParser keeps the Invoke output text-focused when the endpoint returns HTML. 
sections = HtmlParser()(None, response.content) return "\n".join(sections) - + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3))) def _invoke(self, **kwargs): if self.check_if_canceled("Invoke processing"): diff --git a/agent/component/list_operations.py b/agent/component/list_operations.py index 6016f758507..953e1455293 100644 --- a/agent/component/list_operations.py +++ b/agent/component/list_operations.py @@ -10,8 +10,9 @@ class ListOperationsParam(ComponentParamBase): def __init__(self): super().__init__() self.query = "" - self.operations = "topN" - self.n=0 + self.operations = "nth" + self.n = 0 + self.strict = False self.sort_method = "asc" self.filter = { "operator": "=", @@ -34,7 +35,11 @@ def __init__(self): def check(self): self.check_empty(self.query, "query") - self.check_valid_value(self.operations, "Support operations", ["topN","head","tail","filter","sort","drop_duplicates"]) + self.check_valid_value( + self.operations, + "Support operations", + ["nth", "head", "tail", "filter", "sort", "drop_duplicates"], + ) def get_input_form(self) -> dict[str, dict]: return {} @@ -51,8 +56,8 @@ def _invoke(self, **kwargs): if not isinstance(self.inputs, list): raise TypeError("The input of List Operations should be an array.") self.set_input_value(inputs, self.inputs) - if self._param.operations == "topN": - self._topN() + if self._param.operations == "nth": + self._nth() elif self._param.operations == "head": self._head() elif self._param.operations == "tail": @@ -70,35 +75,74 @@ def _coerce_n(self): return int(getattr(self._param, "n", 0)) except Exception: return 0 - + + def _is_strict(self): + strict = getattr(self._param, "strict", False) + if isinstance(strict, str): + return strict.strip().lower() in {"1", "true", "yes", "on"} + return bool(strict) + def _set_outputs(self, outputs): self._param.outputs["result"]["value"] = outputs self._param.outputs["first"]["value"] = outputs[0] if outputs else None self._param.outputs["last"]["value"] = outputs[-1] if outputs else None - - def _topN(self): + + def _raise_strict_range_error(self, operation, n): + raise ValueError( + f"{operation} requires n to be within the valid range in strict mode, got {n}." 
+ ) + + def _nth(self): n = self._coerce_n() - if n < 1: + strict = self._is_strict() + if n == 0: + if strict: + self._raise_strict_range_error("nth", n) outputs = [] + elif n > 0: + if n <= len(self.inputs): + outputs = [self.inputs[n - 1]] + elif strict: + self._raise_strict_range_error("nth", n) + else: + outputs = [] else: - n = min(n, len(self.inputs)) - outputs = self.inputs[:n] + if abs(n) <= len(self.inputs): + outputs = [self.inputs[n]] + elif strict: + self._raise_strict_range_error("nth", n) + else: + outputs = [] self._set_outputs(outputs) def _head(self): n = self._coerce_n() - if 1 <= n <= len(self.inputs): - outputs = [self.inputs[n - 1]] + strict = self._is_strict() + if strict: + if 1 <= n <= len(self.inputs): + outputs = self.inputs[:n] + else: + self._raise_strict_range_error("head", n) else: - outputs = [] + if n < 1: + outputs = [] + else: + outputs = self.inputs[:n] self._set_outputs(outputs) def _tail(self): n = self._coerce_n() - if 1 <= n <= len(self.inputs): - outputs = [self.inputs[-n]] + strict = self._is_strict() + if strict: + if 1 <= n <= len(self.inputs): + outputs = self.inputs[-n:] + else: + self._raise_strict_range_error("tail", n) else: - outputs = [] + if n < 1: + outputs = [] + else: + outputs = self.inputs[-n:] self._set_outputs(outputs) def _filter(self): @@ -107,7 +151,7 @@ def _filter(self): def _norm(self,v): s = "" if v is None else str(v) return s - + def _eval(self, v, operator, value): if operator == "=": return v == value @@ -163,6 +207,6 @@ def _hashable(self,x): if isinstance(x, set): return tuple(sorted(self._hashable(v) for v in x)) return x - + def thoughts(self) -> str: return "ListOperation in progress" diff --git a/agent/component/message.py b/agent/component/message.py index 8db4eedbd14..a52741f6b36 100644 --- a/agent/component/message.py +++ b/agent/component/message.py @@ -75,6 +75,22 @@ def _is_download_info(value: Any) -> bool: key in value for key in ("doc_id", "filename", "mime_type") ) + @staticmethod + def _download_info_includes_content(value: Any) -> bool: + return isinstance(value, dict) and bool(value.get("include_download_info_in_content")) + + @staticmethod + def _normalize_download_info(value: Any) -> Any: + if isinstance(value, list): + return [Message._normalize_download_info(item) for item in value] + + if not isinstance(value, dict): + return value + + normalized = value.copy() + normalized.pop("include_download_info_in_content", None) + return normalized + def _extract_downloads(self, value: Any) -> list[dict[str, Any]]: if isinstance(value, str): try: @@ -100,7 +116,19 @@ def _stringify_message_value( extracted_downloads = self._extract_downloads(value) if extracted_downloads: if downloads is not None: - downloads.extend(extracted_downloads) + downloads.extend(self._normalize_download_info(item) for item in extracted_downloads) + if any(self._download_info_includes_content(item) for item in extracted_downloads): + if isinstance(value, str): + try: + value = json.loads(value) + except Exception: + return value + try: + return json.dumps(self._normalize_download_info(value), ensure_ascii=False) + except Exception: + if fallback_to_str: + return str(value) + return "" return "" if value is None: diff --git a/agent/component/variable_assigner.py b/agent/component/variable_assigner.py index 08b28334312..dd6182c7ce0 100644 --- a/agent/component/variable_assigner.py +++ b/agent/component/variable_assigner.py @@ -141,20 +141,18 @@ def _extend(self,variable,parameter): return variable + parameter def 
_remove_first(self,variable): - if len(variable)==0: - return variable if not isinstance(variable,list): return "ERROR:VARIABLE_NOT_LIST" - else: - return variable[1:] - - def _remove_last(self,variable): if len(variable)==0: return variable + return variable[1:] + + def _remove_last(self,variable): if not isinstance(variable,list): return "ERROR:VARIABLE_NOT_LIST" - else: - return variable[:-1] + if len(variable)==0: + return variable + return variable[:-1] def is_number(self, value): if isinstance(value, bool): diff --git a/agent/sandbox/client.py b/agent/sandbox/client.py index 4d49ae734c6..9ca51cc8e3a 100644 --- a/agent/sandbox/client.py +++ b/agent/sandbox/client.py @@ -23,11 +23,12 @@ import json import logging +import os from typing import Dict, Any, Optional from api.db.services.system_settings_service import SystemSettingsService from agent.sandbox.providers import ProviderManager -from agent.sandbox.providers.base import ExecutionResult +from agent.sandbox.providers.base import ExecutionResult, SandboxProviderConfigError logger = logging.getLogger(__name__) @@ -59,8 +60,8 @@ def _load_provider_from_settings() -> None: """ Load sandbox provider from system settings and configure the provider manager. - This function reads the system settings to determine which provider is active - and initializes it with the appropriate configuration. + This function resolves the active provider type, then loads configuration + from system settings with environment overrides for that provider. """ global _provider_manager @@ -68,41 +69,27 @@ def _load_provider_from_settings() -> None: return try: - # Get active provider type - provider_type_settings = SystemSettingsService.get_by_name("sandbox.provider_type") - if not provider_type_settings: - raise RuntimeError( - "Sandbox provider type not configured. Please set 'sandbox.provider_type' in system settings." - ) - provider_type = provider_type_settings[0].value - - # Get provider configuration - provider_config_settings = SystemSettingsService.get_by_name(f"sandbox.{provider_type}") - - if not provider_config_settings: - logger.warning(f"No configuration found for provider: {provider_type}") - config = {} - else: - try: - config = json.loads(provider_config_settings[0].value) - except json.JSONDecodeError as e: - logger.error(f"Failed to parse sandbox config for {provider_type}: {e}") - config = {} + provider_type, provider_type_from_env = _resolve_provider_type() + config = _load_provider_config(provider_type) # Import and instantiate the provider from agent.sandbox.providers import ( SelfManagedProvider, AliyunCodeInterpreterProvider, E2BProvider, + LocalProvider, ) provider_classes = { "self_managed": SelfManagedProvider, "aliyun_codeinterpreter": AliyunCodeInterpreterProvider, "e2b": E2BProvider, + "local": LocalProvider, } if provider_type not in provider_classes: + if provider_type_from_env: + raise SandboxProviderConfigError(f"Unknown sandbox provider type: {provider_type}") logger.error(f"Unknown provider type: {provider_type}") return @@ -111,19 +98,97 @@ def _load_provider_from_settings() -> None: # Initialize the provider if not provider.initialize(config): - logger.error(f"Failed to initialize sandbox provider: {provider_type}. Config keys: {list(config.keys())}") + message = f"Failed to initialize sandbox provider: {provider_type}. 
Config keys: {list(config.keys())}" + if provider_type == "local" or provider_type_from_env: + raise SandboxProviderConfigError(message) + logger.error(message) return # Set the active provider _provider_manager.set_provider(provider_type, provider) logger.info(f"Sandbox provider '{provider_type}' initialized successfully") + except SandboxProviderConfigError: + raise except Exception as e: logger.error(f"Failed to load sandbox provider from settings: {e}") import traceback traceback.print_exc() +def _load_provider_config_from_settings(provider_type: str) -> Dict[str, Any]: + provider_config_settings = SystemSettingsService.get_by_name(f"sandbox.{provider_type}") + if not provider_config_settings: + logger.warning(f"No configuration found for provider: {provider_type}") + return {} + + try: + return json.loads(provider_config_settings[0].value) + except json.JSONDecodeError as e: + logger.error(f"Failed to parse sandbox config for {provider_type}: {e}") + return {} + + +def _resolve_provider_type() -> tuple[str, bool]: + provider_type = os.environ.get("SANDBOX_PROVIDER_TYPE", "").strip() + if provider_type: + return provider_type, True + + provider_type_settings = SystemSettingsService.get_by_name("sandbox.provider_type") + if not provider_type_settings: + raise RuntimeError( + "Sandbox provider type not configured. Please set 'sandbox.provider_type' in system settings." + ) + return provider_type_settings[0].value, False + + +def _load_provider_config(provider_type: str) -> Dict[str, Any]: + config = _load_provider_config_from_settings(provider_type) + env_config = _load_provider_config_from_env(provider_type) + if env_config: + config.update(env_config) + return config + + +def _load_provider_config_from_env(provider_type: str) -> Dict[str, Any]: + if provider_type == "local": + return _load_local_provider_config_from_env() + if provider_type == "self_managed": + return _load_self_managed_provider_config_from_env() + return {} + + +def _load_local_provider_config_from_env() -> Dict[str, Any]: + env_to_config = { + "SANDBOX_LOCAL_PYTHON_BIN": "python_bin", + "SANDBOX_LOCAL_NODE_BIN": "node_bin", + "SANDBOX_LOCAL_WORK_DIR": "work_dir", + "SANDBOX_LOCAL_TIMEOUT": "timeout", + "SANDBOX_LOCAL_MAX_MEMORY_MB": "max_memory_mb", + "SANDBOX_LOCAL_MAX_OUTPUT_BYTES": "max_output_bytes", + "SANDBOX_LOCAL_MAX_ARTIFACTS": "max_artifacts", + "SANDBOX_LOCAL_MAX_ARTIFACT_BYTES": "max_artifact_bytes", + } + config = {} + for env_name, config_name in env_to_config.items(): + if env_name in os.environ: + config[config_name] = os.environ[env_name] + return config + + +def _load_self_managed_provider_config_from_env() -> Dict[str, Any]: + host = os.environ.get("SANDBOX_HOST", "").strip() + port = os.environ.get("SANDBOX_EXECUTOR_MANAGER_PORT", "").strip() + pool_size = os.environ.get("SANDBOX_EXECUTOR_MANAGER_POOL_SIZE", "").strip() + + config = {} + if host: + config["endpoint"] = f"http://{host}:{port or '9385'}" + if pool_size: + config["pool_size"] = pool_size + return config + + def reload_provider() -> None: """ Reload the sandbox provider from system settings. 
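The new helpers in `agent/sandbox/client.py` resolve the provider type from `SANDBOX_PROVIDER_TYPE` (falling back to the `sandbox.provider_type` system setting) and then overlay provider-specific environment variables onto the stored config, with the environment winning. A rough sketch of that merge for the `local` provider; only the environment variable names come from the diff, and the base dict stands in for whatever `SystemSettingsService` returns:

```python
# Rough sketch of the env-override merge performed by _load_provider_config above.
import os

ENV_TO_CONFIG = {
    "SANDBOX_LOCAL_PYTHON_BIN": "python_bin",
    "SANDBOX_LOCAL_WORK_DIR": "work_dir",
    "SANDBOX_LOCAL_TIMEOUT": "timeout",
}

def merge_local_config(base: dict) -> dict:
    config = dict(base)
    for env_name, key in ENV_TO_CONFIG.items():
        if env_name in os.environ:
            config[key] = os.environ[env_name]   # env value overrides the stored setting
    return config

os.environ["SANDBOX_LOCAL_TIMEOUT"] = "60"
print(merge_local_config({"timeout": 30, "python_bin": "python3"}))
# -> {'timeout': '60', 'python_bin': 'python3'}
```

Note that env values arrive as strings; the `LocalProvider.initialize` shown below casts the numeric fields with `int()`, so string overrides are acceptable there.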
diff --git a/agent/sandbox/providers/__init__.py b/agent/sandbox/providers/__init__.py index 7be1463b9ca..e7cfc2ddc9c 100644 --- a/agent/sandbox/providers/__init__.py +++ b/agent/sandbox/providers/__init__.py @@ -24,20 +24,24 @@ - aliyun_codeinterpreter.py: Aliyun Code Interpreter provider implementation Official Documentation: https://help.aliyun.com/zh/functioncompute/fc/sandbox-sandbox-code-interepreter - e2b.py: E2B provider implementation +- local.py: Local process provider implementation """ -from .base import SandboxProvider, SandboxInstance, ExecutionResult +from .base import SandboxProvider, SandboxInstance, ExecutionResult, SandboxProviderConfigError from .manager import ProviderManager from .self_managed import SelfManagedProvider from .aliyun_codeinterpreter import AliyunCodeInterpreterProvider from .e2b import E2BProvider +from .local import LocalProvider __all__ = [ "SandboxProvider", "SandboxInstance", "ExecutionResult", + "SandboxProviderConfigError", "ProviderManager", "SelfManagedProvider", "AliyunCodeInterpreterProvider", "E2BProvider", + "LocalProvider", ] diff --git a/agent/sandbox/providers/aliyun_codeinterpreter.py b/agent/sandbox/providers/aliyun_codeinterpreter.py index 8ee99ed1ecc..bbec2a26820 100644 --- a/agent/sandbox/providers/aliyun_codeinterpreter.py +++ b/agent/sandbox/providers/aliyun_codeinterpreter.py @@ -30,7 +30,6 @@ import logging import os import time -import base64 import json from typing import Dict, Any, List, Optional from datetime import datetime, timezone @@ -39,10 +38,10 @@ from agentrun.utils.config import Config from agentrun.utils.exception import ServerError +from agent.sandbox.result_protocol import build_javascript_wrapper, build_python_wrapper, extract_structured_result from .base import SandboxProvider, SandboxInstance, ExecutionResult logger = logging.getLogger(__name__) -RESULT_MARKER_PREFIX = "__RAGFLOW_RESULT__:" class AliyunCodeInterpreterProvider(SandboxProvider): @@ -234,9 +233,9 @@ def execute_code(self, instance_id: str, code: str, language: str, timeout: int # Matches self_managed provider behavior: call main(**arguments) args_json = json.dumps(arguments or {}) wrapped_code = ( - self._build_python_wrapper(code, args_json) + build_python_wrapper(code, args_json) if normalized_lang == "python" - else self._build_javascript_wrapper(code, args_json) + else build_javascript_wrapper(code, args_json) ) logger.debug(f"Aliyun Code Interpreter: Wrapped code (first 200 chars): {wrapped_code[:200]}") @@ -284,7 +283,7 @@ def execute_code(self, instance_id: str, code: str, language: str, timeout: int stdout = "\n".join(stdout_parts) stderr = "\n".join(stderr_parts) - stdout, structured_result = self._extract_structured_result(stdout) + stdout, structured_result = extract_structured_result(stdout) logger.info(f"Aliyun Code Interpreter: stdout length={len(stdout)}, stderr length={len(stderr)}, exit_code={exit_code}") if stdout: @@ -364,71 +363,6 @@ def health_check(self) -> bool: # If we get any response (even an error), the service is reachable return "connection" not in str(e).lower() - @staticmethod - def _build_python_wrapper(code: str, args_json: str) -> str: - marker = RESULT_MARKER_PREFIX - return f'''{code} - -if __name__ == "__main__": - import base64 - import json - - result = main(**{args_json}) - payload = json.dumps({{"present": True, "value": result, "type": "json"}}, ensure_ascii=False, separators=(",", ":")) - print("{marker}" + base64.b64encode(payload.encode("utf-8")).decode("ascii")) -''' - - @staticmethod - def 
_build_javascript_wrapper(code: str, args_json: str) -> str: - marker = RESULT_MARKER_PREFIX - return f'''{code} - -const __ragflowArgs = {args_json}; - -(async () => {{ - try {{ - const output = await Promise.resolve(main(__ragflowArgs)); - if (typeof output === 'undefined') {{ - throw new Error('main() must return a value. Use null for an empty result.'); - }} - const payload = JSON.stringify({{ present: true, value: output, type: 'json' }}); - if (typeof payload === 'undefined') {{ - throw new Error('main() returned a non-JSON-serializable value.'); - }} - console.log('{marker}' + Buffer.from(payload, 'utf8').toString('base64')); - }} catch (err) {{ - console.error(err instanceof Error ? err.stack || err.message : String(err)); - }} -}})(); -''' - - @staticmethod - def _extract_structured_result(stdout: str) -> tuple[str, Dict[str, Any]]: - if not stdout: - return "", {} - - cleaned_lines: list[str] = [] - structured_result: Dict[str, Any] = {} - - for line in str(stdout).splitlines(): - if line.startswith(RESULT_MARKER_PREFIX): - payload_b64 = line[len(RESULT_MARKER_PREFIX) :].strip() - if not payload_b64: - continue - try: - payload = base64.b64decode(payload_b64).decode("utf-8") - structured_result = json.loads(payload) - except Exception as exc: - logger.warning(f"Aliyun Code Interpreter: failed to decode structured result marker: {exc}") - cleaned_lines.append(line) - continue - cleaned_lines.append(line) - - cleaned_stdout = "\n".join(cleaned_lines) - if stdout.endswith("\n") and cleaned_stdout and not cleaned_stdout.endswith("\n"): - cleaned_stdout += "\n" - return cleaned_stdout, structured_result - def get_supported_languages(self) -> List[str]: """ Get list of supported programming languages. diff --git a/agent/sandbox/providers/base.py b/agent/sandbox/providers/base.py index c21b583e02b..8f9c04aaa42 100644 --- a/agent/sandbox/providers/base.py +++ b/agent/sandbox/providers/base.py @@ -26,6 +26,10 @@ from typing import Dict, Any, Optional, List +class SandboxProviderConfigError(Exception): + """Raised when the selected provider is explicitly configured but unusable.""" + + @dataclass class SandboxInstance: """Represents a sandbox execution instance""" @@ -209,4 +213,4 @@ def validate_config(self, config: Dict[str, Any]) -> tuple[bool, Optional[str]]: >>> return True, None """ # Default implementation: no custom validation - return True, None \ No newline at end of file + return True, None diff --git a/agent/sandbox/providers/local.py b/agent/sandbox/providers/local.py new file mode 100644 index 00000000000..b8057fa5b43 --- /dev/null +++ b/agent/sandbox/providers/local.py @@ -0,0 +1,296 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import base64 +import json +import mimetypes +import os +import shutil +import signal +import subprocess +import time +import uuid +from pathlib import Path +from typing import Any, Dict, List, Optional + +from agent.sandbox.result_protocol import build_javascript_wrapper, build_python_wrapper, extract_structured_result +from .base import ExecutionResult, SandboxInstance, SandboxProvider, SandboxProviderConfigError + + +ALLOWED_ARTIFACT_EXTENSIONS = { + ".csv", + ".html", + ".jpeg", + ".jpg", + ".json", + ".pdf", + ".png", + ".svg", +} + + +def _env_enabled(name: str) -> bool: + return os.environ.get(name, "").strip().lower() in {"1", "true", "yes", "on"} + + +class LocalProvider(SandboxProvider): + """ + Execute code as a local child process. + + This provider is intentionally gated by SANDBOX_LOCAL_ENABLED because it is + not a sandbox boundary. Use a low-privilege runtime account. + """ + + def __init__(self): + self.python_bin = "python3" + self.node_bin = "node" + self.work_dir = Path("/tmp/ragflow-codeexec") + self.timeout = 30 + self.max_memory_mb = 512 + self.max_output_bytes = 1024 * 1024 + self.max_artifacts = 20 + self.max_artifact_bytes = 10 * 1024 * 1024 + self._initialized = False + self._instances: dict[str, Path] = {} + + def initialize(self, config: Dict[str, Any]) -> bool: + if not _env_enabled("SANDBOX_LOCAL_ENABLED"): + raise SandboxProviderConfigError("Local code execution is disabled. Set SANDBOX_LOCAL_ENABLED=true to enable it.") + + self.python_bin = str(self._resolve_config_value(config, "python_bin", "SANDBOX_LOCAL_PYTHON_BIN", "python3")) + self.node_bin = str(self._resolve_config_value(config, "node_bin", "SANDBOX_LOCAL_NODE_BIN", "node")) + self.work_dir = Path(self._resolve_config_value(config, "work_dir", "SANDBOX_LOCAL_WORK_DIR", "/tmp/ragflow-codeexec")).resolve() + self.timeout = int(self._resolve_config_value(config, "timeout", "SANDBOX_LOCAL_TIMEOUT", 30)) + self.max_memory_mb = int(self._resolve_config_value(config, "max_memory_mb", "SANDBOX_LOCAL_MAX_MEMORY_MB", 512)) + self.max_output_bytes = int(self._resolve_config_value(config, "max_output_bytes", "SANDBOX_LOCAL_MAX_OUTPUT_BYTES", 1024 * 1024)) + self.max_artifacts = int(self._resolve_config_value(config, "max_artifacts", "SANDBOX_LOCAL_MAX_ARTIFACTS", 20)) + self.max_artifact_bytes = int(self._resolve_config_value(config, "max_artifact_bytes", "SANDBOX_LOCAL_MAX_ARTIFACT_BYTES", 10 * 1024 * 1024)) + + self._validate_limits() + self.work_dir.mkdir(parents=True, exist_ok=True, mode=0o700) + self._initialized = True + return True + + def create_instance(self, template: str = "python") -> SandboxInstance: + if not self._initialized: + raise RuntimeError("Provider not initialized. Call initialize() first.") + + language = self._normalize_language(template) + instance_id = str(uuid.uuid4()) + instance_dir = self.work_dir / instance_id + instance_dir.mkdir(mode=0o700) + (instance_dir / "artifacts").mkdir(mode=0o700) + self._instances[instance_id] = instance_dir + + return SandboxInstance( + instance_id=instance_id, + provider="local", + status="running", + metadata={"language": language, "work_dir": str(instance_dir)}, + ) + + def execute_code( + self, + instance_id: str, + code: str, + language: str, + timeout: int = 10, + arguments: Optional[Dict[str, Any]] = None, + ) -> ExecutionResult: + if not self._initialized: + raise RuntimeError("Provider not initialized. 
Call initialize() first.") + + normalized_lang = self._normalize_language(language) + instance_dir = self._instances[instance_id] + args_json = json.dumps(arguments or {}, ensure_ascii=False) + command, script_path = self._prepare_script(instance_dir, normalized_lang, code, args_json) + requested_timeout = self.timeout if timeout is None else int(timeout) + if requested_timeout <= 0: + raise RuntimeError(f"Execution timeout must be greater than 0 seconds, got {requested_timeout}.") + exec_timeout = min(requested_timeout, self.timeout) + + start_time = time.time() + process = subprocess.Popen( + command, + cwd=instance_dir, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + encoding="utf-8", + errors="replace", + env=self._build_child_env(instance_dir), + preexec_fn=self._limit_child_process if os.name == "posix" else None, + start_new_session=os.name == "posix", + ) + + try: + stdout, stderr = process.communicate(timeout=exec_timeout) + except subprocess.TimeoutExpired: + if os.name == "posix": + os.killpg(process.pid, signal.SIGKILL) + else: + process.kill() + process.communicate() + raise TimeoutError(f"Execution timed out after {exec_timeout} seconds") + + execution_time = time.time() - start_time + self._validate_output_size(stdout, stderr) + stdout, structured_result = extract_structured_result(stdout) + + return ExecutionResult( + stdout=stdout, + stderr=stderr, + exit_code=process.returncode, + execution_time=execution_time, + metadata={ + "instance_id": instance_id, + "language": normalized_lang, + "script_path": str(script_path), + "status": "ok" if process.returncode == 0 else "error", + "timeout": exec_timeout, + "artifacts": self._collect_artifacts(instance_dir / "artifacts"), + "result_present": structured_result.get("present", False), + "result_value": structured_result.get("value"), + "result_type": structured_result.get("type"), + }, + ) + + def destroy_instance(self, instance_id: str) -> bool: + if not self._initialized: + raise RuntimeError("Provider not initialized. 
Call initialize() first.") + + instance_dir = self._instances.pop(instance_id) + shutil.rmtree(instance_dir) + return True + + def health_check(self) -> bool: + return self._initialized and self.work_dir.exists() and os.access(self.work_dir, os.W_OK) + + def get_supported_languages(self) -> List[str]: + return ["python", "javascript", "nodejs"] + + @staticmethod + def get_config_schema() -> Dict[str, Dict]: + return { + "python_bin": {"type": "string", "required": False, "default": "python3"}, + "node_bin": {"type": "string", "required": False, "default": "node"}, + "work_dir": {"type": "string", "required": False, "default": "/tmp/ragflow-codeexec"}, + "timeout": {"type": "integer", "required": False, "default": 30}, + "max_memory_mb": {"type": "integer", "required": False, "default": 512}, + "max_output_bytes": {"type": "integer", "required": False, "default": 1048576}, + "max_artifacts": {"type": "integer", "required": False, "default": 20}, + "max_artifact_bytes": {"type": "integer", "required": False, "default": 10485760}, + } + + def _validate_limits(self) -> None: + if self.timeout <= 0: + raise SandboxProviderConfigError("SANDBOX_LOCAL_TIMEOUT must be greater than 0.") + if self.max_memory_mb <= 0: + raise SandboxProviderConfigError("SANDBOX_LOCAL_MAX_MEMORY_MB must be greater than 0.") + if self.max_output_bytes <= 0: + raise SandboxProviderConfigError("SANDBOX_LOCAL_MAX_OUTPUT_BYTES must be greater than 0.") + if self.max_artifacts < 0: + raise SandboxProviderConfigError("SANDBOX_LOCAL_MAX_ARTIFACTS must be greater than or equal to 0.") + if self.max_artifact_bytes <= 0: + raise SandboxProviderConfigError("SANDBOX_LOCAL_MAX_ARTIFACT_BYTES must be greater than 0.") + + def _prepare_script(self, instance_dir: Path, language: str, code: str, args_json: str) -> tuple[list[str], Path]: + if language == "python": + script_path = instance_dir / "main.py" + script_path.write_text(build_python_wrapper(code, args_json), encoding="utf-8") + return [self.python_bin, str(script_path)], script_path + if language in {"javascript", "nodejs"}: + script_path = instance_dir / "main.js" + script_path.write_text(build_javascript_wrapper(code, args_json), encoding="utf-8") + return [self.node_bin, str(script_path)], script_path + raise RuntimeError(f"Unsupported language for local provider: {language}") + + @staticmethod + def _resolve_config_value(config: Dict[str, Any], key: str, env_name: str, default: Any) -> Any: + value = config.get(key) + if value is not None: + return value + return os.environ.get(env_name, default) + + def _build_child_env(self, instance_dir: Path) -> dict[str, str]: + return { + "HOME": str(instance_dir), + "MPLBACKEND": "Agg", + "PATH": os.environ.get("PATH", ""), + "PYTHONUNBUFFERED": "1", + "TMPDIR": str(instance_dir), + } + + def _limit_child_process(self) -> None: + import resource + + self._set_resource_limit(resource.RLIMIT_CPU, self.timeout + 1) + self._set_resource_limit(resource.RLIMIT_AS, self.max_memory_mb * 1024 * 1024) + self._set_resource_limit(resource.RLIMIT_FSIZE, self.max_artifact_bytes) + self._set_resource_limit(resource.RLIMIT_NOFILE, 64) + + @staticmethod + def _set_resource_limit(kind: int, value: int) -> None: + import resource + + _, hard = resource.getrlimit(kind) + limit = value if hard == resource.RLIM_INFINITY else min(value, hard) + resource.setrlimit(kind, (limit, limit)) + + def _validate_output_size(self, stdout: str, stderr: str) -> None: + output_size = len((stdout or "").encode("utf-8")) + len((stderr or "").encode("utf-8")) + if 
output_size > self.max_output_bytes: + raise RuntimeError(f"Local execution output exceeded {self.max_output_bytes} bytes.") + + def _collect_artifacts(self, artifacts_dir: Path) -> list[dict[str, Any]]: + artifacts: list[dict[str, Any]] = [] + for path in sorted(artifacts_dir.rglob("*")): + if path.is_symlink(): + raise RuntimeError(f"Artifact symlinks are not allowed: {path.name}") + if path.is_dir(): + continue + if not path.is_file(): + raise RuntimeError(f"Unsupported artifact entry: {path.name}") + + if len(artifacts) >= self.max_artifacts: + raise RuntimeError(f"Local execution produced more than {self.max_artifacts} artifacts.") + + size = path.stat().st_size + if size > self.max_artifact_bytes: + raise RuntimeError(f"Artifact exceeds {self.max_artifact_bytes} bytes: {path.name}") + + ext = path.suffix.lower() + if ext not in ALLOWED_ARTIFACT_EXTENSIONS: + raise RuntimeError(f"Unsupported artifact type: {path.name}") + + artifacts.append( + { + "name": path.relative_to(artifacts_dir).as_posix(), + "content_b64": base64.b64encode(path.read_bytes()).decode("ascii"), + "mime_type": mimetypes.guess_type(path.name)[0] or "application/octet-stream", + "size": size, + } + ) + return artifacts + + @staticmethod + def _normalize_language(language: str) -> str: + lang_lower = (language or "python").lower() + if lang_lower in {"python", "python3"}: + return "python" + if lang_lower in {"javascript", "nodejs"}: + return "nodejs" + return lang_lower diff --git a/agent/sandbox/result_protocol.py b/agent/sandbox/result_protocol.py new file mode 100644 index 00000000000..f71e5f49968 --- /dev/null +++ b/agent/sandbox/result_protocol.py @@ -0,0 +1,85 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import base64 +import json +from typing import Any + + +RESULT_MARKER_PREFIX = "__RAGFLOW_RESULT__:" + + +def build_python_wrapper(code: str, args_json: str) -> str: + return f'''{code} + +if __name__ == "__main__": + import base64 + import json + + result = main(**{args_json}) + payload = json.dumps({{"present": True, "value": result, "type": "json"}}, ensure_ascii=False, separators=(",", ":")) + print("{RESULT_MARKER_PREFIX}" + base64.b64encode(payload.encode("utf-8")).decode("ascii")) +''' + + +def build_javascript_wrapper(code: str, args_json: str) -> str: + return f'''{code} + +const __ragflowArgs = {args_json}; + +(async () => {{ + const __ragflowMain = typeof main !== 'undefined' ? main : module.exports && module.exports.main; + if (typeof __ragflowMain !== 'function') {{ + throw new Error('main() must be defined or exported.'); + }} + const output = await Promise.resolve(__ragflowMain(__ragflowArgs)); + if (typeof output === 'undefined') {{ + throw new Error('main() must return a value. 
Use null for an empty result.'); + }} + const payload = JSON.stringify({{ present: true, value: output, type: 'json' }}); + if (typeof payload === 'undefined') {{ + throw new Error('main() returned a non-JSON-serializable value.'); + }} + console.log('{RESULT_MARKER_PREFIX}' + Buffer.from(payload, 'utf8').toString('base64')); +}})(); +''' + + +def extract_structured_result(stdout: str) -> tuple[str, dict[str, Any]]: + if not stdout: + return "", {} + + cleaned_lines: list[str] = [] + structured_result: dict[str, Any] = {} + + for line in str(stdout).splitlines(): + if line.startswith(RESULT_MARKER_PREFIX): + payload_b64 = line[len(RESULT_MARKER_PREFIX) :].strip() + if not payload_b64: + cleaned_lines.append(line) + continue + try: + payload = base64.b64decode(payload_b64, validate=True).decode("utf-8") + structured_result = json.loads(payload) + except Exception: + cleaned_lines.append(line) + continue + cleaned_lines.append(line) + + cleaned_stdout = "\n".join(cleaned_lines) + if stdout.endswith("\n") and cleaned_stdout and not cleaned_stdout.endswith("\n"): + cleaned_stdout += "\n" + return cleaned_stdout, structured_result diff --git a/agent/templates/ingestion_pipeline_Book.json b/agent/templates/ingestion_pipeline_book.json similarity index 100% rename from agent/templates/ingestion_pipeline_Book.json rename to agent/templates/ingestion_pipeline_book.json diff --git a/agent/templates/ingestion_pipeline_General.json b/agent/templates/ingestion_pipeline_general.json similarity index 100% rename from agent/templates/ingestion_pipeline_General.json rename to agent/templates/ingestion_pipeline_general.json diff --git a/agent/templates/ingestion_pipeline_Laws.json b/agent/templates/ingestion_pipeline_laws.json similarity index 100% rename from agent/templates/ingestion_pipeline_Laws.json rename to agent/templates/ingestion_pipeline_laws.json diff --git a/agent/templates/ingestion_pipeline_Manual.json b/agent/templates/ingestion_pipeline_manual.json similarity index 100% rename from agent/templates/ingestion_pipeline_Manual.json rename to agent/templates/ingestion_pipeline_manual.json diff --git a/agent/templates/ingestion_pipeline_One.json b/agent/templates/ingestion_pipeline_one.json similarity index 100% rename from agent/templates/ingestion_pipeline_One.json rename to agent/templates/ingestion_pipeline_one.json diff --git a/agent/templates/ingestion_pipeline_Paper.json b/agent/templates/ingestion_pipeline_paper.json similarity index 100% rename from agent/templates/ingestion_pipeline_Paper.json rename to agent/templates/ingestion_pipeline_paper.json diff --git a/agent/templates/ingestion_pipeline_Resume.json b/agent/templates/ingestion_pipeline_resume.json similarity index 98% rename from agent/templates/ingestion_pipeline_Resume.json rename to agent/templates/ingestion_pipeline_resume.json index 7b8d9899577..cb35eb2043e 100644 --- a/agent/templates/ingestion_pipeline_Resume.json +++ b/agent/templates/ingestion_pipeline_resume.json @@ -242,13 +242,14 @@ "include_heading_content": false, "levels": [ [ - 
"^\\s*(?i:(?:\\d+[\\.\\)]\\s*)?(?:EDUCATION|ACADEMIC\\s*BACKGROUND|ACADEMIC\\s*HISTORY|EDUCATIONAL\\s*BACKGROUND|RELEVANT\\s*COURSEWORK|COURSEWORK|EXPERIENCE|WORK\\s*EXPERIENCE|PROFESSIONAL\\s*EXPERIENCE|RELEVANT\\s*EXPERIENCE|EMPLOYMENT\\s*HISTORY|CAREER\\s*HISTORY|INTERNSHIP\\s*EXPERIENCE|PROJECTS|PROJECT\\s*EXPERIENCE|ACADEMIC\\s*PROJECTS|PROFESSIONAL\\s*PROJECTS|SKILLS|TECHNICAL\\s*SKILLS|CORE\\s*COMPETENCIES|COMPETENCIES|QUALIFICATIONS|SUMMARY\\s*OF\\s*QUALIFICATIONS|CERTIFICATIONS|LICENSES|CERTIFICATES|AWARDS|HONORS|HONOURS|ACHIEVEMENTS|PUBLICATIONS|RESEARCH|RESEARCH\\s*EXPERIENCE|LEADERSHIP|LEADERSHIP\\s*EXPERIENCE|ACTIVITIES|EXTRACURRICULAR\\s*ACTIVITIES|ACTIVITIES\\s*(?:&|AND)\\s*SKILLS|INVOLVEMENT|CAMPUS\\s*INVOLVEMENT|VOLUNTEER\\s*EXPERIENCE|VOLUNTEERING|COMMUNITY\\s*SERVICE|LANGUAGES|INTERESTS|HOBBIES|PROFILE|PROFESSIONAL\\s*PROFILE|SUMMARY|PROFESSIONAL\\s*SUMMARY|CAREER\\s*SUMMARY|OBJECTIVE|CAREER\\s*OBJECTIVE|PERSONAL\\s*INFORMATION|CONTACT\\s*INFORMATION|ADDITIONAL\\s*INFORMATION|TRAINING))\\s*[:\uff1a]?\\s*$" + "^\\s*(?i:(?:\\d+[\\.\\)]\\s*)?(?:EDUCATION|ACADEMIC\\s*BACKGROUND|ACADEMIC\\s*HISTORY|EDUCATIONAL\\s*BACKGROUND|RELEVANT\\s*COURSEWORK|COURSEWORK|EXPERIENCE|WORK\\s*EXPERIENCE|PROFESSIONAL\\s*EXPERIENCE|RELEVANT\\s*EXPERIENCE|EMPLOYMENT\\s*HISTORY|CAREER\\s*HISTORY|INTERNSHIP\\s*EXPERIENCE|PROJECTS|PROJECT\\s*EXPERIENCE|ACADEMIC\\s*PROJECTS|PROFESSIONAL\\s*PROJECTS|SKILLS|TECHNICAL\\s*SKILLS|CORE\\s*COMPETENCIES|COMPETENCIES|QUALIFICATIONS|SUMMARY\\s*OF\\s*QUALIFICATIONS|CERTIFICATIONS|LICENSES|CERTIFICATES|AWARDS|HONORS|HONOURS|ACHIEVEMENTS|PUBLICATIONS|RESEARCH|RESEARCH\\s*EXPERIENCE|LEADERSHIP|LEADERSHIP\\s*EXPERIENCE|ACTIVITIES|EXTRACURRICULAR\\s*ACTIVITIES|ACTIVITIES\\s*(?:&|AND)\\s*SKILLS|INVOLVEMENT|CAMPUS\\s*INVOLVEMENT|VOLUNTEER\\s*EXPERIENCE|VOLUNTEERING|COMMUNITY\\s*SERVICE|LANGUAGES|INTERESTS|HOBBIES|PROFILE|PROFESSIONAL\\s*PROFILE|SUMMARY|PROFESSIONAL\\s*SUMMARY|CAREER\\s*SUMMARY|OBJECTIVE|CAREER\\s*OBJECTIVE|PERSONAL\\s*INFORMATION|CONTACT\\s*INFORMATION|ADDITIONAL\\s*INFORMATION|TRAINING))\\s*[:\uff1a]?\\s*$" ], [ 
"^\\s*(?:\\d+[\\.\u3001\\)]\\s*)?(?:\u6559\u80b2\u80cc\u666f|\u6559\u80b2\u7ecf\u5386|\u5b66\u5386\u80cc\u666f|\u5b66\u672f\u80cc\u666f|\u6280\u672f\u80cc\u666f|\u5de5\u4f5c\u7ecf\u5386|\u5de5\u4f5c\u7ecf\u9a8c|\u5b9e\u4e60\u7ecf\u5386|\u9879\u76ee\u7ecf\u5386|\u9879\u76ee\u7ecf\u9a8c|\u79d1\u7814\u7ecf\u5386|\u7814\u7a76\u7ecf\u5386|\u6821\u56ed\u7ecf\u5386|\u5b9e\u8df5\u7ecf\u5386|\u4e13\u4e1a\u7ecf\u5386|\u804c\u4e1a\u7ecf\u5386|\u6280\u80fd|\u4e13\u4e1a\u6280\u80fd|\u6280\u80fd\u7279\u957f|\u6838\u5fc3\u6280\u80fd|\u6280\u672f\u6808|\u4e2a\u4eba\u6280\u80fd|\u5de5\u4f5c\u6280\u80fd|\u804c\u4e1a\u6280\u80fd|\u6280\u80fd\u4e0e\u8bc4\u4ef7|\u6280\u80fd\u4e0e\u81ea\u6211\u8bc4\u4ef7|\u5de5\u4f5c\u6280\u80fd\u4e0e\u81ea\u6211\u8bc4\u4ef7|\u804c\u4e1a\u6280\u80fd\u4e0e\u81ea\u6211\u8bc4\u4ef7|\u8bc1\u4e66|\u8d44\u683c\u8bc1\u4e66|\u804c\u4e1a\u8d44\u683c|\u8d44\u8d28\u8bc1\u4e66|\u83b7\u5956\u60c5\u51b5|\u83b7\u5956\u7ecf\u5386|\u8363\u8a89|\u8363\u8a89\u5956\u9879|\u5956\u9879|\u79d1\u7814\u6210\u679c|\u8bba\u6587\u53d1\u8868|\u53d1\u8868\u8bba\u6587|\u9886\u5bfc\u7ecf\u5386|\u5b66\u751f\u5de5\u4f5c|\u6821\u56ed\u6d3b\u52a8|\u793e\u56e2\u7ecf\u5386|\u6d3b\u52a8\u7ecf\u5386|\u5fd7\u613f\u7ecf\u5386|\u5fd7\u613f\u670d\u52a1|\u793e\u4f1a\u5b9e\u8df5|\u8bed\u8a00\u80fd\u529b|\u8bed\u8a00|\u81ea\u6211\u8bc4\u4ef7|\u4e2a\u4eba\u8bc4\u4ef7|\u81ea\u6211\u603b\u7ed3|\u4e2a\u4eba\u603b\u7ed3|\u4e2a\u4eba\u4f18\u52bf|\u4e2a\u4eba\u7b80\u4ecb|\u4e2a\u4eba\u4fe1\u606f|\u57fa\u672c\u4fe1\u606f|\u8054\u7cfb\u65b9\u5f0f|\u6c42\u804c\u610f\u5411|\u5e94\u8058\u610f\u5411|\u804c\u4e1a\u76ee\u6807|\u6c42\u804c\u76ee\u6807|\u5174\u8da3\u7231\u597d|\u5174\u8da3\u7279\u957f|\u57f9\u8bad\u7ecf\u5386|\u5176\u4ed6\u4fe1\u606f|\u9644\u52a0\u4fe1\u606f)\\s*[:\uff1a]?\\s*$" ] ], - "method": "hierarchy" + "method": "hierarchy", + "root_chunk_as_heading": true } }, "upstream": [ @@ -303,21 +304,24 @@ "data": { "isHovered": false }, - "id": "xy-edge__TitleChunker:FlatMiceFixstart-Extractor:ThreeDrinksActend", - "source": "TitleChunker:FlatMiceFix", + "id": "xy-edge__Extractor:ThreeDrinksActstart-Tokenizer:KindHandsWinend", + "markerEnd": "logo", + "source": "Extractor:ThreeDrinksAct", "sourceHandle": "start", - "target": "Extractor:ThreeDrinksAct", - "targetHandle": "end" + "target": "Tokenizer:KindHandsWin", + "targetHandle": "end", + "type": "buttonEdge", + "zIndex": 1001 }, { "data": { "isHovered": false }, - "id": "xy-edge__Extractor:ThreeDrinksActstart-Tokenizer:KindHandsWinend", + "id": "xy-edge__TitleChunker:FlatMiceFixstart-Extractor:ThreeDrinksActend", "markerEnd": "logo", - "source": "Extractor:ThreeDrinksAct", + "source": "TitleChunker:FlatMiceFix", "sourceHandle": "start", - "target": "Tokenizer:KindHandsWin", + "target": "Extractor:ThreeDrinksAct", "targetHandle": "end", "type": "buttonEdge", "zIndex": 1001 @@ -331,7 +335,7 @@ }, "id": "File", "measured": { - "height": 50, + "height": 49, "width": 200 }, "position": { @@ -460,7 +464,7 @@ "dragging": false, "id": "Parser:HipSignsRhyme", "measured": { - "height": 198, + "height": 197, "width": 200 }, "position": { @@ -489,12 +493,12 @@ "dragging": false, "id": "Tokenizer:KindHandsWin", "measured": { - "height": 114, + "height": 113, "width": 200 }, "position": { - "x": 876.4654525205967, - "y": 189.1906747329592 + "x": 883.0243372012395, + "y": 156.39625132974524 }, "selected": false, "sourcePosition": "right", @@ -514,6 +518,7 @@ } }, "promote_first_heading_to_root": false, + "root_chunk_as_heading": true, "rules": [ { "levels": [ @@ -537,14 +542,14 @@ 
"dragging": false, "id": "TitleChunker:FlatMiceFix", "measured": { - "height": 74, + "height": 73, "width": 200 }, "position": { "x": 572.7908769627791, "y": 141.55515313482098 }, - "selected": false, + "selected": true, "sourcePosition": "right", "targetPosition": "left", "type": "chunkerNode" @@ -580,12 +585,12 @@ "dragging": false, "id": "Extractor:ThreeDrinksAct", "measured": { - "height": 90, + "height": 89, "width": 200 }, "position": { - "x": 583.3659219536569, - "y": 274.7600100230409 + "x": 623.8123774842874, + "y": 236.49984938595793 }, "selected": false, "sourcePosition": "right", diff --git a/agent/tools/base.py b/agent/tools/base.py index f5a42de4d10..194b47fceec 100644 --- a/agent/tools/base.py +++ b/agent/tools/base.py @@ -67,6 +67,19 @@ async def tool_call_async(self, name: str, arguments: dict[str, Any]) -> Any: else: resp = await thread_pool_exec(tool_obj.invoke, **arguments) + if resp is None and hasattr(tool_obj, "output") and callable(tool_obj.output): + try: + fallback_output = tool_obj.output() + if isinstance(fallback_output, dict) and fallback_output.get("content") not in (None, ""): + resp = fallback_output["content"] + elif fallback_output not in (None, ""): + resp = fallback_output + else: + resp = fallback_output + logging.warning(f"[ToolCall] resp is None, fallback to output name={name} output_keys={list(fallback_output.keys()) if isinstance(fallback_output, dict) else type(fallback_output).__name__}") + except Exception as e: + logging.warning(f"[ToolCall] resp is None and output fallback failed name={name} err={e}") + elapsed = timer() - st logging.info(f"[ToolCall] done name={name} elapsed={elapsed:.2f}s result={str(resp)[:200]}") self.callback(name, arguments, resp, elapsed_time=elapsed) diff --git a/agent/tools/code_exec.py b/agent/tools/code_exec.py index 5d65a2e33ae..ece67d97fc9 100644 --- a/agent/tools/code_exec.py +++ b/agent/tools/code_exec.py @@ -357,6 +357,7 @@ def _execute_code(self, language: str, code: str, arguments: dict): # Try using the new sandbox provider system first try: from agent.sandbox.client import execute_code as sandbox_execute_code + from agent.sandbox.providers.base import SandboxProviderConfigError if self.check_if_canceled("CodeExec execution"): return @@ -376,8 +377,16 @@ def _execute_code(self, language: str, code: str, arguments: dict): execution_metadata=result.metadata, ) - except (ImportError, RuntimeError) as provider_error: - # Provider system not available or not configured, fall back to HTTP + except SandboxProviderConfigError as provider_error: + self.set_output("_ERROR", str(provider_error)) + return self.output() + except ImportError as provider_error: + # Provider modules are unavailable, fall back to legacy HTTP sandbox. 
+ logging.info(f"[CodeExec]: Provider system not available, using HTTP fallback: {provider_error}") + except RuntimeError as provider_error: + if not self._should_fallback_to_http(provider_error): + self.set_output("_ERROR", f"Provider system execution failed: {provider_error}") + return self.output() logging.info(f"[CodeExec]: Provider system not available, using HTTP fallback: {provider_error}") # Fallback to direct HTTP request @@ -487,6 +496,15 @@ def _resolve_execution_result_value(self, stdout: str, execution_metadata: Mappi return metadata.get("result_value"), False return self._deserialize_stdout(stdout), True + @staticmethod + def _should_fallback_to_http(provider_error: RuntimeError) -> bool: + message = str(provider_error).lower() + fallback_markers = ( + "no sandbox provider configured", + "sandbox provider type not configured", + ) + return any(marker in message for marker in fallback_markers) + @classmethod def _ensure_bucket_lifecycle(cls): if cls._lifecycle_configured: @@ -533,7 +551,7 @@ def _upload_artifacts(self, artifacts: list) -> list[dict]: settings.STORAGE_IMPL.put(SANDBOX_ARTIFACT_BUCKET, storage_name, binary) - url = f"/v1/document/artifact/{storage_name}" + url = f"/api/v1/documents/artifact/{storage_name}" uploaded.append( { "name": name, diff --git a/agent/tools/crawler.py b/agent/tools/crawler.py index e4d049e1bdd..6558c524f0a 100644 --- a/agent/tools/crawler.py +++ b/agent/tools/crawler.py @@ -19,7 +19,6 @@ from agent.tools.base import ToolParamBase, ToolBase - class CrawlerParam(ToolParamBase): """ Define the Crawler component parameters. @@ -31,20 +30,26 @@ def __init__(self): self.extract_type = "markdown" def check(self): - self.check_valid_value(self.extract_type, "Type of content from the crawler", ['html', 'markdown', 'content']) + self.check_valid_value(self.extract_type, "Type of content from the crawler", ["html", "markdown", "content"]) class Crawler(ToolBase, ABC): component_name = "Crawler" def _run(self, history, **kwargs): - from api.utils.web_utils import is_valid_url + from common.ssrf_guard import assert_url_is_safe, pin_dns_global + ans = self.get_input() ans = " - ".join(ans["content"]) if "content" in ans else "" - if not is_valid_url(ans): + try: + _ssrf_hostname, _ssrf_ip = assert_url_is_safe(ans) + except ValueError: return Crawler.be_output("URL not valid") try: - result = asyncio.run(self.get_web(ans)) + # pin_dns_global is used (not thread-local) because crawl4ai resolves + # DNS in asyncio executor threads that don't share thread-local state. 
+ with pin_dns_global(_ssrf_hostname, _ssrf_ip): + result = asyncio.run(self.get_web(ans)) return Crawler.be_output(result) @@ -57,18 +62,15 @@ async def get_web(self, url): proxy = self._param.proxy if self._param.proxy else None async with AsyncWebCrawler(verbose=True, proxy=proxy) as crawler: - result = await crawler.arun( - url=url, - bypass_cache=True - ) + result = await crawler.arun(url=url, bypass_cache=True) if self.check_if_canceled("Crawler async operation"): return - if self._param.extract_type == 'html': + if self._param.extract_type == "html": return result.cleaned_html - elif self._param.extract_type == 'markdown': + elif self._param.extract_type == "markdown": return result.markdown - elif self._param.extract_type == 'content': + elif self._param.extract_type == "content": return result.extracted_content return result.markdown diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py index 912a5c34850..4496f497aef 100644 --- a/agent/tools/retrieval.py +++ b/agent/tools/retrieval.py @@ -135,7 +135,11 @@ async def _retrieve_kb(self, query_text: str): doc_ids = [] if self._param.meta_data_filter != {}: - metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids) + # Defer the (potentially expensive) metadata table load — manual + # filters served by ES push-down never need it. The loader is + # invoked at most once per request by ``apply_meta_data_filter``. + def _load_metas() -> dict: + return DocMetadataService.get_flatted_meta_by_kbs(kb_ids) def _resolve_manual_filter(flt: dict) -> dict: pat = re.compile(self.variable_ref_patt) @@ -174,11 +178,13 @@ def _resolve_manual_filter(flt: dict) -> dict: doc_ids = await apply_meta_data_filter( self._param.meta_data_filter, - metas, + None, query, chat_mdl, doc_ids, _resolve_manual_filter if self._param.meta_data_filter.get("method") == "manual" else None, + kb_ids=kb_ids, + metas_loader=_load_metas, ) if self._param.cross_languages: diff --git a/agent/tools/searxng.py b/agent/tools/searxng.py index fdc7bea525c..ef03375b306 100644 --- a/agent/tools/searxng.py +++ b/agent/tools/searxng.py @@ -20,6 +20,7 @@ import requests from agent.tools.base import ToolMeta, ToolParamBase, ToolBase from common.connection_utils import timeout +from common.ssrf_guard import assert_url_is_safe, pin_dns class SearXNGParam(ToolParamBase): @@ -36,15 +37,15 @@ def __init__(self): "type": "string", "description": "The search keywords to execute with SearXNG. The keywords should be the most important words/terms(includes synonyms) from the original request.", "default": "{sys.query}", - "required": True + "required": True, }, "searxng_url": { "type": "string", "description": "The base URL of your SearXNG instance (e.g., http://localhost:4000). 
This is required to connect to your SearXNG server.", "required": False, - "default": "" - } - } + "default": "", + }, + }, } super().__init__() self.top_n = 10 @@ -61,17 +62,7 @@ def check(self): self.check_positive_integer(self.top_n, "Top N") def get_input_form(self) -> dict[str, dict]: - return { - "query": { - "name": "Query", - "type": "line" - }, - "searxng_url": { - "name": "SearXNG URL", - "type": "line", - "placeholder": "http://localhost:4000" - } - } + return {"query": {"name": "Query", "type": "line"}, "searxng_url": {"name": "SearXNG URL", "type": "line", "placeholder": "http://localhost:4000"}} class SearXNG(ToolBase, ABC): @@ -94,26 +85,22 @@ def _invoke(self, **kwargs): self.set_output("formalized_content", "") return "" + try: + _ssrf_hostname, _ssrf_ip = assert_url_is_safe(searxng_url) + except ValueError as e: + self.set_output("_ERROR", str(e)) + return f"SearXNG error: SSRF guard blocked {searxng_url!r}: {e}" + last_e = "" - for _ in range(self._param.max_retries+1): + for _ in range(self._param.max_retries + 1): if self.check_if_canceled("SearXNG processing"): return try: - search_params = { - 'q': query, - 'format': 'json', - 'categories': 'general', - 'language': 'auto', - 'safesearch': 1, - 'pageno': 1 - } - - response = requests.get( - f"{searxng_url}/search", - params=search_params, - timeout=10 - ) + search_params = {"q": query, "format": "json", "categories": "general", "language": "auto", "safesearch": 1, "pageno": 1} + + with pin_dns(_ssrf_hostname, _ssrf_ip): + response = requests.get(f"{searxng_url}/search", params=search_params, timeout=10) response.raise_for_status() if self.check_if_canceled("SearXNG processing"): @@ -128,15 +115,12 @@ def _invoke(self, **kwargs): if not isinstance(results, list): raise ValueError("Invalid results format from SearXNG") - results = results[:self._param.top_n] + results = results[: self._param.top_n] if self.check_if_canceled("SearXNG processing"): return - self._retrieve_chunks(results, - get_title=lambda r: r.get("title", ""), - get_url=lambda r: r.get("url", ""), - get_content=lambda r: r.get("content", "")) + self._retrieve_chunks(results, get_title=lambda r: r.get("title", ""), get_url=lambda r: r.get("url", ""), get_content=lambda r: r.get("content", "")) self.set_output("json", results) return self.output("formalized_content") diff --git a/api/apps/__init__.py b/api/apps/__init__.py index 9139954115c..e05bbb03d42 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -79,8 +79,8 @@ def _unauthorized_message(error): app.config["MAX_CONTENT_LENGTH"] = int( os.environ.get("MAX_CONTENT_LENGTH", 1024 * 1024 * 1024) ) -app.config['SECRET_KEY'] = settings.SECRET_KEY -app.secret_key = settings.SECRET_KEY +app.config['SECRET_KEY'] = settings.get_secret_key() +app.secret_key = settings.get_secret_key() commands.register_commands(app) from functools import wraps @@ -93,7 +93,7 @@ def _unauthorized_message(error): def _load_user(): - jwt = Serializer(secret_key=settings.SECRET_KEY) + jwt = Serializer(secret_key=settings.get_secret_key()) authorization = request.headers.get("Authorization") g.user = None if not authorization: @@ -301,6 +301,10 @@ def register_page(page_path): register_page(path) for directory in pages_dir for path in search_pages_path(directory) ] +# Register backward compatibility routes for deprecated APIs +from api.apps.backward_compat import register_backward_compat_routes +register_backward_compat_routes(app) + @app.errorhandler(404) async def not_found(error): diff --git a/api/apps/auth/README.md 
b/api/apps/auth/README.md index 372e75cfbd8..8edab999f82 100644 --- a/api/apps/auth/README.md +++ b/api/apps/auth/README.md @@ -20,7 +20,7 @@ oauth_config = { "authorization_url": "https://your-oauth-provider.com/oauth/authorize", "token_url": "https://your-oauth-provider.com/oauth/token", "userinfo_url": "https://your-oauth-provider.com/oauth/userinfo", - "redirect_uri": "https://your-app.com/v1/user/oauth/callback/" + "redirect_uri": "https://your-app.com/api/v1/auth/oauth//callback" } # OIDC configuration @@ -29,7 +29,7 @@ oidc_config = { "issuer": "https://your-oauth-provider.com/oidc", "client_id": "your_client_id", "client_secret": "your_client_secret", - "redirect_uri": "https://your-app.com/v1/user/oauth/callback/" + "redirect_uri": "https://your-app.com/api/v1/auth/oauth//callback" } # Github OAuth configuration diff --git a/api/apps/backward_compat.py b/api/apps/backward_compat.py new file mode 100644 index 00000000000..a2c950158e6 --- /dev/null +++ b/api/apps/backward_compat.py @@ -0,0 +1,522 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Backward compatibility layer for deprecated API endpoints. + +This module adds support for old API routes that were deprecated during the +RESTful API migration. Each deprecated route forwards to the corresponding +new API implementation. 
 + +Deprecated APIs and their replacements: +- POST /api/v1/agents/{agent_id}/completions -> POST /api/v1/agents/chat/completions +- POST /api/v1/chats/{chat_id}/completions -> POST /api/v1/chat/completions +- POST /api/v1/chats_openai/{chat_id}/chat/completions -> POST /api/v1/openai/{chat_id}/chat/completions +- PUT /api/v1/chats/{chat_id}/sessions/{session_id} -> PATCH /api/v1/chats/{chat_id}/sessions/{session_id} +- DELETE /api/v1/chats -> DELETE /api/v1/chats/{chat_id} (with body) +- POST /api/v1/file/convert -> POST /api/v1/files/link-to-datasets +- GET /api/v1/file/* -> GET /api/v1/files* +- POST /api/v1/file/* -> POST /api/v1/files* +- GET /api/v1/document/get/{doc_id} -> GET /api/v1/documents/{doc_id}/preview +- GET /api/v1/document/download/{doc_id} -> GET /api/v1/documents/{doc_id}/download +- GET /v1/document/download/{attachment_id} -> GET /api/v1/documents/{attachment_id}/download +- GET /v1/system/healthz -> GET /api/v1/system/healthz +- POST /api/v1/sessions/related_questions -> POST /api/v1/chat/recommendation +- PUT (chunk update) -> PATCH (chunk update) +""" +import logging + +from quart import Blueprint, jsonify, request + +from api.apps import login_required +from api.apps.restful_apis import chat_api, file_api, file2document_api, chunk_api, openai_api, document_api +from api.apps.restful_apis.system_api import run_health_checks +from api.apps.restful_apis import agent_api +from api.apps.services import file_api_service +from api.utils.api_utils import get_data_error_result, get_json_result, add_tenant_id_to_kwargs + +manager = Blueprint("backward_compat", __name__) +legacy_v1_manager = Blueprint("backward_compat_legacy_v1", __name__) + + +# ============================================================================= +# System APIs +# ============================================================================= + +@legacy_v1_manager.route("/system/healthz", methods=["GET"]) +async def deprecated_system_healthz(): + """ + Deprecated: Use GET /api/v1/system/healthz instead. + + Old path: GET /v1/system/healthz + New path: GET /api/v1/system/healthz + """ + logging.warning( + "API endpoint /v1/system/healthz is deprecated. " + "Please use /api/v1/system/healthz instead." + ) + result, all_ok = run_health_checks() + return jsonify(result), (200 if all_ok else 500) + +# ============================================================================= +# Chat Completion APIs +# ============================================================================= + +@manager.route("/chats/<chat_id>/completions", methods=["POST"]) +@login_required +async def deprecated_chat_completions(chat_id): + """ + Deprecated: Use POST /api/v1/chat/completions instead. + + Old path: POST /api/v1/chats/{chat_id}/completions + New path: POST /api/v1/chat/completions + """ + logging.warning( + "API endpoint /api/v1/chats/%s/completions is deprecated. " + "Please use /api/v1/chat/completions instead.", + chat_id, + ) + # Forward to the new API implementation + return await chat_api.session_completion(chat_id) + + +@manager.route("/chats_openai/<chat_id>/chat/completions", methods=["POST"]) +@login_required +async def deprecated_openai_chat_completions(chat_id): + """ + Deprecated: Use POST /api/v1/openai/{chat_id}/chat/completions instead. + + Old path: POST /api/v1/chats_openai/{chat_id}/chat/completions + New path: POST /api/v1/openai/{chat_id}/chat/completions + """ + logging.warning( + "API endpoint /api/v1/chats_openai/%s/chat/completions is deprecated. 
" + "Please use /api/v1/openai/%s/chat/completions instead.", + chat_id, chat_id, + ) + # Forward to the new API implementation + return await openai_api.openai_chat_completions(chat_id) + + +# ============================================================================= +# Chat Session APIs +# ============================================================================= + +@manager.route("/chats//sessions/", methods=["PUT"]) +@login_required +async def deprecated_update_session(chat_id, session_id): + """ + Deprecated: Use PATCH /api/v1/chats/{chat_id}/sessions/{session_id} instead. + + Old path: PUT /api/v1/chats/{chat_id}/sessions/{session_id} + New path: PATCH /api/v1/chats/{chat_id}/sessions/{session_id} + """ + logging.warning( + "API endpoint PUT /api/v1/chats/%s/sessions/%s is deprecated. " + "Please use PATCH /api/v1/chats/%s/sessions/%s instead.", + chat_id, session_id, chat_id, session_id, + ) + # Forward to the new API implementation + return await chat_api.update_session(chat_id, session_id) + + +# ============================================================================= +# File APIs (Old /api/v1/file/* -> New /api/v1/files*) +# ============================================================================= + +@manager.route("/file/get/", methods=["GET"]) +@login_required +async def deprecated_file_get(file_id): + """ + Deprecated: Use GET /api/v1/files/{file_id} instead. + + Old path: GET /api/v1/file/get/{file_id} + New path: GET /api/v1/files/{file_id} + """ + logging.warning( + "API endpoint /api/v1/file/get/%s is deprecated. " + "Please use /api/v1/files/%s instead.", + file_id, file_id, + ) + # Forward to the new API implementation (download) + return await file_api.download(file_id=file_id) + + +@manager.route("/file/list", methods=["GET"]) +@login_required +async def deprecated_file_list(): + """ + Deprecated: Use GET /api/v1/files instead. + + Old path: GET /api/v1/file/list?... + New path: GET /api/v1/files?... + """ + logging.warning( + "API endpoint /api/v1/file/list is deprecated. " + "Please use /api/v1/files instead." + ) + # Forward to the new API implementation + return await file_api.list_files() + + +@manager.route("/file/all_parent_folder", methods=["GET"]) +@login_required +async def deprecated_file_all_parent_folder(): + """ + Deprecated: Use GET /api/v1/files/{file_id}/ancestors instead. + + Old path: GET /api/v1/file/all_parent_folder?file_id=... + New path: GET /api/v1/files/{file_id}/ancestors + """ + file_id = request.args.get("file_id") + if not file_id: + return get_data_error_result(message="`file_id` query parameter is required") + logging.warning( + "API endpoint /api/v1/file/all_parent_folder is deprecated. " + "Please use /api/v1/files/%s/ancestors instead.", + file_id, + ) + # Forward to the new API implementation + return await file_api.ancestors(file_id=file_id) + + +@manager.route("/file/parent_folder", methods=["GET"]) +@login_required +async def deprecated_file_parent_folder(): + """ + Deprecated: Use GET /api/v1/files/{file_id}/parent instead. + + Old path: GET /api/v1/file/parent_folder?file_id=... + New path: GET /api/v1/files/{file_id}/parent + """ + file_id = request.args.get("file_id") + if not file_id: + return get_data_error_result(message="`file_id` query parameter is required") + logging.warning( + "API endpoint /api/v1/file/parent_folder is deprecated. 
" + "Please use /api/v1/files/%s/parent instead.", + file_id, + ) + # Forward to the new API implementation + return await file_api.parent_folder(file_id=file_id) + + +@manager.route("/file/root_folder", methods=["GET"]) +@login_required +async def deprecated_file_root_folder(): + """ + Deprecated: Root folder is now accessible via GET /api/v1/files with parent_id=... + + Old path: GET /api/v1/file/root_folder + New path: GET /api/v1/files?parent_id= + """ + logging.warning( + "API endpoint /api/v1/file/root_folder is deprecated. " + "Please use /api/v1/files with appropriate parent_id instead." + ) + # Forward to the new API implementation with empty parent_id to get root + return await file_api.list_files() + + +@manager.route("/file/create", methods=["POST"]) +@login_required +@add_tenant_id_to_kwargs +async def deprecated_file_create(tenant_id=None): + """ + Deprecated: Use POST /api/v1/files instead. + + Old path: POST /api/v1/file/create + New path: POST /api/v1/files + """ + logging.warning( + "API endpoint /api/v1/file/create is deprecated. " + "Please use POST /api/v1/files instead." + ) + # Forward to the new API implementation + return await file_api.create_or_upload(tenant_id=tenant_id) + + +@manager.route("/file/upload", methods=["POST"]) +@login_required +@add_tenant_id_to_kwargs +async def deprecated_file_upload(tenant_id=None): + """ + Deprecated: Use POST /api/v1/files (with multipart/form-data) instead. + + Old path: POST /api/v1/file/upload + New path: POST /api/v1/files + """ + logging.warning( + "API endpoint /api/v1/file/upload is deprecated. " + "Please use POST /api/v1/files with multipart/form-data instead." + ) + # Forward to the new API implementation + return await file_api.create_or_upload(tenant_id=tenant_id) + + +@manager.route("/file/convert", methods=["POST"]) +@login_required +async def deprecated_file_convert(): + """ + Deprecated: Use POST /api/v1/files/link-to-datasets instead. + + Old path: POST /api/v1/file/convert + New path: POST /api/v1/files/link-to-datasets + """ + logging.warning( + "API endpoint /api/v1/file/convert is deprecated. " + "Please use POST /api/v1/files/link-to-datasets instead." + ) + return await file2document_api.convert() + + +@manager.route("/file/mv", methods=["POST"]) +@login_required +@add_tenant_id_to_kwargs +async def deprecated_file_mv(tenant_id=None): + """ + Deprecated: Use POST /api/v1/files/move instead. + + Old path: POST /api/v1/file/mv + New path: POST /api/v1/files/move + """ + logging.warning( + "API endpoint /api/v1/file/mv is deprecated. " + "Please use POST /api/v1/files/move instead." + ) + # Forward to the new API implementation + return await file_api.move(tenant_id=tenant_id) + + +@manager.route("/file/rename", methods=["POST"]) +@login_required +@add_tenant_id_to_kwargs +async def deprecated_file_rename(tenant_id=None): + """ + Deprecated: Use POST /api/v1/files/move with new_name instead. + + Old path: POST /api/v1/file/rename + New path: POST /api/v1/files/move + """ + logging.warning( + "API endpoint /api/v1/file/rename is deprecated. " + "Please use POST /api/v1/files/move with `new_name` instead." 
+ ) + # Transform the old API format to new format + req = await request.get_json() + # Old API used `file_id` and `name`, new API uses `src_file_ids` and `new_name` + src_file_ids = [req.get("file_id")] + new_name = req.get("name") + # Call the underlying service directly with transformed data + try: + success, result = await file_api_service.move_files( + tenant_id, src_file_ids, None, new_name + ) + if success: + return get_json_result(data=result) + else: + return get_data_error_result(message=result) + except Exception as e: + logging.exception(e) + return get_data_error_result(message="Internal server error") + + +@manager.route("/file/rm", methods=["POST"]) +@login_required +@add_tenant_id_to_kwargs +async def deprecated_file_rm(tenant_id=None): + """ + Deprecated: Use DELETE /api/v1/files instead. + + Old path: POST /api/v1/file/rm + New path: DELETE /api/v1/files + """ + logging.warning( + "API endpoint /api/v1/file/rm is deprecated. " + "Please use DELETE /api/v1/files instead." + ) + # Transform POST with body to DELETE behavior + # The new API expects a JSON body with `ids` + return await file_api.delete(tenant_id=tenant_id) + + +# ============================================================================= +# Related Questions API +# ============================================================================= + +@manager.route("/sessions/related_questions", methods=["POST"]) +@login_required +async def deprecated_related_questions(): + """ + Deprecated: Use POST /api/v1/chat/recommendation instead. + + Old path: POST /api/v1/sessions/related_questions + New path: POST /api/v1/chat/recommendation + """ + logging.warning( + "API endpoint /api/v1/sessions/related_questions is deprecated. " + "Please use /api/v1/chat/recommendation instead." + ) + # Forward to the new API implementation + return await chat_api.recommendation() + + +# ============================================================================= +# Chunk Update API (PUT -> PATCH) +# ============================================================================= + +@manager.route("/datasets/<dataset_id>/documents/<document_id>/chunks/<chunk_id>", methods=["PUT"]) +@login_required +async def deprecated_update_chunk(dataset_id, document_id, chunk_id): + """ + Deprecated: Use PATCH /api/v1/datasets/{dataset_id}/documents/{document_id}/chunks/{chunk_id} instead. + + Old path: PUT /api/v1/datasets/{dataset_id}/documents/{document_id}/chunks/{chunk_id} + New path: PATCH /api/v1/datasets/{dataset_id}/documents/{document_id}/chunks/{chunk_id} + """ + logging.warning( + "API endpoint PUT /api/v1/datasets/%s/documents/%s/chunks/%s is deprecated. " + "Please use PATCH instead.", + dataset_id, document_id, chunk_id, + ) + # Forward to the new API implementation + return await chunk_api.update_chunk(dataset_id=dataset_id, document_id=document_id, chunk_id=chunk_id) + + +# ============================================================================= +# File Upload Info API +# ============================================================================= + +@manager.route("/file/upload_info", methods=["POST"]) +@login_required +async def deprecated_file_upload_info(): + """ + Deprecated: Use POST /api/v1/documents/upload instead. + + Old path: POST /api/v1/file/upload_info + New path: POST /api/v1/documents/upload + """ + from api.apps import current_user + + logging.warning( + "API endpoint /api/v1/file/upload_info is deprecated. " + "Please use POST /api/v1/documents/upload instead." 
+ ) + # Forward to the new API implementation + # Need to pass tenant_id explicitly since we're calling the function directly + tenant_id = current_user.id + return await document_api.upload_info(tenant_id=tenant_id) + + +# ============================================================================= +# Document APIs +# ============================================================================= + +@manager.route("/datasets/<dataset_id>/documents/<document_id>", methods=["PUT"]) +@login_required +async def deprecated_update_document(dataset_id, document_id): + """ + Deprecated: Use PATCH /api/v1/datasets/{dataset_id}/documents/{document_id} instead. + + Old path: PUT /api/v1/datasets/{dataset_id}/documents/{document_id} + New path: PATCH /api/v1/datasets/{dataset_id}/documents/{document_id} + """ + logging.warning( + "API endpoint PUT /api/v1/datasets/%s/documents/%s is deprecated. " + "Please use PATCH instead.", + dataset_id, document_id, + ) + # Forward to the new API implementation + return await document_api.update_document(dataset_id=dataset_id, document_id=document_id) + + +@manager.route("/document/get/<doc_id>", methods=["GET"]) +@login_required +async def deprecated_document_get(doc_id): + """ + Deprecated: Use GET /api/v1/documents/{doc_id}/preview instead. + + Old path: GET /api/v1/document/get/{doc_id} + New path: GET /api/v1/documents/{doc_id}/preview + """ + logging.warning( + "API endpoint /api/v1/document/get/%s is deprecated. " + "Please use /api/v1/documents/%s/preview instead.", + doc_id, doc_id, + ) + return await document_api.get(doc_id) + + +@manager.route("/document/download/<doc_id>", methods=["GET"]) +@login_required +async def deprecated_document_download(doc_id): + """ + Deprecated: Use GET /api/v1/documents/{doc_id}/download instead. + + Old path: GET /api/v1/document/download/{doc_id} + New path: GET /api/v1/documents/{doc_id}/download + """ + logging.warning( + "API endpoint /api/v1/document/download/%s is deprecated. " + "Please use /api/v1/documents/%s/download instead.", + doc_id, doc_id, + ) + return await document_api.download_attachment(doc_id=doc_id) + + +@legacy_v1_manager.route("/document/download/<attachment_id>", methods=["GET"]) +@login_required +async def document_download_v1(attachment_id): + """ + Compatibility alias for document download under /v1. + + Old path: GET /v1/document/download/{attachment_id} + New path: GET /api/v1/documents/{attachment_id}/download + """ + logging.warning( + "API endpoint /v1/document/download/%s is deprecated. " + "Please use /api/v1/documents/%s/download instead.", + attachment_id, attachment_id, + ) + return await document_api.download_attachment(attachment_id=attachment_id) + +# ============================================================================= +# Agent Chat API +# ============================================================================= + +@manager.route("/agents/<agent_id>/completions", methods=["POST"]) +@login_required +@add_tenant_id_to_kwargs +async def deprecated_agent_completions(agent_id, tenant_id=None): + """ + Deprecated: Use POST /api/v1/agents/chat/completions instead. + + Old path: POST /api/v1/agents/{agent_id}/completions + New path: POST /api/v1/agents/chat/completions + """ + logging.warning( + "API endpoint /api/v1/agents/%s/completions is deprecated. " + "Please use /api/v1/agents/chat/completions instead.", + agent_id, + ) + return await agent_api.agent_chat_completion(tenant_id=tenant_id, agent_id=agent_id) + +def register_backward_compat_routes(app_instance): + """ + Register all backward compatibility routes with the app. 
+ """ + app_instance.register_blueprint(manager, url_prefix="/api/v1") + app_instance.register_blueprint(legacy_v1_manager, url_prefix="/v1") + logging.info("Backward compatibility routes registered successfully.") diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py deleted file mode 100644 index 8c896e36add..00000000000 --- a/api/apps/canvas_app.py +++ /dev/null @@ -1,755 +0,0 @@ -# -# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import copy -import inspect -import json -import logging -from functools import partial -from quart import request, Response, make_response -from agent.component import LLM -from api.db import CanvasCategory -from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService -from api.db.services.document_service import DocumentService -from api.db.services.file_service import FileService -from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.pipeline_operation_log_service import PipelineOperationLogService -from api.db.services.task_service import queue_dataflow, CANVAS_DEBUG_DOC_ID, TaskService -from api.db.services.user_service import TenantService -from api.db.services.user_canvas_version import UserCanvasVersionService -from common.constants import RetCode -from common.misc_utils import get_uuid, thread_pool_exec -from api.utils.api_utils import ( - get_json_result, - server_error_response, - validate_request, - get_data_error_result, - get_request_json, -) -from agent.canvas import Canvas -from agent.dsl_migration import normalize_chunker_dsl -from peewee import MySQLDatabase, PostgresqlDatabase -from api.db.db_models import APIToken, Task - -from rag.flow.pipeline import Pipeline -from rag.nlp import search -from rag.utils.redis_conn import REDIS_CONN -from common import settings -from api.apps import login_required, current_user -from api.apps.services.canvas_replica_service import CanvasReplicaService -from api.db.services.canvas_service import completion as agent_completion - - -@manager.route('/templates', methods=['GET']) # noqa: F821 -@login_required -def templates(): - return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.get_all()]) - - -@manager.route('/rm', methods=['POST']) # noqa: F821 -@validate_request("canvas_ids") -@login_required -async def rm(): - req = await get_request_json() - for i in req["canvas_ids"]: - if not UserCanvasService.accessible(i, current_user.id): - return get_json_result( - data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) - UserCanvasService.delete_by_id(i) - return get_json_result(data=True) - - -@manager.route('/set', methods=['POST']) # noqa: F821 -@validate_request("dsl", "title") -@login_required -async def save(): - req = await get_request_json() - req['release'] = bool(req.get("release", "")) - try: - req["dsl"] = CanvasReplicaService.normalize_dsl(req["dsl"]) - except ValueError as e: - 
return get_data_error_result(message=str(e)) - cate = req.get("canvas_category", CanvasCategory.Agent) - if "id" not in req: - req["user_id"] = current_user.id - if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=cate): - return get_data_error_result(message=f"{req['title'].strip()} already exists.") - req["id"] = get_uuid() - if not UserCanvasService.save(**req): - return get_data_error_result(message="Fail to save canvas.") - else: - if not UserCanvasService.accessible(req["id"], current_user.id): - return get_json_result( - data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) - UserCanvasService.update_by_id(req["id"], req) - # save version - UserCanvasVersionService.save_or_replace_latest( - user_canvas_id=req["id"], - dsl=req["dsl"], - title=UserCanvasVersionService.build_version_title(getattr(current_user, "nickname", current_user.id), req.get("title")), - release=req.get("release"), - ) - replica_ok = CanvasReplicaService.replace_for_set( - canvas_id=req["id"], - tenant_id=str(current_user.id), - runtime_user_id=str(current_user.id), - dsl=req["dsl"], - canvas_category=req.get("canvas_category", cate), - title=req.get("title", ""), - ) - if not replica_ok: - return get_data_error_result(message="canvas saved, but replica sync failed.") - return get_json_result(data=req) - - -@manager.route('/get/', methods=['GET']) # noqa: F821 -@login_required -def get(canvas_id): - if not UserCanvasService.accessible(canvas_id, current_user.id): - return get_data_error_result(message="canvas not found.") - e, c = UserCanvasService.get_by_canvas_id(canvas_id) - if not e: - return get_data_error_result(message="canvas not found.") - try: - # DELETE - CanvasReplicaService.bootstrap( - canvas_id=canvas_id, - tenant_id=str(current_user.id), - runtime_user_id=str(current_user.id), - dsl=c.get("dsl"), - canvas_category=c.get("canvas_category", CanvasCategory.Agent), - title=c.get("title", ""), - ) - except ValueError as e: - return get_data_error_result(message=str(e)) - - # Get the last publication time (latest released version's update_time) - last_publish_time = None - versions = UserCanvasVersionService.list_by_canvas_id(canvas_id) - if versions: - released_versions = [v for v in versions if v.release] - if released_versions: - # Sort by update_time descending and get the latest - released_versions.sort(key=lambda x: x.update_time, reverse=True) - last_publish_time = released_versions[0].update_time - - # Add last_publish_time to response data - if isinstance(c, dict): - c["dsl"] = normalize_chunker_dsl(c.get("dsl", {})) - c["last_publish_time"] = last_publish_time - else: - # If c is a model object, convert to dict first - c = c.to_dict() - c["dsl"] = normalize_chunker_dsl(c.get("dsl", {})) - c["last_publish_time"] = last_publish_time - - # For pipeline type, get associated datasets - if c.get("canvas_category") == CanvasCategory.DataFlow: - datasets = list(KnowledgebaseService.query(pipeline_id=canvas_id)) - c["datasets"] = [{"id": d.id, "name": d.name, "avatar": d.avatar} for d in datasets] - - return get_json_result(data=c) - - -@manager.route('/getsse/', methods=['GET']) # type: ignore # noqa: F821 -def getsse(canvas_id): - token = request.headers.get('Authorization').split() - if len(token) != 2: - return get_data_error_result(message='Authorization is not valid!') - token = token[1] - objs = APIToken.query(beta=token) - if not objs: - return get_data_error_result(message='Authentication error: 
API key is invalid!"') - tenant_id = objs[0].tenant_id - if not UserCanvasService.query(user_id=tenant_id, id=canvas_id): - return get_json_result( - data=False, - message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR - ) - e, c = UserCanvasService.get_by_id(canvas_id) - if not e or c.user_id != tenant_id: - return get_data_error_result(message="canvas not found.") - return get_json_result(data=c.to_dict()) - - -@manager.route('/completion', methods=['POST']) # noqa: F821 -@validate_request("id") -@login_required -async def run(): - req = await get_request_json() - query = req.get("query", "") - files = req.get("files", []) - inputs = req.get("inputs", {}) - tenant_id = str(current_user.id) - runtime_user_id = req.get("user_id") or tenant_id - user_id = str(runtime_user_id) - if not await thread_pool_exec(UserCanvasService.accessible, req["id"], tenant_id): - return get_json_result( - data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) - - replica_payload = CanvasReplicaService.load_for_run( - canvas_id=req["id"], - tenant_id=tenant_id, - runtime_user_id=user_id, - ) - - if not replica_payload: - return get_data_error_result(message="canvas replica not found, please call /get/ first.") - - replica_dsl = replica_payload.get("dsl", {}) - canvas_title = replica_payload.get("title", "") - canvas_category = replica_payload.get("canvas_category", CanvasCategory.Agent) - dsl_str = json.dumps(replica_dsl, ensure_ascii=False) - - _, cvs = await thread_pool_exec(UserCanvasService.get_by_id, req["id"]) - if cvs.canvas_category == CanvasCategory.DataFlow: - task_id = get_uuid() - Pipeline(dsl_str, tenant_id=tenant_id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"]) - ok, error_message = await thread_pool_exec(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0) - if not ok: - return get_data_error_result(message=error_message) - return get_json_result(data={"message_id": task_id}) - - try: - canvas = Canvas(dsl_str, tenant_id, canvas_id=req["id"]) - except Exception as e: - return server_error_response(e) - - async def sse(): - nonlocal canvas, user_id - try: - async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs): - yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n" - - commit_ok = CanvasReplicaService.commit_after_run( - canvas_id=req["id"], - tenant_id=tenant_id, - runtime_user_id=user_id, - dsl=json.loads(str(canvas)), - canvas_category=canvas_category, - title=canvas_title, - ) - if not commit_ok: - logging.error( - "Canvas runtime replica commit failed: canvas_id=%s tenant_id=%s runtime_user_id=%s", - req["id"], - tenant_id, - user_id, - ) - - except Exception as e: - logging.exception(e) - canvas.cancel_task() - yield "data:" + json.dumps({"code": 500, "message": str(e), "data": False}, ensure_ascii=False) + "\n\n" - - resp = Response(sse(), mimetype="text/event-stream") - resp.headers.add_header("Cache-control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") - #resp.call_on_close(lambda: canvas.cancel_task()) - return resp - - -@manager.route("//completion", methods=["POST"]) # noqa: F821 -@login_required -async def exp_agent_completion(canvas_id): - tenant_id = current_user.id - req = await get_request_json() - return_trace = bool(req.get("return_trace", False)) - async 
def generate(): - trace_items = [] - async for answer in agent_completion(tenant_id=tenant_id, agent_id=canvas_id, **req): - if isinstance(answer, str): - try: - ans = json.loads(answer[5:]) # remove "data:" - except Exception: - continue - - event = ans.get("event") - if event == "node_finished": - if return_trace: - data = ans.get("data", {}) - trace_items.append( - { - "component_id": data.get("component_id"), - "trace": [copy.deepcopy(data)], - } - ) - ans.setdefault("data", {})["trace"] = trace_items - answer = "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n" - yield answer - - if event not in ["message", "message_end"]: - continue - - yield answer - - yield "data:[DONE]\n\n" - - resp = Response(generate(), mimetype="text/event-stream") - resp.headers.add_header("Cache-control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") - return resp - - -@manager.route('/rerun', methods=['POST']) # noqa: F821 -@validate_request("id", "dsl", "component_id") -@login_required -async def rerun(): - req = await get_request_json() - doc = PipelineOperationLogService.get_documents_info(req["id"]) - if not doc: - return get_data_error_result(message="Document not found.") - doc = doc[0] - if 0 < doc["progress"] < 1: - return get_data_error_result(message=f"`{doc['name']}` is processing...") - - if settings.docStoreConn.index_exist(search.index_name(current_user.id), doc["kb_id"]): - settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"]) - doc["progress_msg"] = "" - doc["chunk_num"] = 0 - doc["token_num"] = 0 - DocumentService.clear_chunk_num_when_rerun(doc["id"]) - DocumentService.update_by_id(id, doc) - TaskService.filter_delete([Task.doc_id == id]) - - dsl = req["dsl"] - dsl["path"] = [req["component_id"]] - PipelineOperationLogService.update_by_id(req["id"], {"dsl": dsl}) - queue_dataflow(tenant_id=current_user.id, flow_id=req["id"], task_id=get_uuid(), doc_id=doc["id"], priority=0, rerun=True) - return get_json_result(data=True) - - -@manager.route('/cancel/', methods=['PUT']) # noqa: F821 -@login_required -def cancel(task_id): - try: - REDIS_CONN.set(f"{task_id}-cancel", "x") - except Exception as e: - logging.exception(e) - return get_json_result(data=True) - - -@manager.route('/reset', methods=['POST']) # noqa: F821 -@validate_request("id") -@login_required -async def reset(): - req = await get_request_json() - if not UserCanvasService.accessible(req["id"], current_user.id): - return get_json_result( - data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) - try: - e, user_canvas = UserCanvasService.get_by_id(req["id"]) - if not e: - return get_data_error_result(message="canvas not found.") - - canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id, canvas_id=user_canvas.id) - canvas.reset() - req["dsl"] = json.loads(str(canvas)) - UserCanvasService.update_by_id(req["id"], {"dsl": req["dsl"]}) - return get_json_result(data=req["dsl"]) - except Exception as e: - return server_error_response(e) - - -@manager.route("/upload/", methods=["POST"]) # noqa: F821 -async def upload(canvas_id): - e, cvs = UserCanvasService.get_by_canvas_id(canvas_id) - if not e: - return get_data_error_result(message="canvas not found.") - - user_id = cvs["user_id"] - files = await request.files - file_objs = files.getlist("file") if files and files.get("file") else 
[] - try: - if len(file_objs) == 1: - return get_json_result(data=FileService.upload_info(user_id, file_objs[0], request.args.get("url"))) - results = [FileService.upload_info(user_id, f) for f in file_objs] - return get_json_result(data=results) - except Exception as e: - return server_error_response(e) - - -@manager.route('/input_form', methods=['GET']) # noqa: F821 -@login_required -def input_form(): - cvs_id = request.args.get("id") - cpn_id = request.args.get("component_id") - try: - e, user_canvas = UserCanvasService.get_by_id(cvs_id) - if not e: - return get_data_error_result(message="canvas not found.") - if not UserCanvasService.query(user_id=current_user.id, id=cvs_id): - return get_json_result( - data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) - - canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id, canvas_id=user_canvas.id) - return get_json_result(data=canvas.get_component_input_form(cpn_id)) - except Exception as e: - return server_error_response(e) - - -@manager.route('/debug', methods=['POST']) # noqa: F821 -@validate_request("id", "component_id", "params") -@login_required -async def debug(): - req = await get_request_json() - if not UserCanvasService.accessible(req["id"], current_user.id): - return get_json_result( - data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) - try: - e, user_canvas = UserCanvasService.get_by_id(req["id"]) - canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id, canvas_id=user_canvas.id) - canvas.reset() - canvas.message_id = get_uuid() - component = canvas.get_component(req["component_id"])["obj"] - component.reset() - - if isinstance(component, LLM): - component.set_debug_inputs(req["params"]) - component.invoke(**{k: o["value"] for k,o in req["params"].items()}) - outputs = component.output() - for k in outputs.keys(): - if isinstance(outputs[k], partial): - txt = "" - iter_obj = outputs[k]() - if inspect.isasyncgen(iter_obj): - async for c in iter_obj: - txt += c - else: - for c in iter_obj: - txt += c - outputs[k] = txt - return get_json_result(data=outputs) - except Exception as e: - return server_error_response(e) - - -@manager.route('/test_db_connect', methods=['POST']) # noqa: F821 -@validate_request("db_type", "database", "username", "host", "port", "password") -@login_required -async def test_db_connect(): - req = await get_request_json() - try: - if req["db_type"] in ["mysql", "mariadb"]: - db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"], - password=req["password"]) - elif req["db_type"] == "oceanbase": - db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"], - password=req["password"], charset="utf8mb4") - elif req["db_type"] == 'postgres': - db = PostgresqlDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"], - password=req["password"]) - elif req["db_type"] == 'mssql': - import pyodbc - connection_string = ( - f"DRIVER={{ODBC Driver 17 for SQL Server}};" - f"SERVER={req['host']},{req['port']};" - f"DATABASE={req['database']};" - f"UID={req['username']};" - f"PWD={req['password']};" - ) - db = pyodbc.connect(connection_string) - cursor = db.cursor() - cursor.execute("SELECT 1") - cursor.close() - elif req["db_type"] == 'IBM DB2': - import ibm_db - conn_str = ( - f"DATABASE={req['database']};" - f"HOSTNAME={req['host']};" - f"PORT={req['port']};" - f"PROTOCOL=TCPIP;" - f"UID={req['username']};" 
- f"PWD={req['password']};" - ) - redacted_conn_str = ( - f"DATABASE={req['database']};" - f"HOSTNAME={req['host']};" - f"PORT={req['port']};" - f"PROTOCOL=TCPIP;" - f"UID={req['username']};" - f"PWD=****;" - ) - logging.info(redacted_conn_str) - conn = ibm_db.connect(conn_str, "", "") - stmt = ibm_db.exec_immediate(conn, "SELECT 1 FROM sysibm.sysdummy1") - ibm_db.fetch_assoc(stmt) - ibm_db.close(conn) - return get_json_result(data="Database Connection Successful!") - elif req["db_type"] == 'trino': - def _parse_catalog_schema(db_name: str): - if not db_name: - return None, None - if "." in db_name: - catalog_name, schema_name = db_name.split(".", 1) - elif "/" in db_name: - catalog_name, schema_name = db_name.split("/", 1) - else: - catalog_name, schema_name = db_name, "default" - return catalog_name, schema_name - try: - import trino - import os - except Exception as e: - return server_error_response(f"Missing dependency 'trino'. Please install: pip install trino, detail: {e}") - - catalog, schema = _parse_catalog_schema(req["database"]) - if not catalog: - return server_error_response("For Trino, 'database' must be 'catalog.schema' or at least 'catalog'.") - - http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http" - - auth = None - if http_scheme == "https" and req.get("password"): - auth = trino.BasicAuthentication(req.get("username") or "ragflow", req["password"]) - - conn = trino.dbapi.connect( - host=req["host"], - port=int(req["port"] or 8080), - user=req["username"] or "ragflow", - catalog=catalog, - schema=schema or "default", - http_scheme=http_scheme, - auth=auth - ) - cur = conn.cursor() - cur.execute("SELECT 1") - cur.fetchall() - cur.close() - conn.close() - return get_json_result(data="Database Connection Successful!") - else: - return server_error_response("Unsupported database type.") - if req["db_type"] != 'mssql': - db.connect() - db.close() - - return get_json_result(data="Database Connection Successful!") - except Exception as e: - return server_error_response(e) - - -#api get list version dsl of canvas -@manager.route('/getlistversion/', methods=['GET']) # noqa: F821 -@login_required -def getlistversion(canvas_id): - try: - versions =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1) - return get_json_result(data=versions) - except Exception as e: - return get_data_error_result(message=f"Error getting history files: {e}") - - -#api get version dsl of canvas -@manager.route('/getversion/', methods=['GET']) # noqa: F821 -@login_required -def getversion( version_id): - try: - e, version = UserCanvasVersionService.get_by_id(version_id) - if version: - return get_json_result(data=version.to_dict()) - except Exception as e: - return get_json_result(data=f"Error getting history file: {e}") - - -@manager.route('/list', methods=['GET']) # noqa: F821 -@login_required -def list_canvas(): - keywords = request.args.get("keywords", "") - page_number = int(request.args.get("page", 0)) - items_per_page = int(request.args.get("page_size", 0)) - orderby = request.args.get("orderby", "create_time") - canvas_category = request.args.get("canvas_category") - if request.args.get("desc", "true").lower() == "false": - desc = False - else: - desc = True - owner_ids = [id for id in request.args.get("owner_ids", "").strip().split(",") if id] - if not owner_ids: - tenants = TenantService.get_joined_tenants_by_user_id(current_user.id) - tenants = [m["tenant_id"] for m in tenants] - 
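# Scope note for the branch below: with no explicit owner filter, the query covers
# every tenant the caller has joined plus the caller's own tenant id (appended next),
# honouring the page/page_size/orderby/desc/keywords/canvas_category query parameters.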
tenants.append(current_user.id) - canvas, total = UserCanvasService.get_by_tenant_ids( - tenants, current_user.id, page_number, - items_per_page, orderby, desc, keywords, canvas_category) - else: - tenants = owner_ids - canvas, total = UserCanvasService.get_by_tenant_ids( - tenants, current_user.id, 0, - 0, orderby, desc, keywords, canvas_category) - return get_json_result(data={"canvas": canvas, "total": total}) - - -@manager.route('/setting', methods=['POST']) # noqa: F821 -@validate_request("id", "title", "permission") -@login_required -async def setting(): - req = await get_request_json() - req["user_id"] = current_user.id - - if not UserCanvasService.accessible(req["id"], current_user.id): - return get_json_result( - data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) - - e,flow = UserCanvasService.get_by_id(req["id"]) - if not e: - return get_data_error_result(message="canvas not found.") - flow = flow.to_dict() - flow["title"] = req["title"] - - for key in ["description", "permission", "avatar"]: - if value := req.get(key): - flow[key] = value - - num= UserCanvasService.update_by_id(req["id"], flow) - return get_json_result(data=num) - - -@manager.route('/trace', methods=['GET']) # noqa: F821 -def trace(): - cvs_id = request.args.get("canvas_id") - msg_id = request.args.get("message_id") - try: - binary = REDIS_CONN.get(f"{cvs_id}-{msg_id}-logs") - if not binary: - return get_json_result(data={}) - - return get_json_result(data=json.loads(binary.encode("utf-8"))) - except Exception as e: - logging.exception(e) - - -@manager.route('//sessions', methods=['GET']) # noqa: F821 -@login_required -def sessions(canvas_id): - tenant_id = current_user.id - if not UserCanvasService.accessible(canvas_id, tenant_id): - return get_json_result( - data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) - - user_id = request.args.get("user_id") - page_number = int(request.args.get("page", 1)) - items_per_page = int(request.args.get("page_size", 30)) - keywords = request.args.get("keywords") - from_date = request.args.get("from_date") - to_date = request.args.get("to_date") - orderby = request.args.get("orderby", "update_time") - exp_user_id = request.args.get("exp_user_id") - if request.args.get("desc") == "False" or request.args.get("desc") == "false": - desc = False - else: - desc = True - - if exp_user_id: - sess = API4ConversationService.get_names(canvas_id, exp_user_id) - return get_json_result(data={"total": len(sess), "sessions": sess}) - - # dsl defaults to True in all cases except for False and false - include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false" - total, sess = API4ConversationService.get_list(canvas_id, tenant_id, page_number, items_per_page, orderby, desc, - None, user_id, include_dsl, keywords, from_date, to_date, exp_user_id=exp_user_id) - try: - return get_json_result(data={"total": total, "sessions": sess}) - except Exception as e: - return server_error_response(e) - - -@manager.route('//sessions', methods=['PUT']) # noqa: F821 -@login_required -async def set_session(canvas_id): - req = await get_request_json() - tenant_id = current_user.id - e, cvs = UserCanvasService.get_by_id(canvas_id) - assert e, "Agent not found." 
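# The stored agent DSL may arrive as a dict or as a JSON string; it is normalised to a
# string below, a fresh Canvas is built and reset, and the new session record is saved
# with the latest (not necessarily released) version title attached.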
- if not isinstance(cvs.dsl, str): - cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) - session_id=get_uuid() - canvas = Canvas(cvs.dsl, tenant_id, canvas_id, canvas_id=cvs.id) - canvas.reset() - # Get the version title for this canvas (using latest, not necessarily released) - version_title = UserCanvasVersionService.get_latest_version_title(cvs.id, release_mode=False) - conv = { - "id": session_id, - "name": req.get("name", ""), - "dialog_id": cvs.id, - "user_id": tenant_id, - "exp_user_id": tenant_id, - "message": [], - "source": "agent", - "dsl": cvs.dsl, - "reference": [], - "version_title": version_title - } - API4ConversationService.save(**conv) - return get_json_result(data=conv) - - -@manager.route('//sessions/', methods=['GET']) # noqa: F821 -@login_required -def get_session(canvas_id, session_id): - tenant_id = current_user.id - if not UserCanvasService.accessible(canvas_id, tenant_id): - return get_json_result( - data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) - _, conv = API4ConversationService.get_by_id(session_id) - return get_json_result(data=conv.to_dict()) - - -@manager.route('//sessions/', methods=['DELETE']) # noqa: F821 -@login_required -def del_session(canvas_id, session_id): - tenant_id = current_user.id - if not UserCanvasService.accessible(canvas_id, tenant_id): - return get_json_result( - data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) - return get_json_result(data=API4ConversationService.delete_by_id(session_id)) - - -@manager.route('/prompts', methods=['GET']) # noqa: F821 -@login_required -def prompts(): - from rag.prompts.generator import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE - - return get_json_result(data={ - "task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER, - "plan_generation": NEXT_STEP, - "reflection": REFLECT, - #"context_summary": SUMMARY4MEMORY, - #"context_ranking": RANK_MEMORY, - "citation_guidelines": CITATION_PROMPT_TEMPLATE - }) - - -@manager.route('/download', methods=['GET']) # noqa: F821 -async def download(): - id = request.args.get("id") - created_by = request.args.get("created_by") - blob = FileService.get_blob(created_by, id) - return await make_response(blob) diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py deleted file mode 100644 index e6ceb66e695..00000000000 --- a/api/apps/chunk_app.py +++ /dev/null @@ -1,580 +0,0 @@ -# -# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
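# Illustrative sketch, not part of the original module: the create() handler further
# below derives chunk ids deterministically from the chunk text plus its document id
# via xxhash, so re-submitting identical content to the same document reproduces the
# same id. Minimal standalone example, assuming the `xxhash` package is installed:
import xxhash

def _example_chunk_id(content: str, doc_id: str) -> str:
    # Same scheme as create(): 64-bit xxhash of the content concatenated with doc_id.
    return xxhash.xxh64((content + doc_id).encode("utf-8")).hexdigest()

assert _example_chunk_id("hello", "doc-1") == _example_chunk_id("hello", "doc-1")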
-# -import base64 -import datetime -import json -import logging -import re -import xxhash -from quart import request - -from api.db.services.document_service import DocumentService -from api.db.services.doc_metadata_service import DocMetadataService -from api.utils.image_utils import store_chunk_image -from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.llm_service import LLMBundle -from common.metadata_utils import apply_meta_data_filter -from api.db.services.search_service import SearchService -from api.db.services.user_service import UserTenantService -from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_tenant_default_model_by_type, get_model_config_by_type_and_name -from api.utils.api_utils import ( - get_data_error_result, - get_json_result, - server_error_response, - validate_request, - get_request_json, -) -from common.misc_utils import thread_pool_exec -from common.tag_feature_utils import validate_tag_features -from rag.app.qa import beAdoc, rmPrefix -from rag.app.tag import label_question -from rag.nlp import rag_tokenizer, search -from rag.prompts.generator import cross_languages, keyword_extraction -from common.string_utils import is_content_empty, remove_redundant_spaces -from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD -from common import settings -from api.apps import login_required, current_user - -@manager.route('/list', methods=['POST']) # noqa: F821 -@login_required -@validate_request("doc_id") -async def list_chunk(): - req = await get_request_json() - doc_id = req["doc_id"] - page = int(req.get("page", 1)) - size = int(req.get("size", 30)) - question = req.get("keywords", "") - try: - tenant_id = DocumentService.get_tenant_id(req["doc_id"]) - if not tenant_id: - return get_data_error_result(message="Tenant not found!") - e, doc = DocumentService.get_by_id(doc_id) - if not e: - return get_data_error_result(message="Document not found!") - kb_ids = KnowledgebaseService.get_kb_ids(tenant_id) - query = { - "doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True - } - if "available_int" in req: - query["available_int"] = int(req["available_int"]) - sres = await settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"]) - res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()} - for id in sres.ids: - d = { - "chunk_id": id, - "content_with_weight": remove_redundant_spaces(sres.highlight[id]) if question and id in sres.highlight else sres.field[ - id].get( - "content_with_weight", ""), - "doc_id": sres.field[id]["doc_id"], - "docnm_kwd": sres.field[id]["docnm_kwd"], - "important_kwd": sres.field[id].get("important_kwd", []), - "question_kwd": sres.field[id].get("question_kwd", []), - "image_id": sres.field[id].get("img_id", ""), - "available_int": int(sres.field[id].get("available_int", 1)), - "positions": sres.field[id].get("position_int", []), - "doc_type_kwd": sres.field[id].get("doc_type_kwd") - } - assert isinstance(d["positions"], list) - assert len(d["positions"]) == 0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5) - res["chunks"].append(d) - return get_json_result(data=res) - except Exception as e: - if str(e).find("not_found") > 0: - return get_json_result(data=False, message='No chunk found!', - code=RetCode.DATA_ERROR) - return server_error_response(e) - - -@manager.route('/get', methods=['GET']) # noqa: F821 -@login_required -def get(): - chunk_id = 
request.args["chunk_id"] - try: - chunk = None - tenants = UserTenantService.query(user_id=current_user.id) - if not tenants: - return get_data_error_result(message="Tenant not found!") - for tenant in tenants: - kb_ids = KnowledgebaseService.get_kb_ids(tenant.tenant_id) - chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids) - if chunk: - break - if chunk is None: - return server_error_response(Exception("Chunk not found")) - - k = [] - for n in chunk.keys(): - if re.search(r"(_vec$|_sm_|_tks|_ltks)", n): - k.append(n) - for n in k: - del chunk[n] - - return get_json_result(data=chunk) - except Exception as e: - if str(e).find("NotFoundError") >= 0: - return get_json_result(data=False, message='Chunk not found!', - code=RetCode.DATA_ERROR) - return server_error_response(e) - - -@manager.route('/set', methods=['POST']) # noqa: F821 -@login_required -@validate_request("doc_id", "chunk_id", "content_with_weight") -async def set(): - req = await get_request_json() - content_with_weight = req["content_with_weight"] - if not isinstance(content_with_weight, (str, bytes)): - raise TypeError("expected string or bytes-like object") - if isinstance(content_with_weight, bytes): - content_with_weight = content_with_weight.decode("utf-8", errors="ignore") - if is_content_empty(content_with_weight): - return get_data_error_result(message="`content_with_weight` is required") - d = { - "id": req["chunk_id"], - "content_with_weight": content_with_weight} - d["content_ltks"] = rag_tokenizer.tokenize(content_with_weight) - d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) - if "important_kwd" in req: - if not isinstance(req["important_kwd"], list): - return get_data_error_result(message="`important_kwd` should be a list") - d["important_kwd"] = req["important_kwd"] - d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"])) - if "question_kwd" in req: - if not isinstance(req["question_kwd"], list): - return get_data_error_result(message="`question_kwd` should be a list") - d["question_kwd"] = req["question_kwd"] - d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"])) - if "tag_kwd" in req: - if not isinstance(req["tag_kwd"], list): - return get_data_error_result(message="`tag_kwd` should be a list") - if not all(isinstance(t, str) for t in req["tag_kwd"]): - return get_data_error_result(message="`tag_kwd` must be a list of strings") - d["tag_kwd"] = req["tag_kwd"] - if "tag_feas" in req: - try: - d["tag_feas"] = validate_tag_features(req["tag_feas"]) - except ValueError as exc: - return get_data_error_result(message=f"`tag_feas` {exc}") - if "available_int" in req: - d["available_int"] = req["available_int"] - - try: - def _set_sync(): - tenant_id = DocumentService.get_tenant_id(req["doc_id"]) - if not tenant_id: - return get_data_error_result(message="Tenant not found!") - - e, doc = DocumentService.get_by_id(req["doc_id"]) - if not e: - return get_data_error_result(message="Document not found!") - - tenant_embd_id = DocumentService.get_tenant_embd_id(req["doc_id"]) - if tenant_embd_id: - embd_model_config = get_model_config_by_id(tenant_embd_id) - else: - embd_id = DocumentService.get_embd_id(req["doc_id"]) - if embd_id: - embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_id) - else: - embd_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.EMBEDDING) - embd_mdl = LLMBundle(tenant_id, embd_model_config) - - _d = d - if doc.parser_id == ParserType.QA: - 
arr = [ - t for t in re.split( - r"[\n\t]", - req["content_with_weight"]) if len(t) > 1] - q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:])) - _d = beAdoc(d, q, a, not any( - [rag_tokenizer.is_chinese(t) for t in q + a])) - - v, c = embd_mdl.encode([doc.name, content_with_weight if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])]) - v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] - _d["q_%d_vec" % len(v)] = v.tolist() - settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id) - - # update image - image_base64 = req.get("image_base64", None) - img_id = req.get("img_id", "") - if image_base64 and img_id and "-" in img_id: - bkt, name = img_id.split("-", 1) - image_binary = base64.b64decode(image_base64) - settings.STORAGE_IMPL.put(bkt, name, image_binary) - return get_json_result(data=True) - - return await thread_pool_exec(_set_sync) - except Exception as e: - return server_error_response(e) - - -@manager.route('/switch', methods=['POST']) # noqa: F821 -@login_required -@validate_request("chunk_ids", "available_int", "doc_id") -async def switch(): - req = await get_request_json() - try: - def _switch_sync(): - e, doc = DocumentService.get_by_id(req["doc_id"]) - if not e: - return get_data_error_result(message="Document not found!") - for cid in req["chunk_ids"]: - if not settings.docStoreConn.update({"id": cid}, - {"available_int": int(req["available_int"])}, - search.index_name(DocumentService.get_tenant_id(req["doc_id"])), - doc.kb_id): - return get_data_error_result(message="Index updating failure") - return get_json_result(data=True) - - return await thread_pool_exec(_switch_sync) - except Exception as e: - return server_error_response(e) - - -@manager.route('/rm', methods=['POST']) # noqa: F821 -@login_required -@validate_request("doc_id") -async def rm(): - req = await get_request_json() - try: - def _rm_sync(): - deleted_chunk_ids = req.get("chunk_ids") - if isinstance(deleted_chunk_ids, list): - unique_chunk_ids = list(dict.fromkeys(deleted_chunk_ids)) - has_ids = len(unique_chunk_ids) > 0 - elif deleted_chunk_ids is not None: - unique_chunk_ids = [deleted_chunk_ids] - has_ids = deleted_chunk_ids not in (None, "") - else: - unique_chunk_ids = [] - has_ids = False - if not has_ids: - if req.get("delete_all") is True: - e, doc = DocumentService.get_by_id(req["doc_id"]) - if not e: - return get_data_error_result(message="Document not found!") - tenant_id = DocumentService.get_tenant_id(req["doc_id"]) - # Clean up storage assets while index rows still exist for discovery - DocumentService.delete_chunk_images(doc, tenant_id) - condition = {"doc_id": req["doc_id"]} - try: - deleted_count = settings.docStoreConn.delete(condition, search.index_name(tenant_id), doc.kb_id) - except Exception: - return get_data_error_result(message="Chunk deleting failure") - if deleted_count > 0: - DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, deleted_count, 0) - return get_json_result(data=True) - return get_json_result(data=True) - - e, doc = DocumentService.get_by_id(req["doc_id"]) - if not e: - return get_data_error_result(message="Document not found!") - condition = {"id": req["chunk_ids"], "doc_id": req["doc_id"]} - try: - deleted_count = settings.docStoreConn.delete(condition, - search.index_name(DocumentService.get_tenant_id(req["doc_id"])), - doc.kb_id) - except Exception: - return get_data_error_result(message="Chunk deleting failure") - if has_ids and deleted_count == 0: - return 
get_data_error_result(message="Index updating failure") - if deleted_count > 0 and deleted_count < len(unique_chunk_ids): - deleted_count += settings.docStoreConn.delete({"doc_id": req["doc_id"]}, - search.index_name(DocumentService.get_tenant_id(req["doc_id"])), - doc.kb_id) - chunk_number = deleted_count - DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0) - for cid in deleted_chunk_ids: - if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid): - settings.STORAGE_IMPL.rm(doc.kb_id, cid) - return get_json_result(data=True) - - return await thread_pool_exec(_rm_sync) - except Exception as e: - return server_error_response(e) - - -@manager.route('/create', methods=['POST']) # noqa: F821 -@login_required -@validate_request("doc_id", "content_with_weight") -async def create(): - req = await get_request_json() - req_id = request.headers.get("X-Request-ID") - chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest() - d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]), - "content_with_weight": req["content_with_weight"]} - d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) - d["important_kwd"] = req.get("important_kwd", []) - if not isinstance(d["important_kwd"], list): - return get_data_error_result(message="`important_kwd` is required to be a list") - d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"])) - d["question_kwd"] = req.get("question_kwd", []) - if not isinstance(d["question_kwd"], list): - return get_data_error_result(message="`question_kwd` is required to be a list") - d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"])) - d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19] - d["create_timestamp_flt"] = datetime.datetime.now().timestamp() - if "tag_kwd" in req: - if not isinstance(req["tag_kwd"], list): - return get_data_error_result(message="`tag_kwd` is required to be a list") - if not all(isinstance(t, str) for t in req["tag_kwd"]): - return get_data_error_result(message="`tag_kwd` must be a list of strings") - d["tag_kwd"] = req["tag_kwd"] - if "tag_feas" in req: - try: - d["tag_feas"] = validate_tag_features(req["tag_feas"]) - except ValueError as exc: - return get_data_error_result(message=f"`tag_feas` {exc}") - image_base64 = req.get("image_base64", None) - - try: - def _log_response(resp, code, message): - logging.info( - "chunk_create response req_id=%s status=%s code=%s message=%s", - req_id, - getattr(resp, "status_code", None), - code, - message, - ) - - def _create_sync(): - e, doc = DocumentService.get_by_id(req["doc_id"]) - if not e: - resp = get_data_error_result(message="Document not found!") - _log_response(resp, RetCode.DATA_ERROR, "Document not found!") - return resp - d["kb_id"] = [doc.kb_id] - d["docnm_kwd"] = doc.name - d["title_tks"] = rag_tokenizer.tokenize(doc.name) - d["doc_id"] = doc.id - - tenant_id = DocumentService.get_tenant_id(req["doc_id"]) - if not tenant_id: - resp = get_data_error_result(message="Tenant not found!") - _log_response(resp, RetCode.DATA_ERROR, "Tenant not found!") - return resp - - e, kb = KnowledgebaseService.get_by_id(doc.kb_id) - if not e: - resp = get_data_error_result(message="Knowledgebase not found!") - _log_response(resp, RetCode.DATA_ERROR, "Knowledgebase not found!") - return resp - if kb.pagerank: - d[PAGERANK_FLD] = kb.pagerank - - tenant_embd_id = DocumentService.get_tenant_embd_id(req["doc_id"]) - if tenant_embd_id: - embd_model_config 
= get_model_config_by_id(tenant_embd_id) - else: - embd_id = DocumentService.get_embd_id(req["doc_id"]) - if embd_id: - embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_id) - else: - embd_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.EMBEDDING) - embd_mdl = LLMBundle(tenant_id, embd_model_config) - - if image_base64: - d["img_id"] = "{}-{}".format(doc.kb_id, chunck_id) - d["doc_type_kwd"] = "image" - - v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])]) - v = 0.1 * v[0] + 0.9 * v[1] - d["q_%d_vec" % len(v)] = v.tolist() - settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id) - - if image_base64: - store_chunk_image(doc.kb_id, chunck_id, base64.b64decode(image_base64)) - - DocumentService.increment_chunk_num( - doc.id, doc.kb_id, c, 1, 0) - resp = get_json_result(data={"chunk_id": chunck_id, "image_id": d.get("img_id", "")}) - _log_response(resp, RetCode.SUCCESS, "success") - return resp - - return await thread_pool_exec(_create_sync) - except Exception as e: - logging.info("chunk_create exception req_id=%s error=%r", req_id, e) - return server_error_response(e) - - -@manager.route('/retrieval_test', methods=['POST']) # noqa: F821 -@login_required -@validate_request("kb_id", "question") -async def retrieval_test(): - req = await get_request_json() - page = int(req.get("page", 1)) - size = int(req.get("size", 30)) - question = req["question"] - kb_ids = req["kb_id"] - if isinstance(kb_ids, str): - kb_ids = [kb_ids] - if not kb_ids: - return get_json_result(data=False, message='Please specify dataset firstly.', - code=RetCode.DATA_ERROR) - - doc_ids = req.get("doc_ids", []) - use_kg = req.get("use_kg", False) - top = int(req.get("top_k", 1024)) - langs = req.get("cross_languages", []) - user_id = current_user.id - - async def _retrieval(): - local_doc_ids = list(doc_ids) if doc_ids else [] - tenant_ids = [] - - meta_data_filter = {} - chat_mdl = None - if req.get("search_id", ""): - search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {}) - meta_data_filter = search_config.get("meta_data_filter", {}) - if meta_data_filter.get("method") in ["auto", "semi_auto"]: - chat_id = search_config.get("chat_id", "") - if chat_id: - chat_model_config = get_model_config_by_type_and_name(user_id, LLMType.CHAT, search_config["chat_id"]) - else: - chat_model_config = get_tenant_default_model_by_type(user_id, LLMType.CHAT) - chat_mdl = LLMBundle(user_id, chat_model_config) - else: - meta_data_filter = req.get("meta_data_filter") or {} - if meta_data_filter.get("method") in ["auto", "semi_auto"]: - chat_model_config = get_tenant_default_model_by_type(user_id, LLMType.CHAT) - chat_mdl = LLMBundle(user_id, chat_model_config) - - if meta_data_filter: - metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids) - local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, local_doc_ids) - - tenants = UserTenantService.query(user_id=user_id) - for kb_id in kb_ids: - for tenant in tenants: - if KnowledgebaseService.query( - tenant_id=tenant.tenant_id, id=kb_id): - tenant_ids.append(tenant.tenant_id) - break - else: - return get_json_result( - data=False, message='Only owner of dataset authorized for this operation.', - code=RetCode.OPERATING_ERROR) - - e, kb = KnowledgebaseService.get_by_id(kb_ids[0]) - if not e: - return get_data_error_result(message="Knowledgebase not found!") - - _question = question 
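# Retrieval flow below: optionally rewrite the question across the requested languages,
# resolve the embedding model (kb-specific id, then named model, then tenant default),
# optionally attach a reranker and LLM keyword expansion, then run hybrid retrieval with
# the supplied similarity_threshold, vector_similarity_weight, top_k and pagination;
# knowledge-graph hits, when enabled, are prepended to the returned chunks.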
- if langs: - _question = await cross_languages(kb.tenant_id, None, _question, langs) - if kb.tenant_embd_id: - embd_model_config = get_model_config_by_id(kb.tenant_embd_id) - elif kb.embd_id: - embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) - else: - embd_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.EMBEDDING) - embd_mdl = LLMBundle(kb.tenant_id, embd_model_config) - - rerank_mdl = None - if req.get("tenant_rerank_id"): - rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"]) - rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) - elif req.get("rerank_id"): - rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"]) - rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) - - if req.get("keyword", False): - default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) - chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config) - _question += await keyword_extraction(chat_mdl, _question) - - labels = label_question(_question, [kb]) - ranks = await settings.retriever.retrieval( - _question, - embd_mdl, - tenant_ids, - kb_ids, - page, - size, - float(req.get("similarity_threshold", 0.0)), - float(req.get("vector_similarity_weight", 0.3)), - doc_ids=local_doc_ids, - top=top, - rerank_mdl=rerank_mdl, - rank_feature=labels - ) - - if use_kg: - default_chat_model_config = get_tenant_default_model_by_type(user_id, LLMType.CHAT) - ck = await settings.kg_retriever.retrieval(_question, - tenant_ids, - kb_ids, - embd_mdl, - LLMBundle(kb.tenant_id, default_chat_model_config)) - if ck["content_with_weight"]: - ranks["chunks"].insert(0, ck) - ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids) - - for c in ranks["chunks"]: - c.pop("vector", None) - ranks["labels"] = labels - - return get_json_result(data=ranks) - - try: - return await _retrieval() - except Exception as e: - if str(e).find("not_found") > 0: - return get_json_result(data=False, message='No chunk found! 
Check the chunk status please!', - code=RetCode.DATA_ERROR) - return server_error_response(e) - - -@manager.route('/knowledge_graph', methods=['GET']) # noqa: F821 -@login_required -async def knowledge_graph(): - doc_id = request.args["doc_id"] - tenant_id = DocumentService.get_tenant_id(doc_id) - kb_ids = KnowledgebaseService.get_kb_ids(tenant_id) - req = { - "doc_ids": [doc_id], - "knowledge_graph_kwd": ["graph", "mind_map"] - } - sres = await settings.retriever.search(req, search.index_name(tenant_id), kb_ids) - obj = {"graph": {}, "mind_map": {}} - for id in sres.ids[:2]: - ty = sres.field[id]["knowledge_graph_kwd"] - try: - content_json = json.loads(sres.field[id]["content_with_weight"]) - except Exception: - continue - - if ty == 'mind_map': - node_dict = {} - - def repeat_deal(content_json, node_dict): - if 'id' in content_json: - if content_json['id'] in node_dict: - node_name = content_json['id'] - content_json['id'] += f"({node_dict[content_json['id']]})" - node_dict[node_name] += 1 - else: - node_dict[content_json['id']] = 1 - if 'children' in content_json and content_json['children']: - for item in content_json['children']: - repeat_deal(item, node_dict) - - repeat_deal(content_json, node_dict) - - obj[ty] = content_json - - return get_json_result(data=obj) diff --git a/api/apps/document_app.py b/api/apps/document_app.py deleted file mode 100644 index 9a9cafb9b1c..00000000000 --- a/api/apps/document_app.py +++ /dev/null @@ -1,716 +0,0 @@ -# -# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
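# The deleted document_app module below covered dataset document management: web crawl
# and manual creation, filtering, metadata updates, status toggling, removal, (re)parsing
# runs, raw downloads with content-type mapping, and sandbox artifact retrieval with
# filename validation.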
-# See the License for the specific language governing permissions and -# limitations under the License -# -import os.path -import re -from pathlib import Path, PurePosixPath, PureWindowsPath - -from quart import make_response, request - -from api.apps import current_user, login_required -from api.common.check_team_permission import check_kb_team_permission -from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX -from api.db import VALID_FILE_TYPES, FileType -from api.db.db_models import Task -from api.db.services import duplicate_name -from api.db.services.doc_metadata_service import DocMetadataService -from api.db.services.document_service import DocumentService, doc_upload_and_parse -from api.db.services.file2document_service import File2DocumentService -from api.db.services.file_service import FileService -from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.task_service import TaskService, cancel_all_task_of -from api.db.services.user_service import UserTenantService -from api.utils.api_utils import ( - get_data_error_result, - get_json_result, - get_request_json, - server_error_response, - validate_request, -) -from api.utils.file_utils import filename_type, thumbnail -from api.utils.web_utils import CONTENT_TYPE_MAP, apply_safe_file_response_headers, html2pdf, is_valid_url -from common import settings -from common.constants import SANDBOX_ARTIFACT_BUCKET, VALID_TASK_STATUS, ParserType, RetCode, TaskStatus -from common.file_utils import get_project_base_directory -from common.misc_utils import get_uuid, thread_pool_exec -from deepdoc.parser.html_parser import RAGFlowHtmlParser -from rag.nlp import search - - -def _is_safe_download_filename(name: str) -> bool: - if not name or name in {".", ".."}: - return False - if "\x00" in name or len(name) > 255: - return False - if name != PurePosixPath(name).name: - return False - if name != PureWindowsPath(name).name: - return False - return True - - -@manager.route("/web_crawl", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("kb_id", "name", "url") -async def web_crawl(): - form = await request.form - kb_id = form.get("kb_id") - if not kb_id: - return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) - name = form.get("name") - url = form.get("url") - if not is_valid_url(url): - return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR) - e, kb = KnowledgebaseService.get_by_id(kb_id) - if not e: - raise LookupError("Can't find this dataset!") - if not check_kb_team_permission(kb, current_user.id): - return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) - - blob = html2pdf(url) - if not blob: - return server_error_response(ValueError("Download failure.")) - - root_folder = FileService.get_root_folder(current_user.id) - pf_id = root_folder["id"] - FileService.init_knowledgebase_docs(pf_id, current_user.id) - kb_root_folder = FileService.get_kb_folder(current_user.id) - kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"]) - - try: - filename = duplicate_name(DocumentService.query, name=name + ".pdf", kb_id=kb.id) - filetype = filename_type(filename) - if filetype == FileType.OTHER.value: - raise RuntimeError("This type of file has not been supported yet!") - - location = filename - while settings.STORAGE_IMPL.obj_exist(kb_id, location): - location += "_" - settings.STORAGE_IMPL.put(kb_id, location, blob) - doc = { - "id": 
get_uuid(), - "kb_id": kb.id, - "parser_id": kb.parser_id, - "parser_config": kb.parser_config, - "created_by": current_user.id, - "type": filetype, - "name": filename, - "location": location, - "size": len(blob), - "thumbnail": thumbnail(filename, blob), - "suffix": Path(filename).suffix.lstrip("."), - } - if doc["type"] == FileType.VISUAL: - doc["parser_id"] = ParserType.PICTURE.value - if doc["type"] == FileType.AURAL: - doc["parser_id"] = ParserType.AUDIO.value - if re.search(r"\.(ppt|pptx|pages)$", filename): - doc["parser_id"] = ParserType.PRESENTATION.value - if re.search(r"\.(eml)$", filename): - doc["parser_id"] = ParserType.EMAIL.value - DocumentService.insert(doc) - FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id) - except Exception as e: - return server_error_response(e) - return get_json_result(data=True) - - -@manager.route("/create", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("name", "kb_id") -async def create(): - req = await get_request_json() - kb_id = req["kb_id"] - if not kb_id: - return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) - if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT: - return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR) - - if req["name"].strip() == "": - return get_json_result(data=False, message="File name can't be empty.", code=RetCode.ARGUMENT_ERROR) - req["name"] = req["name"].strip() - - try: - e, kb = KnowledgebaseService.get_by_id(kb_id) - if not e: - return get_data_error_result(message="Can't find this dataset!") - - if DocumentService.query(name=req["name"], kb_id=kb_id): - return get_data_error_result(message="Duplicated document name in the same dataset.") - - kb_root_folder = FileService.get_kb_folder(kb.tenant_id) - if not kb_root_folder: - return get_data_error_result(message="Cannot find the root folder.") - kb_folder = FileService.new_a_file_from_kb( - kb.tenant_id, - kb.name, - kb_root_folder["id"], - ) - if not kb_folder: - return get_data_error_result(message="Cannot find the kb folder for this file.") - - doc = DocumentService.insert( - { - "id": get_uuid(), - "kb_id": kb.id, - "parser_id": kb.parser_id, - "pipeline_id": kb.pipeline_id, - "parser_config": kb.parser_config, - "created_by": current_user.id, - "type": FileType.VIRTUAL, - "name": req["name"], - "suffix": Path(req["name"]).suffix.lstrip("."), - "location": "", - "size": 0, - } - ) - - FileService.add_file_from_kb(doc.to_dict(), kb_folder["id"], kb.tenant_id) - - return get_json_result(data=doc.to_json()) - except Exception as e: - return server_error_response(e) - - -@manager.route("/filter", methods=["POST"]) # noqa: F821 -@login_required -async def get_filter(): - req = await get_request_json() - - kb_id = req.get("kb_id") - if not kb_id: - return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) - tenants = UserTenantService.query(user_id=current_user.id) - for tenant in tenants: - if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id): - break - else: - return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR) - - keywords = req.get("keywords", "") - - suffix = req.get("suffix", []) - - run_status = req.get("run_status", []) - if run_status: - invalid_status = {s for s in run_status if s not in VALID_TASK_STATUS} - if invalid_status: - return get_data_error_result(message=f"Invalid 
filter run status conditions: {', '.join(invalid_status)}") - - types = req.get("types", []) - if types: - invalid_types = {t for t in types if t not in VALID_FILE_TYPES} - if invalid_types: - return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}") - - try: - filter, total = DocumentService.get_filter_by_kb_id(kb_id, keywords, run_status, types, suffix) - return get_json_result(data={"total": total, "filter": filter}) - except Exception as e: - return server_error_response(e) - - -@manager.route("/infos", methods=["POST"]) # noqa: F821 -@login_required -async def doc_infos(): - req = await get_request_json() - doc_ids = req["doc_ids"] - for doc_id in doc_ids: - if not DocumentService.accessible(doc_id, current_user.id): - return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) - docs = DocumentService.get_by_ids(doc_ids) - docs_list = list(docs.dicts()) - # Add meta_fields for each document - for doc in docs_list: - doc["meta_fields"] = DocMetadataService.get_document_metadata(doc["id"]) - return get_json_result(data=docs_list) - - -@manager.route("/metadata/update", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("doc_ids") -async def metadata_update(): - req = await get_request_json() - kb_id = req.get("kb_id") - document_ids = req.get("doc_ids") - updates = req.get("updates", []) or [] - deletes = req.get("deletes", []) or [] - - if not kb_id: - return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) - - if not isinstance(updates, list) or not isinstance(deletes, list): - return get_json_result(data=False, message="updates and deletes must be lists.", code=RetCode.ARGUMENT_ERROR) - - for upd in updates: - if not isinstance(upd, dict) or not upd.get("key") or "value" not in upd: - return get_json_result(data=False, message="Each update requires key and value.", code=RetCode.ARGUMENT_ERROR) - for d in deletes: - if not isinstance(d, dict) or not d.get("key"): - return get_json_result(data=False, message="Each delete requires key.", code=RetCode.ARGUMENT_ERROR) - - updated = DocMetadataService.batch_update_metadata(kb_id, document_ids, updates, deletes) - return get_json_result(data={"updated": updated, "matched_docs": len(document_ids)}) - - -@manager.route("/update_metadata_setting", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("doc_id", "metadata") -async def update_metadata_setting(): - req = await get_request_json() - if not DocumentService.accessible(req["doc_id"], current_user.id): - return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) - - e, doc = DocumentService.get_by_id(req["doc_id"]) - if not e: - return get_data_error_result(message="Document not found!") - - DocumentService.update_parser_config(doc.id, {"metadata": req["metadata"]}) - e, doc = DocumentService.get_by_id(doc.id) - if not e: - return get_data_error_result(message="Document not found!") - - return get_json_result(data=doc.to_dict()) - - -@manager.route("/thumbnails", methods=["GET"]) # noqa: F821 -# @login_required -def thumbnails(): - doc_ids = request.args.getlist("doc_ids") - if not doc_ids: - return get_json_result(data=False, message='Lack of "Document ID"', code=RetCode.ARGUMENT_ERROR) - - try: - docs = DocumentService.get_thumbnails(doc_ids) - - for doc_item in docs: - if doc_item["thumbnail"] and not doc_item["thumbnail"].startswith(IMG_BASE64_PREFIX): - 
doc_item["thumbnail"] = f"/v1/document/image/{doc_item['kb_id']}-{doc_item['thumbnail']}" - - return get_json_result(data={d["id"]: d["thumbnail"] for d in docs}) - except Exception as e: - return server_error_response(e) - - -@manager.route("/change_status", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("doc_ids", "status") -async def change_status(): - req = await get_request_json() - doc_ids = req.get("doc_ids", []) - status = str(req.get("status", "")) - - if status not in ["0", "1"]: - return get_json_result(data=False, message='"Status" must be either 0 or 1!', code=RetCode.ARGUMENT_ERROR) - - result = {} - has_error = False - for doc_id in doc_ids: - if not DocumentService.accessible(doc_id, current_user.id): - result[doc_id] = {"error": "No authorization."} - has_error = True - continue - - try: - e, doc = DocumentService.get_by_id(doc_id) - if not e: - result[doc_id] = {"error": "No authorization."} - has_error = True - continue - e, kb = KnowledgebaseService.get_by_id(doc.kb_id) - if not e: - result[doc_id] = {"error": "Can't find this dataset!"} - has_error = True - continue - current_status = str(doc.status) - if current_status == status: - result[doc_id] = {"status": status} - continue - if not DocumentService.update_by_id(doc_id, {"status": str(status)}): - result[doc_id] = {"error": "Database error (Document update)!"} - has_error = True - continue - - status_int = int(status) - if getattr(doc, "chunk_num", 0) > 0: - try: - ok = settings.docStoreConn.update( - {"doc_id": doc_id}, - {"available_int": status_int}, - search.index_name(kb.tenant_id), - doc.kb_id, - ) - except Exception as exc: - msg = str(exc) - if "3022" in msg: - result[doc_id] = {"error": "Document store table missing."} - else: - result[doc_id] = {"error": f"Document store update failed: {msg}"} - has_error = True - continue - if not ok: - result[doc_id] = {"error": "Database error (docStore update)!"} - has_error = True - continue - result[doc_id] = {"status": status} - except Exception as e: - result[doc_id] = {"error": f"Internal server error: {str(e)}"} - has_error = True - - if has_error: - return get_json_result(data=result, message="Partial failure", code=RetCode.SERVER_ERROR) - return get_json_result(data=result) - - -@manager.route("/rm", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("doc_id") -async def rm(): - req = await get_request_json() - doc_ids = req["doc_id"] - if isinstance(doc_ids, str): - doc_ids = [doc_ids] - - for doc_id in doc_ids: - if not DocumentService.accessible4deletion(doc_id, current_user.id): - return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) - - errors = await thread_pool_exec(FileService.delete_docs, doc_ids, current_user.id) - - if errors: - return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR) - - return get_json_result(data=True) - - -@manager.route("/run", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("doc_ids", "run") -async def run(): - req = await get_request_json() - uid = current_user.id - try: - - def _run_sync(): - for doc_id in req["doc_ids"]: - if not DocumentService.accessible(doc_id, uid): - return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) - - kb_table_num_map = {} - for id in req["doc_ids"]: - info = {"run": str(req["run"]), "progress": 0} - if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False): - info["progress_msg"] = "" - info["chunk_num"] = 0 - 
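# Re-run reset: when a document is re-queued with delete=True, its progress message and
# chunk/token counters are cleared, existing tasks are removed, and previously indexed
# chunks are deleted from the doc store before new parsing tasks are queued.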
info["token_num"] = 0 - - tenant_id = DocumentService.get_tenant_id(id) - if not tenant_id: - return get_data_error_result(message="Tenant not found!") - e, doc = DocumentService.get_by_id(id) - if not e: - return get_data_error_result(message="Document not found!") - - if str(req["run"]) == TaskStatus.CANCEL.value: - tasks = list(TaskService.query(doc_id=id)) - has_unfinished_task = any((task.progress or 0) < 1 for task in tasks) - if str(doc.run) in [TaskStatus.RUNNING.value, TaskStatus.CANCEL.value] or has_unfinished_task: - cancel_all_task_of(id) - else: - return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status") - if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]): - DocumentService.clear_chunk_num_when_rerun(doc.id) - - DocumentService.update_by_id(id, info) - if req.get("delete", False): - TaskService.filter_delete([Task.doc_id == id]) - if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id): - settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id) - - if str(req["run"]) == TaskStatus.RUNNING.value: - if req.get("apply_kb"): - e, kb = KnowledgebaseService.get_by_id(doc.kb_id) - if not e: - raise LookupError("Can't find this dataset!") - doc.parser_config["llm_id"] = kb.parser_config.get("llm_id") - doc.parser_config["enable_metadata"] = kb.parser_config.get("enable_metadata", False) - doc.parser_config["metadata"] = kb.parser_config.get("metadata", {}) - DocumentService.update_parser_config(doc.id, doc.parser_config) - doc_dict = doc.to_dict() - DocumentService.run(tenant_id, doc_dict, kb_table_num_map) - - return get_json_result(data=True) - - return await thread_pool_exec(_run_sync) - except Exception as e: - return server_error_response(e) - -@manager.route("/get/", methods=["GET"]) # noqa: F821 -@login_required -async def get(doc_id): - try: - e, doc = DocumentService.get_by_id(doc_id) - if not e: - return get_data_error_result(message="Document not found!") - - b, n = File2DocumentService.get_storage_address(doc_id=doc_id) - data = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n) - response = await make_response(data) - - ext = re.search(r"\.([^.]+)$", doc.name.lower()) - ext = ext.group(1) if ext else None - content_type = None - if ext: - fallback_prefix = "image" if doc.type == FileType.VISUAL.value else "application" - content_type = CONTENT_TYPE_MAP.get(ext, f"{fallback_prefix}/{ext}") - apply_safe_file_response_headers(response, content_type, ext) - return response - except Exception as e: - return server_error_response(e) - - -@manager.route("/download/", methods=["GET"]) # noqa: F821 -@login_required -async def download_attachment(attachment_id): - try: - ext = request.args.get("ext", "markdown") - data = await thread_pool_exec(settings.STORAGE_IMPL.get, current_user.id, attachment_id) - response = await make_response(data) - content_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}") - apply_safe_file_response_headers(response, content_type, ext) - - return response - - except Exception as e: - return server_error_response(e) - - -@manager.route("/change_parser", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("doc_id") -async def change_parser(): - req = await get_request_json() - if not DocumentService.accessible(req["doc_id"], current_user.id): - return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) - - e, doc = 
DocumentService.get_by_id(req["doc_id"]) - if not e: - return get_data_error_result(message="Document not found!") - - def reset_doc(): - nonlocal doc - e = DocumentService.update_by_id(doc.id, {"pipeline_id": req["pipeline_id"], "parser_id": req["parser_id"], "progress": 0, "progress_msg": "", "run": TaskStatus.UNSTART.value}) - if not e: - return get_data_error_result(message="Document not found!") - if doc.token_num > 0: - e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, doc.process_duration * -1) - if not e: - return get_data_error_result(message="Document not found!") - tenant_id = DocumentService.get_tenant_id(req["doc_id"]) - if not tenant_id: - return get_data_error_result(message="Tenant not found!") - DocumentService.delete_chunk_images(doc, tenant_id) - if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id): - settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) - return None - - try: - if "pipeline_id" in req and req["pipeline_id"] != "": - if doc.pipeline_id == req["pipeline_id"]: - return get_json_result(data=True) - DocumentService.update_by_id(doc.id, {"pipeline_id": req["pipeline_id"]}) - reset_doc() - return get_json_result(data=True) - - if doc.parser_id.lower() == req["parser_id"].lower(): - if "parser_config" in req: - if req["parser_config"] == doc.parser_config: - return get_json_result(data=True) - else: - return get_json_result(data=True) - - if (doc.type == FileType.VISUAL and req["parser_id"] != "picture") or (re.search(r"\.(ppt|pptx|pages)$", doc.name) and req["parser_id"] != "presentation"): - return get_data_error_result(message="Not supported yet!") - if "parser_config" in req: - DocumentService.update_parser_config(doc.id, req["parser_config"]) - reset_doc() - return get_json_result(data=True) - except Exception as e: - return server_error_response(e) - - -@manager.route("/image/", methods=["GET"]) # noqa: F821 -# @login_required -async def get_image(image_id): - try: - arr = image_id.split("-") - if len(arr) != 2: - return get_data_error_result(message="Image not found.") - bkt, nm = image_id.split("-") - data = await thread_pool_exec(settings.STORAGE_IMPL.get, bkt, nm) - response = await make_response(data) - response.headers.set("Content-Type", "image/JPEG") - return response - except Exception as e: - return server_error_response(e) - - -ARTIFACT_CONTENT_TYPES = { - ".png": "image/png", - ".jpg": "image/jpeg", - ".jpeg": "image/jpeg", - ".svg": "image/svg+xml", - ".pdf": "application/pdf", - ".csv": "text/csv", - ".json": "application/json", - ".html": "text/html", -} - - -@manager.route("/artifact/", methods=["GET"]) # noqa: F821 -@login_required -async def get_artifact(filename): - try: - bucket = SANDBOX_ARTIFACT_BUCKET - # Validate filename: must be uuid hex + allowed extension, nothing else - basename = os.path.basename(filename) - if basename != filename or "/" in filename or "\\" in filename: - return get_data_error_result(message="Invalid filename.") - ext = os.path.splitext(basename)[1].lower() - if ext not in ARTIFACT_CONTENT_TYPES: - return get_data_error_result(message="Invalid file type.") - data = await thread_pool_exec(settings.STORAGE_IMPL.get, bucket, basename) - if not data: - return get_data_error_result(message="Artifact not found.") - content_type = ARTIFACT_CONTENT_TYPES.get(ext, "application/octet-stream") - response = await make_response(data) - safe_filename = re.sub(r"[^\w.\-]", "_", basename) - 
apply_safe_file_response_headers(response, content_type, ext) - if not response.headers.get("Content-Disposition"): - response.headers.set("Content-Disposition", f'inline; filename="{safe_filename}"') - return response - except Exception as e: - return server_error_response(e) - - -@manager.route("/upload_and_parse", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("conversation_id") -async def upload_and_parse(): - files = await request.files - if "file" not in files: - return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR) - - file_objs = files.getlist("file") - for file_obj in file_objs: - if file_obj.filename == "": - return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR) - - form = await request.form - doc_ids = doc_upload_and_parse(form.get("conversation_id"), file_objs, current_user.id) - return get_json_result(data=doc_ids) - - -@manager.route("/parse", methods=["POST"]) # noqa: F821 -@login_required -async def parse(): - req = await get_request_json() - url = req.get("url", "") - if url: - if not is_valid_url(url): - return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR) - download_path = os.path.join(get_project_base_directory(), "logs/downloads") - os.makedirs(download_path, exist_ok=True) - from seleniumwire.webdriver import Chrome, ChromeOptions - - options = ChromeOptions() - options.add_argument("--headless") - options.add_argument("--disable-gpu") - options.add_argument("--no-sandbox") - options.add_argument("--disable-dev-shm-usage") - options.add_experimental_option("prefs", {"download.default_directory": download_path, "download.prompt_for_download": False, "download.directory_upgrade": True, "safebrowsing.enabled": True}) - driver = Chrome(options=options) - driver.get(url) - res_headers = [r.response.headers for r in driver.requests if r and r.response] - if len(res_headers) > 1: - sections = RAGFlowHtmlParser().parser_txt(driver.page_source) - driver.quit() - return get_json_result(data="\n".join(sections)) - - class File: - filename: str - filepath: str - - def __init__(self, filename, filepath): - self.filename = filename - self.filepath = filepath - - def read(self): - with open(self.filepath, "rb") as f: - return f.read() - - r = re.search(r"filename=\"([^\"]+)\"", str(res_headers)) - if not r or not r.group(1): - return get_json_result(data=False, message="Can't not identify downloaded file", code=RetCode.ARGUMENT_ERROR) - filename = r.group(1).strip() - if not _is_safe_download_filename(filename): - return get_json_result(data=False, message="Invalid downloaded filename", code=RetCode.ARGUMENT_ERROR) - filepath = os.path.join(download_path, filename) - f = File(filename, filepath) - txt = FileService.parse_docs([f], current_user.id) - return get_json_result(data=txt) - - files = await request.files - if "file" not in files: - return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR) - - file_objs = files.getlist("file") - txt = FileService.parse_docs(file_objs, current_user.id) - - return get_json_result(data=txt) - - -@manager.route("/upload_info", methods=["POST"]) # noqa: F821 -@login_required -async def upload_info(): - files = await request.files - file_objs = files.getlist("file") if files and files.get("file") else [] - url = request.args.get("url") - - if file_objs and url: - return get_json_result( - data=False, - message="Provide either multipart file(s) or ?url=..., not both.", - 
code=RetCode.BAD_REQUEST, - ) - - if not file_objs and not url: - return get_json_result( - data=False, - message="Missing input: provide multipart file(s) or url", - code=RetCode.BAD_REQUEST, - ) - - try: - if url and not file_objs: - return get_json_result(data=FileService.upload_info(current_user.id, None, url)) - - if len(file_objs) == 1: - return get_json_result(data=FileService.upload_info(current_user.id, file_objs[0], None)) - - results = [FileService.upload_info(current_user.id, f, None) for f in file_objs] - return get_json_result(data=results) - except Exception as e: - return server_error_response(e) diff --git a/api/apps/evaluation_app.py b/api/apps/evaluation_app.py deleted file mode 100644 index b33db26da17..00000000000 --- a/api/apps/evaluation_app.py +++ /dev/null @@ -1,479 +0,0 @@ -# -# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -RAG Evaluation API Endpoints - -Provides REST API for RAG evaluation functionality including: -- Dataset management -- Test case management -- Evaluation execution -- Results retrieval -- Configuration recommendations -""" - -from quart import request -from api.apps import login_required, current_user -from api.db.services.evaluation_service import EvaluationService -from api.utils.api_utils import ( - get_data_error_result, - get_json_result, - get_request_json, - server_error_response, - validate_request -) -from common.constants import RetCode - - -# ==================== Dataset Management ==================== - -@manager.route('/dataset/create', methods=['POST']) # noqa: F821 -@login_required -@validate_request("name", "kb_ids") -async def create_dataset(): - """ - Create a new evaluation dataset. - - Request body: - { - "name": "Dataset name", - "description": "Optional description", - "kb_ids": ["kb_id1", "kb_id2"] - } - """ - try: - req = await get_request_json() - name = req.get("name", "").strip() - description = req.get("description", "") - kb_ids = req.get("kb_ids", []) - - if not name: - return get_data_error_result(message="Dataset name cannot be empty") - - if not kb_ids or not isinstance(kb_ids, list): - return get_data_error_result(message="kb_ids must be a non-empty list") - - success, result = EvaluationService.create_dataset( - name=name, - description=description, - kb_ids=kb_ids, - tenant_id=current_user.id, - user_id=current_user.id - ) - - if not success: - return get_data_error_result(message=result) - - return get_json_result(data={"dataset_id": result}) - except Exception as e: - return server_error_response(e) - - -@manager.route('/dataset/list', methods=['GET']) # noqa: F821 -@login_required -async def list_datasets(): - """ - List evaluation datasets for current tenant. 
- - Query params: - - page: Page number (default: 1) - - page_size: Items per page (default: 20) - """ - try: - page = int(request.args.get("page", 1)) - page_size = int(request.args.get("page_size", 20)) - - result = EvaluationService.list_datasets( - tenant_id=current_user.id, - user_id=current_user.id, - page=page, - page_size=page_size - ) - - return get_json_result(data=result) - except Exception as e: - return server_error_response(e) - - -@manager.route('/dataset/', methods=['GET']) # noqa: F821 -@login_required -async def get_dataset(dataset_id): - """Get dataset details by ID""" - try: - dataset = EvaluationService.get_dataset(dataset_id) - if not dataset: - return get_data_error_result( - message="Dataset not found", - code=RetCode.DATA_ERROR - ) - - return get_json_result(data=dataset) - except Exception as e: - return server_error_response(e) - - -@manager.route('/dataset/', methods=['PUT']) # noqa: F821 -@login_required -async def update_dataset(dataset_id): - """ - Update dataset. - - Request body: - { - "name": "New name", - "description": "New description", - "kb_ids": ["kb_id1", "kb_id2"] - } - """ - try: - req = await get_request_json() - - # Remove fields that shouldn't be updated - req.pop("id", None) - req.pop("tenant_id", None) - req.pop("created_by", None) - req.pop("create_time", None) - - success = EvaluationService.update_dataset(dataset_id, **req) - - if not success: - return get_data_error_result(message="Failed to update dataset") - - return get_json_result(data={"dataset_id": dataset_id}) - except Exception as e: - return server_error_response(e) - - -@manager.route('/dataset/', methods=['DELETE']) # noqa: F821 -@login_required -async def delete_dataset(dataset_id): - """Delete dataset (soft delete)""" - try: - success = EvaluationService.delete_dataset(dataset_id) - - if not success: - return get_data_error_result(message="Failed to delete dataset") - - return get_json_result(data={"dataset_id": dataset_id}) - except Exception as e: - return server_error_response(e) - - -# ==================== Test Case Management ==================== - -@manager.route('/dataset//case/add', methods=['POST']) # noqa: F821 -@login_required -@validate_request("question") -async def add_test_case(dataset_id): - """ - Add a test case to a dataset. - - Request body: - { - "question": "Test question", - "reference_answer": "Optional ground truth answer", - "relevant_doc_ids": ["doc_id1", "doc_id2"], - "relevant_chunk_ids": ["chunk_id1", "chunk_id2"], - "metadata": {"key": "value"} - } - """ - try: - req = await get_request_json() - question = req.get("question", "").strip() - - if not question: - return get_data_error_result(message="Question cannot be empty") - - success, result = EvaluationService.add_test_case( - dataset_id=dataset_id, - question=question, - reference_answer=req.get("reference_answer"), - relevant_doc_ids=req.get("relevant_doc_ids"), - relevant_chunk_ids=req.get("relevant_chunk_ids"), - metadata=req.get("metadata") - ) - - if not success: - return get_data_error_result(message=result) - - return get_json_result(data={"case_id": result}) - except Exception as e: - return server_error_response(e) - - -@manager.route('/dataset//case/import', methods=['POST']) # noqa: F821 -@login_required -@validate_request("cases") -async def import_test_cases(dataset_id): - """ - Bulk import test cases. - - Request body: - { - "cases": [ - { - "question": "Question 1", - "reference_answer": "Answer 1", - ... - }, - { - "question": "Question 2", - ... 
- } - ] - } - """ - try: - req = await get_request_json() - cases = req.get("cases", []) - - if not cases or not isinstance(cases, list): - return get_data_error_result(message="cases must be a non-empty list") - - success_count, failure_count = EvaluationService.import_test_cases( - dataset_id=dataset_id, - cases=cases - ) - - return get_json_result(data={ - "success_count": success_count, - "failure_count": failure_count, - "total": len(cases) - }) - except Exception as e: - return server_error_response(e) - - -@manager.route('/dataset//cases', methods=['GET']) # noqa: F821 -@login_required -async def get_test_cases(dataset_id): - """Get all test cases for a dataset""" - try: - cases = EvaluationService.get_test_cases(dataset_id) - return get_json_result(data={"cases": cases, "total": len(cases)}) - except Exception as e: - return server_error_response(e) - - -@manager.route('/case/', methods=['DELETE']) # noqa: F821 -@login_required -async def delete_test_case(case_id): - """Delete a test case""" - try: - success = EvaluationService.delete_test_case(case_id) - - if not success: - return get_data_error_result(message="Failed to delete test case") - - return get_json_result(data={"case_id": case_id}) - except Exception as e: - return server_error_response(e) - - -# ==================== Evaluation Execution ==================== - -@manager.route('/run/start', methods=['POST']) # noqa: F821 -@login_required -@validate_request("dataset_id", "dialog_id") -async def start_evaluation(): - """ - Start an evaluation run. - - Request body: - { - "dataset_id": "dataset_id", - "dialog_id": "dialog_id", - "name": "Optional run name" - } - """ - try: - req = await get_request_json() - dataset_id = req.get("dataset_id") - dialog_id = req.get("dialog_id") - name = req.get("name") - - success, result = EvaluationService.start_evaluation( - dataset_id=dataset_id, - dialog_id=dialog_id, - user_id=current_user.id, - name=name - ) - - if not success: - return get_data_error_result(message=result) - - return get_json_result(data={"run_id": result}) - except Exception as e: - return server_error_response(e) - - -@manager.route('/run/', methods=['GET']) # noqa: F821 -@login_required -async def get_evaluation_run(run_id): - """Get evaluation run details""" - try: - result = EvaluationService.get_run_results(run_id) - - if not result: - return get_data_error_result( - message="Evaluation run not found", - code=RetCode.DATA_ERROR - ) - - return get_json_result(data=result) - except Exception as e: - return server_error_response(e) - - -@manager.route('/run//results', methods=['GET']) # noqa: F821 -@login_required -async def get_run_results(run_id): - """Get detailed results for an evaluation run""" - try: - result = EvaluationService.get_run_results(run_id) - - if not result: - return get_data_error_result( - message="Evaluation run not found", - code=RetCode.DATA_ERROR - ) - - return get_json_result(data=result) - except Exception as e: - return server_error_response(e) - - -@manager.route('/run/list', methods=['GET']) # noqa: F821 -@login_required -async def list_evaluation_runs(): - """ - List evaluation runs. 
- - Query params: - - dataset_id: Filter by dataset (optional) - - dialog_id: Filter by dialog (optional) - - page: Page number (default: 1) - - page_size: Items per page (default: 20) - """ - try: - # TODO: Implement list_runs in EvaluationService - return get_json_result(data={"runs": [], "total": 0}) - except Exception as e: - return server_error_response(e) - - -@manager.route('/run/', methods=['DELETE']) # noqa: F821 -@login_required -async def delete_evaluation_run(run_id): - """Delete an evaluation run""" - try: - # TODO: Implement delete_run in EvaluationService - return get_json_result(data={"run_id": run_id}) - except Exception as e: - return server_error_response(e) - - -# ==================== Analysis & Recommendations ==================== - -@manager.route('/run//recommendations', methods=['GET']) # noqa: F821 -@login_required -async def get_recommendations(run_id): - """Get configuration recommendations based on evaluation results""" - try: - recommendations = EvaluationService.get_recommendations(run_id) - return get_json_result(data={"recommendations": recommendations}) - except Exception as e: - return server_error_response(e) - - -@manager.route('/compare', methods=['POST']) # noqa: F821 -@login_required -@validate_request("run_ids") -async def compare_runs(): - """ - Compare multiple evaluation runs. - - Request body: - { - "run_ids": ["run_id1", "run_id2", "run_id3"] - } - """ - try: - req = await get_request_json() - run_ids = req.get("run_ids", []) - - if not run_ids or not isinstance(run_ids, list) or len(run_ids) < 2: - return get_data_error_result( - message="run_ids must be a list with at least 2 run IDs" - ) - - # TODO: Implement compare_runs in EvaluationService - return get_json_result(data={"comparison": {}}) - except Exception as e: - return server_error_response(e) - - -@manager.route('/run//export', methods=['GET']) # noqa: F821 -@login_required -async def export_results(run_id): - """Export evaluation results as JSON/CSV""" - try: - # format_type = request.args.get("format", "json") # TODO: Use for CSV export - - result = EvaluationService.get_run_results(run_id) - - if not result: - return get_data_error_result( - message="Evaluation run not found", - code=RetCode.DATA_ERROR - ) - - # TODO: Implement CSV export - return get_json_result(data=result) - except Exception as e: - return server_error_response(e) - - -# ==================== Real-time Evaluation ==================== - -@manager.route('/evaluate_single', methods=['POST']) # noqa: F821 -@login_required -@validate_request("question", "dialog_id") -async def evaluate_single(): - """ - Evaluate a single question-answer pair in real-time. - - Request body: - { - "question": "Test question", - "dialog_id": "dialog_id", - "reference_answer": "Optional ground truth", - "relevant_chunk_ids": ["chunk_id1", "chunk_id2"] - } - """ - try: - # req = await get_request_json() # TODO: Use for single evaluation implementation - - # TODO: Implement single evaluation - # This would execute the RAG pipeline and return metrics immediately - - return get_json_result(data={ - "answer": "", - "metrics": {}, - "retrieved_chunks": [] - }) - except Exception as e: - return server_error_response(e) diff --git a/api/apps/file_app.py b/api/apps/file_app.py deleted file mode 100644 index 172b49ff850..00000000000 --- a/api/apps/file_app.py +++ /dev/null @@ -1,464 +0,0 @@ -# # -# # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. 
-# # -# # Licensed under the Apache License, Version 2.0 (the "License"); -# # you may not use this file except in compliance with the License. -# # You may obtain a copy of the License at -# # -# # http://www.apache.org/licenses/LICENSE-2.0 -# # -# # Unless required by applicable law or agreed to in writing, software -# # distributed under the License is distributed on an "AS IS" BASIS, -# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# # See the License for the specific language governing permissions and -# # limitations under the License -# # -# import logging -# import os -# import pathlib -# import re -# from quart import request, make_response -# from api.apps import login_required, current_user -# -# from api.common.check_team_permission import check_file_team_permission -# from api.db.services.document_service import DocumentService -# from api.db.services.file2document_service import File2DocumentService -# from api.utils.api_utils import server_error_response, get_data_error_result, validate_request -# from common.misc_utils import get_uuid, thread_pool_exec -# from common.constants import RetCode, FileSource -# from api.db import FileType -# from api.db.services import duplicate_name -# from api.db.services.file_service import FileService -# from api.utils.api_utils import get_json_result, get_request_json -# from api.utils.file_utils import filename_type -# from api.utils.web_utils import CONTENT_TYPE_MAP, apply_safe_file_response_headers -# from common import settings -# -# @manager.route('/upload', methods=['POST']) # noqa: F821 -# @login_required -# # @validate_request("parent_id") -# async def upload(): -# form = await request.form -# pf_id = form.get("parent_id") -# -# if not pf_id: -# root_folder = FileService.get_root_folder(current_user.id) -# pf_id = root_folder["id"] -# -# files = await request.files -# if 'file' not in files: -# return get_json_result( -# data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR) -# file_objs = files.getlist('file') -# -# for file_obj in file_objs: -# if file_obj.filename == '': -# return get_json_result( -# data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR) -# file_res = [] -# try: -# e, pf_folder = FileService.get_by_id(pf_id) -# if not e: -# return get_data_error_result( message="Can't find this folder!") -# -# async def _handle_single_file(file_obj): -# MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0)) -# if 0 < MAX_FILE_NUM_PER_USER <= await thread_pool_exec(DocumentService.get_doc_count, current_user.id): -# return get_data_error_result( message="Exceed the maximum file number of a free user!") -# -# # split file name path -# if not file_obj.filename: -# file_obj_names = [pf_folder.name, file_obj.filename] -# else: -# full_path = '/' + file_obj.filename -# file_obj_names = full_path.split('/') -# file_len = len(file_obj_names) -# -# # get folder -# file_id_list = await thread_pool_exec(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id]) -# len_id_list = len(file_id_list) -# -# # create folder -# if file_len != len_id_list: -# e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 1]) -# if not e: -# return get_data_error_result(message="Folder not found!") -# last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names, -# len_id_list) -# else: -# e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 2]) -# if not e: -# 
return get_data_error_result(message="Folder not found!") -# last_folder = await thread_pool_exec(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names, -# len_id_list) -# -# # file type -# filetype = filename_type(file_obj_names[file_len - 1]) -# location = file_obj_names[file_len - 1] -# while await thread_pool_exec(settings.STORAGE_IMPL.obj_exist, last_folder.id, location): -# location += "_" -# blob = await thread_pool_exec(file_obj.read) -# filename = await thread_pool_exec( -# duplicate_name, -# FileService.query, -# name=file_obj_names[file_len - 1], -# parent_id=last_folder.id) -# await thread_pool_exec(settings.STORAGE_IMPL.put, last_folder.id, location, blob) -# file_data = { -# "id": get_uuid(), -# "parent_id": last_folder.id, -# "tenant_id": current_user.id, -# "created_by": current_user.id, -# "type": filetype, -# "name": filename, -# "location": location, -# "size": len(blob), -# } -# inserted = await thread_pool_exec(FileService.insert, file_data) -# return inserted.to_json() -# -# for file_obj in file_objs: -# res = await _handle_single_file(file_obj) -# file_res.append(res) -# -# return get_json_result(data=file_res) -# except Exception as e: -# return server_error_response(e) -# -# -# @manager.route('/create', methods=['POST']) # noqa: F821 -# @login_required -# @validate_request("name") -# async def create(): -# req = await get_request_json() -# pf_id = req.get("parent_id") -# input_file_type = req.get("type") -# if not pf_id: -# root_folder = FileService.get_root_folder(current_user.id) -# pf_id = root_folder["id"] -# -# try: -# if not FileService.is_parent_folder_exist(pf_id): -# return get_json_result( -# data=False, message="Parent Folder Doesn't Exist!", code=RetCode.OPERATING_ERROR) -# if FileService.query(name=req["name"], parent_id=pf_id): -# return get_data_error_result( -# message="Duplicated folder name in the same folder.") -# -# if input_file_type == FileType.FOLDER.value: -# file_type = FileType.FOLDER.value -# else: -# file_type = FileType.VIRTUAL.value -# -# file = FileService.insert({ -# "id": get_uuid(), -# "parent_id": pf_id, -# "tenant_id": current_user.id, -# "created_by": current_user.id, -# "name": req["name"], -# "location": "", -# "size": 0, -# "type": file_type -# }) -# -# return get_json_result(data=file.to_json()) -# except Exception as e: -# return server_error_response(e) -# -# -# @manager.route('/list', methods=['GET']) # noqa: F821 -# @login_required -# def list_files(): -# pf_id = request.args.get("parent_id") -# -# keywords = request.args.get("keywords", "") -# -# page_number = int(request.args.get("page", 1)) -# items_per_page = int(request.args.get("page_size", 15)) -# orderby = request.args.get("orderby", "create_time") -# desc = request.args.get("desc", True) -# if not pf_id: -# root_folder = FileService.get_root_folder(current_user.id) -# pf_id = root_folder["id"] -# FileService.init_knowledgebase_docs(pf_id, current_user.id) -# try: -# e, file = FileService.get_by_id(pf_id) -# if not e: -# return get_data_error_result(message="Folder not found!") -# -# files, total = FileService.get_by_pf_id( -# current_user.id, pf_id, page_number, items_per_page, orderby, desc, keywords) -# -# parent_folder = FileService.get_parent_folder(pf_id) -# if not parent_folder: -# return get_json_result(message="File not found!") -# -# return get_json_result(data={"total": total, "files": files, "parent_folder": parent_folder.to_json()}) -# except Exception as e: -# return server_error_response(e) -# -# -# 
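
The removed upload handler above avoids overwriting existing objects by appending underscores to the storage location until the name is free. Below is a minimal standalone sketch of that collision-avoidance loop, with a generic `exists` predicate standing in for the storage-layer existence check used in the handler; the attempt cap and the helper name are added assumptions, not part of the removed code.

```python
# Illustrative sketch only: `exists` and `max_attempts` are stand-ins, not the storage API used above.
from typing import Callable


def pick_free_location(name: str, exists: Callable[[str], bool], max_attempts: int = 1000) -> str:
    """Append underscores to `name` until `exists` reports the candidate as unused."""
    candidate = name
    for _ in range(max_attempts):
        if not exists(candidate):
            return candidate
        candidate += "_"
    raise RuntimeError(f"No free location derived from {name!r} after {max_attempts} attempts")


# Usage with an in-memory stand-in for the object store:
taken = {"report.pdf", "report.pdf_"}
print(pick_free_location("report.pdf", taken.__contains__))  # -> "report.pdf__"
```
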
@manager.route('/root_folder', methods=['GET']) # noqa: F821 -# @login_required -# def get_root_folder(): -# try: -# root_folder = FileService.get_root_folder(current_user.id) -# return get_json_result(data={"root_folder": root_folder}) -# except Exception as e: -# return server_error_response(e) -# -# -# @manager.route('/parent_folder', methods=['GET']) # noqa: F821 -# @login_required -# def get_parent_folder(): -# file_id = request.args.get("file_id") -# try: -# e, file = FileService.get_by_id(file_id) -# if not e: -# return get_data_error_result(message="Folder not found!") -# -# parent_folder = FileService.get_parent_folder(file_id) -# return get_json_result(data={"parent_folder": parent_folder.to_json()}) -# except Exception as e: -# return server_error_response(e) -# -# -# @manager.route('/all_parent_folder', methods=['GET']) # noqa: F821 -# @login_required -# def get_all_parent_folders(): -# file_id = request.args.get("file_id") -# try: -# e, file = FileService.get_by_id(file_id) -# if not e: -# return get_data_error_result(message="Folder not found!") -# -# parent_folders = FileService.get_all_parent_folders(file_id) -# parent_folders_res = [] -# for parent_folder in parent_folders: -# parent_folders_res.append(parent_folder.to_json()) -# return get_json_result(data={"parent_folders": parent_folders_res}) -# except Exception as e: -# return server_error_response(e) -# -# -# @manager.route("/rm", methods=["POST"]) # noqa: F821 -# @login_required -# @validate_request("file_ids") -# async def rm(): -# req = await get_request_json() -# file_ids = req["file_ids"] -# uid = current_user.id -# -# try: -# def _delete_single_file(file): -# try: -# if file.location: -# settings.STORAGE_IMPL.rm(file.parent_id, file.location) -# except Exception as e: -# logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}") -# -# informs = File2DocumentService.get_by_file_id(file.id) -# for inform in informs: -# doc_id = inform.document_id -# e, doc = DocumentService.get_by_id(doc_id) -# if e and doc: -# tenant_id = DocumentService.get_tenant_id(doc_id) -# if tenant_id: -# DocumentService.remove_document(doc, tenant_id) -# File2DocumentService.delete_by_file_id(file.id) -# -# FileService.delete(file) -# -# def _delete_folder_recursive(folder, tenant_id): -# sub_files = FileService.list_all_files_by_parent_id(folder.id) -# for sub_file in sub_files: -# if sub_file.type == FileType.FOLDER.value: -# _delete_folder_recursive(sub_file, tenant_id) -# else: -# _delete_single_file(sub_file) -# -# FileService.delete(folder) -# -# def _rm_sync(): -# for file_id in file_ids: -# e, file = FileService.get_by_id(file_id) -# if not e or not file: -# return get_data_error_result(message="File or Folder not found!") -# if not file.tenant_id: -# return get_data_error_result(message="Tenant not found!") -# if not check_file_team_permission(file, uid): -# return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) -# -# if file.source_type == FileSource.KNOWLEDGEBASE: -# continue -# -# if file.type == FileType.FOLDER.value: -# _delete_folder_recursive(file, uid) -# continue -# -# _delete_single_file(file) -# -# return get_json_result(data=True) -# -# return await thread_pool_exec(_rm_sync) -# -# except Exception as e: -# return server_error_response(e) -# -# -# @manager.route('/rename', methods=['POST']) # noqa: F821 -# @login_required -# @validate_request("file_id", "name") -# async def rename(): -# req = await get_request_json() -# try: -# e, file = 
FileService.get_by_id(req["file_id"]) -# if not e: -# return get_data_error_result(message="File not found!") -# if not check_file_team_permission(file, current_user.id): -# return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR) -# if file.type != FileType.FOLDER.value \ -# and pathlib.Path(req["name"].lower()).suffix != pathlib.Path( -# file.name.lower()).suffix: -# return get_json_result( -# data=False, -# message="The extension of file can't be changed", -# code=RetCode.ARGUMENT_ERROR) -# for file in FileService.query(name=req["name"], pf_id=file.parent_id): -# if file.name == req["name"]: -# return get_data_error_result( -# message="Duplicated file name in the same folder.") -# -# if not FileService.update_by_id( -# req["file_id"], {"name": req["name"]}): -# return get_data_error_result( -# message="Database error (File rename)!") -# -# informs = File2DocumentService.get_by_file_id(req["file_id"]) -# if informs: -# if not DocumentService.update_by_id( -# informs[0].document_id, {"name": req["name"]}): -# return get_data_error_result( -# message="Database error (Document rename)!") -# -# return get_json_result(data=True) -# except Exception as e: -# return server_error_response(e) -# -# -# @manager.route('/get/', methods=['GET']) # noqa: F821 -# @login_required -# async def get(file_id): -# try: -# e, file = FileService.get_by_id(file_id) -# if not e: -# return get_data_error_result(message="Document not found!") -# if not check_file_team_permission(file, current_user.id): -# return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR) -# -# blob = await thread_pool_exec(settings.STORAGE_IMPL.get, file.parent_id, file.location) -# if not blob: -# b, n = File2DocumentService.get_storage_address(file_id=file_id) -# blob = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n) -# -# response = await make_response(blob) -# ext = re.search(r"\.([^.]+)$", file.name.lower()) -# ext = ext.group(1) if ext else None -# content_type = None -# if ext: -# fallback_prefix = "image" if file.type == FileType.VISUAL.value else "application" -# content_type = CONTENT_TYPE_MAP.get(ext, f"{fallback_prefix}/{ext}") -# apply_safe_file_response_headers(response, content_type, ext) -# return response -# except Exception as e: -# return server_error_response(e) -# -# -# @manager.route("/mv", methods=["POST"]) # noqa: F821 -# @login_required -# @validate_request("src_file_ids", "dest_file_id") -# async def move(): -# req = await get_request_json() -# try: -# file_ids = req["src_file_ids"] -# dest_parent_id = req["dest_file_id"] -# -# ok, dest_folder = FileService.get_by_id(dest_parent_id) -# if not ok or not dest_folder: -# return get_data_error_result(message="Parent folder not found!") -# -# files = FileService.get_by_ids(file_ids) -# if not files: -# return get_data_error_result(message="Source files not found!") -# -# files_dict = {f.id: f for f in files} -# -# for file_id in file_ids: -# file = files_dict.get(file_id) -# if not file: -# return get_data_error_result(message="File or folder not found!") -# if not file.tenant_id: -# return get_data_error_result(message="Tenant not found!") -# if not check_file_team_permission(file, current_user.id): -# return get_json_result( -# data=False, -# message="No authorization.", -# code=RetCode.AUTHENTICATION_ERROR, -# ) -# -# def _move_entry_recursive(source_file_entry, dest_folder): -# if source_file_entry.type == FileType.FOLDER.value: -# existing_folder = 
FileService.query(name=source_file_entry.name, parent_id=dest_folder.id) -# if existing_folder: -# new_folder = existing_folder[0] -# else: -# new_folder = FileService.insert( -# { -# "id": get_uuid(), -# "parent_id": dest_folder.id, -# "tenant_id": source_file_entry.tenant_id, -# "created_by": current_user.id, -# "name": source_file_entry.name, -# "location": "", -# "size": 0, -# "type": FileType.FOLDER.value, -# } -# ) -# -# sub_files = FileService.list_all_files_by_parent_id(source_file_entry.id) -# for sub_file in sub_files: -# _move_entry_recursive(sub_file, new_folder) -# -# FileService.delete_by_id(source_file_entry.id) -# return -# -# old_parent_id = source_file_entry.parent_id -# old_location = source_file_entry.location -# filename = source_file_entry.name -# -# new_location = filename -# while settings.STORAGE_IMPL.obj_exist(dest_folder.id, new_location): -# new_location += "_" -# -# try: -# settings.STORAGE_IMPL.move(old_parent_id, old_location, dest_folder.id, new_location) -# except Exception as storage_err: -# raise RuntimeError(f"Move file failed at storage layer: {str(storage_err)}") -# -# FileService.update_by_id( -# source_file_entry.id, -# { -# "parent_id": dest_folder.id, -# "location": new_location, -# }, -# ) -# -# def _move_sync(): -# for file in files: -# _move_entry_recursive(file, dest_folder) -# return get_json_result(data=True) -# -# return await thread_pool_exec(_move_sync) -# -# except Exception as e: -# return server_error_response(e) diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py deleted file mode 100644 index 730d63c66ca..00000000000 --- a/api/apps/kb_app.py +++ /dev/null @@ -1,1012 +0,0 @@ -# -# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -import logging -import random -import re - -from common.metadata_utils import turn2jsonschema -from quart import request -import numpy as np - -from api.db.services.connector_service import Connector2KbService -from api.db.services.llm_service import LLMBundle -from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks -from api.db.services.doc_metadata_service import DocMetadataService -from api.db.services.pipeline_operation_log_service import PipelineOperationLogService -from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID -from api.db.services.user_service import UserTenantService -from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name, get_model_config_by_id -from api.utils.api_utils import ( - get_error_data_result, - server_error_response, - get_data_error_result, - validate_request, - get_request_json, -) -from api.db import VALID_FILE_TYPES -from api.db.services.knowledgebase_service import KnowledgebaseService -from api.utils.api_utils import get_json_result -from rag.nlp import search -from rag.utils.redis_conn import REDIS_CONN -from common.constants import RetCode, PipelineTaskType, VALID_TASK_STATUS, LLMType -from common import settings -from common.doc_store.doc_store_base import OrderByExpr -from api.apps import login_required, current_user - -""" -Deprecated, todo delete -@manager.route('/create', methods=['post']) # noqa: F821 -@login_required -@validate_request("name") -async def create(): - req = await get_request_json() - create_dict = ensure_tenant_model_id_for_params(current_user.id, req) - e, res = KnowledgebaseService.create_with_name( - name = create_dict.pop("name", None), - tenant_id = current_user.id, - parser_id = create_dict.pop("parser_id", None), - **create_dict - ) - - if not e: - return res - - try: - if not KnowledgebaseService.save(**res): - return get_data_error_result() - return get_json_result(data={"kb_id":res["id"]}) - except Exception as e: - return server_error_response(e) - - -@manager.route('/update', methods=['post']) # noqa: F821 -@login_required -@validate_request("kb_id", "name", "description", "parser_id") -@not_allowed_parameters("id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by") -async def update(): - req = await get_request_json() - update_dict = ensure_tenant_model_id_for_params(current_user.id, req) - if not isinstance(update_dict["name"], str): - return get_data_error_result(message="Dataset name must be string.") - if update_dict["name"].strip() == "": - return get_data_error_result(message="Dataset name can't be empty.") - if len(update_dict["name"].encode("utf-8")) > DATASET_NAME_LIMIT: - return get_data_error_result( - message=f"Dataset name length is {len(update_dict['name'])} which is large than {DATASET_NAME_LIMIT}") - update_dict["name"] = update_dict["name"].strip() - if settings.DOC_ENGINE_INFINITY: - parser_id = update_dict.get("parser_id") - if isinstance(parser_id, str) and parser_id.lower() == "tag": - return get_json_result( - code=RetCode.OPERATING_ERROR, - message="The chunking method Tag has not been supported by Infinity yet.", - data=False, - ) - if "pagerank" in update_dict and update_dict["pagerank"] > 0: - return get_json_result( - code=RetCode.DATA_ERROR, - message="'pagerank' can only be set when doc_engine is elasticsearch", - data=False, - ) - - if not KnowledgebaseService.accessible4deletion(update_dict["kb_id"], current_user.id): - return get_json_result( 
- data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) - try: - if not KnowledgebaseService.query( - created_by=current_user.id, id=update_dict["kb_id"]): - return get_json_result( - data=False, message='Only owner of dataset authorized for this operation.', - code=RetCode.OPERATING_ERROR) - - e, kb = KnowledgebaseService.get_by_id(update_dict["kb_id"]) - - # Rename folder in FileService - if e and update_dict["name"].lower() != kb.name.lower(): - FileService.filter_update( - [ - File.tenant_id == kb.tenant_id, - File.source_type == FileSource.KNOWLEDGEBASE, - File.type == "folder", - File.name == kb.name, - ], - {"name": update_dict["name"]}, - ) - - if not e: - return get_data_error_result( - message="Can't find this dataset!") - - if update_dict["name"].lower() != kb.name.lower() \ - and len( - KnowledgebaseService.query(name=update_dict["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) >= 1: - return get_data_error_result( - message="Duplicated dataset name.") - - del update_dict["kb_id"] - connectors = [] - if "connectors" in update_dict: - connectors = update_dict["connectors"] - del update_dict["connectors"] - if not KnowledgebaseService.update_by_id(kb.id, update_dict): - return get_data_error_result() - - if kb.pagerank != update_dict.get("pagerank", 0): - if update_dict.get("pagerank", 0) > 0: - await thread_pool_exec( - settings.docStoreConn.update, - {"kb_id": kb.id}, - {PAGERANK_FLD: update_dict["pagerank"]}, - search.index_name(kb.tenant_id), - kb.id, - ) - else: - # Elasticsearch requires PAGERANK_FLD be non-zero! - await thread_pool_exec( - settings.docStoreConn.update, - {"exists": PAGERANK_FLD}, - {"remove": PAGERANK_FLD}, - search.index_name(kb.tenant_id), - kb.id, - ) - - e, kb = KnowledgebaseService.get_by_id(kb.id) - if not e: - return get_data_error_result( - message="Database error (Knowledgebase rename)!") - errors = Connector2KbService.link_connectors(kb.id, [conn for conn in connectors], current_user.id) - if errors: - logging.error("Link KB errors: ", errors) - kb = kb.to_dict() - kb.update(update_dict) - kb["connectors"] = connectors - - return get_json_result(data=kb) - except Exception as e: - return server_error_response(e) -""" - -@manager.route('/update_metadata_setting', methods=['post']) # noqa: F821 -@login_required -@validate_request("kb_id", "metadata") -async def update_metadata_setting(): - req = await get_request_json() - e, kb = KnowledgebaseService.get_by_id(req["kb_id"]) - if not e: - return get_data_error_result( - message="Database error (Knowledgebase rename)!") - kb = kb.to_dict() - kb["parser_config"]["metadata"] = req["metadata"] - kb["parser_config"]["enable_metadata"] = req.get("enable_metadata", True) - KnowledgebaseService.update_by_id(kb["id"], kb) - return get_json_result(data=kb) - - -@manager.route('/detail', methods=['GET']) # noqa: F821 -@login_required -def detail(): - kb_id = request.args["kb_id"] - try: - tenants = UserTenantService.query(user_id=current_user.id) - for tenant in tenants: - if KnowledgebaseService.query( - tenant_id=tenant.tenant_id, id=kb_id): - break - else: - return get_json_result( - data=False, message='Only owner of dataset authorized for this operation.', - code=RetCode.OPERATING_ERROR) - kb = KnowledgebaseService.get_detail(kb_id) - if not kb: - return get_data_error_result( - message="Can't find this dataset!") - kb["size"] = DocumentService.get_total_size_by_kb_id(kb_id=kb["id"],keywords="", run_status=[], types=[]) - kb["connectors"] = 
Connector2KbService.list_connectors(kb_id) - if kb["parser_config"].get("metadata"): - kb["parser_config"]["metadata"] = turn2jsonschema(kb["parser_config"]["metadata"]) - - for key in ["graphrag_task_finish_at", "raptor_task_finish_at", "mindmap_task_finish_at"]: - if finish_at := kb.get(key): - kb[key] = finish_at.strftime("%Y-%m-%d %H:%M:%S") - return get_json_result(data=kb) - except Exception as e: - return server_error_response(e) - -""" -Deprecated, todo delete -@manager.route('/list', methods=['POST']) # noqa: F821 -@login_required -async def list_kbs(): - args = request.args - keywords = args.get("keywords", "") - page_number = int(args.get("page", 0)) - items_per_page = int(args.get("page_size", 0)) - parser_id = args.get("parser_id") - orderby = args.get("orderby", "create_time") - if args.get("desc", "true").lower() == "false": - desc = False - else: - desc = True - - req = await get_request_json() - owner_ids = req.get("owner_ids", []) - try: - if not owner_ids: - tenants = TenantService.get_joined_tenants_by_user_id(current_user.id) - tenants = [m["tenant_id"] for m in tenants] - kbs, total = KnowledgebaseService.get_by_tenant_ids( - tenants, current_user.id, page_number, - items_per_page, orderby, desc, keywords, parser_id) - else: - tenants = owner_ids - kbs, total = KnowledgebaseService.get_by_tenant_ids( - tenants, current_user.id, 0, - 0, orderby, desc, keywords, parser_id) - kbs = [kb for kb in kbs if kb["tenant_id"] in tenants] - total = len(kbs) - if page_number and items_per_page: - kbs = kbs[(page_number-1)*items_per_page:page_number*items_per_page] - return get_json_result(data={"kbs": kbs, "total": total}) - except Exception as e: - return server_error_response(e) - - -@manager.route('/rm', methods=['post']) # noqa: F821 -@login_required -@validate_request("kb_id") -async def rm(): - req = await get_request_json() - uid = current_user.id - if not KnowledgebaseService.accessible4deletion(req["kb_id"], uid): - return get_json_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) - try: - kbs = KnowledgebaseService.query( - created_by=uid, id=req["kb_id"]) - if not kbs: - return get_json_result( - data=False, message='Only owner of dataset authorized for this operation.', - code=RetCode.OPERATING_ERROR) - - def _rm_sync(): - for doc in DocumentService.query(kb_id=req["kb_id"]): - if not DocumentService.remove_document(doc, kbs[0].tenant_id): - return get_data_error_result( - message="Database error (Document removal)!") - f2d = File2DocumentService.get_by_document_id(doc.id) - if f2d: - FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id]) - File2DocumentService.delete_by_document_id(doc.id) - FileService.filter_delete( - [ - File.tenant_id == kbs[0].tenant_id, - File.source_type == FileSource.KNOWLEDGEBASE, - File.type == "folder", - File.name == kbs[0].name, - ] - ) - # Delete the table BEFORE deleting the database record - for kb in kbs: - try: - settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id) - settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id) - logging.info(f"Dropped index for dataset {kb.id}") - except Exception as e: - logging.error(f"Failed to drop index for dataset {kb.id}: {e}") - - if not KnowledgebaseService.delete_by_id(req["kb_id"]): - return get_data_error_result( - message="Database error (Knowledgebase removal)!") - for kb in kbs: - if hasattr(settings.STORAGE_IMPL, 'remove_bucket'): - 
settings.STORAGE_IMPL.remove_bucket(kb.id) - return get_json_result(data=True) - - return await thread_pool_exec(_rm_sync) - except Exception as e: - return server_error_response(e) -""" - -@manager.route('//tags', methods=['GET']) # noqa: F821 -@login_required -def list_tags(kb_id): - if not KnowledgebaseService.accessible(kb_id, current_user.id): - return get_json_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) - - tenants = UserTenantService.get_tenants_by_user_id(current_user.id) - tags = [] - for tenant in tenants: - tags += settings.retriever.all_tags(tenant["tenant_id"], [kb_id]) - return get_json_result(data=tags) - - -@manager.route('/tags', methods=['GET']) # noqa: F821 -@login_required -def list_tags_from_kbs(): - kb_ids = request.args.get("kb_ids", "").split(",") - for kb_id in kb_ids: - if not KnowledgebaseService.accessible(kb_id, current_user.id): - return get_json_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) - - tenants = UserTenantService.get_tenants_by_user_id(current_user.id) - tags = [] - for tenant in tenants: - tags += settings.retriever.all_tags(tenant["tenant_id"], kb_ids) - return get_json_result(data=tags) - - -@manager.route('//rm_tags', methods=['POST']) # noqa: F821 -@login_required -async def rm_tags(kb_id): - req = await get_request_json() - if not KnowledgebaseService.accessible(kb_id, current_user.id): - return get_json_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) - e, kb = KnowledgebaseService.get_by_id(kb_id) - - for t in req["tags"]: - settings.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]}, - {"remove": {"tag_kwd": t}}, - search.index_name(kb.tenant_id), - kb_id) - return get_json_result(data=True) - - -@manager.route('//rename_tag', methods=['POST']) # noqa: F821 -@login_required -async def rename_tags(kb_id): - req = await get_request_json() - if not KnowledgebaseService.accessible(kb_id, current_user.id): - return get_json_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) - e, kb = KnowledgebaseService.get_by_id(kb_id) - - settings.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]}, - {"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}}, - search.index_name(kb.tenant_id), - kb_id) - return get_json_result(data=True) - -""" -Deprecated, todo delete -@manager.route('//knowledge_graph', methods=['GET']) # noqa: F821 -@login_required -async def knowledge_graph(kb_id): - if not KnowledgebaseService.accessible(kb_id, current_user.id): - return get_json_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) - _, kb = KnowledgebaseService.get_by_id(kb_id) - req = { - "kb_id": [kb_id], - "knowledge_graph_kwd": ["graph"] - } - - obj = {"graph": {}, "mind_map": {}} - if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), kb_id): - return get_json_result(data=obj) - sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id]) - if not len(sres.ids): - return get_json_result(data=obj) - - for id in sres.ids[:1]: - ty = sres.field[id]["knowledge_graph_kwd"] - try: - content_json = json.loads(sres.field[id]["content_with_weight"]) - except Exception: - continue - - obj[ty] = content_json - - if "nodes" in obj["graph"]: - obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256] - if "edges" in 
obj["graph"]: - node_id_set = { o["id"] for o in obj["graph"]["nodes"] } - filtered_edges = [o for o in obj["graph"]["edges"] if o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set] - obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128] - return get_json_result(data=obj) - - -@manager.route('//knowledge_graph', methods=['DELETE']) # noqa: F821 -@login_required -def delete_knowledge_graph(kb_id): - if not KnowledgebaseService.accessible(kb_id, current_user.id): - return get_json_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) - _, kb = KnowledgebaseService.get_by_id(kb_id) - settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id) - - return get_json_result(data=True) -""" - -@manager.route("/get_meta", methods=["GET"]) # noqa: F821 -@login_required -def get_meta(): - kb_ids = request.args.get("kb_ids", "").split(",") - for kb_id in kb_ids: - if not KnowledgebaseService.accessible(kb_id, current_user.id): - return get_json_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) - return get_json_result(data=DocMetadataService.get_flatted_meta_by_kbs(kb_ids)) - - -@manager.route("/basic_info", methods=["GET"]) # noqa: F821 -@login_required -def get_basic_info(): - kb_id = request.args.get("kb_id", "") - if not KnowledgebaseService.accessible(kb_id, current_user.id): - return get_json_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) - - basic_info = DocumentService.knowledgebase_basic_info(kb_id) - - return get_json_result(data=basic_info) - - -@manager.route("/list_pipeline_logs", methods=["POST"]) # noqa: F821 -@login_required -async def list_pipeline_logs(): - kb_id = request.args.get("kb_id") - if not kb_id: - return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) - - keywords = request.args.get("keywords", "") - - page_number = int(request.args.get("page", 0)) - items_per_page = int(request.args.get("page_size", 0)) - orderby = request.args.get("orderby", "create_time") - if request.args.get("desc", "true").lower() == "false": - desc = False - else: - desc = True - create_date_from = request.args.get("create_date_from", "") - create_date_to = request.args.get("create_date_to", "") - if create_date_to > create_date_from: - return get_data_error_result(message="Create data filter is abnormal.") - - req = await get_request_json() - - operation_status = req.get("operation_status", []) - if operation_status: - invalid_status = {s for s in operation_status if s not in VALID_TASK_STATUS} - if invalid_status: - return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}") - - types = req.get("types", []) - if types: - invalid_types = {t for t in types if t not in VALID_FILE_TYPES} - if invalid_types: - return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}") - - suffix = req.get("suffix", []) - - try: - logs, count = PipelineOperationLogService.get_file_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix, create_date_from, create_date_to) - return get_json_result(data={"total": count, "logs": logs}) - except Exception as e: - return server_error_response(e) - - 
-@manager.route("/list_pipeline_dataset_logs", methods=["POST"]) # noqa: F821 -@login_required -async def list_pipeline_dataset_logs(): - kb_id = request.args.get("kb_id") - if not kb_id: - return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) - - page_number = int(request.args.get("page", 0)) - items_per_page = int(request.args.get("page_size", 0)) - orderby = request.args.get("orderby", "create_time") - if request.args.get("desc", "true").lower() == "false": - desc = False - else: - desc = True - create_date_from = request.args.get("create_date_from", "") - create_date_to = request.args.get("create_date_to", "") - if create_date_to > create_date_from: - return get_data_error_result(message="Create data filter is abnormal.") - - req = await get_request_json() - - operation_status = req.get("operation_status", []) - if operation_status: - invalid_status = {s for s in operation_status if s not in VALID_TASK_STATUS} - if invalid_status: - return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}") - - try: - logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from, create_date_to) - return get_json_result(data={"total": tol, "logs": logs}) - except Exception as e: - return server_error_response(e) - - -@manager.route("/delete_pipeline_logs", methods=["POST"]) # noqa: F821 -@login_required -async def delete_pipeline_logs(): - kb_id = request.args.get("kb_id") - if not kb_id: - return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) - - req = await get_request_json() - log_ids = req.get("log_ids", []) - - PipelineOperationLogService.delete_by_ids(log_ids) - - return get_json_result(data=True) - - -@manager.route("/pipeline_log_detail", methods=["GET"]) # noqa: F821 -@login_required -def pipeline_log_detail(): - log_id = request.args.get("log_id") - if not log_id: - return get_json_result(data=False, message='Lack of "Pipeline log ID"', code=RetCode.ARGUMENT_ERROR) - - ok, log = PipelineOperationLogService.get_by_id(log_id) - if not ok: - return get_data_error_result(message="Invalid pipeline log ID") - - return get_json_result(data=log.to_dict()) - - -""" -Deprecated, todo delete -@manager.route("/run_graphrag", methods=["POST"]) # noqa: F821 -@login_required -async def run_graphrag(): - req = await get_request_json() - - kb_id = req.get("kb_id", "") - if not kb_id: - return get_error_data_result(message='Lack of "KB ID"') - - ok, kb = KnowledgebaseService.get_by_id(kb_id) - if not ok: - return get_error_data_result(message="Invalid Knowledgebase ID") - - task_id = kb.graphrag_task_id - if task_id: - ok, task = TaskService.get_by_id(task_id) - if not ok: - logging.warning(f"A valid GraphRAG task id is expected for kb {kb_id}") - - if task and task.progress not in [-1, 1]: - return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. 
A Graph Task is already running.") - - documents, _ = DocumentService.get_by_kb_id( - kb_id=kb_id, - page_number=0, - items_per_page=0, - orderby="create_time", - desc=False, - keywords="", - run_status=[], - types=[], - suffix=[], - ) - if not documents: - return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}") - - sample_document = documents[0] - document_ids = [document["id"] for document in documents] - - task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) - - if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}): - logging.warning(f"Cannot save graphrag_task_id for kb {kb_id}") - - return get_json_result(data={"graphrag_task_id": task_id}) - - -@manager.route("/trace_graphrag", methods=["GET"]) # noqa: F821 -@login_required -def trace_graphrag(): - kb_id = request.args.get("kb_id", "") - if not kb_id: - return get_error_data_result(message='Lack of "KB ID"') - - ok, kb = KnowledgebaseService.get_by_id(kb_id) - if not ok: - return get_error_data_result(message="Invalid Knowledgebase ID") - - task_id = kb.graphrag_task_id - if not task_id: - return get_json_result(data={}) - - ok, task = TaskService.get_by_id(task_id) - if not ok: - return get_json_result(data={}) - - return get_json_result(data=task.to_dict()) - - -@manager.route("/run_raptor", methods=["POST"]) # noqa: F821 -@login_required -async def run_raptor(): - req = await get_request_json() - - kb_id = req.get("kb_id", "") - if not kb_id: - return get_error_data_result(message='Lack of "KB ID"') - - ok, kb = KnowledgebaseService.get_by_id(kb_id) - if not ok: - return get_error_data_result(message="Invalid Knowledgebase ID") - - task_id = kb.raptor_task_id - if task_id: - ok, task = TaskService.get_by_id(task_id) - if not ok: - logging.warning(f"A valid RAPTOR task id is expected for kb {kb_id}") - - if task and task.progress not in [-1, 1]: - return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. 
A RAPTOR Task is already running.") - - documents, _ = DocumentService.get_by_kb_id( - kb_id=kb_id, - page_number=0, - items_per_page=0, - orderby="create_time", - desc=False, - keywords="", - run_status=[], - types=[], - suffix=[], - ) - if not documents: - return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}") - - sample_document = documents[0] - document_ids = [document["id"] for document in documents] - - task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) - - if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}): - logging.warning(f"Cannot save raptor_task_id for kb {kb_id}") - - return get_json_result(data={"raptor_task_id": task_id}) - - -@manager.route("/trace_raptor", methods=["GET"]) # noqa: F821 -@login_required -def trace_raptor(): - kb_id = request.args.get("kb_id", "") - if not kb_id: - return get_error_data_result(message='Lack of "KB ID"') - - ok, kb = KnowledgebaseService.get_by_id(kb_id) - if not ok: - return get_error_data_result(message="Invalid Knowledgebase ID") - - task_id = kb.raptor_task_id - if not task_id: - return get_json_result(data={}) - - ok, task = TaskService.get_by_id(task_id) - if not ok: - return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred") - - return get_json_result(data=task.to_dict()) -""" - -@manager.route("/run_mindmap", methods=["POST"]) # noqa: F821 -@login_required -async def run_mindmap(): - req = await get_request_json() - - kb_id = req.get("kb_id", "") - if not kb_id: - return get_error_data_result(message='Lack of "KB ID"') - - ok, kb = KnowledgebaseService.get_by_id(kb_id) - if not ok: - return get_error_data_result(message="Invalid Knowledgebase ID") - - task_id = kb.mindmap_task_id - if task_id: - ok, task = TaskService.get_by_id(task_id) - if not ok: - logging.warning(f"A valid Mindmap task id is expected for kb {kb_id}") - - if task and task.progress not in [-1, 1]: - return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. 
A Mindmap Task is already running.") - - documents, _ = DocumentService.get_by_kb_id( - kb_id=kb_id, - page_number=0, - items_per_page=0, - orderby="create_time", - desc=False, - keywords="", - run_status=[], - types=[], - suffix=[], - ) - if not documents: - return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}") - - sample_document = documents[0] - document_ids = [document["id"] for document in documents] - - task_id = queue_raptor_o_graphrag_tasks(sample_doc=sample_document, ty="mindmap", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) - - if not KnowledgebaseService.update_by_id(kb.id, {"mindmap_task_id": task_id}): - logging.warning(f"Cannot save mindmap_task_id for kb {kb_id}") - - return get_json_result(data={"mindmap_task_id": task_id}) - - -@manager.route("/trace_mindmap", methods=["GET"]) # noqa: F821 -@login_required -def trace_mindmap(): - kb_id = request.args.get("kb_id", "") - if not kb_id: - return get_error_data_result(message='Lack of "KB ID"') - - ok, kb = KnowledgebaseService.get_by_id(kb_id) - if not ok: - return get_error_data_result(message="Invalid Knowledgebase ID") - - task_id = kb.mindmap_task_id - if not task_id: - return get_json_result(data={}) - - ok, task = TaskService.get_by_id(task_id) - if not ok: - return get_error_data_result(message="Mindmap Task Not Found or Error Occurred") - - return get_json_result(data=task.to_dict()) - - -@manager.route("/unbind_task", methods=["DELETE"]) # noqa: F821 -@login_required -def delete_kb_task(): - kb_id = request.args.get("kb_id", "") - if not kb_id: - return get_error_data_result(message='Lack of "KB ID"') - ok, kb = KnowledgebaseService.get_by_id(kb_id) - if not ok: - return get_json_result(data=True) - - pipeline_task_type = request.args.get("pipeline_task_type", "") - if not pipeline_task_type or pipeline_task_type not in [PipelineTaskType.GRAPH_RAG, PipelineTaskType.RAPTOR, PipelineTaskType.MINDMAP]: - return get_error_data_result(message="Invalid task type") - - def cancel_task(task_id): - REDIS_CONN.set(f"{task_id}-cancel", "x") - - kb_task_id_field: str = "" - kb_task_finish_at: str = "" - match pipeline_task_type: - case PipelineTaskType.GRAPH_RAG: - kb_task_id_field = "graphrag_task_id" - task_id = kb.graphrag_task_id - kb_task_finish_at = "graphrag_task_finish_at" - cancel_task(task_id) - settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id) - case PipelineTaskType.RAPTOR: - kb_task_id_field = "raptor_task_id" - task_id = kb.raptor_task_id - kb_task_finish_at = "raptor_task_finish_at" - cancel_task(task_id) - settings.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), kb_id) - case PipelineTaskType.MINDMAP: - kb_task_id_field = "mindmap_task_id" - task_id = kb.mindmap_task_id - kb_task_finish_at = "mindmap_task_finish_at" - cancel_task(task_id) - case _: - return get_error_data_result(message="Internal Error: Invalid task type") - - - ok = KnowledgebaseService.update_by_id(kb_id, {kb_task_id_field: "", kb_task_finish_at: None}) - if not ok: - return server_error_response(f"Internal error: cannot delete task {pipeline_task_type}") - - return get_json_result(data=True) - -@manager.route("/check_embedding", methods=["post"]) # noqa: F821 -@login_required -async def check_embedding(): - - def _guess_vec_field(src: dict) -> str | None: - for k in src or {}: - if k.endswith("_vec"): - return k - return None - - def _as_float_vec(v): - if v is None: 
- return [] - if isinstance(v, str): - return [float(x) for x in v.split("\t") if x != ""] - if isinstance(v, (list, tuple, np.ndarray)): - return [float(x) for x in v] - return [] - - def _to_1d(x): - a = np.asarray(x, dtype=np.float32) - return a.reshape(-1) - - def _cos_sim(a, b, eps=1e-12): - a = _to_1d(a) - b = _to_1d(b) - na = np.linalg.norm(a) - nb = np.linalg.norm(b) - if na < eps or nb < eps: - return 0.0 - return float(np.dot(a, b) / (na * nb)) - - def sample_random_chunks_with_vectors( - docStoreConn, - tenant_id: str, - kb_id: str, - n: int = 5, - base_fields=("docnm_kwd","doc_id","content_with_weight","page_num_int","position_int","top_int"), - ): - index_nm = search.index_name(tenant_id) - - res0 = docStoreConn.search( - select_fields=[], highlight_fields=[], - condition={"kb_id": kb_id, "available_int": 1}, - match_expressions=[], order_by=OrderByExpr(), - offset=0, limit=1, - index_names=index_nm, knowledgebase_ids=[kb_id] - ) - total = docStoreConn.get_total(res0) - if total <= 0: - return [] - - n = min(n, total) - offsets = sorted(random.sample(range(min(total,1000)), n)) - out = [] - - for off in offsets: - res1 = docStoreConn.search( - select_fields=list(base_fields), - highlight_fields=[], - condition={"kb_id": kb_id, "available_int": 1}, - match_expressions=[], order_by=OrderByExpr(), - offset=off, limit=1, - index_names=index_nm, knowledgebase_ids=[kb_id] - ) - ids = docStoreConn.get_doc_ids(res1) - if not ids: - continue - - cid = ids[0] - full_doc = docStoreConn.get(cid, index_nm, [kb_id]) or {} - vec_field = _guess_vec_field(full_doc) - vec = _as_float_vec(full_doc.get(vec_field)) - - out.append({ - "chunk_id": cid, - "kb_id": kb_id, - "doc_id": full_doc.get("doc_id"), - "doc_name": full_doc.get("docnm_kwd"), - "vector_field": vec_field, - "vector_dim": len(vec), - "vector": vec, - "page_num_int": full_doc.get("page_num_int"), - "position_int": full_doc.get("position_int"), - "top_int": full_doc.get("top_int"), - "content_with_weight": full_doc.get("content_with_weight") or "", - "question_kwd": full_doc.get("question_kwd") or [] - }) - return out - - def _clean(s: str) -> str: - s = re.sub(r"]{0,12})?>", " ", s or "") - return s if s else "None" - req = await get_request_json() - kb_id = req.get("kb_id", "") - tenant_embd_id = req.get("tenant_embd_id") - embd_id = req.get("embd_id", "") - n = int(req.get("check_num", 5)) - _, kb = KnowledgebaseService.get_by_id(kb_id) - tenant_id = kb.tenant_id - if tenant_embd_id: - embd_model_config = get_model_config_by_id(tenant_embd_id) - elif embd_id: - embd_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING, embd_id) - else: - return get_error_data_result("`tenant_embd_id` or `embd_id` is required.") - emb_mdl = LLMBundle(tenant_id, embd_model_config) - samples = sample_random_chunks_with_vectors(settings.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n) - - results, eff_sims = [], [] - for ck in samples: - title = ck.get("doc_name") or "Title" - txt_in = "\n".join(ck.get("question_kwd") or []) or ck.get("content_with_weight") or "" - txt_in = _clean(txt_in) - if not txt_in: - results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"}) - continue - - if not ck.get("vector"): - results.append({"chunk_id": ck["chunk_id"], "reason": "no_stored_vector"}) - continue - - try: - v, _ = emb_mdl.encode([title, txt_in]) - assert len(v[1]) == len(ck["vector"]), f"The dimension ({len(v[1])}) of given embedding model is different from the original ({len(ck['vector'])})" - sim_content = 
_cos_sim(v[1], ck["vector"]) - title_w = 0.1 - qv_mix = title_w * v[0] + (1 - title_w) * v[1] - sim_mix = _cos_sim(qv_mix, ck["vector"]) - sim = sim_content - mode = "content_only" - if sim_mix > sim: - sim = sim_mix - mode = "title+content" - except Exception as e: - return get_error_data_result(message=f"Embedding failure. {e}") - - eff_sims.append(sim) - results.append({ - "chunk_id": ck["chunk_id"], - "doc_id": ck["doc_id"], - "doc_name": ck["doc_name"], - "vector_field": ck["vector_field"], - "vector_dim": ck["vector_dim"], - "cos_sim": round(sim, 6), - }) - - summary = { - "kb_id": kb_id, - "model": embd_id, - "sampled": len(samples), - "valid": len(eff_sims), - "avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6), - "min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6), - "max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6), - "match_mode": mode, - } - if summary["avg_cos_sim"] > 0.9: - return get_json_result(data={"summary": summary, "results": results}) - return get_json_result(code=RetCode.NOT_EFFECTIVE, message="Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.", data={"summary": summary, "results": results}) diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 91c20fddfa7..583e05af7c9 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -29,6 +29,23 @@ from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel, Seq2txtModel +def _resolve_my_llm_is_tools(o_dict: dict) -> bool: + decode_api_key_config = getattr(TenantLLMService, "_decode_api_key_config", None) + if callable(decode_api_key_config): + _, is_tools, _ = decode_api_key_config(o_dict.get("api_key", "")) + if is_tools is not None: + return bool(is_tools) + + try: + base_name, fid = TenantLLMService.split_model_name_and_factory(o_dict["llm_name"]) + llm_cfg = LLMService.query(llm_name=base_name, fid=fid) if fid else LLMService.query(llm_name=base_name) + if not llm_cfg and fid: + llm_cfg = LLMService.query(llm_name=base_name) + return bool(llm_cfg[0].is_tools) if llm_cfg else False + except Exception: + return False + + @manager.route("/factories", methods=["GET"]) # noqa: F821 @login_required def factories(): @@ -185,7 +202,9 @@ def apikey_json(keys): elif factory == "Bedrock": # For Bedrock, due to its special authentication method # Assemble bedrock_ak, bedrock_sk, bedrock_region - api_key = apikey_json(["auth_mode", "bedrock_ak", "bedrock_sk", "bedrock_region", "aws_role_arn"]) + # Write into req["api_key"] to prevent the "existing key" override logic from replacing it + req["api_key"] = apikey_json(["auth_mode", "bedrock_ak", "bedrock_sk", "bedrock_region", "aws_role_arn"]) + api_key = req["api_key"] elif factory == "LocalAI": llm_name += "___LocalAI" @@ -226,6 +245,22 @@ def apikey_json(keys): elif factory == "PaddleOCR": api_key = apikey_json(["api_key", "provider_order"]) + elif factory == "OpenDataLoader": + api_key = apikey_json(["api_key", "provider_order"]) + + existing_llm = None + existing_api_key = None + if req.get("api_key") is None: + existing_llms = TenantLLMService.query(tenant_id=current_user.id, llm_factory=factory, llm_name=llm_name) + if existing_llms: + existing_llm = existing_llms[0] + existing_api_key, _, existing_api_key_payload = TenantLLMService._decode_api_key_config(existing_llm.api_key) + if existing_api_key_payload is not None: + existing_api_key = existing_api_key_payload + + if req.get("api_key") is None: + 
api_key = existing_api_key if existing_api_key is not None else "x" + llm = { "tenant_id": current_user.id, "llm_factory": factory, @@ -350,6 +385,9 @@ def drain_tts(): if msg: return get_data_error_result(message=msg) + if "is_tools" in req: + llm["api_key"] = TenantLLMService._encode_api_key_config(llm["api_key"], bool(req["is_tools"])) + if not TenantLLMService.filter_update([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm["llm_name"]], llm): TenantLLMService.save(**llm) @@ -390,6 +428,7 @@ async def delete_factory(): def my_llms(): try: TenantLLMService.ensure_mineru_from_env(current_user.id) + TenantLLMService.ensure_opendataloader_from_env(current_user.id) include_details = request.args.get("include_details", "false").lower() == "true" if include_details: @@ -417,6 +456,7 @@ def my_llms(): "api_base": o_dict["api_base"] or "", "max_tokens": o_dict["max_tokens"] or 8192, "status": o_dict["status"] or "1", + "is_tools": _resolve_my_llm_is_tools(o_dict), } ) else: diff --git a/api/apps/restful_apis/agent_api.py b/api/apps/restful_apis/agent_api.py new file mode 100644 index 00000000000..c0c6c604af7 --- /dev/null +++ b/api/apps/restful_apis/agent_api.py @@ -0,0 +1,1892 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import asyncio +import base64 +import copy +import hashlib +import hmac +import inspect +import ipaddress +import json +import logging +import time +from functools import partial, wraps + +import jwt +from quart import Response, jsonify, request + +from agent.canvas import Canvas +from agent.component import LLM +from agent.dsl_migration import normalize_chunker_dsl +from api.apps import current_user, login_required +from api.apps.services.canvas_replica_service import CanvasReplicaService +from api.db import CanvasCategory +from api.db.db_models import Task +from api.db.services.api_service import API4ConversationService +from api.db.services.canvas_service import ( + CanvasTemplateService, + UserCanvasService, + completion as agent_completion, + completion_openai, +) +from api.db.services.document_service import DocumentService +from api.db.services.file_service import FileService +from api.db.services.knowledgebase_service import KnowledgebaseService +from api.db.services.pipeline_operation_log_service import PipelineOperationLogService +from api.db.services.task_service import CANVAS_DEBUG_DOC_ID, TaskService, queue_dataflow +from api.db.services.user_service import TenantService, UserService +from api.db.services.user_canvas_version import UserCanvasVersionService +from api.utils.api_utils import ( + add_tenant_id_to_kwargs, + get_data_error_result, + get_json_result, + get_result, + get_request_json, + server_error_response, + validate_request, +) +from common import settings +from common.constants import RetCode +from common.misc_utils import get_uuid, thread_pool_exec +from peewee import MySQLDatabase, PostgresqlDatabase +from rag.flow.pipeline import Pipeline +from rag.nlp import search +from rag.utils.redis_conn import REDIS_CONN + + +def _require_canvas_access_sync(func): + @wraps(func) + def wrapper(*args, **kwargs): + if not UserCanvasService.accessible(kwargs.get('agent_id'), kwargs.get('tenant_id')): + return get_json_result(data=False, message="Make sure you have permission to access the agent.", code=RetCode.OPERATING_ERROR) + return func(*args, **kwargs) + return wrapper + + +def _require_canvas_access_async(func): + @wraps(func) + async def wrapper(*args, **kwargs): + agent_id = kwargs.get('agent_id') + tenant_id = kwargs.get('tenant_id') + if not await thread_pool_exec(UserCanvasService.accessible, agent_id, tenant_id): + return get_json_result(data=False, message="Make sure you have permission to access the agent.", code=RetCode.OPERATING_ERROR) + return await func(*args, **kwargs) + return wrapper + + +def _require_canvas_owner_sync(func): + @wraps(func) + def wrapper(*args, **kwargs): + if not UserCanvasService.query(user_id=kwargs.get('tenant_id'), id=kwargs.get('agent_id')): + return get_json_result(data=False, message="Only the owner of the agent is authorized for this operation.", code=RetCode.OPERATING_ERROR) + return func(*args, **kwargs) + return wrapper + + +def _get_user_nickname(user_id: str) -> str: + exists, user = UserService.get_by_id(user_id) + if not exists: + return user_id + return str(getattr(user, "nickname", "") or user_id) + + +def _build_sse_response(body): + resp = Response(body, mimetype="text/event-stream") + resp.headers.add_header("Cache-control", "no-cache") + resp.headers.add_header("Connection", "keep-alive") + resp.headers.add_header("X-Accel-Buffering", "no") + resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") + return resp + + +def _normalize_agent_session(conv): + conv["messages"] = 
conv.pop("message") + for info in conv["messages"]: + if "prompt" in info: + info.pop("prompt") + conv["agent_id"] = conv.pop("dialog_id") + if isinstance(conv["reference"], dict): + if "chunks" in conv["reference"]: + conv["reference"] = [conv["reference"]] + else: + conv["reference"] = [value for _, value in sorted(conv["reference"].items(), key=lambda item: int(item[0]))] + + if conv["reference"]: + messages = [message for i, message in enumerate(conv["messages"]) if i != 0 and message["role"] != "user"] + for message, reference in zip(messages, conv["reference"]): + chunks = reference["chunks"] + message["reference"] = [ + { + "id": chunk.get("chunk_id", chunk.get("id")), + "content": chunk.get("content_with_weight", chunk.get("content")), + "document_id": chunk.get("doc_id", chunk.get("document_id")), + "document_name": chunk.get("docnm_kwd", chunk.get("document_name")), + "dataset_id": chunk.get("kb_id", chunk.get("dataset_id")), + "image_id": chunk.get("image_id", chunk.get("img_id")), + "positions": chunk.get("positions", chunk.get("position_int")), + } + for chunk in chunks + ] + del conv["reference"] + return conv + + +def _agent_session_list_result(data, total): + return jsonify({"code": RetCode.SUCCESS, "message": "success", "data": data, "total": total}) + + +@manager.route("/agents//sessions", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +@_require_canvas_access_sync +def list_agent_sessions(agent_id, tenant_id): + session_id = request.args.get("id") + user_id = request.args.get("user_id") + page_number = int(request.args.get("page", 1)) + items_per_page = int(request.args.get("page_size", 30)) + keywords = request.args.get("keywords") + from_date = request.args.get("from_date") + to_date = request.args.get("to_date") + orderby = request.args.get("orderby", "update_time") + exp_user_id = request.args.get("exp_user_id") + desc = request.args.get("desc") not in {"False", "false"} + + if exp_user_id: + sessions = API4ConversationService.get_names(agent_id, exp_user_id) + return _agent_session_list_result(sessions, len(sessions)) + + include_dsl = request.args.get("dsl") not in {"False", "false"} + total, sessions = API4ConversationService.get_list( + agent_id, + tenant_id, + page_number, + items_per_page, + orderby, + desc, + session_id, + user_id, + include_dsl, + keywords, + from_date, + to_date, + exp_user_id=exp_user_id, + ) + sessions = [_normalize_agent_session(session) for session in sessions] + return _agent_session_list_result(sessions, total) + + +@manager.route("/agents//sessions", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +@_require_canvas_access_async +async def create_agent_session(agent_id, tenant_id): + req = await get_request_json() + user_id = req.get("user_id") or request.args.get("user_id", tenant_id) + release_mode = bool(req.get("release", request.args.get("release", False))) + + try: + cvs, dsl = UserCanvasService.get_agent_dsl_with_release(agent_id, release_mode, tenant_id) + except LookupError: + return get_data_error_result(message="Agent not found.") + except PermissionError as e: + return get_data_error_result(message=str(e)) + + session_id = get_uuid() + canvas = Canvas(dsl, tenant_id, agent_id, canvas_id=cvs.id) + canvas.reset() + + cvs.dsl = json.loads(str(canvas)) + version_title = UserCanvasVersionService.get_latest_version_title(cvs.id, release_mode=release_mode) + conv = { + "id": session_id, + "name": req.get("name", ""), + "dialog_id": cvs.id, + "user_id": user_id, + "exp_user_id": 
user_id, + "message": [{"role": "assistant", "content": canvas.get_prologue()}], + "source": "agent", + "dsl": cvs.dsl, + "reference": [], + "version_title": version_title, + } + API4ConversationService.save(**conv) + return get_result(data=_normalize_agent_session(conv)) + + +@manager.route("/agents//sessions/", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +@_require_canvas_access_sync +def get_agent_session(agent_id, session_id, tenant_id): + _, conv = API4ConversationService.get_by_id(session_id) + return get_json_result(data=conv.to_dict()) + + +@manager.route("/agents//sessions/", methods=["DELETE"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +@_require_canvas_access_sync +def delete_agent_session_item(agent_id, session_id, tenant_id): + return get_json_result(data=API4ConversationService.delete_by_id(session_id)) + + +@manager.route("/agents/download", methods=["GET"]) # noqa: F821 +async def download_agent_file(): + id = request.args.get("id") + created_by = request.args.get("created_by") + blob = FileService.get_blob(created_by, id) + return Response(blob) + + +async def _iter_session_completion_events(tenant_id, agent_id, req, return_trace): + # Stream and non-stream session completions share the same event parsing and trace injection. + trace_items = [] + async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req): + if isinstance(answer, str): + try: + ans = json.loads(answer[5:]) + except Exception: + continue + else: + ans = answer + + event = ans.get("event") + if event == "node_finished": + if return_trace: + data = ans.get("data", {}) + trace_items.append( + { + "component_id": data.get("component_id"), + "trace": [copy.deepcopy(data)], + } + ) + ans.setdefault("data", {})["trace"] = trace_items + yield ans + continue + + if event in ["message", "message_end"]: + yield ans + + +@manager.route("/agents/templates", methods=["GET"]) # noqa: F821 +@login_required +def list_agent_template(): + return get_json_result(data=[item.to_dict() for item in CanvasTemplateService.get_all()]) + + +@manager.route("/agents/prompts", methods=["GET"]) # noqa: F821 +@login_required +def prompts(): + from rag.prompts.generator import ( + ANALYZE_TASK_SYSTEM, + ANALYZE_TASK_USER, + CITATION_PROMPT_TEMPLATE, + NEXT_STEP, + REFLECT, + ) + + return get_json_result( + data={ + "task_analysis": f"{ANALYZE_TASK_SYSTEM}\n\n{ANALYZE_TASK_USER}", + "plan_generation": NEXT_STEP, + "reflection": REFLECT, + "citation_guidelines": CITATION_PROMPT_TEMPLATE, + } + ) + + +@manager.route("/agents", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def list_agents(tenant_id): + keywords = request.args.get("keywords", "") + canvas_category = request.args.get("canvas_category") + owner_ids = [item for item in request.args.get("owner_ids", "").strip().split(",") if item] + + page_number = int(request.args.get("page", 0)) + items_per_page = int(request.args.get("page_size", 0)) + order_by = request.args.get("orderby", "create_time") + desc = str(request.args.get("desc", "true")).lower() != "false" + tenants = TenantService.get_joined_tenants_by_user_id(tenant_id) + authorized_owner_ids = {member["tenant_id"] for member in tenants} + authorized_owner_ids.add(tenant_id) + + if owner_ids: + requested_owner_ids = set(owner_ids) + unauthorized_owner_ids = requested_owner_ids - authorized_owner_ids + if unauthorized_owner_ids: + return get_json_result( + data=False, + message="Only authorized owner_ids can be queried.", + 
code=RetCode.OPERATING_ERROR, + ) + effective_owner_ids = list(requested_owner_ids) + else: + effective_owner_ids = list(authorized_owner_ids) + + canvas, total = UserCanvasService.get_by_tenant_ids( + effective_owner_ids, + tenant_id, + page_number, + items_per_page, + order_by, + desc, + keywords, + canvas_category, + ) + + return get_json_result(data={"canvas": canvas, "total": total}) + + +@manager.route("/agents", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def create_agent(tenant_id): + req = {k: v for k, v in (await get_request_json()).items() if v is not None} + req["user_id"] = tenant_id + req["canvas_category"] = req.get("canvas_category") or CanvasCategory.Agent + req["release"] = bool(req.get("release", "")) + + if req.get("dsl") is None: + return get_json_result( + data=False, + message="No DSL data in request.", + code=RetCode.ARGUMENT_ERROR, + ) + + try: + req["dsl"] = CanvasReplicaService.normalize_dsl(req["dsl"]) + except ValueError as exc: + return get_json_result( + data=False, + message=str(exc), + code=RetCode.ARGUMENT_ERROR, + ) + + if req.get("title") is None: + return get_json_result( + data=False, + message="No title in request.", + code=RetCode.ARGUMENT_ERROR, + ) + + req["title"] = req["title"].strip() + if UserCanvasService.query( + user_id=tenant_id, + title=req["title"], + canvas_category=req["canvas_category"], + ): + return get_data_error_result(message=f"{req['title']} already exists.") + + req["id"] = get_uuid() + if not UserCanvasService.save(**req): + return get_data_error_result(message="Fail to create agent.") + + owner_nickname = _get_user_nickname(tenant_id) + UserCanvasVersionService.save_or_replace_latest( + user_canvas_id=req["id"], + title=UserCanvasVersionService.build_version_title(owner_nickname, req.get("title")), + dsl=req["dsl"], + release=req.get("release"), + ) + replica_ok = CanvasReplicaService.replace_for_set( + canvas_id=req["id"], + tenant_id=str(tenant_id), + runtime_user_id=str(tenant_id), + dsl=req["dsl"], + canvas_category=req["canvas_category"], + title=req.get("title", ""), + ) + if not replica_ok: + return get_data_error_result(message="canvas saved, but replica sync failed.") + + exists, created_agent = UserCanvasService.get_by_canvas_id(req["id"]) + if not exists: + return get_data_error_result(message="Fail to create agent.") + return get_json_result(data=created_agent) + + +@manager.route("/agents//upload", methods=["POST"]) # noqa: F821 +async def upload_agent_file(agent_id): + exists, canvas = UserCanvasService.get_by_canvas_id(agent_id) + if not exists: + return get_data_error_result(message="canvas not found.") + + user_id = canvas["user_id"] + files = await request.files + file_objs = files.getlist("file") if files and files.get("file") else [] + try: + if len(file_objs) == 1: + return get_json_result( + data=FileService.upload_info(user_id, file_objs[0], request.args.get("url")) + ) + results = [FileService.upload_info(user_id, file_obj) for file_obj in file_objs] + return get_json_result(data=results) + except Exception as exc: + return server_error_response(exc) + + +@manager.route("/agents//components//input-form", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +@_require_canvas_access_sync +def get_agent_component_input_form(agent_id, component_id, tenant_id): + try: + exists, user_canvas = UserCanvasService.get_by_id(agent_id) + if not exists: + return get_data_error_result(message="canvas not found.") + canvas = Canvas(json.dumps(user_canvas.dsl), 
tenant_id, canvas_id=user_canvas.id) + return get_json_result(data=canvas.get_component_input_form(component_id)) + except Exception as exc: + return server_error_response(exc) + + +@manager.route("/agents//components//debug", methods=["POST"]) # noqa: F821 +@validate_request("params") +@login_required +@add_tenant_id_to_kwargs +@_require_canvas_access_async +async def debug_agent_component(agent_id, component_id, tenant_id): + req = await get_request_json() + try: + _, user_canvas = UserCanvasService.get_by_id(agent_id) + canvas = Canvas(json.dumps(user_canvas.dsl), tenant_id, canvas_id=user_canvas.id) + canvas.reset() + canvas.message_id = get_uuid() + component = canvas.get_component(component_id)["obj"] + component.reset() + + if isinstance(component, LLM): + component.set_debug_inputs(req["params"]) + component.invoke(**{k: o["value"] for k, o in req["params"].items()}) + outputs = component.output() + for k in outputs.keys(): + if isinstance(outputs[k], partial): + txt = "" + iter_obj = outputs[k]() + if inspect.isasyncgen(iter_obj): + async for c in iter_obj: + txt += c + else: + for c in iter_obj: + txt += c + outputs[k] = txt + return get_json_result(data=outputs) + except Exception as exc: + return server_error_response(exc) + + +@manager.route("/agents/", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def get_agent(agent_id, tenant_id): + if not UserCanvasService.accessible(agent_id, tenant_id): + return get_data_error_result(message="canvas not found.") + + exists, canvas = UserCanvasService.get_by_canvas_id(agent_id) + if not exists: + return get_data_error_result(message="canvas not found.") + + try: + CanvasReplicaService.bootstrap( + canvas_id=agent_id, + tenant_id=str(tenant_id), + runtime_user_id=str(tenant_id), + dsl=canvas.get("dsl"), + canvas_category=canvas.get("canvas_category", CanvasCategory.Agent), + title=canvas.get("title", ""), + ) + except ValueError as exc: + return get_data_error_result(message=str(exc)) + + last_publish_time = None + versions = UserCanvasVersionService.list_by_canvas_id(agent_id) + if versions: + released_versions = [version for version in versions if version.release] + if released_versions: + released_versions.sort(key=lambda version: version.update_time, reverse=True) + last_publish_time = released_versions[0].update_time + + canvas["dsl"] = normalize_chunker_dsl(canvas.get("dsl", {})) + canvas["last_publish_time"] = last_publish_time + + if canvas.get("canvas_category") == CanvasCategory.DataFlow: + datasets = list(KnowledgebaseService.query(pipeline_id=agent_id)) + canvas["datasets"] = [{"id": item.id, "name": item.name, "avatar": item.avatar} for item in datasets] + + return get_json_result(data=canvas) + + +@manager.route("/agents//versions", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +@_require_canvas_access_sync +def list_agent_versions(agent_id, tenant_id): + try: + versions = sorted( + [item.to_dict() for item in UserCanvasVersionService.list_by_canvas_id(agent_id)], + key=lambda item: item["update_time"] * -1, + ) + return get_json_result(data=versions) + except Exception as exc: + return get_data_error_result(message=f"Error getting history files: {exc}") + + +@manager.route("/agents//versions/", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +@_require_canvas_access_sync +def get_agent_version(agent_id, version_id, tenant_id): + try: + exists, version = UserCanvasVersionService.get_by_id(version_id) + if not exists or not version or 
str(version.user_canvas_id) != str(agent_id): + return get_data_error_result(message="Version not found.") + return get_json_result(data=version.to_dict()) + except Exception as exc: + return get_data_error_result(message=f"Error getting history file: {exc}") + + +@manager.route("/agents//logs/", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +@_require_canvas_access_sync +def get_agent_logs(agent_id, message_id, tenant_id): + try: + binary = REDIS_CONN.get(f"{agent_id}-{message_id}-logs") + if not binary: + return get_json_result(data={}) + + return get_json_result(data=json.loads(binary.encode("utf-8"))) + except Exception as exc: + logging.exception(exc) + return server_error_response(exc) + + +@manager.route("/agents/", methods=["DELETE"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +@_require_canvas_owner_sync +def delete_agent(agent_id, tenant_id): + UserCanvasService.delete_by_id(agent_id) + return get_json_result(data=True) + + +@manager.route("/agents/", methods=["PUT"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +@_require_canvas_access_async +async def update_agent(agent_id, tenant_id): + req = {k: v for k, v in (await get_request_json()).items() if v is not None} + req["release"] = bool(req.get("release", "")) + + if req.get("dsl") is not None: + try: + req["dsl"] = CanvasReplicaService.normalize_dsl(req["dsl"]) + except ValueError as exc: + return get_json_result( + data=False, + message=str(exc), + code=RetCode.ARGUMENT_ERROR, + ) + + if req.get("title") is not None: + req["title"] = req["title"].strip() + + _, current_agent = UserCanvasService.get_by_id(agent_id) + agent_title_for_version = req.get("title") or (current_agent.title if current_agent else "") + canvas_category = ( + req.get("canvas_category") + or (current_agent.canvas_category if current_agent else CanvasCategory.Agent) + ) + owner_nickname = _get_user_nickname(tenant_id) + UserCanvasService.update_by_id(agent_id, req) + + if req.get("dsl") is not None: + UserCanvasVersionService.save_or_replace_latest( + user_canvas_id=agent_id, + title=UserCanvasVersionService.build_version_title(owner_nickname, agent_title_for_version), + dsl=req["dsl"], + release=req.get("release"), + ) + replica_ok = CanvasReplicaService.replace_for_set( + canvas_id=agent_id, + tenant_id=str(tenant_id), + runtime_user_id=str(tenant_id), + dsl=req["dsl"], + canvas_category=canvas_category, + title=agent_title_for_version, + ) + if not replica_ok: + return get_data_error_result(message="agent saved, but replica sync failed.") + + return get_json_result(data=True) + + +@manager.route("/agents//reset", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +@_require_canvas_access_async +async def reset_agent(agent_id, tenant_id): + try: + exists, user_canvas = UserCanvasService.get_by_id(agent_id) + if not exists: + return get_data_error_result(message="canvas not found.") + + canvas = Canvas(json.dumps(user_canvas.dsl), tenant_id, canvas_id=user_canvas.id) + canvas.reset() + dsl = json.loads(str(canvas)) + UserCanvasService.update_by_id(agent_id, {"dsl": dsl}) + replica_ok = CanvasReplicaService.replace_for_set( + canvas_id=agent_id, + tenant_id=str(tenant_id), + runtime_user_id=str(tenant_id), + dsl=dsl, + canvas_category=user_canvas.canvas_category, + title=user_canvas.title, + ) + if not replica_ok: + return get_data_error_result(message="agent reset, but replica sync failed.") + return get_json_result(data=dsl) + except Exception as exc: + return 
server_error_response(exc) + + +@manager.route("/agents/rerun", methods=["POST"]) # noqa: F821 +@validate_request("id", "dsl", "component_id") +@login_required +@add_tenant_id_to_kwargs +async def rerun_agent(tenant_id): + req = await get_request_json() + doc = PipelineOperationLogService.get_documents_info(req["id"]) + if not doc: + return get_data_error_result(message="Document not found.") + doc = doc[0] + if 0 < doc["progress"] < 1: + return get_data_error_result(message=f"`{doc['name']}` is processing...") + + if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc["kb_id"]): + settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(tenant_id), doc["kb_id"]) + doc["progress_msg"] = "" + doc["chunk_num"] = 0 + doc["token_num"] = 0 + DocumentService.clear_chunk_num_when_rerun(doc["id"]) + DocumentService.update_by_id(doc["id"], doc) + TaskService.filter_delete([Task.doc_id == doc["id"]]) + + dsl = req["dsl"] + dsl["path"] = [req["component_id"]] + PipelineOperationLogService.update_by_id(req["id"], {"dsl": dsl}) + queue_dataflow( + tenant_id=tenant_id, + flow_id=req["id"], + task_id=get_uuid(), + doc_id=doc["id"], + priority=0, + rerun=True, + ) + return get_json_result(data=True) + + +@manager.route("/agents/test_db_connection", methods=["POST"]) # noqa: F821 +@validate_request("db_type", "database", "username", "host", "port", "password") +@login_required +async def test_db_connection(): + req = await get_request_json() + try: + if req["db_type"] in ["mysql", "mariadb"]: + db = MySQLDatabase( + req["database"], + user=req["username"], + host=req["host"], + port=req["port"], + password=req["password"], + ) + elif req["db_type"] == "oceanbase": + db = MySQLDatabase( + req["database"], + user=req["username"], + host=req["host"], + port=req["port"], + password=req["password"], + charset="utf8mb4", + ) + elif req["db_type"] == "postgres": + db = PostgresqlDatabase( + req["database"], + user=req["username"], + host=req["host"], + port=req["port"], + password=req["password"], + ) + elif req["db_type"] == "mssql": + import pyodbc + + connection_string = ( + f"DRIVER={{ODBC Driver 17 for SQL Server}};" + f"SERVER={req['host']},{req['port']};" + f"DATABASE={req['database']};" + f"UID={req['username']};" + f"PWD={req['password']};" + ) + db = pyodbc.connect(connection_string) + cursor = db.cursor() + cursor.execute("SELECT 1") + cursor.close() + elif req["db_type"] == "IBM DB2": + import ibm_db + + conn_str = ( + f"DATABASE={req['database']};" + f"HOSTNAME={req['host']};" + f"PORT={req['port']};" + f"PROTOCOL=TCPIP;" + f"UID={req['username']};" + f"PWD={req['password']};" + ) + logging.info( + "DATABASE=%s;HOSTNAME=%s;PORT=%s;PROTOCOL=TCPIP;UID=%s;PWD=****;", + req["database"], + req["host"], + req["port"], + req["username"], + ) + conn = ibm_db.connect(conn_str, "", "") + stmt = ibm_db.exec_immediate(conn, "SELECT 1 FROM sysibm.sysdummy1") + ibm_db.fetch_assoc(stmt) + ibm_db.close(conn) + return get_json_result(data="Database Connection Successful!") + elif req["db_type"] == "trino": + import os + import trino + + db_name = req["database"] + if "." 
in db_name: + catalog, schema = db_name.split(".", 1) + elif "/" in db_name: + catalog, schema = db_name.split("/", 1) + else: + catalog, schema = db_name, "default" + + http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http" + auth = None + if http_scheme == "https" and req.get("password"): + auth = trino.BasicAuthentication(req.get("username") or "ragflow", req["password"]) + + conn = trino.dbapi.connect( + host=req["host"], + port=int(req["port"] or 8080), + user=req["username"] or "ragflow", + catalog=catalog, + schema=schema or "default", + http_scheme=http_scheme, + auth=auth, + ) + cur = conn.cursor() + cur.execute("SELECT 1") + cur.fetchall() + cur.close() + conn.close() + return get_json_result(data="Database Connection Successful!") + else: + return server_error_response("Unsupported database type.") + + if req["db_type"] != "mssql": + db.connect() + db.close() + return get_json_result(data="Database Connection Successful!") + except Exception as exc: + return server_error_response(exc) + + +@manager.route("/agents/chat/completion", methods=["POST"]) # noqa: F821 +@manager.route("/agents/chat/completions", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def agent_chat_completion(tenant_id, agent_id=None): + # This endpoint serves two execution modes: + # 1. Draft/runtime execution without session state. The request runs against the caller's + # runtime replica, which is populated from the editable canvas state. + # 2. Session continuation with an existing session_id. The request resumes from the stored + # API4Conversation state and must stay bound to the same agent and an accessible canvas. + # + # Security constraints: + # - agent_id is always supplied at the route layer and is not forwarded downstream as a free-form kwarg. + # - New runs without session_id must pass UserCanvasService.accessible(...) before the runtime replica is loaded. + # - Existing sessions are validated here at the route layer before handing control to the lower-level + # completion functions, so canvas_service only executes a pre-authorized session payload. + # + # Response modes: + # - Regular mode emits internal agent events. + # - openai-compatible mode reshapes the same execution into an OpenAI-like wire format. + req = await get_request_json() + agent_id = agent_id or req.get("agent_id") + openai_compatible = bool(req.get("openai-compatible", False)) + if not agent_id: + return get_json_result( + data=False, + message="`agent_id` is required.", + code=RetCode.ARGUMENT_ERROR, + ) + # Route-level selectors should not be forwarded into the lower-level completion functions. + req = dict(req) + req.pop("agent_id", None) + req.pop("openai-compatible", None) + session_id = req.get("session_id") + if session_id: + exists, conv = API4ConversationService.get_by_id(session_id) + if not exists: + return get_data_error_result(message="Session not found!") + if conv.dialog_id != agent_id: + return get_json_result( + data=False, + message="Session does not belong to the requested agent.", + code=RetCode.OPERATING_ERROR, + ) + if not UserCanvasService.accessible(agent_id, tenant_id): + return get_json_result( + data=False, + message="Only authorized users can access this agent session.", + code=RetCode.OPERATING_ERROR, + ) + + if openai_compatible: + # OpenAI-compatible mode uses a different wire format, keep it separate from regular agent events. 
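+        # This branch expects an OpenAI-style chat payload, for example:
+        #   {"messages": [{"role": "user", "content": "Hello"}], "stream": true}
+        # The latest user message becomes the question; "stream" selects SSE versus a single
+        # JSON response; the remaining fields are forwarded to completion_openai() unchanged.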
+ messages = req.get("messages", []) + if not messages: + return get_data_error_result(message="You must provide at least one message.") + question = next((m.get("content", "") for m in reversed(messages) if m.get("role") == "user"), "") + stream = req.pop("stream", False) + session_id = req.pop("session_id", req.get("id", "")) or req.get("metadata", {}).get("id", "") + if stream: + return _build_sse_response( + completion_openai( + tenant_id, + agent_id, + question, + session_id=session_id, + stream=True, + **req, + ) + ) + + async for response in completion_openai( + tenant_id, + agent_id, + question, + session_id=session_id, + stream=False, + **req, + ): + return jsonify(response) + return None + + if not session_id: + # Without session state, run against the runtime replica that tracks draft edits. + query = req.get("query", "") or req.get("question", "") + files = req.get("files", []) + inputs = req.get("inputs", {}) + runtime_user_id = req.get("user_id") or tenant_id + user_id = str(runtime_user_id) + custom_header = req.get("custom_header", "") + + if not UserCanvasService.accessible(agent_id, tenant_id): + return get_json_result( + data=False, + message="Make sure you have permission to access the agent.", + code=RetCode.OPERATING_ERROR, + ) + + _, cvs = await thread_pool_exec(UserCanvasService.get_by_id, agent_id) + if not cvs: + return get_data_error_result(message="canvas not found.") + + replica_payload = CanvasReplicaService.load_for_run( + canvas_id=agent_id, + tenant_id=str(tenant_id), + runtime_user_id=user_id, + ) + if not replica_payload: + try: + replica_payload = CanvasReplicaService.bootstrap( + canvas_id=agent_id, + tenant_id=str(tenant_id), + runtime_user_id=user_id, + dsl=cvs.dsl, + canvas_category=getattr(cvs, "canvas_category", CanvasCategory.Agent), + title=getattr(cvs, "title", ""), + ) + except ValueError as exc: + return get_data_error_result(message=str(exc)) + if not replica_payload: + return get_data_error_result(message="canvas replica not found, please fetch the agent first.") + + replica_dsl = replica_payload.get("dsl", {}) + canvas_title = replica_payload.get("title", "") + canvas_category = replica_payload.get("canvas_category", CanvasCategory.Agent) + dsl_str = json.dumps(replica_dsl, ensure_ascii=False) + + if cvs.canvas_category == CanvasCategory.DataFlow: + task_id = get_uuid() + Pipeline( + dsl_str, + tenant_id=str(tenant_id), + doc_id=CANVAS_DEBUG_DOC_ID, + task_id=task_id, + flow_id=agent_id, + ) + ok, error_message = await thread_pool_exec( + queue_dataflow, + user_id, + agent_id, + task_id, + CANVAS_DEBUG_DOC_ID, + files[0], + 0, + ) + if not ok: + return get_data_error_result(message=error_message) + return get_json_result(data={"message_id": task_id}) + + try: + canvas = Canvas(dsl_str, str(tenant_id), canvas_id=agent_id, custom_header=custom_header) + except Exception as exc: + return server_error_response(exc) + + async def commit_runtime_replica(): + commit_ok = CanvasReplicaService.commit_after_run( + canvas_id=agent_id, + tenant_id=str(tenant_id), + runtime_user_id=user_id, + dsl=json.loads(str(canvas)), + canvas_category=canvas_category, + title=canvas_title, + ) + if not commit_ok: + logging.error( + "Canvas runtime replica commit failed: canvas_id=%s tenant_id=%s runtime_user_id=%s", + agent_id, + tenant_id, + user_id, + ) + + if req.get("stream", True): + async def sse(): + nonlocal canvas + try: + async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs): + yield "data:" + json.dumps(ans, 
ensure_ascii=False) + "\n\n" + + await commit_runtime_replica() + except Exception as exc: + logging.exception(exc) + canvas.cancel_task() + yield ( + "data:" + + json.dumps({"code": 500, "message": str(exc), "data": False}, ensure_ascii=False) + + "\n\n" + ) + + return _build_sse_response(sse()) + + full_content = "" + reference = {} + final_ans = {} + trace_items = [] + structured_output = {} + try: + async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs): + if ans.get("event") == "message": + full_content += ans.get("data", {}).get("content", "") + if ans.get("data", {}).get("reference", None): + reference.update(ans["data"]["reference"]) + if ans.get("event") == "node_finished": + data = ans.get("data", {}) + node_out = data.get("outputs", {}) + component_id = data.get("component_id") + if component_id is not None and "structured" in node_out: + structured_output[component_id] = copy.deepcopy(node_out["structured"]) + if req.get("return_trace", False): + trace_items.append( + { + "component_id": data.get("component_id"), + "trace": [copy.deepcopy(data)], + } + ) + final_ans = ans + except Exception as exc: + logging.exception(exc) + canvas.cancel_task() + return get_result(data=f"**ERROR**: {str(exc)}") + + if not final_ans: + await commit_runtime_replica() + return get_result(data={}) + + if "data" not in final_ans or not isinstance(final_ans["data"], dict): + final_ans["data"] = {} + final_ans["data"]["content"] = full_content + final_ans["data"]["reference"] = reference + if structured_output: + final_ans["data"]["structured"] = structured_output + if trace_items: + final_ans["data"]["trace"] = trace_items + + await commit_runtime_replica() + return get_result(data=final_ans) + + return_trace = bool(req.get("return_trace", False)) + if req.get("stream", True): + + async def generate(): + async for ans in _iter_session_completion_events(tenant_id, agent_id, req, return_trace): + yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n" + yield "data:[DONE]\n\n" + + return _build_sse_response(generate()) + + full_content = "" + reference = {} + final_ans = {} + trace_items = [] + structured_output = {} + async for ans in _iter_session_completion_events(tenant_id, agent_id, req, return_trace): + try: + if ans["event"] == "message": + full_content += ans["data"]["content"] + if ans.get("data", {}).get("reference", None): + reference.update(ans["data"]["reference"]) + if ans.get("event") == "node_finished": + data = ans.get("data", {}) + node_out = data.get("outputs", {}) + component_id = data.get("component_id") + if component_id is not None and "structured" in node_out: + structured_output[component_id] = copy.deepcopy(node_out["structured"]) + if return_trace: + trace_items.append( + { + "component_id": data.get("component_id"), + "trace": [copy.deepcopy(data)], + } + ) + final_ans = ans + except Exception as exc: + return get_result(data=f"**ERROR**: {str(exc)}") + + if not final_ans: + return get_result(data={}) + + if "data" not in final_ans or not isinstance(final_ans["data"], dict): + final_ans["data"] = {} + final_ans["data"]["content"] = full_content + final_ans["data"]["reference"] = reference + if structured_output: + final_ans["data"]["structured"] = structured_output + if return_trace and final_ans: + final_ans["data"]["trace"] = trace_items + return get_result(data=final_ans) + + +@manager.route("/agents//webhook", methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"]) # noqa: F821 
+@manager.route("/agents//webhook/test",methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"],) # noqa: F821 +async def webhook(agent_id: str): + is_test = request.path.startswith(f"/api/v1/agents/{agent_id}/webhook/test") + start_ts = time.time() + + # 1. Fetch canvas by agent_id + exists, cvs = UserCanvasService.get_by_id(agent_id) + if not exists: + return get_data_error_result(code=RetCode.BAD_REQUEST,message="Canvas not found."),RetCode.BAD_REQUEST + + # 2. Check canvas category + if cvs.canvas_category == CanvasCategory.DataFlow: + return get_data_error_result(code=RetCode.BAD_REQUEST,message="Dataflow can not be triggered by webhook."),RetCode.BAD_REQUEST + + # 3. Load DSL from canvas + dsl = getattr(cvs, "dsl", None) + if not isinstance(dsl, dict): + return get_data_error_result(code=RetCode.BAD_REQUEST,message="Invalid DSL format."),RetCode.BAD_REQUEST + + # 4. Check webhook configuration in DSL + webhook_cfg = {} + components = dsl.get("components", {}) + for k, _ in components.items(): + cpn_obj = components[k]["obj"] + if cpn_obj["component_name"].lower() == "begin" and cpn_obj["params"]["mode"] == "Webhook": + webhook_cfg = cpn_obj["params"] + + if not webhook_cfg: + return get_data_error_result(code=RetCode.BAD_REQUEST,message="Webhook not configured for this agent."),RetCode.BAD_REQUEST + + # 5. Validate request method against webhook_cfg.methods + allowed_methods = webhook_cfg.get("methods", []) + request_method = request.method.upper() + if allowed_methods and request_method not in allowed_methods: + return get_data_error_result( + code=RetCode.BAD_REQUEST,message=f"HTTP method '{request_method}' not allowed for this webhook." + ),RetCode.BAD_REQUEST + + # 6. Validate webhook security + async def validate_webhook_security(security_cfg: dict): + """Validate webhook security rules based on security configuration.""" + + if not security_cfg: + return # No security config → allowed by default + + # 1. Validate max body size + await _validate_max_body_size(security_cfg) + + # 2. Validate IP whitelist + _validate_ip_whitelist(security_cfg) + + # # 3. Validate rate limiting + _validate_rate_limit(security_cfg) + + # 4. 
Validate authentication + auth_type = security_cfg.get("auth_type", "none") + + if auth_type == "none": + return + + if auth_type == "token": + _validate_token_auth(security_cfg) + + elif auth_type == "basic": + _validate_basic_auth(security_cfg) + + elif auth_type == "jwt": + _validate_jwt_auth(security_cfg) + + else: + raise Exception(f"Unsupported auth_type: {auth_type}") + + async def _validate_max_body_size(security_cfg): + """Check request size does not exceed max_body_size.""" + max_size = security_cfg.get("max_body_size") + if not max_size: + return + + # Convert "10MB" → bytes + units = {"kb": 1024, "mb": 1024**2} + size_str = max_size.lower() + + for suffix, factor in units.items(): + if size_str.endswith(suffix): + limit = int(size_str.replace(suffix, "")) * factor + break + else: + raise Exception("Invalid max_body_size format") + MAX_LIMIT = 10 * 1024 * 1024 # 10MB + if limit > MAX_LIMIT: + raise Exception("max_body_size exceeds maximum allowed size (10MB)") + + content_length = request.content_length or 0 + if content_length > limit: + raise Exception(f"Request body too large: {content_length} > {limit}") + + def _validate_ip_whitelist(security_cfg): + """Allow only IPs listed in ip_whitelist.""" + whitelist = security_cfg.get("ip_whitelist", []) + if not whitelist: + return + + client_ip = request.remote_addr + + + for rule in whitelist: + if "/" in rule: + # CIDR notation + if ipaddress.ip_address(client_ip) in ipaddress.ip_network(rule, strict=False): + return + else: + # Single IP + if client_ip == rule: + return + + raise Exception(f"IP {client_ip} is not allowed by whitelist") + + def _validate_rate_limit(security_cfg): + """Simple in-memory rate limiting.""" + rl = security_cfg.get("rate_limit") + if not rl: + return + + limit = int(rl.get("limit", 60)) + if limit <= 0: + raise Exception("rate_limit.limit must be > 0") + per = rl.get("per", "minute") + + window = { + "second": 1, + "minute": 60, + "hour": 3600, + "day": 86400, + }.get(per) + + if not window: + raise Exception(f"Invalid rate_limit.per: {per}") + + capacity = limit + rate = limit / window + cost = 1 + + key = f"rl:tb:{agent_id}" + now = time.time() + + try: + res = REDIS_CONN.lua_token_bucket( + keys=[key], + args=[capacity, rate, now, cost], + client=REDIS_CONN.REDIS, + ) + + allowed = int(res[0]) + if allowed != 1: + raise Exception("Too many requests (rate limit exceeded)") + + except Exception as e: + raise Exception(f"Rate limit error: {e}") + + def _validate_token_auth(security_cfg): + """Validate header-based token authentication.""" + token_cfg = security_cfg.get("token",{}) + header = token_cfg.get("token_header") + token_value = token_cfg.get("token_value") + + provided = request.headers.get(header) + if provided != token_value: + raise Exception("Invalid token authentication") + + def _validate_basic_auth(security_cfg): + """Validate HTTP Basic Auth credentials.""" + auth_cfg = security_cfg.get("basic_auth", {}) + username = auth_cfg.get("username") + password = auth_cfg.get("password") + + auth = request.authorization + if not auth or auth.username != username or auth.password != password: + raise Exception("Invalid Basic Auth credentials") + + def _validate_jwt_auth(security_cfg): + """Validate JWT token in Authorization header.""" + jwt_cfg = security_cfg.get("jwt", {}) + secret = jwt_cfg.get("secret") + if not secret: + raise Exception("JWT secret not configured") + + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + raise 
Exception("Missing Bearer token") + + token = auth_header[len("Bearer "):].strip() + if not token: + raise Exception("Empty Bearer token") + + alg = (jwt_cfg.get("algorithm") or "HS256").upper() + + decode_kwargs = { + "key": secret, + "algorithms": [alg], + } + options = {} + if jwt_cfg.get("audience"): + decode_kwargs["audience"] = jwt_cfg["audience"] + options["verify_aud"] = True + else: + options["verify_aud"] = False + + if jwt_cfg.get("issuer"): + decode_kwargs["issuer"] = jwt_cfg["issuer"] + options["verify_iss"] = True + else: + options["verify_iss"] = False + try: + decoded = jwt.decode( + token, + options=options, + **decode_kwargs, + ) + except Exception as e: + raise Exception(f"Invalid JWT: {str(e)}") + + raw_required_claims = jwt_cfg.get("required_claims", []) + if isinstance(raw_required_claims, str): + required_claims = [raw_required_claims] + elif isinstance(raw_required_claims, (list, tuple, set)): + required_claims = list(raw_required_claims) + else: + required_claims = [] + + required_claims = [ + c for c in required_claims + if isinstance(c, str) and c.strip() + ] + + RESERVED_CLAIMS = {"exp", "sub", "aud", "iss", "nbf", "iat"} + for claim in required_claims: + if claim in RESERVED_CLAIMS: + raise Exception(f"Reserved JWT claim cannot be required: {claim}") + + for claim in required_claims: + if claim not in decoded: + raise Exception(f"Missing JWT claim: {claim}") + + return decoded + + try: + security_config=webhook_cfg.get("security", {}) + await validate_webhook_security(security_config) + except Exception as e: + return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST + if not isinstance(cvs.dsl, str): + dsl = json.dumps(cvs.dsl, ensure_ascii=False) + try: + canvas = Canvas(dsl, cvs.user_id, agent_id, canvas_id=agent_id) + except Exception as e: + resp=get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)) + resp.status_code = RetCode.BAD_REQUEST + return resp + + # 7. Parse request body + async def parse_webhook_request(content_type): + """Parse request based on content-type and return structured data.""" + + # 1. Query + query_data = {k: v for k, v in request.args.items()} + + # 2. Headers + header_data = {k: v for k, v in request.headers.items()} + + # 3. 
Body + ctype = request.headers.get("Content-Type", "").split(";")[0].strip() + if ctype and ctype != content_type: + raise ValueError( + f"Invalid Content-Type: expect '{content_type}', got '{ctype}'" + ) + + body_data: dict = {} + + try: + if ctype == "application/json": + body_data = await request.get_json() or {} + + elif ctype == "multipart/form-data": + nonlocal canvas + form = await request.form + files = await request.files + + body_data = {} + + for key, value in form.items(): + body_data[key] = value + + if len(files) > 10: + raise Exception("Too many uploaded files") + for key, file in files.items(): + desc = FileService.upload_info( + cvs.user_id, # user + file, # FileStorage + None # url (None for webhook) + ) + file_parsed= await canvas.get_files_async([desc]) + body_data[key] = file_parsed + + elif ctype == "application/x-www-form-urlencoded": + form = await request.form + body_data = dict(form) + + else: + # text/plain / octet-stream / empty / unknown + raw = await request.get_data() + if raw: + try: + body_data = json.loads(raw.decode("utf-8")) + except Exception: + body_data = {} + else: + body_data = {} + + except Exception: + body_data = {} + + return { + "query": query_data, + "headers": header_data, + "body": body_data, + "content_type": ctype, + } + + def extract_by_schema(data, schema, name="section"): + """ + Extract only fields defined in schema. + Required fields must exist. + Optional fields default to type-based default values. + Type validation included. + """ + props = schema.get("properties", {}) + required = schema.get("required", []) + + extracted = {} + + for field, field_schema in props.items(): + field_type = field_schema.get("type") + + # 1. Required field missing + if field in required and field not in data: + raise Exception(f"{name} missing required field: {field}") + + # 2. Optional → default value + if field not in data: + extracted[field] = default_for_type(field_type) + continue + + raw_value = data[field] + + # 3. Auto convert value + try: + value = auto_cast_value(raw_value, field_type) + except Exception as e: + raise Exception(f"{name}.{field} auto-cast failed: {str(e)}") + + # 4. 
Type validation + if not validate_type(value, field_type): + raise Exception( + f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}" + ) + + extracted[field] = value + + return extracted + + + def default_for_type(t): + """Return default value for the given schema type.""" + if t == "file": + return [] + if t == "object": + return {} + if t == "boolean": + return False + if t == "number": + return 0 + if t == "string": + return "" + if t and t.startswith("array"): + return [] + if t == "null": + return None + return None + + def auto_cast_value(value, expected_type): + """Convert string values into schema type when possible.""" + + # Non-string values already good + if not isinstance(value, str): + return value + + v = value.strip() + + # Boolean + if expected_type == "boolean": + if v.lower() in ["true", "1"]: + return True + if v.lower() in ["false", "0"]: + return False + raise Exception(f"Cannot convert '{value}' to boolean") + + # Number + if expected_type == "number": + # integer + if v.isdigit() or (v.startswith("-") and v[1:].isdigit()): + return int(v) + + # float + try: + return float(v) + except Exception: + raise Exception(f"Cannot convert '{value}' to number") + + # Object + if expected_type == "object": + try: + parsed = json.loads(v) + if isinstance(parsed, dict): + return parsed + else: + raise Exception("JSON is not an object") + except Exception: + raise Exception(f"Cannot convert '{value}' to object") + + # Array + if expected_type.startswith("array"): + try: + parsed = json.loads(v) + if isinstance(parsed, list): + return parsed + else: + raise Exception("JSON is not an array") + except Exception: + raise Exception(f"Cannot convert '{value}' to array") + + # String (accept original) + if expected_type == "string": + return value + + # File + if expected_type == "file": + return value + # Default: do nothing + return value + + + def validate_type(value, t): + """Validate value type against schema type t.""" + if t == "file": + return isinstance(value, list) + + if t == "string": + return isinstance(value, str) + + if t == "number": + return isinstance(value, (int, float)) + + if t == "boolean": + return isinstance(value, bool) + + if t == "object": + return isinstance(value, dict) + + # array / array / array + if t.startswith("array"): + if not isinstance(value, list): + return False + + if "<" in t and ">" in t: + inner = t[t.find("<") + 1 : t.find(">")] + + # Check each element type + for item in value: + if not validate_type(item, inner): + return False + + return True + + return True + parsed = await parse_webhook_request(webhook_cfg.get("content_types")) + SCHEMA = webhook_cfg.get("schema", {"query": {}, "headers": {}, "body": {}}) + + # Extract strictly by schema + try: + query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query") + header_clean = extract_by_schema(parsed["headers"], SCHEMA.get("headers", {}), name="headers") + body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body") + except Exception as e: + return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST + + clean_request = { + "query": query_clean, + "headers": header_clean, + "body": body_clean, + "input": parsed + } + + execution_mode = webhook_cfg.get("execution_mode", "Immediately") + response_cfg = webhook_cfg.get("response", {}) + + def append_webhook_trace(agent_id: str, start_ts: float,event: dict, ttl=600): + key = f"webhook-trace-{agent_id}-logs" + + raw = REDIS_CONN.get(key) 
+ obj = json.loads(raw) if raw else {"webhooks": {}} + + ws = obj["webhooks"].setdefault( + str(start_ts), + {"start_ts": start_ts, "events": []} + ) + + ws["events"].append({ + "ts": time.time(), + **event + }) + + REDIS_CONN.set_obj(key, obj, ttl) + + if execution_mode == "Immediately": + status = response_cfg.get("status", 200) + try: + status = int(status) + except (TypeError, ValueError): + return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}")),RetCode.BAD_REQUEST + + if not (200 <= status <= 399): + return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}, must be between 200 and 399")),RetCode.BAD_REQUEST + + body_tpl = response_cfg.get("body_template", "") + + def parse_body(body: str): + if not body: + return None, "application/json" + + try: + parsed = json.loads(body) + return parsed, "application/json" + except (json.JSONDecodeError, TypeError): + return body, "text/plain" + + + body, content_type = parse_body(body_tpl) + resp = Response( + json.dumps(body, ensure_ascii=False) if content_type == "application/json" else body, + status=status, + content_type=content_type, + ) + + async def background_run(): + try: + async for ans in canvas.run( + query="", + user_id=cvs.user_id, + webhook_payload=clean_request + ): + if is_test: + append_webhook_trace(agent_id, start_ts, ans) + + if is_test: + append_webhook_trace( + agent_id, + start_ts, + { + "event": "finished", + "elapsed_time": time.time() - start_ts, + "success": True, + } + ) + + cvs.dsl = json.loads(str(canvas)) + UserCanvasService.update_by_id(cvs.user_id, cvs.to_dict()) + + except Exception as e: + logging.exception("Webhook background run failed") + if is_test: + try: + append_webhook_trace( + agent_id, + start_ts, + { + "event": "error", + "message": str(e), + "error_type": type(e).__name__, + } + ) + append_webhook_trace( + agent_id, + start_ts, + { + "event": "finished", + "elapsed_time": time.time() - start_ts, + "success": False, + } + ) + except Exception: + logging.exception("Failed to append webhook trace") + + asyncio.create_task(background_run()) + return resp + else: + async def sse(): + nonlocal canvas + contents: list[str] = [] + status = 200 + try: + async for ans in canvas.run( + query="", + user_id=cvs.user_id, + webhook_payload=clean_request, + ): + if ans["event"] == "message": + content = ans["data"]["content"] + if ans["data"].get("start_to_think", False): + content = "" + elif ans["data"].get("end_to_think", False): + content = "" + if content: + contents.append(content) + if ans["event"] == "message_end": + status = int(ans["data"].get("status", status)) + if is_test: + append_webhook_trace( + agent_id, + start_ts, + ans + ) + if is_test: + append_webhook_trace( + agent_id, + start_ts, + { + "event": "finished", + "elapsed_time": time.time() - start_ts, + "success": True, + } + ) + final_content = "".join(contents) + return { + "message": final_content, + "success": True, + "code": status, + } + + except Exception as e: + if is_test: + append_webhook_trace( + agent_id, + start_ts, + { + "event": "error", + "message": str(e), + "error_type": type(e).__name__, + } + ) + append_webhook_trace( + agent_id, + start_ts, + { + "event": "finished", + "elapsed_time": time.time() - start_ts, + "success": False, + } + ) + return {"code": 400, "message": str(e),"success":False} + + result = await sse() + return Response( + json.dumps(result), + status=result["code"], + mimetype="application/json", + ) + + 
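For reference, here is a minimal, illustrative sketch of a Begin-component webhook configuration of the shape the handler above reads from the canvas DSL, together with a matching test call. The field names (`methods`, `content_types`, `schema`, `security`, `execution_mode`, `response`) come from the handler itself; the host/port, token value, and concrete schema fields are assumptions for illustration only, not part of this patch.

```python
# Illustrative sketch only (not part of the patch): a Begin-component `params`
# block of the shape the webhook handler reads from the canvas DSL.
webhook_params_example = {
    "mode": "Webhook",
    "methods": ["POST"],
    "content_types": "application/json",
    "execution_mode": "Immediately",
    "response": {"status": 200, "body_template": '{"accepted": true}'},
    "schema": {
        "query": {"properties": {}, "required": []},
        "headers": {"properties": {}, "required": []},
        "body": {
            "properties": {"order_id": {"type": "string"}, "amount": {"type": "number"}},
            "required": ["order_id"],
        },
    },
    "security": {
        "auth_type": "token",
        "token": {"token_header": "X-Webhook-Token", "token_value": "my-secret"},
        "max_body_size": "1MB",
        "ip_whitelist": ["127.0.0.1", "10.0.0.0/8"],
        "rate_limit": {"limit": 60, "per": "minute"},
    },
}

# A matching client call against the test route; the host and port are assumptions.
import requests

agent_id = "your-agent-id"  # hypothetical placeholder
resp = requests.post(
    f"http://localhost:9380/api/v1/agents/{agent_id}/webhook/test",
    headers={"X-Webhook-Token": "my-secret"},
    json={"order_id": "A-1001", "amount": "19.9"},  # "19.9" is auto-cast to a number per the schema
)
print(resp.status_code, resp.text)
```

Because the call hits the `/webhook/test` path, the run is traced to Redis and can be polled through the webhook logs endpoint added below.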
+@manager.route("/agents//webhook/logs", methods=["GET"]) # noqa: F821 +@login_required +async def webhook_trace(agent_id: str): + exists, cvs = UserCanvasService.get_by_id(agent_id) + if not exists or str(cvs.user_id) != str(current_user.id): + return get_data_error_result( + message="Canvas not found.", + ) + + def encode_webhook_id(start_ts: str) -> str: + WEBHOOK_ID_SECRET = "webhook_id_secret" + sig = hmac.new( + WEBHOOK_ID_SECRET.encode("utf-8"), + start_ts.encode("utf-8"), + hashlib.sha256, + ).digest() + return base64.urlsafe_b64encode(sig).decode("utf-8").rstrip("=") + + def decode_webhook_id(enc_id: str, webhooks: dict) -> str | None: + for ts in webhooks.keys(): + if encode_webhook_id(ts) == enc_id: + return ts + return None + since_ts = request.args.get("since_ts", type=float) + webhook_id = request.args.get("webhook_id") + + key = f"webhook-trace-{agent_id}-logs" + raw = REDIS_CONN.get(key) + + if since_ts is None: + now = time.time() + return get_json_result( + data={ + "webhook_id": None, + "events": [], + "next_since_ts": now, + "finished": False, + } + ) + + if not raw: + return get_json_result( + data={ + "webhook_id": None, + "events": [], + "next_since_ts": since_ts, + "finished": False, + } + ) + + obj = json.loads(raw) + webhooks = obj.get("webhooks", {}) + + if webhook_id is None: + candidates = [ + float(k) for k in webhooks.keys() if float(k) > since_ts + ] + + if not candidates: + return get_json_result( + data={ + "webhook_id": None, + "events": [], + "next_since_ts": since_ts, + "finished": False, + } + ) + + start_ts = min(candidates) + real_id = str(start_ts) + webhook_id = encode_webhook_id(real_id) + + return get_json_result( + data={ + "webhook_id": webhook_id, + "events": [], + "next_since_ts": start_ts, + "finished": False, + } + ) + + real_id = decode_webhook_id(webhook_id, webhooks) + + if not real_id: + return get_json_result( + data={ + "webhook_id": webhook_id, + "events": [], + "next_since_ts": since_ts, + "finished": True, + } + ) + + ws = webhooks.get(str(real_id)) + events = ws.get("events", []) + new_events = [e for e in events if e.get("ts", 0) > since_ts] + + next_ts = since_ts + for e in new_events: + next_ts = max(next_ts, e["ts"]) + + finished = any(e.get("event") == "finished" for e in new_events) + + return get_json_result( + data={ + "webhook_id": webhook_id, + "events": new_events, + "next_since_ts": next_ts, + "finished": finished, + } + ) diff --git a/api/apps/restful_apis/chat_api.py b/api/apps/restful_apis/chat_api.py index 263294b53fa..fab74f5c62a 100644 --- a/api/apps/restful_apis/chat_api.py +++ b/api/apps/restful_apis/chat_api.py @@ -20,6 +20,7 @@ import re import tempfile from copy import deepcopy +from types import SimpleNamespace from quart import Response, request @@ -30,7 +31,7 @@ ) from api.db.services.chunk_feedback_service import ChunkFeedbackService from api.db.services.conversation_service import ConversationService, structure_answer -from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap +from api.db.services.dialog_service import DialogService, async_chat, gen_mindmap from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle from api.db.services.search_service import SearchService @@ -67,6 +68,15 @@ "tts": False, "refine_multiturn": True, } +_DEFAULT_DIRECT_CHAT_PROMPT_CONFIG = { + "system": "", + "prologue": "", + "parameters": [], + "empty_response": "", + "quote": False, + "tts": False, + "refine_multiturn": True, 
+} _DEFAULT_RERANK_MODELS = {"BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"} _READONLY_FIELDS = {"id", "tenant_id", "created_by", "create_time", "create_date", "update_time", "update_date"} _PERSISTED_FIELDS = set(DialogService.model._meta.fields) @@ -124,6 +134,39 @@ def _ensure_owned_chat(chat_id): ) +def _build_default_completion_dialog(): + return SimpleNamespace( + tenant_id=current_user.id, + llm_id="", + tenant_llm_id=None, + llm_setting={}, + prompt_config=deepcopy(_DEFAULT_DIRECT_CHAT_PROMPT_CONFIG), + kb_ids=[], + top_n=6, + top_k=1024, + rerank_id="", + similarity_threshold=0.1, + vector_similarity_weight=0.3, + meta_data_filter=None, + ) + + +def _create_session_for_completion(chat_id, dialog, user_id): + conv = { + "id": get_uuid(), + "dialog_id": chat_id, + "name": "New session", + "message": [{"role": "assistant", "content": dialog.prompt_config.get("prologue", "")}], + "user_id": user_id, + "reference": [], + } + ConversationService.save(**conv) + ok, conv_obj = ConversationService.get_by_id(conv["id"]) + if not ok: + raise LookupError("Fail to create a session!") + return conv_obj + + def _validate_llm_id(llm_id, tenant_id, llm_setting=None): if not llm_id: return None @@ -565,6 +608,15 @@ async def bulk_delete_chats(): if not ids: return get_json_result(data={}) else: + # keep backward compatibility, DELETE with chat_id in request body + chat_id = req.get("chat_id") + if chat_id: + try: + if not DialogService.update_by_id(chat_id, {"status": StatusEnum.INVALID.value}): + return get_data_error_result(message=f"Failed to delete chat {chat_id}") + return get_json_result(data=True) + except Exception as ex: + return server_error_response(ex) return get_json_result(data={}) errors = [] @@ -671,7 +723,7 @@ async def get_session(chat_id, session_id): return server_error_response(ex) -@manager.route("/chats//sessions/", methods=["PUT"]) # noqa: F821 +@manager.route("/chats//sessions/", methods=["PATCH"]) # noqa: F821 @login_required async def update_session(chat_id, session_id): if not _ensure_owned_chat(chat_id): @@ -829,7 +881,7 @@ async def update_message_feedback(chat_id, session_id, msg_id): return server_error_response(ex) -@manager.route("/chats/tts", methods=["POST"]) # noqa: F821 +@manager.route("/chat/audio/speech", methods=["POST"]) # noqa: F821 @login_required async def tts(): req = await get_request_json() @@ -857,9 +909,9 @@ def stream_audio(): return resp -@manager.route("/chats/transcriptions", methods=["POST"]) # noqa: F821 +@manager.route("/chat/audio/transcription", methods=["POST"]) # noqa: F821 @login_required -async def transcriptions(): +async def transcription(): req = await request.form stream_mode = req.get("stream", "false").lower() == "true" files = await request.files @@ -915,7 +967,7 @@ async def event_stream(): return Response(event_stream(), content_type="text/event-stream") -@manager.route("/chats/mindmap", methods=["POST"]) # noqa: F821 +@manager.route("/chat/mindmap", methods=["POST"]) # noqa: F821 @login_required @validate_request("question", "kb_ids") async def mindmap(): @@ -933,10 +985,10 @@ async def mindmap(): return get_json_result(data=mind_map) -@manager.route("/chats/related_questions", methods=["POST"]) # noqa: F821 +@manager.route("/chat/recommendation", methods=["POST"]) # noqa: F821 @login_required @validate_request("question") -async def related_questions(): +async def recommendation(): req = await get_request_json() search_id = req.get("search_id", "") @@ -971,10 +1023,10 @@ async def related_questions(): return 
get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)]) -@manager.route("/chats//sessions//completions", methods=["POST"]) # noqa: F821 +@manager.route("/chat/completions", methods=["POST"]) # noqa: F821 @login_required @validate_request("messages") -async def session_completion(chat_id, session_id): +async def session_completion(chat_id_in_arg=""): req = await get_request_json() msg = [] for m in req["messages"]: @@ -984,6 +1036,9 @@ async def session_completion(chat_id, session_id): continue msg.append(m) message_id = msg[-1].get("id") if msg else None + chat_id = req.pop("chat_id", "") or "" + chat_id = chat_id or chat_id_in_arg + session_id = req.pop("session_id", "") or "" chat_model_id = req.pop("llm_id", "") chat_model_config = {} @@ -993,21 +1048,41 @@ async def session_completion(chat_id, session_id): chat_model_config[model_config] = config try: - e, conv = ConversationService.get_by_id(session_id) - if not e: - return get_data_error_result(message="Session not found!") - if conv.dialog_id != chat_id: - return get_data_error_result(message="Session does not belong to this chat!") - conv.message = deepcopy(req["messages"]) - e, dia = DialogService.get_by_id(chat_id) - if not e: - return get_data_error_result(message="Chat not found!") + conv = None + if session_id and not chat_id: + return get_data_error_result(message="`chat_id` is required when `session_id` is provided.") + + if chat_id: + if not _ensure_owned_chat(chat_id): + return get_json_result( + data=False, + message="No authorization.", + code=RetCode.AUTHENTICATION_ERROR, + ) + e, dia = DialogService.get_by_id(chat_id) + if not e: + return get_data_error_result(message="Chat not found!") + if session_id: + e, conv = ConversationService.get_by_id(session_id) + if not e: + return get_data_error_result(message="Session not found!") + if conv.dialog_id != chat_id: + return get_data_error_result(message="Session does not belong to this chat!") + else: + conv = _create_session_for_completion(chat_id, dia, req.get("user_id", current_user.id)) + session_id = conv.id + conv.message = deepcopy(req["messages"]) + else: + dia = _build_default_completion_dialog() + dia.llm_setting = chat_model_config + del req["messages"] - if not conv.reference: - conv.reference = [] - conv.reference = [r for r in conv.reference if r] - conv.reference.append({"chunks": [], "doc_aggs": []}) + if conv is not None: + if not conv.reference: + conv.reference = [] + conv.reference = [r for r in conv.reference if r] + conv.reference.append({"chunks": [], "doc_aggs": []}) if chat_model_id: if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=chat_model_id): @@ -1015,16 +1090,21 @@ async def session_completion(chat_id, session_id): dia.llm_id = chat_model_id dia.llm_setting = chat_model_config - is_embedded = bool(chat_model_id) stream_mode = req.pop("stream", True) + def _format_answer(ans): + formatted = structure_answer(conv, ans, message_id, session_id) + if chat_id: + formatted["chat_id"] = chat_id + return formatted + async def stream(): nonlocal dia, msg, req, conv try: async for ans in async_chat(dia, msg, True, **req): - ans = structure_answer(conv, ans, message_id, conv.id) + ans = _format_answer(ans) yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" - if not is_embedded: + if conv is not None: ConversationService.update_by_id(conv.id, conv.to_dict()) except Exception as ex: logging.exception(ex) @@ -1041,40 +1121,10 @@ async def 
stream(): answer = None async for ans in async_chat(dia, msg, **req): - answer = structure_answer(conv, ans, message_id, conv.id) - if not is_embedded: + answer = _format_answer(ans) + if conv is not None: ConversationService.update_by_id(conv.id, conv.to_dict()) break return get_json_result(data=answer) except Exception as ex: return server_error_response(ex) - - -@manager.route("/chats/ask", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("question", "kb_ids") -async def ask(): - req = await get_request_json() - uid = current_user.id - - search_id = req.get("search_id", "") - search_config = {} - if search_id: - if search_app := SearchService.get_detail(search_id): - search_config = search_app.get("search_config", {}) - - async def stream(): - nonlocal req, uid - try: - async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config): - yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" - except Exception as ex: - yield "data:" + json.dumps({"code": 500, "message": str(ex), "data": {"answer": "**ERROR**: " + str(ex), "reference": []}}, ensure_ascii=False) + "\n\n" - yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" - - resp = Response(stream(), mimetype="text/event-stream") - resp.headers.add_header("Cache-control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") - return resp diff --git a/api/apps/restful_apis/chunk_api.py b/api/apps/restful_apis/chunk_api.py new file mode 100644 index 00000000000..13b5cb5801e --- /dev/null +++ b/api/apps/restful_apis/chunk_api.py @@ -0,0 +1,445 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import base64 +import datetime +import re + +import xxhash +from pydantic import BaseModel, Field, validator +from quart import request + +from api.apps import login_required +from api.db.joint_services.tenant_model_service import ( + get_model_config_by_id, + get_model_config_by_type_and_name, +) +from api.db.services.document_service import DocumentService +from api.db.services.knowledgebase_service import KnowledgebaseService +from api.db.services.tenant_llm_service import TenantLLMService +from api.utils.api_utils import ( + add_tenant_id_to_kwargs, + check_duplicate_ids, + get_error_data_result, + get_request_json, + get_result, + server_error_response, +) +from api.utils.image_utils import store_chunk_image +from common import settings +from common.constants import LLMType, ParserType, RetCode +from common.misc_utils import thread_pool_exec +from common.string_utils import is_content_empty, remove_redundant_spaces +from common.tag_feature_utils import validate_tag_features +from rag.app.qa import beAdoc, rmPrefix +from rag.nlp import rag_tokenizer, search + + +class Chunk(BaseModel): + id: str = "" + content: str = "" + document_id: str = "" + docnm_kwd: str = "" + important_keywords: list = Field(default_factory=list) + tag_kwd: list = Field(default_factory=list) + questions: list = Field(default_factory=list) + question_tks: str = "" + image_id: str = "" + available: bool = True + positions: list[list[int]] = Field(default_factory=list) + + @validator("positions") + def validate_positions(cls, value): + for sublist in value: + if len(sublist) != 5: + raise ValueError("Each sublist in positions must have a length of 5") + return value + + +def _map_doc(doc): + key_mapping = { + "chunk_num": "chunk_count", + "kb_id": "dataset_id", + "token_num": "token_count", + "parser_id": "chunk_method", + } + run_mapping = { + "0": "UNSTART", + "1": "RUNNING", + "2": "CANCEL", + "3": "DONE", + "4": "FAIL", + } + renamed_doc = {} + for key, value in doc.to_dict().items(): + renamed_doc[key_mapping.get(key, key)] = value + if key == "run": + renamed_doc["run"] = run_mapping.get(str(value)) + return renamed_doc + + +def _strip_chunk_runtime_fields(chunk): + for name in [name for name in chunk.keys() if re.search(r"(_vec$|_sm_|_tks|_ltks)", name)]: + del chunk[name] + return chunk + + +@manager.route("/datasets//documents//chunks", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def list_chunks(tenant_id, dataset_id, document_id): + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): + return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") + doc = DocumentService.query(id=document_id, kb_id=dataset_id) + if not doc: + return get_error_data_result(message=f"You don't own the document {document_id}.") + doc = doc[0] + req = request.args + page = int(req.get("page", 1)) + size = int(req.get("page_size", 30)) + question = req.get("keywords", "") + query = { + "doc_ids": [document_id], + "page": page, + "size": size, + "question": question, + "sort": True, + } + if "available" in req: + query["available_int"] = 1 if req["available"] == "true" else 0 + + res = {"total": 0, "chunks": [], "doc": _map_doc(doc)} + if req.get("id"): + chunk = settings.docStoreConn.get(req.get("id"), search.index_name(tenant_id), [dataset_id]) + if not chunk: + return get_result(message=f"Chunk not found: {dataset_id}/{req.get('id')}", code=RetCode.DATA_ERROR) + if str(chunk.get("doc_id", chunk.get("document_id"))) != str(document_id): + 
return get_result(message=f"Chunk not found: {dataset_id}/{req.get('id')}", code=RetCode.DATA_ERROR) + _strip_chunk_runtime_fields(chunk) + res["total"] = 1 + final_chunk = { + "id": chunk.get("id", chunk.get("chunk_id")), + "content": chunk["content_with_weight"], + "document_id": chunk.get("doc_id", chunk.get("document_id")), + "docnm_kwd": chunk["docnm_kwd"], + "important_keywords": chunk.get("important_kwd", []), + "questions": chunk.get("question_kwd", []), + "dataset_id": chunk.get("kb_id", chunk.get("dataset_id")), + "image_id": chunk.get("img_id", ""), + "available": bool(chunk.get("available_int", 1)), + "positions": chunk.get("position_int", []), + "tag_kwd": chunk.get("tag_kwd", []), + "tag_feas": chunk.get("tag_feas", {}), + } + res["chunks"].append(final_chunk) + _ = Chunk(**final_chunk) + elif settings.docStoreConn.index_exist(search.index_name(tenant_id), dataset_id): + sres = await settings.retriever.search( + query, + search.index_name(tenant_id), + [dataset_id], + emb_mdl=None, + highlight=True, + ) + res["total"] = sres.total + for chunk_id in sres.ids: + d = { + "id": chunk_id, + "content": ( + remove_redundant_spaces(sres.highlight[chunk_id]) + if question and chunk_id in sres.highlight + else sres.field[chunk_id].get("content_with_weight", "") + ), + "document_id": sres.field[chunk_id]["doc_id"], + "docnm_kwd": sres.field[chunk_id]["docnm_kwd"], + "important_keywords": sres.field[chunk_id].get("important_kwd", []), + "tag_kwd": sres.field[chunk_id].get("tag_kwd", []), + "questions": sres.field[chunk_id].get("question_kwd", []), + "dataset_id": sres.field[chunk_id].get("kb_id", sres.field[chunk_id].get("dataset_id")), + "image_id": sres.field[chunk_id].get("img_id", ""), + "available": bool(int(sres.field[chunk_id].get("available_int", "1"))), + "positions": sres.field[chunk_id].get("position_int", []), + } + res["chunks"].append(d) + _ = Chunk(**d) + return get_result(data=res) + + +@manager.route("/datasets//documents//chunks/", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def get_chunk(tenant_id, dataset_id, document_id, chunk_id): + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): + return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") + doc = DocumentService.query(id=document_id, kb_id=dataset_id) + if not doc: + return get_error_data_result(message=f"You don't own the document {document_id}.") + try: + chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id]) + if chunk is None or str(chunk.get("doc_id", chunk.get("document_id"))) != str(document_id): + return get_result(data=False, message="Chunk not found!", code=RetCode.DATA_ERROR) + return get_result(data=_strip_chunk_runtime_fields(chunk)) + except Exception as e: + if str(e).find("NotFoundError") >= 0: + return get_result(data=False, message="Chunk not found!", code=RetCode.DATA_ERROR) + return server_error_response(e) + + +@manager.route("/datasets//documents//chunks", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def add_chunk(tenant_id, dataset_id, document_id): + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): + return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") + doc = DocumentService.query(id=document_id, kb_id=dataset_id) + if not doc: + return get_error_data_result(message=f"You don't own the document {document_id}.") + doc = doc[0] + req = await get_request_json() + if 
is_content_empty(req.get("content")): + return get_error_data_result(message="`content` is required") + if "important_keywords" in req and not isinstance(req["important_keywords"], list): + return get_error_data_result("`important_keywords` is required to be a list") + if "questions" in req and not isinstance(req["questions"], list): + return get_error_data_result("`questions` is required to be a list") + + chunk_id = xxhash.xxh64((req["content"] + document_id).encode("utf-8")).hexdigest() + d = { + "id": chunk_id, + "content_ltks": rag_tokenizer.tokenize(req["content"]), + "content_with_weight": req["content"], + } + d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) + d["important_kwd"] = req.get("important_keywords", []) + d["important_tks"] = rag_tokenizer.tokenize(" ".join(req.get("important_keywords", []))) + d["question_kwd"] = [str(q).strip() for q in req.get("questions", []) if str(q).strip()] + d["question_tks"] = rag_tokenizer.tokenize("\n".join(req.get("questions", []))) + d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19] + d["create_timestamp_flt"] = datetime.datetime.now().timestamp() + d["kb_id"] = dataset_id + d["docnm_kwd"] = doc.name + d["doc_id"] = document_id + + if "tag_kwd" in req: + if not isinstance(req["tag_kwd"], list): + return get_error_data_result("`tag_kwd` is required to be a list") + if not all(isinstance(t, str) for t in req["tag_kwd"]): + return get_error_data_result("`tag_kwd` must be a list of strings") + d["tag_kwd"] = req["tag_kwd"] + if "tag_feas" in req: + try: + d["tag_feas"] = validate_tag_features(req["tag_feas"]) + except ValueError as exc: + return get_error_data_result(f"`tag_feas` {exc}") + + image_base64 = req.get("image_base64") + if image_base64: + d["img_id"] = f"{dataset_id}-{chunk_id}" + d["doc_type_kwd"] = "image" + + tenant_embd_id = DocumentService.get_tenant_embd_id(document_id) + if tenant_embd_id: + model_config = get_model_config_by_id(tenant_embd_id) + else: + embd_id = DocumentService.get_embd_id(document_id) + model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING.value, embd_id) + embd_mdl = TenantLLMService.model_instance(model_config) + v, c = embd_mdl.encode([doc.name, req["content"] if not d["question_kwd"] else "\n".join(d["question_kwd"])]) + v = 0.1 * v[0] + 0.9 * v[1] + d[f"q_{len(v)}_vec"] = v.tolist() + settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id) + + if image_base64: + store_chunk_image(dataset_id, chunk_id, base64.b64decode(image_base64)) + + DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0) + key_mapping = { + "id": "id", + "content_with_weight": "content", + "doc_id": "document_id", + "important_kwd": "important_keywords", + "tag_kwd": "tag_kwd", + "question_kwd": "questions", + "kb_id": "dataset_id", + "create_timestamp_flt": "create_timestamp", + "create_time": "create_time", + "document_keyword": "document", + "img_id": "image_id", + } + renamed_chunk = {new_key: d[key] for key, new_key in key_mapping.items() if key in d} + _ = Chunk(**renamed_chunk) + return get_result(data={"chunk": renamed_chunk}) + + +@manager.route("/datasets//documents//chunks", methods=["DELETE"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def rm_chunk(tenant_id, dataset_id, document_id): + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): + return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") + docs = DocumentService.query(id=document_id, 
kb_id=dataset_id) + if not docs: + return get_error_data_result(message=f"You don't own the document {document_id}.") + req = await get_request_json() + if not req: + return get_result() + + chunk_ids = req.get("chunk_ids") + if not chunk_ids: + if req.get("delete_all") is True: + doc = docs[0] + DocumentService.delete_chunk_images(doc, tenant_id) + chunk_number = settings.docStoreConn.delete({"doc_id": document_id}, search.index_name(tenant_id), dataset_id) + if chunk_number != 0: + DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0) + return get_result(message=f"deleted {chunk_number} chunks") + return get_result() + + unique_chunk_ids, duplicate_messages = check_duplicate_ids(chunk_ids, "chunk") + chunk_number = settings.docStoreConn.delete( + {"doc_id": document_id, "id": unique_chunk_ids}, + search.index_name(tenant_id), + dataset_id, + ) + if chunk_number != 0: + DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0) + if chunk_number != len(unique_chunk_ids): + if len(unique_chunk_ids) == 0: + return get_result(message=f"deleted {chunk_number} chunks") + return get_error_data_result(message=f"rm_chunk deleted chunks {chunk_number}, expect {len(unique_chunk_ids)}") + if duplicate_messages: + return get_result( + message=f"Partially deleted {chunk_number} chunks with {len(duplicate_messages)} errors", + data={"success_count": chunk_number, "errors": duplicate_messages}, + ) + return get_result(message=f"deleted {chunk_number} chunks") + + +@manager.route("/datasets//documents//chunks/", methods=["PATCH"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def update_chunk(tenant_id, dataset_id, document_id, chunk_id): + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): + return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") + doc = DocumentService.query(id=document_id, kb_id=dataset_id) + if not doc: + return get_error_data_result(message=f"You don't own the document {document_id}.") + doc = doc[0] + chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id]) + if chunk is None or str(chunk.get("doc_id", chunk.get("document_id"))) != str(document_id): + return get_error_data_result(f"Can't find this chunk {chunk_id}") + req = await get_request_json() + content = req.get("content") + if content is not None: + if is_content_empty(content): + return get_error_data_result(message="`content` is required") + else: + content = chunk.get("content_with_weight", "") + d = {"id": chunk_id, "content_with_weight": content} + d["content_ltks"] = rag_tokenizer.tokenize(d["content_with_weight"]) + d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) + if "important_keywords" in req: + if not isinstance(req["important_keywords"], list): + return get_error_data_result("`important_keywords` should be a list") + d["important_kwd"] = req.get("important_keywords", []) + d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_keywords"])) + if "questions" in req: + if not isinstance(req["questions"], list): + return get_error_data_result("`questions` should be a list") + d["question_kwd"] = [str(q).strip() for q in req.get("questions", []) if str(q).strip()] + d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["questions"])) + if "available" in req: + d["available_int"] = int(req["available"]) + if "positions" in req: + if not isinstance(req["positions"], list): + return get_error_data_result("`positions` should be a 
list") + d["position_int"] = req["positions"] + if "tag_kwd" in req: + if not isinstance(req["tag_kwd"], list): + return get_error_data_result("`tag_kwd` should be a list") + if not all(isinstance(t, str) for t in req["tag_kwd"]): + return get_error_data_result("`tag_kwd` must be a list of strings") + d["tag_kwd"] = req["tag_kwd"] + if "tag_feas" in req: + try: + d["tag_feas"] = validate_tag_features(req["tag_feas"]) + except ValueError as exc: + return get_error_data_result(f"`tag_feas` {exc}") + image_base64 = req.get("image_base64") + if image_base64: + d["img_id"] = f"{dataset_id}-{chunk_id}" + d["doc_type_kwd"] = "image" + + tenant_embd_id = DocumentService.get_tenant_embd_id(document_id) + if tenant_embd_id: + model_config = get_model_config_by_id(tenant_embd_id) + else: + embd_id = DocumentService.get_embd_id(document_id) + model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING.value, embd_id) + embd_mdl = TenantLLMService.model_instance(model_config) + if doc.parser_id == ParserType.QA: + arr = [t for t in re.split(r"[\n\t]", d["content_with_weight"]) if len(t) > 1] + if len(arr) != 2: + return get_error_data_result(message="Q&A must be separated by TAB/ENTER key.") + q, a = rmPrefix(arr[0]), rmPrefix(arr[1]) + d = beAdoc(d, arr[0], arr[1], not any([rag_tokenizer.is_chinese(t) for t in q + a])) + + v, _ = embd_mdl.encode( + [ + doc.name, + d["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"]), + ] + ) + v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] + d[f"q_{len(v)}_vec"] = v.tolist() + settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id) + if image_base64: + store_chunk_image(dataset_id, chunk_id, base64.b64decode(image_base64)) + return get_result() + + +@manager.route("/datasets//documents//chunks", methods=["PATCH"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def switch_chunks(tenant_id, dataset_id, document_id): + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): + return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") + req = await get_request_json() + if not req.get("chunk_ids"): + return get_error_data_result(message="`chunk_ids` is required.") + if "available_int" not in req and "available" not in req: + return get_error_data_result(message="`available_int` or `available` is required.") + available_int = int(req["available_int"]) if "available_int" in req else (1 if req.get("available") else 0) + + try: + def _switch_sync(): + e, doc = DocumentService.get_by_id(document_id) + if not e: + return get_error_data_result(message="Document not found!") + if not doc or str(doc.kb_id) != str(dataset_id): + return get_error_data_result(message="Document not found!") + for cid in req["chunk_ids"]: + if not settings.docStoreConn.update( + {"id": cid}, + {"available_int": available_int}, + search.index_name(tenant_id), + doc.kb_id, + ): + return get_error_data_result(message="Index updating failure") + return get_result(data=True) + + return await thread_pool_exec(_switch_sync) + except Exception as e: + return server_error_response(e) diff --git a/api/apps/connector_app.py b/api/apps/restful_apis/connector_api.py similarity index 86% rename from api/apps/connector_app.py rename to api/apps/restful_apis/connector_api.py index 0c123f70077..99a58930211 100644 --- a/api/apps/connector_app.py +++ b/api/apps/restful_apis/connector_api.py @@ -35,15 +35,30 @@ from api.apps import login_required, 
current_user from box_sdk_gen import BoxOAuth, OAuthConfig, GetAuthorizeUrlOptions - -@manager.route("/set", methods=["POST"]) # noqa: F821 +@manager.route("/connectors/", methods=["PATCH"]) # noqa: F821 @login_required -async def set_connector(): +async def update_connector(connector_id): req = await get_request_json() - if req.get("id"): + e, conn = ConnectorService.get_by_id(connector_id) + if not e: + return get_data_error_result(message="Can't find this Connector!") + + if req: conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req} - ConnectorService.update_by_id(req["id"], conn) - else: + conn["id"] = connector_id + ConnectorService.update_by_id(connector_id, conn) + + await asyncio.sleep(1) + e, conn = ConnectorService.get_by_id(connector_id) + + return get_json_result(data=conn.to_dict()) + + +@manager.route("/connectors", methods=["POST"]) # noqa: F821 +@login_required +async def create_connector(): + req = await get_request_json() + if req: req["id"] = get_uuid() conn = { "id": req["id"], @@ -65,13 +80,13 @@ async def set_connector(): return get_json_result(data=conn.to_dict()) -@manager.route("/list", methods=["GET"]) # noqa: F821 +@manager.route("/connectors", methods=["GET"]) # noqa: F821 @login_required def list_connector(): return get_json_result(data=ConnectorService.list(current_user.id)) -@manager.route("/", methods=["GET"]) # noqa: F821 +@manager.route("/connectors/", methods=["GET"]) # noqa: F821 @login_required def get_connector(connector_id): e, conn = ConnectorService.get_by_id(connector_id) @@ -80,7 +95,7 @@ def get_connector(connector_id): return get_json_result(data=conn.to_dict()) -@manager.route("//logs", methods=["GET"]) # noqa: F821 +@manager.route("/connectors//logs", methods=["GET"]) # noqa: F821 @login_required def list_logs(connector_id): req = request.args.to_dict(flat=True) @@ -88,7 +103,7 @@ def list_logs(connector_id): return get_json_result(data={"total": total, "logs": arr}) -@manager.route("//resume", methods=["PUT"]) # noqa: F821 +@manager.route("/connectors//resume", methods=["POST"]) # noqa: F821 @login_required async def resume(connector_id): req = await get_request_json() @@ -99,7 +114,7 @@ async def resume(connector_id): return get_json_result(data=True) -@manager.route("//rebuild", methods=["PUT"]) # noqa: F821 +@manager.route("/connectors//rebuild", methods=["POST"]) # noqa: F821 @login_required @validate_request("kb_id") async def rebuild(connector_id): @@ -110,7 +125,7 @@ async def rebuild(connector_id): return get_json_result(data=True) -@manager.route("//rm", methods=["POST"]) # noqa: F821 +@manager.route("/connectors/", methods=["DELETE"]) # noqa: F821 @login_required def rm_connector(connector_id): ConnectorService.resume(connector_id, TaskStatus.CANCEL) @@ -157,6 +172,22 @@ def _get_web_client_config(credentials: dict[str, Any]) -> dict[str, Any]: return {"web": web_section} +def _exchange_google_web_oauth_code( + client_config: dict[str, Any], + scopes: list[str], + redirect_uri: str, + code: str, + code_verifier: str | None, +) -> Flow: + flow = Flow.from_client_config(client_config, scopes=scopes) + flow.redirect_uri = redirect_uri + fetch_token_kwargs: dict[str, Any] = {"code": code} + if code_verifier: + fetch_token_kwargs["code_verifier"] = code_verifier + flow.fetch_token(**fetch_token_kwargs) + return flow + + async def _render_web_oauth_popup(flow_id: str, success: bool, message: str, source="drive"): status = "success" if success else "error" auto_close = "window.close();" if 
success else "" @@ -185,7 +216,7 @@ async def _render_web_oauth_popup(flow_id: str, success: bool, message: str, sou return response -@manager.route("/google/oauth/web/start", methods=["POST"]) # noqa: F821 +@manager.route("/connectors/google/oauth/web/start", methods=["POST"]) # noqa: F821 @login_required @validate_request("credentials") async def start_google_web_oauth(): @@ -252,6 +283,7 @@ async def start_google_web_oauth(): "user_id": current_user.id, "client_config": client_config, "redirect_uri": redirect_uri, + "code_verifier": flow.code_verifier, "created_at": int(time.time()), } REDIS_CONN.set_obj(_web_state_cache_key(flow_id, source), cache_payload, WEB_FLOW_TTL_SECS) @@ -265,7 +297,7 @@ async def start_google_web_oauth(): ) -@manager.route("/gmail/oauth/web/callback", methods=["GET"]) # noqa: F821 +@manager.route("/connectors/gmail/oauth/web/callback", methods=["GET"]) # noqa: F821 async def google_gmail_web_oauth_callback(): state_id = request.args.get("state") error = request.args.get("error") @@ -283,6 +315,7 @@ async def google_gmail_web_oauth_callback(): state_obj = json.loads(state_cache) client_config = state_obj.get("client_config") redirect_uri = state_obj.get("redirect_uri", GMAIL_WEB_OAUTH_REDIRECT_URI) + code_verifier = state_obj.get("code_verifier") if not client_config: REDIS_CONN.delete(_web_state_cache_key(state_id, source)) return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source) @@ -296,10 +329,13 @@ async def google_gmail_web_oauth_callback(): return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source) try: - # TODO(google-oauth): branch scopes/redirect_uri based on source_type (drive vs gmail) - flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GMAIL]) - flow.redirect_uri = redirect_uri - flow.fetch_token(code=code) + flow = _exchange_google_web_oauth_code( + client_config=client_config, + scopes=GOOGLE_SCOPES[DocumentSource.GMAIL], + redirect_uri=redirect_uri, + code=code, + code_verifier=code_verifier, + ) except Exception as exc: # pragma: no cover - defensive logging.exception("Failed to exchange Google OAuth code: %s", exc) REDIS_CONN.delete(_web_state_cache_key(state_id, source)) @@ -316,7 +352,7 @@ async def google_gmail_web_oauth_callback(): return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source) -@manager.route("/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821 +@manager.route("/connectors/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821 async def google_drive_web_oauth_callback(): state_id = request.args.get("state") error = request.args.get("error") @@ -334,6 +370,7 @@ async def google_drive_web_oauth_callback(): state_obj = json.loads(state_cache) client_config = state_obj.get("client_config") redirect_uri = state_obj.get("redirect_uri", GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI) + code_verifier = state_obj.get("code_verifier") if not client_config: REDIS_CONN.delete(_web_state_cache_key(state_id, source)) return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. 
Please retry.", source) @@ -347,10 +384,13 @@ async def google_drive_web_oauth_callback(): return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source) try: - # TODO(google-oauth): branch scopes/redirect_uri based on source_type (drive vs gmail) - flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE]) - flow.redirect_uri = redirect_uri - flow.fetch_token(code=code) + flow = _exchange_google_web_oauth_code( + client_config=client_config, + scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE], + redirect_uri=redirect_uri, + code=code, + code_verifier=code_verifier, + ) except Exception as exc: # pragma: no cover - defensive logging.exception("Failed to exchange Google OAuth code: %s", exc) REDIS_CONN.delete(_web_state_cache_key(state_id, source)) @@ -366,7 +406,7 @@ async def google_drive_web_oauth_callback(): return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source) -@manager.route("/google/oauth/web/result", methods=["POST"]) # noqa: F821 +@manager.route("/connectors/google/oauth/web/result", methods=["POST"]) # noqa: F821 @login_required @validate_request("flow_id") async def poll_google_web_result(): @@ -386,7 +426,7 @@ async def poll_google_web_result(): REDIS_CONN.delete(_web_result_cache_key(flow_id, source)) return get_json_result(data={"credentials": result.get("credentials")}) -@manager.route("/box/oauth/web/start", methods=["POST"]) # noqa: F821 +@manager.route("/connectors/box/oauth/web/start", methods=["POST"]) # noqa: F821 @login_required async def start_box_web_oauth(): req = await get_request_json() @@ -429,7 +469,7 @@ async def start_box_web_oauth(): "expires_in": WEB_FLOW_TTL_SECS,} ) -@manager.route("/box/oauth/web/callback", methods=["GET"]) # noqa: F821 +@manager.route("/connectors/box/oauth/web/callback", methods=["GET"]) # noqa: F821 async def box_web_oauth_callback(): flow_id = request.args.get("state") if not flow_id: @@ -471,7 +511,7 @@ async def box_web_oauth_callback(): return await _render_web_oauth_popup(flow_id, True, "Authorization completed successfully.", "box") -@manager.route("/box/oauth/web/result", methods=["POST"]) # noqa: F821 +@manager.route("/connectors/box/oauth/web/result", methods=["POST"]) # noqa: F821 @login_required @validate_request("flow_id") async def poll_box_web_result(): diff --git a/api/apps/restful_apis/dataset_api.py b/api/apps/restful_apis/dataset_api.py index 4f3ff2d59a4..55ded90e028 100644 --- a/api/apps/restful_apis/dataset_api.py +++ b/api/apps/restful_apis/dataset_api.py @@ -19,11 +19,13 @@ from quart import request from common.constants import RetCode from api.apps import login_required, current_user -from api.utils.api_utils import get_error_argument_result, get_error_data_result, get_result, add_tenant_id_to_kwargs +from api.utils.api_utils import get_error_argument_result, get_error_data_result, get_json_result, get_result, add_tenant_id_to_kwargs from api.utils.validation_utils import ( CreateDatasetReq, DeleteDatasetReq, ListDatasetReq, + SearchDatasetReq, + SearchDatasetsReq, UpdateDatasetReq, validate_and_parse_json_request, validate_and_parse_request_args, @@ -31,10 +33,54 @@ from api.apps.services import dataset_api_service +@manager.route("/datasets/tags/aggregation", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def aggregate_tags(tenant_id): + dataset_ids = request.args.get("dataset_ids", "").split(",") + dataset_ids = [d for d in dataset_ids if d] + if not 
dataset_ids: + return get_error_data_result(message="Lack of dataset_ids in query parameters") + + try: + success, result = dataset_api_service.aggregate_tags(dataset_ids, tenant_id) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets/metadata/flattened", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def get_flattened_metadata(tenant_id): + dataset_ids = request.args.get("dataset_ids", "").split(",") + dataset_ids = [d for d in dataset_ids if d] + if not dataset_ids: + return get_error_data_result(message="Lack of dataset_ids in query parameters") + + try: + success, result = dataset_api_service.get_flattened_metadata(dataset_ids, tenant_id) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + @manager.route("/datasets", methods=["POST"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs -async def create(tenant_id: str=None): +async def create(tenant_id: str = None): """ Create a new dataset. --- @@ -102,6 +148,8 @@ async def create(tenant_id: str=None): return get_result(data=result) else: return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) except Exception as e: logging.exception(e) return get_error_data_result(message="Internal server error") @@ -330,26 +378,188 @@ def list_datasets(tenant_id): return get_error_data_result(message="Internal server error") -@manager.route('/datasets//knowledge_graph', methods=['GET']) # noqa: F821 +@manager.route("/datasets/", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def get_dataset(tenant_id, dataset_id): + try: + success, result = dataset_api_service.get_dataset(dataset_id, tenant_id) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets//ingestions/summary", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def get_ingestion_summary(tenant_id, dataset_id): + try: + success, result = dataset_api_service.get_ingestion_summary(dataset_id, tenant_id) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets//tags", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def list_tags(tenant_id, dataset_id): + try: + success, result = dataset_api_service.list_tags(dataset_id, tenant_id) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + 
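A minimal usage sketch for the two aggregate read endpoints above. Both are `login_required` routes and both read a single comma-separated `dataset_ids` query parameter; the base URL and the authorization header here are assumptions for illustration, not part of this patch.

```python
# Illustrative sketch only: calling the tag-aggregation and flattened-metadata
# endpoints added above. Base URL and auth header are assumptions.
import requests

BASE = "http://localhost:9380/api/v1"          # assumed deployment URL
HEADERS = {"Authorization": "Bearer <token>"}  # hypothetical; these routes require a logged-in user
dataset_ids = ["ds_id_1", "ds_id_2"]           # hypothetical dataset IDs

# Both handlers expect `dataset_ids` as one comma-separated query parameter.
params = {"dataset_ids": ",".join(dataset_ids)}

tags = requests.get(f"{BASE}/datasets/tags/aggregation", params=params, headers=HEADERS).json()
meta = requests.get(f"{BASE}/datasets/metadata/flattened", params=params, headers=HEADERS).json()
print(tags.get("data"), meta.get("data"))
```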
+@manager.route("/datasets//tags", methods=["DELETE"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def delete_tags(tenant_id, dataset_id): + req = await request.get_json() + if not req or "tags" not in req: + return get_error_data_result(message="Lack of tags in request body") + if not isinstance(req["tags"], list) or not all(isinstance(t, str) for t in req["tags"]): + return get_error_argument_result("tags must be a list of strings") + + try: + success, result = dataset_api_service.delete_tags(dataset_id, tenant_id, req["tags"]) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets//tags", methods=["PUT"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs -async def knowledge_graph(tenant_id, dataset_id): +async def rename_tag(tenant_id, dataset_id): + req = await request.get_json() + if not req or "from_tag" not in req or "to_tag" not in req: + return get_error_data_result(message="Lack of from_tag or to_tag in request body") + if not isinstance(req["from_tag"], str) or not isinstance(req["to_tag"], str): + return get_error_argument_result("from_tag and to_tag must be strings") + + if not req["from_tag"].strip() or not req["to_tag"].strip(): + return get_error_argument_result("from_tag and to_tag must not be empty") + + try: + success, result = dataset_api_service.rename_tag(dataset_id, tenant_id, req["from_tag"], req["to_tag"]) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets/search", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def search_datasets(tenant_id): + """Search (retrieval test) across multiple datasets. + + POST /api/v1/datasets/search + JSON body: {"dataset_ids": list[str] (required), "question": str (required), "doc_ids": list[str], "top_k": int, "page": int, "size": int, + "similarity_threshold": float, "vector_similarity_weight": float, "use_kg": bool, + "cross_languages": list[str], "keyword": bool, "meta_data_filter": dict} + Success: {"code": 0, "data": {"chunks": [...], "total": int, "labels": [...]}} + Errors: ARGUMENT_ERROR (101) for invalid payload; DATA_ERROR (102) for access denied or internal errors. + """ + req, err = await validate_and_parse_json_request(request, SearchDatasetsReq) + if err is not None: + return get_error_argument_result(err) + try: + success, result = await dataset_api_service.search_datasets(tenant_id, req) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except Exception as e: + logging.exception(e) + if "not_found" in str(e): + return get_error_data_result(message="No chunk found! Check the chunk status please!") + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets//search", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def search(tenant_id, dataset_id): + """Search (retrieval test) within a dataset. 
+ + POST /api/v1/datasets//search + JSON body: {"question": str (required), "doc_ids": list[str], "top_k": int, "page": int, "size": int, + "similarity_threshold": float, "vector_similarity_weight": float, "use_kg": bool, + "cross_languages": list[str], "keyword": bool, "meta_data_filter": dict} + Success: {"code": 0, "data": {"chunks": [...], "total": int, "labels": [...]}} + Errors: ARGUMENT_ERROR (101) for invalid payload; DATA_ERROR (102) for access denied or internal errors. + """ + req, err = await validate_and_parse_json_request(request, SearchDatasetReq) + if err is not None: + return get_error_argument_result(err) + req['dataset_ids'] = [dataset_id] + try: + success, result = await dataset_api_service.search_datasets(tenant_id, req) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except Exception as e: + logging.exception(e) + if "not_found" in str(e): + return get_error_data_result(message="No chunk found! Check the chunk status please!") + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets//graph", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def get_knowledge_graph(tenant_id, dataset_id): + """Get the knowledge graph of a dataset. + + GET /api/v1/datasets//graph + Query params: optional filter params. + Success: {"code": 0, "data": {...}} + Errors: AUTHENTICATION_ERROR for access denied; DATA_ERROR for internal errors. + """ try: success, result = await dataset_api_service.get_knowledge_graph(dataset_id, tenant_id) if success: return get_result(data=result) else: - return get_result( - data=False, - message=result, - code=RetCode.AUTHENTICATION_ERROR - ) + return get_result(data=False, message=result, code=RetCode.AUTHENTICATION_ERROR) except Exception as e: logging.exception(e) return get_error_data_result(message="Internal server error") -@manager.route('/datasets//knowledge_graph', methods=['DELETE']) # noqa: F821 +@manager.route("/datasets//graph", methods=["DELETE"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs def delete_knowledge_graph(tenant_id, dataset_id): @@ -358,67 +568,82 @@ def delete_knowledge_graph(tenant_id, dataset_id): if success: return get_result(data=result) else: - return get_result( - data=False, - message=result, - code=RetCode.AUTHENTICATION_ERROR - ) + return get_result(data=False, message=result, code=RetCode.AUTHENTICATION_ERROR) except Exception as e: logging.exception(e) return get_error_data_result(message="Internal server error") -@manager.route("/datasets//run_graphrag", methods=["POST"]) # noqa: F821 +@manager.route("/datasets//index", methods=["POST"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs -async def run_graphrag(tenant_id, dataset_id): +async def run_index(tenant_id, dataset_id): + index_type = request.args.get("type", "") + index_type = index_type.lower() try: - success, result = dataset_api_service.run_graphrag(dataset_id, tenant_id) + success, result = dataset_api_service.run_index(dataset_id, tenant_id, index_type) if success: return get_result(data=result) else: return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) except Exception as e: logging.exception(e) return get_error_data_result(message="Internal server error") -@manager.route("/datasets//trace_graphrag", methods=["GET"]) # noqa: F821 +@manager.route("/datasets//index", methods=["GET"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs -def 
trace_graphrag(tenant_id, dataset_id): +def trace_index(tenant_id, dataset_id): + index_type = request.args.get("type", "") + index_type = index_type.lower() try: - success, result = dataset_api_service.trace_graphrag(dataset_id, tenant_id) + success, result = dataset_api_service.trace_index(dataset_id, tenant_id, index_type) if success: return get_result(data=result) else: return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) except Exception as e: logging.exception(e) return get_error_data_result(message="Internal server error") -@manager.route("/datasets//run_raptor", methods=["POST"]) # noqa: F821 +@manager.route("/datasets//", methods=["DELETE"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs -async def run_raptor(tenant_id, dataset_id): +def delete_index(tenant_id, dataset_id, index_type): + index_type = index_type.lower() + if index_type not in dataset_api_service._VALID_INDEX_TYPES: + return get_error_argument_result(f"Invalid index type '{index_type}'") + # `wipe` controls whether the persisted index artefacts (graph rows / + # raptor summaries) are removed. Default true preserves historical + # behaviour; pass wipe=false to cancel the running task while keeping + # prior progress so it can be resumed later. + wipe_arg = (request.args.get("wipe", "true") or "true").strip().lower() + wipe = wipe_arg not in ("false", "0", "no", "off") try: - success, result = dataset_api_service.run_raptor(dataset_id, tenant_id) + success, result = dataset_api_service.delete_index(dataset_id, tenant_id, index_type, wipe=wipe) if success: return get_result(data=result) else: return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) except Exception as e: logging.exception(e) return get_error_data_result(message="Internal server error") -@manager.route("/datasets//trace_raptor", methods=["GET"]) # noqa: F821 +@manager.route("/datasets//embedding", methods=["POST"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs -def trace_raptor(tenant_id, dataset_id): +async def run_embedding(tenant_id, dataset_id): try: - success, result = dataset_api_service.trace_raptor(dataset_id, tenant_id) + success, result = dataset_api_service.run_embedding(dataset_id, tenant_id) if success: return get_result(data=result) else: @@ -428,7 +653,70 @@ def trace_raptor(tenant_id, dataset_id): return get_error_data_result(message="Internal server error") -@manager.route("/datasets//auto_metadata", methods=["GET"]) # noqa: F821 +@manager.route("/datasets//embedding/check", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def check_embedding(tenant_id, dataset_id): + try: + req = await request.get_json() + if not req or not req.get("embd_id"): + return get_error_data_result(message="`embd_id` is required.") + status, result = dataset_api_service.check_embedding(dataset_id, tenant_id, req) + if status is True: + return get_result(data=result) + elif status == "not_effective": + return get_json_result(code=result["code"], message=result["message"], data=result["data"]) + else: + return get_error_data_result(message=result) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets//ingestions", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def list_ingestion_logs(tenant_id, dataset_id): + try: + page = int(request.args.get("page", 0)) + page_size = 
int(request.args.get("page_size", 0)) + orderby = request.args.get("orderby", "create_time") + desc = request.args.get("desc", "true").lower() != "false" + operation_status = request.args.getlist("operation_status") + create_date_from = request.args.get("create_date_from", None) + create_date_to = request.args.get("create_date_to", None) + log_type = request.args.get("log_type", "dataset") + keywords = request.args.get("keywords", None) + success, result = dataset_api_service.list_ingestion_logs(dataset_id, tenant_id, page, page_size, orderby, desc, operation_status, create_date_from, create_date_to, log_type, keywords) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets//ingestions/", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def get_ingestion_log(tenant_id, dataset_id, log_id): + try: + success, result = dataset_api_service.get_ingestion_log(dataset_id, tenant_id, log_id) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets//metadata/config", methods=["GET"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs def get_auto_metadata(tenant_id, dataset_id): @@ -462,12 +750,14 @@ def get_auto_metadata(tenant_id, dataset_id): return get_result(data=result) else: return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) except Exception as e: logging.exception(e) return get_error_data_result(message="Internal server error") -@manager.route("/datasets//auto_metadata", methods=["PUT"]) # noqa: F821 +@manager.route("/datasets//metadata/config", methods=["PUT"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs async def update_auto_metadata(tenant_id, dataset_id): @@ -502,6 +792,7 @@ async def update_auto_metadata(tenant_id, dataset_id): type: object """ from api.utils.validation_utils import AutoMetadataConfig + cfg, err = await validate_and_parse_json_request(request, AutoMetadataConfig) if err is not None: return get_error_argument_result(err) @@ -512,6 +803,8 @@ async def update_auto_metadata(tenant_id, dataset_id): return get_result(data=result) else: return get_error_data_result(message=result) + except ValueError as e: + return get_error_argument_result(str(e)) except Exception as e: logging.exception(e) return get_error_data_result(message="Internal server error") diff --git a/api/apps/restful_apis/document_api.py b/api/apps/restful_apis/document_api.py index b2e749f3e51..7300a55a9f7 100644 --- a/api/apps/restful_apis/document_api.py +++ b/api/apps/restful_apis/document_api.py @@ -15,26 +15,107 @@ # import logging import json +import os +import re +from pathlib import Path -from quart import request +from quart import request, make_response from peewee import OperationalError from pydantic import ValidationError -from api.apps import login_required +from api.apps import current_user, login_required +from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX from api.apps.services.document_api_service import validate_document_update_fields, map_doc_keys, \ - 
map_doc_keys_with_run_status, update_document_name_only, update_chunk_method_only, update_document_status_only -from api.constants import IMG_BASE64_PREFIX -from api.db import VALID_FILE_TYPES + map_doc_keys_with_run_status, update_document_name_only, update_chunk_method, update_document_status_only, \ + reset_document_for_reparse +from api.db import VALID_FILE_TYPES, FileType +from api.db.services import duplicate_name from api.db.services.doc_metadata_service import DocMetadataService +from api.db.db_models import Task from api.db.services.document_service import DocumentService +from api.db.services.file2document_service import File2DocumentService +from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService +from api.common.check_team_permission import check_kb_team_permission +from api.db.services.task_service import TaskService, cancel_all_task_of from api.utils.api_utils import get_data_error_result, get_error_data_result, get_result, get_json_result, \ - server_error_response, add_tenant_id_to_kwargs, get_request_json + server_error_response, add_tenant_id_to_kwargs, get_request_json, get_error_argument_result, check_duplicate_ids from api.utils.validation_utils import ( - UpdateDocumentReq, format_validation_error_message, + UpdateDocumentReq, format_validation_error_message, validate_and_parse_json_request, DeleteDocumentReq, ) -from common.constants import RetCode + +from common import settings +from common.constants import ParserType, RetCode, TaskStatus, SANDBOX_ARTIFACT_BUCKET from common.metadata_utils import convert_conditions, meta_filter, turn2jsonschema +from common.misc_utils import get_uuid, thread_pool_exec +from api.utils.file_utils import filename_type, thumbnail +from api.utils.web_utils import CONTENT_TYPE_MAP, html2pdf, is_valid_url, apply_safe_file_response_headers +from common.ssrf_guard import assert_url_is_safe +from rag.nlp import search + + +@manager.route("/documents/upload", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def upload_info(tenant_id: str): + """ + Upload a document and get its parsed info. + --- + tags: + - Documents + security: + - ApiKeyAuth: [] + parameters: + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: formData + name: file + type: file + required: false + description: File to upload. + - in: query + name: url + type: string + required: false + description: URL to fetch file from. + responses: + 200: + description: Successful operation. 
+ """ + files = await request.files + file_objs = files.getlist("file") if files and files.get("file") else [] + url = request.args.get("url") + + if file_objs and url: + return get_error_argument_result("Provide either multipart file(s) or ?url=..., not both.") + + if not file_objs and not url: + return get_error_argument_result("Missing input: provide multipart file(s) or url") + + try: + if url and not file_objs: + try: + assert_url_is_safe(url) + except ValueError as ve: + logging.warning("upload_info: rejected unsafe url: %s", ve) + return get_error_argument_result(str(ve)) + + data = await thread_pool_exec(FileService.upload_info, tenant_id, None, url) + return get_result(data=data) + + if len(file_objs) == 1: + data = await thread_pool_exec(FileService.upload_info, tenant_id, file_objs[0], None) + return get_result(data=data) + + results = [await thread_pool_exec(FileService.upload_info, tenant_id, f, None) for f in file_objs] + return get_result(data=results) + except Exception as e: + logging.exception("upload_info failed") + return server_error_response(e) + @manager.route("/datasets//documents/", methods=["PATCH"]) # noqa: F821 @login_required @@ -125,16 +206,26 @@ async def update_document(tenant_id, dataset_id, document_id): if error := update_document_name_only(document_id, req["name"]): return error + # "parser_id" provided but does not match with existing doc's file type + if "parser_id" in req and ((doc.type == FileType.VISUAL and req["parser_id"] != "picture") + or (re.search(r"\.(ppt|pptx|pages)$", doc.name) and req["parser_id"] != "presentation")): + return get_data_error_result(message="Not supported yet!") + # parser config provided (already validated in UpdateDocumentReq), update it if update_doc_req.parser_config: + req["parser_config"].update(update_doc_req.parser_config.ext) DocumentService.update_parser_config(doc.id, req["parser_config"]) + # pipeline_id provided - reset document for reparse + if update_doc_req.pipeline_id: + if error := reset_document_for_reparse(doc, tenant_id, pipeline_id=update_doc_req.pipeline_id): + return error # chunk method provided - the update method will check if it's different with existing one - if update_doc_req.chunk_method: - if error := update_chunk_method_only(req, doc, dataset_id, tenant_id): + elif update_doc_req.chunk_method: + if error := update_chunk_method(req, doc, tenant_id): return error - if "enabled" in req: # already checked in UpdateDocumentReq - it's int if it's present + if "enabled" in req: # already checked in UpdateDocumentReq - it's int if present # "enabled" flag provided, the update method will check if it's changed and then update if so if error := update_document_status_only(int(req["enabled"]), doc, kb): return error @@ -189,6 +280,88 @@ async def metadata_summary(dataset_id, tenant_id): return server_error_response(e) +@manager.route("/datasets//metadata/update", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def metadata_batch_update(dataset_id, tenant_id): + """ + Batch update metadata for documents in a dataset. + --- + tags: + - Documents + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: dataset_id + type: string + required: true + description: ID of the dataset. + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + selector: + type: object + updates: + type: array + deletes: + type: array + responses: + 200: + description: Metadata updated successfully. 
+ """ + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): + return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ") + + req = await get_request_json() + selector = req.get("selector", {}) or {} + updates = req.get("updates", []) or [] + deletes = req.get("deletes", []) or [] + + if not isinstance(selector, dict): + return get_error_data_result(message="selector must be an object.") + if not isinstance(updates, list) or not isinstance(deletes, list): + return get_error_data_result(message="updates and deletes must be lists.") + + metadata_condition = selector.get("metadata_condition", {}) or {} + if metadata_condition and not isinstance(metadata_condition, dict): + return get_error_data_result(message="metadata_condition must be an object.") + + document_ids = selector.get("document_ids", []) or [] + if document_ids and not isinstance(document_ids, list): + return get_error_data_result(message="document_ids must be a list.") + + for upd in updates: + if not isinstance(upd, dict) or not upd.get("key") or "value" not in upd: + return get_error_data_result(message="Each update requires key and value.") + for d in deletes: + if not isinstance(d, dict) or not d.get("key"): + return get_error_data_result(message="Each delete requires key.") + + target_doc_ids = set() + if document_ids: + kb_doc_ids = KnowledgebaseService.list_documents_by_ids([dataset_id]) + invalid_ids = set(document_ids) - set(kb_doc_ids) + if invalid_ids: + return get_error_data_result(message=f"These documents do not belong to dataset {dataset_id}: {', '.join(invalid_ids)}") + target_doc_ids = set(document_ids) + + if metadata_condition: + metas = DocMetadataService.get_flatted_meta_by_kbs([dataset_id]) + filtered_ids = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))) + target_doc_ids = target_doc_ids & filtered_ids + if metadata_condition.get("conditions") and not target_doc_ids: + return get_result(data={"updated": 0, "matched_docs": 0}) + + target_doc_ids = list(target_doc_ids) + updated = DocMetadataService.batch_update_metadata(dataset_id, target_doc_ids, updates, deletes) + return get_result(data={"updated": updated, "matched_docs": len(target_doc_ids)}) + + @manager.route("/datasets//documents", methods=["POST"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs @@ -259,19 +432,148 @@ async def upload_document(dataset_id, tenant_id): type: string description: Processing status. 
""" - from api.constants import FILE_NAME_LEN_LIMIT - from api.common.check_team_permission import check_kb_team_permission - from api.db.services.file_service import FileService - from common.misc_utils import thread_pool_exec - + upload_type = (request.args.get("type") or "local").lower() + e, kb = KnowledgebaseService.get_by_id(dataset_id) + if not e: + logging.error(f"Can't find the dataset with ID {dataset_id}!") + return get_error_data_result(message=f"Can't find the dataset with ID {dataset_id}!", code=RetCode.DATA_ERROR) + + if not check_kb_team_permission(kb, tenant_id): + logging.error("No authorization.") + return get_error_data_result(message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) + + if upload_type == "web": + return await _upload_web_document(dataset_id, kb, tenant_id) + + if upload_type == "empty": + return await _upload_empty_document(dataset_id, kb, tenant_id) + + if upload_type != "local": + return get_error_data_result( + message='`type` must be one of "local", "web", or "empty".', + code=RetCode.ARGUMENT_ERROR, + ) + + return await _upload_local_documents(kb, tenant_id) + + +async def _upload_web_document(dataset_id, kb, tenant_id): + form = await request.form + name = (form.get("name") or "").strip() + url = form.get("url") + + if not name: + return get_error_data_result(message='Lack of "name"', code=RetCode.ARGUMENT_ERROR) + if not url: + return get_error_data_result(message='Lack of "url"', code=RetCode.ARGUMENT_ERROR) + if len(name.encode("utf-8")) > FILE_NAME_LEN_LIMIT: + return get_error_data_result( + message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", + code=RetCode.ARGUMENT_ERROR, + ) + if not is_valid_url(url): + return get_error_data_result(message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR) + + blob = html2pdf(url) + if not blob: + return server_error_response(ValueError("Download failure.")) + + root_folder = FileService.get_root_folder(tenant_id) + FileService.init_knowledgebase_docs(root_folder["id"], tenant_id) + kb_root_folder = FileService.get_kb_folder(tenant_id) + kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"]) + + try: + filename = duplicate_name(DocumentService.query, name=f"{name}.pdf", kb_id=kb.id) + filetype = filename_type(filename) + if filetype == FileType.OTHER.value: + raise RuntimeError("This type of file has not been supported yet!") + + location = filename + while settings.STORAGE_IMPL.obj_exist(dataset_id, location): + location += "_" + settings.STORAGE_IMPL.put(dataset_id, location, blob) + + doc = { + "id": get_uuid(), + "kb_id": kb.id, + "parser_id": kb.parser_id, + "pipeline_id": kb.pipeline_id, + "parser_config": kb.parser_config, + "created_by": tenant_id, + "type": filetype, + "name": filename, + "location": location, + "size": len(blob), + "thumbnail": thumbnail(filename, blob), + "suffix": Path(filename).suffix.lstrip("."), + } + if doc["type"] == FileType.VISUAL: + doc["parser_id"] = ParserType.PICTURE.value + if doc["type"] == FileType.AURAL: + doc["parser_id"] = ParserType.AUDIO.value + if re.search(r"\.(ppt|pptx|pages)$", filename): + doc["parser_id"] = ParserType.PRESENTATION.value + if re.search(r"\.(eml)$", filename): + doc["parser_id"] = ParserType.EMAIL.value + + DocumentService.insert(doc) + FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id) + return get_result(data=map_doc_keys_with_run_status(doc, run_status="0")) + except Exception as e: + return server_error_response(e) + + +async def 
_upload_empty_document(dataset_id, kb, tenant_id): + req = await get_request_json() + name = (req.get("name") or "").strip() + + if not name: + return get_error_data_result(message="File name can't be empty.", code=RetCode.ARGUMENT_ERROR) + if len(name.encode("utf-8")) > FILE_NAME_LEN_LIMIT: + return get_error_data_result( + message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", + code=RetCode.ARGUMENT_ERROR, + ) + if DocumentService.query(name=name, kb_id=dataset_id): + return get_error_data_result(message="Duplicated document name in the same dataset.") + + try: + kb_root_folder = FileService.get_kb_folder(kb.tenant_id) + if not kb_root_folder: + return get_error_data_result(message="Cannot find the root folder.") + kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"]) + if not kb_folder: + return get_error_data_result(message="Cannot find the kb folder for this file.") + + doc = DocumentService.insert( + { + "id": get_uuid(), + "kb_id": kb.id, + "parser_id": kb.parser_id, + "pipeline_id": kb.pipeline_id, + "parser_config": kb.parser_config, + "created_by": tenant_id, + "type": FileType.VIRTUAL, + "name": name, + "suffix": Path(name).suffix.lstrip("."), + "location": "", + "size": 0, + } + ) + FileService.add_file_from_kb(doc.to_dict(), kb_folder["id"], kb.tenant_id) + return get_result(data=map_doc_keys(doc)) + except Exception as e: + return server_error_response(e) + + +async def _upload_local_documents(kb, tenant_id): form = await request.form files = await request.files - - # Validation if "file" not in files: logging.error("No file part!") return get_error_data_result(message="No file part!", code=RetCode.ARGUMENT_ERROR) - + file_objs = files.getlist("file") for file_obj in file_objs: if file_obj is None or file_obj.filename is None or file_obj.filename == "": @@ -282,18 +584,6 @@ async def upload_document(dataset_id, tenant_id): logging.error(msg) return get_error_data_result(message=msg, code=RetCode.ARGUMENT_ERROR) - # KB Lookup - e, kb = KnowledgebaseService.get_by_id(dataset_id) - if not e: - logging.error(f"Can't find the dataset with ID {dataset_id}!") - return get_error_data_result(message=f"Can't find the dataset with ID {dataset_id}!", code=RetCode.DATA_ERROR) - - # Permission Check - if not check_kb_team_permission(kb, tenant_id): - logging.error("No authorization.") - return get_error_data_result(message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) - - # File Upload (async) err, files = await thread_pool_exec( FileService.upload_document, kb, file_objs, tenant_id, parent_path=form.get("parent_path") @@ -307,10 +597,8 @@ async def upload_document(dataset_id, tenant_id): msg = "There seems to be an issue with your file format. please verify it is correct and not corrupted." logging.error(msg) return get_error_data_result(message=msg, code=RetCode.DATA_ERROR) - - files = [f[0] for f in files] # remove the blob - # Check if we should return raw files without document key mapping + files = [f[0] for f in files] # remove the blob return_raw_files = request.args.get("return_raw_files", "false").lower() == "true" if return_raw_files: @@ -432,19 +720,24 @@ def list_docs(dataset_id, tenant_id): logging.error(f"You don't own the dataset {dataset_id}. ") return get_error_data_result(message=f"You don't own the dataset {dataset_id}. 
") - err_code, err_msg, docs, total = _get_docs_with_request(request, dataset_id) + if request.args.get("type") == "filter": + err_code, err_msg, payload, total = _get_doc_filters_with_request(request, dataset_id) + if err_code != RetCode.SUCCESS: + return get_data_error_result(code=err_code, message=err_msg) + return get_json_result(data={"total": total, "filter": payload}) + + err_code, err_msg, payload, total = _get_docs_with_request(request, dataset_id) if err_code != RetCode.SUCCESS: return get_data_error_result(code=err_code, message=err_msg) - renamed_doc_list = [map_doc_keys(doc) for doc in docs] + renamed_doc_list = [map_doc_keys(doc) for doc in payload] for doc_item in renamed_doc_list: if doc_item["thumbnail"] and not doc_item["thumbnail"].startswith(IMG_BASE64_PREFIX): - doc_item["thumbnail"] = f"/v1/document/image/{dataset_id}-{doc_item['thumbnail']}" + doc_item["thumbnail"] = f"/api/v1/documents/images/{dataset_id}-{doc_item['thumbnail']}" if doc_item.get("source_type"): doc_item["source_type"] = doc_item["source_type"].split("/")[0] if doc_item["parser_config"].get("metadata"): doc_item["parser_config"]["metadata"] = turn2jsonschema(doc_item["parser_config"]["metadata"]) - return get_json_result(data={"total": total, "docs": renamed_doc_list}) @@ -517,13 +810,21 @@ def _get_docs_with_request(req, dataset_id:str): doc_name = q.get("name") doc_id = q.get("id") - if doc_id and not DocumentService.query(id=doc_id, kb_id=dataset_id): - return RetCode.DATA_ERROR, f"You don't own the document {doc_id}.", [], 0 + if doc_id: + if not DocumentService.query(id=doc_id, kb_id=dataset_id): + return RetCode.DATA_ERROR, f"You don't own the document {doc_id}.", [], 0 + doc_ids_filter = [doc_id] # id provided, ignore other filters if doc_name and not DocumentService.query(name=doc_name, kb_id=dataset_id): return RetCode.DATA_ERROR, f"You don't own the document {doc_name}.", [], 0 + doc_ids = q.getlist("ids") + if doc_id and len(doc_ids) > 0: + return RetCode.DATA_ERROR, f"Should not provide both 'id':{doc_id} and 'ids'{doc_ids}" + if len(doc_ids) > 0: + doc_ids_filter = doc_ids + docs, total = DocumentService.get_by_kb_id(dataset_id, page, page_size, orderby, desc, keywords, run_status_converted, types, suffix, - doc_id=doc_id, name=doc_name, doc_ids_filter=doc_ids_filter, return_empty_metadata=return_empty_metadata) + name=doc_name, doc_ids=doc_ids_filter, return_empty_metadata=return_empty_metadata) # time range filter (0 means no bound) create_time_from = int(q.get("create_time_from", 0)) @@ -533,6 +834,40 @@ def _get_docs_with_request(req, dataset_id:str): return RetCode.SUCCESS, "", docs, total + +def _get_doc_filters_with_request(req, dataset_id: str): + """Get aggregated document filters with request parameters from a dataset.""" + q = req.args + + keywords = q.get("keywords", "") + + suffix = q.getlist("suffix") + + types = q.getlist("types") + if types: + invalid_types = {t for t in types if t not in VALID_FILE_TYPES} + if invalid_types: + msg = f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}" + return RetCode.DATA_ERROR, msg, {}, 0 + + run_status = q.getlist("run") + run_status_text_to_numeric = {"UNSTART": "0", "RUNNING": "1", "CANCEL": "2", "DONE": "3", "FAIL": "4"} + run_status_converted = [run_status_text_to_numeric.get(v, v) for v in run_status] + if run_status_converted: + invalid_status = {s for s in run_status_converted if s not in run_status_text_to_numeric.values()} + if invalid_status: + msg = f"Invalid filter run status 
conditions: {', '.join(invalid_status)}" + return RetCode.DATA_ERROR, msg, {}, 0 + + docs_filter, total = DocumentService.get_filter_by_kb_id( + dataset_id, + keywords, + run_status_converted, + types, + suffix, + ) + return RetCode.SUCCESS, "", docs_filter, total + def _parse_doc_id_filter_with_metadata(req, kb_id): """Parse document ID filter based on metadata conditions from the request. @@ -568,7 +903,7 @@ def _parse_doc_id_filter_with_metadata(req, kb_id): - The metadata_condition uses operators like: =, !=, >, <, >=, <=, contains, not contains, in, not in, start with, end with, empty, not empty. - The metadata parameter performs exact matching where values are OR'd within the same key - and AND'd across different keys. + & AND'd across different keys. Examples: Simple metadata filter (exact match): @@ -622,11 +957,11 @@ def _parse_doc_id_filter_with_metadata(req, kb_id): if metadata and not isinstance(metadata, dict): return RetCode.DATA_ERROR, "metadata must be an object.", [], return_empty_metadata - doc_ids_filter = None - metas = None + metas = dict() if metadata_condition or metadata: metas = DocMetadataService.get_flatted_meta_by_kbs([kb_id]) + doc_ids_filter = None if metadata_condition: doc_ids_filter = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))) if metadata_condition.get("conditions") and not doc_ids_filter: @@ -651,6 +986,7 @@ def _parse_doc_id_filter_with_metadata(req, kb_id): metadata_doc_ids &= key_doc_ids if not metadata_doc_ids: return RetCode.SUCCESS, "", [], return_empty_metadata + if metadata_doc_ids is not None: if doc_ids_filter is None: doc_ids_filter = metadata_doc_ids @@ -660,3 +996,900 @@ def _parse_doc_id_filter_with_metadata(req, kb_id): return RetCode.SUCCESS, "", [], return_empty_metadata return RetCode.SUCCESS, "", list(doc_ids_filter) if doc_ids_filter is not None else [], return_empty_metadata + + +@manager.route("/datasets//documents", methods=["DELETE"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def delete_documents(tenant_id, dataset_id): + """ + Delete documents from a dataset. + --- + tags: + - Documents + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: dataset_id + type: string + required: true + description: ID of the dataset containing the documents. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Document deletion parameters. + required: true + schema: + type: object + properties: + ids: + type: array or null + items: + type: string + description: | + Specifies the documents to delete: + - An array of IDs, only the specified documents will be deleted. + delete_all: + type: boolean + default: false + description: Whether to delete all documents in the dataset. + responses: + 200: + description: Successful operation. + schema: + type: object + """ + req, err = await validate_and_parse_json_request(request, DeleteDocumentReq) + if err is not None or req is None: + return get_error_argument_result(err) + + try: + # Validate dataset exists and user has permission + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): + return get_error_data_result(message=f"You don't own the dataset {dataset_id}. 
") + + # Get documents to delete + doc_ids = req.get("ids") or [] + delete_all = req.get("delete_all", False) + if not delete_all and len(doc_ids) == 0: + return get_error_data_result(message=f"should either provide doc ids or set delete_all(true), dataset: {dataset_id}. ") + + if len(doc_ids) > 0 and delete_all: + return get_error_data_result(message=f"should not provide both doc ids and delete_all(true), dataset: {dataset_id}. ") + if delete_all: + doc_ids = [doc.id for doc in DocumentService.query(kb_id=dataset_id)] + + dataset_doc_ids = {doc.id for doc in DocumentService.query(kb_id=dataset_id)} + invalid_ids = [doc_id for doc_id in doc_ids if doc_id not in dataset_doc_ids] + if invalid_ids: + return get_error_data_result( + message=f"These documents do not belong to dataset {dataset_id} or Document not found: {', '.join(invalid_ids)}" + ) + + # make sure each id is unique + unique_doc_ids, duplicate_messages = check_duplicate_ids(doc_ids, "document") + if duplicate_messages: + logging.warning(f"duplicate_messages:{duplicate_messages}") + else: + doc_ids = unique_doc_ids + + # Delete documents using existing FileService.delete_docs + errors = await thread_pool_exec(FileService.delete_docs, doc_ids, tenant_id) + + if errors: + return get_error_data_result(message=str(errors)) + + return get_result(data={"deleted": len(doc_ids)}) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + +@manager.route("/datasets//documents//metadata/config", methods=["PUT"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def update_metadata_config(tenant_id, dataset_id, document_id): + """ + Update document metadata configuration. + --- + tags: + - Documents + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: dataset_id + type: string + required: true + description: ID of the dataset. + - in: path + name: document_id + type: string + required: true + description: ID of the document. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Metadata configuration. + required: true + schema: + type: object + properties: + metadata: + type: object + description: Metadata configuration JSON. + responses: + 200: + description: Document updated successfully. 
+ """ + # Verify ownership and existence of dataset + if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): + return get_error_data_result(message="You don't own the dataset.") + + # Verify document exists in the dataset + doc = DocumentService.query(id=document_id, kb_id=dataset_id) + if not doc: + msg = f"Document {document_id} not found in dataset {dataset_id}" + return get_error_data_result(message=msg) + doc = doc[0] + + # Get request body + req = await get_request_json() + if "metadata" not in req: + return get_error_argument_result(message="metadata is required") + + # Update parser config with metadata + try: + DocumentService.update_parser_config(doc.id, {"metadata": req["metadata"]}) + except Exception as e: + logging.error("error when update_parser_config", exc_info=e) + return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e)) + + # Get updated document + try: + e, doc = DocumentService.get_by_id(doc.id) + if not e: + return get_data_error_result(message="Document not found!") + except Exception as e: + return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e)) + + return get_result(data=doc.to_dict()) + + +@manager.route("/thumbnails", methods=["GET"]) # noqa: F821 +def list_thumbnails(): + """ + Get thumbnails for documents. + --- + tags: + - Documents + parameters: + - in: query + name: doc_ids + type: array + required: true + description: List of document IDs to get thumbnails for. + responses: + 200: + description: Successfully retrieved thumbnails + 400: + description: Missing document IDs + """ + from api.constants import IMG_BASE64_PREFIX + from api.db.services.document_service import DocumentService + + doc_ids = request.args.getlist("doc_ids") + if not doc_ids: + return get_json_result(data=False, message='Lack of "Document ID"', code=RetCode.ARGUMENT_ERROR) + + try: + docs = DocumentService.get_thumbnails(doc_ids) + + for doc_item in docs: + if doc_item["thumbnail"] and not doc_item["thumbnail"].startswith(IMG_BASE64_PREFIX): + doc_item["thumbnail"] = f"/api/v1/documents/images/{doc_item['kb_id']}-{doc_item['thumbnail']}" + + return get_json_result(data={d["id"]: d["thumbnail"] for d in docs}) + except Exception as e: + return server_error_response(e) + + +@manager.route("/datasets//documents/metadatas", methods=["PATCH"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def update_metadata(tenant_id, dataset_id): + """ + Update document metadata in batch. + --- + tags: + - Documents + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: dataset_id + type: string + required: true + description: ID of the dataset. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Metadata update request. + required: true + schema: + type: object + properties: + selector: + type: object + description: Document selector. + properties: + document_ids: + type: array + items: + type: string + description: List of document IDs to update. + metadata_condition: + type: object + description: Filter documents by existing metadata. + updates: + type: array + items: + type: object + properties: + key: + type: string + value: + type: any + description: List of metadata key-value pairs to update. + deletes: + type: array + items: + type: object + properties: + key: + type: string + description: List of metadata keys to delete. + responses: + 200: + description: Metadata updated successfully. 
+ """ + # Verify ownership of dataset + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): + return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") + + # Get request body + req = await get_request_json() + selector = req.get("selector", {}) or {} + updates = req.get("updates", []) or [] + deletes = req.get("deletes", []) or [] + + # Validate selector + if not isinstance(selector, dict): + return get_error_data_result(message="selector must be an object.") + if not isinstance(updates, list) or not isinstance(deletes, list): + return get_error_data_result(message="updates and deletes must be lists.") + + # Validate metadata_condition + metadata_condition = selector.get("metadata_condition", {}) or {} + if metadata_condition and not isinstance(metadata_condition, dict): + return get_error_data_result(message="metadata_condition must be an object.") + + # Validate document_ids + document_ids = selector.get("document_ids", []) or [] + if document_ids and not isinstance(document_ids, list): + return get_error_data_result(message="document_ids must be a list.") + + # Validate updates + for upd in updates: + if not isinstance(upd, dict) or not upd.get("key") or "value" not in upd: + return get_error_data_result(message="Each update requires key and value.") + + # Validate deletes + for d in deletes: + if not isinstance(d, dict) or not d.get("key"): + return get_error_data_result(message="Each delete requires key.") + + # Initialize target document IDs + target_doc_ids = set() + + # If document_ids provided, validate they belong to the dataset + if document_ids: + kb_doc_ids = KnowledgebaseService.list_documents_by_ids([dataset_id]) + invalid_ids = set(document_ids) - set(kb_doc_ids) + if invalid_ids: + return get_error_data_result( + message=f"These documents do not belong to dataset {dataset_id}: {', '.join(invalid_ids)}" + ) + target_doc_ids = set(document_ids) + + # Apply metadata_condition filtering if provided + if metadata_condition: + metas = DocMetadataService.get_flatted_meta_by_kbs([dataset_id]) + filtered_ids = set( + meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")) + ) + target_doc_ids = target_doc_ids & filtered_ids + if metadata_condition.get("conditions") and not target_doc_ids: + return get_result(data={"updated": 0, "matched_docs": 0}) + + # Convert to list and perform update + target_doc_ids = list(target_doc_ids) + updated = DocMetadataService.batch_update_metadata(dataset_id, target_doc_ids, updates, deletes) + return get_result(data={"updated": updated, "matched_docs": len(target_doc_ids)}) + + +@manager.route("/documents/ingest", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def ingest(tenant_id): + req = await get_request_json() + try: + user_id = tenant_id + + error_code, error_message = await thread_pool_exec(_run_sync, user_id, req) + + if error_code: + logging.error(f"error when ingest documents:{req}, error message:{error_message}") + return get_json_result(error_code, error_message) + + return get_json_result(data=True) + except Exception as e: + logging.exception("document ingest/run failed") + return server_error_response(e) + +def _run_sync(user_id:str, req): + for doc_id in req["doc_ids"]: + if not DocumentService.accessible(doc_id, user_id): + return RetCode.AUTHENTICATION_ERROR, "No authorization." 
+ + kb_table_num_map = {} + for doc_id in req["doc_ids"]: + info = {"run": str(req["run"]), "progress": 0} + rerun_with_delete = str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False) + if rerun_with_delete: + info["progress_msg"] = "" + info["chunk_num"] = 0 + info["token_num"] = 0 + + doc_tenant_id = DocumentService.get_tenant_id(doc_id) + if not doc_tenant_id: + return RetCode.DATA_ERROR, "Tenant not found!" + e, doc = DocumentService.get_by_id(doc_id) + if not e: + return RetCode.DATA_ERROR, "Document not found!" + + if str(req["run"]) == TaskStatus.CANCEL.value: + tasks = list(TaskService.query(doc_id=doc_id)) + has_unfinished_task = any((task.progress or 0) < 1 for task in tasks) + if str(doc.run) in [TaskStatus.RUNNING.value, TaskStatus.CANCEL.value] or has_unfinished_task: + cancel_all_task_of(doc_id) + else: + return RetCode.DATA_ERROR, "Cannot cancel a task that is not in RUNNING status" + if all([rerun_with_delete, str(doc.run) == TaskStatus.DONE.value]): + DocumentService.clear_chunk_num_when_rerun(doc_id) + + DocumentService.update_by_id(doc_id, info) + if req.get("delete", False): + TaskService.filter_delete([Task.doc_id == doc_id]) + if settings.docStoreConn.index_exist(search.index_name(doc_tenant_id), doc.kb_id): + settings.docStoreConn.delete({"doc_id": doc_id}, search.index_name(doc_tenant_id), doc.kb_id) + + if str(req["run"]) == TaskStatus.RUNNING.value: + if req.get("apply_kb"): + e, kb = KnowledgebaseService.get_by_id(doc.kb_id) + if not e: + raise LookupError("Can't find this dataset!") + doc.parser_config["llm_id"] = kb.parser_config.get("llm_id") + doc.parser_config["enable_metadata"] = kb.parser_config.get("enable_metadata", False) + doc.parser_config["metadata"] = kb.parser_config.get("metadata", {}) + DocumentService.update_parser_config(doc.id, doc.parser_config) + doc_dict = doc.to_dict() + DocumentService.run(doc_tenant_id, doc_dict, kb_table_num_map) + + return None, None + + +@manager.route("/datasets//documents/parse", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def parse_documents(tenant_id, dataset_id): + """ + Start parsing documents in a dataset. + --- + tags: + - Documents + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: dataset_id + type: string + required: true + description: ID of the dataset. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Document parse parameters. + required: true + schema: + type: object + properties: + document_ids: + type: array + items: + type: string + description: List of document IDs to parse. + responses: + 200: + description: Successful operation. 
+ """ + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): + return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") + + req = await get_request_json() + if req is None: + return get_error_data_result(message="Request body is required") + + document_ids = req.get("document_ids") + if document_ids is None or not isinstance(document_ids, list): + return get_error_data_result(message="`document_ids` is required") + if len(document_ids) == 0: + return get_error_data_result(message="`document_ids` is required") + + # Check for duplicate document IDs + unique_doc_ids, duplicate_messages = check_duplicate_ids(document_ids, "document") + errors = duplicate_messages if duplicate_messages else [] + + # Validate all document IDs belong to the dataset + not_found_ids = [] + valid_doc_ids = [] + for doc_id in unique_doc_ids: + docs = DocumentService.query(kb_id=dataset_id, id=doc_id) + if not docs: + not_found_ids.append(doc_id) + else: + valid_doc_ids.append(doc_id) + + if not_found_ids: + errors.append(f"Documents not found: {not_found_ids}") + # Still parse valid documents, but return error code + if not valid_doc_ids: + return get_error_data_result(message=f"Documents not found: {not_found_ids}") + + try: + def _run_sync(): + kb_table_num_map = {} + success_count = 0 + for doc_id in valid_doc_ids: + e, doc = DocumentService.get_by_id(doc_id) + if not e: + errors.append(f"Document not found: {doc_id}") + continue + + info = {"run": str(TaskStatus.RUNNING.value), "progress": 0} + # If re-running a completed document, clear previous chunks + if str(doc.run) == TaskStatus.DONE.value: + DocumentService.clear_chunk_num_when_rerun(doc.id) + info["progress_msg"] = "" + info["chunk_num"] = 0 + info["token_num"] = 0 + + DocumentService.update_by_id(doc_id, info) + TaskService.filter_delete([Task.doc_id == doc_id]) + if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id): + settings.docStoreConn.delete({"doc_id": doc_id}, search.index_name(tenant_id), doc.kb_id) + + doc_dict = doc.to_dict() + DocumentService.run(tenant_id, doc_dict, kb_table_num_map) + success_count += 1 + + result = {"success_count": success_count} + if errors: + result["errors"] = errors + return result + + result = await thread_pool_exec(_run_sync) + if not_found_ids: + return get_error_data_result(message=f"Documents not found: {not_found_ids}") + return get_result(data=result) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets//documents/stop", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def stop_parse_documents(tenant_id, dataset_id): + """ + Stop parsing documents in a dataset. + --- + tags: + - Documents + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: dataset_id + type: string + required: true + description: ID of the dataset. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Document stop parse parameters. + required: true + schema: + type: object + properties: + document_ids: + type: array + items: + type: string + description: List of document IDs to stop parsing. + responses: + 200: + description: Successful operation. 
+ """ + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): + return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") + + req = await get_request_json() + if req is None: + return get_error_data_result(message="Request body is required") + + document_ids = req.get("document_ids") + if document_ids is None or not isinstance(document_ids, list): + return get_error_data_result(message="`document_ids` is required") + if len(document_ids) == 0: + return get_error_data_result(message="`document_ids` is required") + + # Check for duplicate document IDs + unique_doc_ids, duplicate_messages = check_duplicate_ids(document_ids, "document") + errors = duplicate_messages if duplicate_messages else [] + + # Validate all document IDs belong to the dataset + not_found_ids = [] + valid_doc_ids = [] + for doc_id in unique_doc_ids: + docs = DocumentService.query(kb_id=dataset_id, id=doc_id) + if not docs: + not_found_ids.append(doc_id) + else: + valid_doc_ids.append(doc_id) + + if not_found_ids: + return get_error_data_result(message=f"Documents not found: {not_found_ids}") + + try: + def _run_sync(): + success_count = 0 + for doc_id in valid_doc_ids: + e, doc = DocumentService.get_by_id(doc_id) + if not e: + errors.append(f"Document not found: {doc_id}") + continue + + # Check if the document is currently running + tasks = list(TaskService.query(doc_id=doc_id)) + has_unfinished_task = any((task.progress or 0) < 1 for task in tasks) + if str(doc.run) not in [TaskStatus.RUNNING.value, TaskStatus.CANCEL.value] and not has_unfinished_task: + errors.append("Can't stop parsing document that has not started or already completed") + continue + + cancel_all_task_of(doc_id) + DocumentService.update_by_id(doc_id, {"run": str(TaskStatus.CANCEL.value)}) + success_count += 1 + + result = {"success_count": success_count} + if errors: + result["errors"] = errors + return result + + result = await thread_pool_exec(_run_sync) + if not_found_ids: + return get_error_data_result(message=f"Documents not found: {not_found_ids}") + return get_result(data=result) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/documents/images/", methods=["GET"]) # noqa: F821 +async def get_document_image(image_id): + """ + Get a document image by ID. + --- + tags: + - Documents + parameters: + - name: image_id + in: path + required: true + schema: + type: string + description: The image ID (format: bucket-name-image-name) + responses: + 200: + description: Image file + content: + image/jpeg: + schema: + type: string + format: binary + """ + try: + arr = image_id.split("-") + if len(arr) != 2: + return get_data_error_result(message="Image not found.") + bkt, nm = image_id.split("-") + data = await thread_pool_exec(settings.STORAGE_IMPL.get, bkt, nm) + response = await make_response(data) + response.headers.set("Content-Type", "image/JPEG") + return response + except Exception as e: + return server_error_response(e) + + +ARTIFACT_CONTENT_TYPES = { + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".svg": "image/svg+xml", + ".pdf": "application/pdf", + ".csv": "text/csv", + ".json": "application/json", + ".html": "text/html", +} + + +@manager.route("/documents/artifact/", methods=["GET"]) # noqa: F821 +@login_required +async def get_artifact(filename): + """ + Get an artifact file. 
+ --- + tags: + - Documents + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: filename + type: string + required: true + description: Name of the artifact file. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + responses: + 200: + description: Artifact file returned successfully. + """ + from common import settings + + try: + bucket = SANDBOX_ARTIFACT_BUCKET + # Validate filename: must be uuid hex + allowed extension, nothing else + basename = os.path.basename(filename) + if basename != filename or "/" in filename or "\\" in filename: + return get_data_error_result(message="Invalid filename.") + ext = os.path.splitext(basename)[1].lower() + if ext not in ARTIFACT_CONTENT_TYPES: + return get_data_error_result(message="Invalid file type.") + data = await thread_pool_exec(settings.STORAGE_IMPL.get, bucket, basename) + if not data: + return get_data_error_result(message="Artifact not found.") + content_type = ARTIFACT_CONTENT_TYPES.get(ext, "application/octet-stream") + response = await make_response(data) + safe_filename = re.sub(r"[^\w.\-]", "_", basename) + apply_safe_file_response_headers(response, content_type, ext) + if not response.headers.get("Content-Disposition"): + response.headers.set("Content-Disposition", f'inline; filename="{safe_filename}"') + return response + except Exception as e: + return server_error_response(e) + + +@manager.route("/datasets//documents/batch-update-status", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def batch_update_document_status(tenant_id, dataset_id): + """ + Batch update status of documents within a dataset. + --- + tags: + - Documents + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: dataset_id + type: string + required: true + description: ID of the dataset. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Document status update parameters. + required: true + schema: + type: object + required: + - doc_ids + - status + properties: + doc_ids: + type: array + items: + type: string + description: List of document IDs to update. + status: + type: string + enum: ["0", "1"] + description: New status (0 = disabled, 1 = enabled). + responses: + 200: + description: Document statuses updated successfully. 
+ """ + + req = await get_request_json() + doc_ids = req.get("doc_ids", []) + if not isinstance(doc_ids, list) or not doc_ids: + return get_error_argument_result(message='"doc_ids" must be a non-empty list.') + if any(not isinstance(doc_id, str) or not doc_id for doc_id in doc_ids): + return get_error_argument_result(message='"doc_ids" must contain non-empty document IDs.') + + status = str(req.get("status", -1)) + if status not in ["0", "1"]: + return get_error_argument_result(message=f'"Status" must be either 0 or 1:{status}!') + + # Verify dataset ownership + if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): + return get_error_data_result(message="You don't own the dataset.") + + e, kb = KnowledgebaseService.get_by_id(dataset_id) + if not e: + return get_error_data_result(message="Can't find this dataset!") + + result = {} + has_error = False + for doc_id in doc_ids: + try: + e, doc = DocumentService.get_by_id(doc_id) + if not e: + result[doc_id] = {"error": "Document not found"} + has_error = True + continue + + if doc.kb_id != dataset_id: + logging.warning(f"Document {doc.kb_id} not in dataset {dataset_id}") + result[doc_id] = {"error": "Document not found in this dataset."} + has_error = True + continue + + current_status = str(doc.status) + if current_status == status: + result[doc_id] = {"status": status} + continue + if not DocumentService.update_by_id(doc_id, {"status": str(status)}): + result[doc_id] = {"error": "Database error (Document update)!"} + has_error = True + continue + + status_int = int(status) + if getattr(doc, "chunk_num", 0) > 0: + try: + ok = settings.docStoreConn.update( + {"doc_id": doc_id}, + {"available_int": status_int}, + search.index_name(kb.tenant_id), + doc.kb_id, + ) + except Exception as exc: + msg = str(exc) + if "3022" in msg: + result[doc_id] = {"error": "Document store table missing."} + else: + result[doc_id] = {"error": f"Document store update failed: {msg}"} + has_error = True + continue + if not ok: + result[doc_id] = {"error": "Database error (docStore update)!"} + has_error = True + continue + result[doc_id] = {"status": status} + except Exception as e: + result[doc_id] = {"error": f"Internal server error: {str(e)}"} + has_error = True + + if has_error: + return get_json_result(data=result, message="Partial failure", code=RetCode.SERVER_ERROR) + return get_json_result(data=result) + +@manager.route("/documents//preview", methods=["GET"]) # noqa: F821 +@login_required +async def get(doc_id): + """Return the raw file bytes for a document the requesting user is authorized to read. + + The user must belong to the tenant that owns the document's knowledge base; otherwise + the response is indistinguishable from a missing document to avoid cross-tenant ID + enumeration. 
+ """ + try: + if not DocumentService.accessible(doc_id, current_user.id): + return get_data_error_result(message="Document not found!") + + e, doc = DocumentService.get_by_id(doc_id) + if not e: + return get_data_error_result(message="Document not found!") + + b, n = File2DocumentService.get_storage_address(doc_id=doc_id) + data = await thread_pool_exec(settings.STORAGE_IMPL.get, b, n) + response = await make_response(data) + + ext = re.search(r"\.([^.]+)$", doc.name.lower()) + ext = ext.group(1) if ext else None + content_type = None + if ext: + fallback_prefix = "image" if doc.type == FileType.VISUAL.value else "application" + content_type = CONTENT_TYPE_MAP.get(ext, f"{fallback_prefix}/{ext}") + apply_safe_file_response_headers(response, content_type, ext) + return response + except Exception as e: + return server_error_response(e) + + +@manager.route("/documents//download", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def download_attachment(tenant_id=None, doc_id=None, attachment_id=None): + """Stream a document's underlying file to the requesting user. + + Mirrors the authorization model of the preview endpoint: the user must belong + to the tenant that owns the document's knowledge base. A denial returns the + same "Document not found!" response so the endpoint cannot be used to + enumerate doc ids across tenants. + """ + try: + # Keep backward compatibility with older callers and unit tests that still + # pass `attachment_id` instead of the route parameter name. + doc_id = doc_id or attachment_id + if not DocumentService.accessible(doc_id, current_user.id): + return get_data_error_result(message="Document not found!") + ext = request.args.get("ext", "markdown") + data = await thread_pool_exec(settings.STORAGE_IMPL.get, tenant_id, doc_id) + response = await make_response(data) + content_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}") + apply_safe_file_response_headers(response, content_type, ext) + + return response + + except Exception as e: + return server_error_response(e) diff --git a/api/apps/file2document_app.py b/api/apps/restful_apis/file2document_api.py similarity index 63% rename from api/apps/file2document_app.py rename to api/apps/restful_apis/file2document_api.py index c82207ab73a..9c466a441d3 100644 --- a/api/apps/file2document_app.py +++ b/api/apps/restful_apis/file2document_api.py @@ -18,6 +18,7 @@ import logging from pathlib import Path +from api.common.check_team_permission import check_file_team_permission, check_kb_team_permission from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService @@ -25,10 +26,11 @@ from api.db.services.knowledgebase_service import KnowledgebaseService from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request from common.misc_utils import get_uuid -from common.constants import RetCode from api.db import FileType from api.db.services.document_service import DocumentService +logger = logging.getLogger(__name__) + def _convert_files(file_ids, kb_ids, user_id): """Synchronous worker: delete old docs and insert new ones for the given file/kb pairs.""" @@ -74,7 +76,7 @@ def _convert_files(file_ids, kb_ids, user_id): }) -@manager.route('/convert', methods=['POST']) # noqa: F821 +@manager.route('/files/link-to-datasets', methods=['POST']) # noqa: F821 @login_required @validate_request("file_ids", "kb_ids") async def convert(): @@ -89,13 +91,29 @@ async def convert(): # 
Validate all files exist before starting any work for file_id in file_ids: if not files_set.get(file_id): + logger.warning( + "user_id=%s resource_type=file resource_id=%s action=validate_file_lookup result=not_found file_ids=%s kb_ids=%s", + current_user.id, + file_id, + file_ids, + kb_ids, + ) return get_data_error_result(message="File not found!") # Validate all kb_ids exist before scheduling background work + kb_map = {} for kb_id in kb_ids: - e, _ = KnowledgebaseService.get_by_id(kb_id) + e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: + logger.warning( + "user_id=%s resource_type=dataset resource_id=%s action=validate_dataset_lookup result=not_found file_ids=%s kb_ids=%s", + current_user.id, + kb_id, + file_ids, + kb_ids, + ) return get_data_error_result(message="Can't find this dataset!") + kb_map[kb_id] = kb # Expand folders to their innermost file IDs all_file_ids = [] @@ -107,6 +125,38 @@ async def convert(): all_file_ids.append(file_id) user_id = current_user.id + for file_id in all_file_ids: + e, file = FileService.get_by_id(file_id) + if not e or not file: + logger.warning( + "user_id=%s resource_type=file resource_id=%s action=validate_expanded_file_lookup result=not_found file_ids=%s kb_ids=%s", + user_id, + file_id, + file_ids, + kb_ids, + ) + return get_data_error_result(message="File not found!") + if not check_file_team_permission(file, user_id): + logger.warning( + "user_id=%s resource_type=file resource_id=%s action=authorize_file result=denied file_ids=%s kb_ids=%s", + user_id, + file_id, + file_ids, + kb_ids, + ) + return get_data_error_result(message="No authorization.") + + for kb_id, kb in kb_map.items(): + if not check_kb_team_permission(kb, user_id): + logger.warning( + "user_id=%s resource_type=dataset resource_id=%s action=authorize_dataset result=denied file_ids=%s kb_ids=%s", + user_id, + kb_id, + file_ids, + kb_ids, + ) + return get_data_error_result(message="No authorization.") + # Run the blocking DB work in a thread so the event loop is not blocked. # For large folders this prevents 504 Gateway Timeout by returning as # soon as the background task is scheduled. 
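+    # Illustrative sketch of the scheduling pattern (executor setup assumed; only the
+    # done-callback below is part of this change):
+    #     from concurrent.futures import ThreadPoolExecutor
+    #     executor = ThreadPoolExecutor(max_workers=4)
+    #     future = executor.submit(_convert_files, all_file_ids, kb_ids, user_id)
+    #     future.add_done_callback(
+    #         lambda f: logging.error("_convert_files failed: %s", f.exception()) if f.exception() else None
+    #     )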
@@ -115,39 +165,12 @@ async def convert(): future.add_done_callback( lambda f: logging.error("_convert_files failed: %s", f.exception()) if f.exception() else None ) - return get_json_result(data=True) - except Exception as e: - return server_error_response(e) - - -@manager.route('/rm', methods=['POST']) # noqa: F821 -@login_required -@validate_request("file_ids") -async def rm(): - req = await get_request_json() - file_ids = req["file_ids"] - if not file_ids: - return get_json_result( - data=False, message='Lack of "Files ID"', code=RetCode.ARGUMENT_ERROR) - try: - for file_id in file_ids: - informs = File2DocumentService.get_by_file_id(file_id) - if not informs: - return get_data_error_result(message="Inform not found!") - for inform in informs: - if not inform: - return get_data_error_result(message="Inform not found!") - File2DocumentService.delete_by_file_id(file_id) - doc_id = inform.document_id - e, doc = DocumentService.get_by_id(doc_id) - if not e: - return get_data_error_result(message="Document not found!") - tenant_id = DocumentService.get_tenant_id(doc_id) - if not tenant_id: - return get_data_error_result(message="Tenant not found!") - if not DocumentService.remove_document(doc, tenant_id): - return get_data_error_result( - message="Database error (Document removal)!") + logger.info( + "user_id=%s resource_type=file_to_dataset_link resource_id=batch action=schedule_convert result=scheduled file_ids=%s kb_ids=%s", + user_id, + all_file_ids, + kb_ids, + ) return get_json_result(data=True) except Exception as e: return server_error_response(e) diff --git a/api/apps/restful_apis/file_api.py b/api/apps/restful_apis/file_api.py index fbe1e39d50a..b67aa30ffce 100644 --- a/api/apps/restful_apis/file_api.py +++ b/api/apps/restful_apis/file_api.py @@ -24,8 +24,10 @@ add_tenant_id_to_kwargs, get_error_argument_result, get_error_data_result, + get_json_result, get_result, ) +from common.constants import RetCode from api.utils.validation_utils import ( CreateFolderReq, DeleteFileReq, @@ -99,7 +101,7 @@ async def create_or_upload(tenant_id: str = None): @manager.route("/files", methods=["GET"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs -def list_files(tenant_id: str = None): +async def list_files(tenant_id: str = None): """ List files under a folder. --- @@ -185,10 +187,22 @@ async def delete(tenant_id: str = None): return get_error_argument_result(err) try: - success, result = await file_api_service.delete_files(tenant_id, req["ids"]) + # Get Authorization header to pass to Go backend + auth_header = request.headers.get("Authorization", "") + success, result = await file_api_service.delete_files(tenant_id, req["ids"], auth_header) if success: return get_result(data=result) else: + if isinstance(result, dict): + success_count = result.get("success_count", 0) + errors = result.get("errors", []) + return get_json_result( + code=RetCode.DATA_ERROR, + message=f"Partially deleted {success_count} files with {len(errors)} errors" + if success_count > 0 + else f"Deleted files failed with {len(errors)} errors", + data=result, + ) return get_error_data_result(message=result) except Exception as e: logging.exception(e) @@ -303,7 +317,7 @@ async def download(tenant_id: str = None, file_id: str = None): @manager.route("/files//parent", methods=["GET"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs -def parent_folder(tenant_id: str = None, file_id: str = None): +async def parent_folder(tenant_id: str = None, file_id: str = None): """ Get parent folder of a file. 
--- @@ -321,7 +335,7 @@ def parent_folder(tenant_id: str = None, file_id: str = None): description: Parent folder information. """ try: - success, result = file_api_service.get_parent_folder(file_id) + success, result = file_api_service.get_parent_folder(file_id, user_id=tenant_id) if success: return get_result(data=result) else: @@ -334,7 +348,7 @@ def parent_folder(tenant_id: str = None, file_id: str = None): @manager.route("/files//ancestors", methods=["GET"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs -def ancestors(tenant_id: str = None, file_id: str = None): +async def ancestors(tenant_id: str = None, file_id: str = None): """ Get all ancestor folders of a file. --- @@ -352,7 +366,7 @@ def ancestors(tenant_id: str = None, file_id: str = None): description: List of ancestor folders. """ try: - success, result = file_api_service.get_all_parent_folders(file_id) + success, result = file_api_service.get_all_parent_folders(file_id, user_id=tenant_id) if success: return get_result(data=result) else: @@ -360,5 +374,3 @@ def ancestors(tenant_id: str = None, file_id: str = None): except Exception as e: logging.exception(e) return get_error_data_result(message="Internal server error") - - diff --git a/api/apps/langfuse_app.py b/api/apps/restful_apis/langfuse_api.py similarity index 94% rename from api/apps/langfuse_app.py rename to api/apps/restful_apis/langfuse_api.py index 1d7993d365c..70b81b42c63 100644 --- a/api/apps/langfuse_app.py +++ b/api/apps/restful_apis/langfuse_api.py @@ -23,7 +23,7 @@ from api.utils.api_utils import get_error_data_result, get_json_result, get_request_json, server_error_response, validate_request -@manager.route("/api_key", methods=["POST", "PUT"]) # noqa: F821 +@manager.route("/langfuse/api-key", methods=["POST", "PUT"]) # noqa: F821 @login_required @validate_request("secret_key", "public_key", "host") async def set_api_key(): @@ -58,7 +58,7 @@ async def set_api_key(): return server_error_response(e) -@manager.route("/api_key", methods=["GET"]) # noqa: F821 +@manager.route("/langfuse/api-key", methods=["GET"]) # noqa: F821 @login_required @validate_request() def get_api_key(): @@ -82,7 +82,7 @@ def get_api_key(): return get_json_result(data=langfuse_entry) -@manager.route("/api_key", methods=["DELETE"]) # noqa: F821 +@manager.route("/langfuse/api-key", methods=["DELETE"]) # noqa: F821 @login_required @validate_request() def delete_api_key(): diff --git a/api/apps/mcp_server_app.py b/api/apps/restful_apis/mcp_api.py similarity index 62% rename from api/apps/mcp_server_app.py rename to api/apps/restful_apis/mcp_api.py index 187560d626b..ec384f6074d 100644 --- a/api/apps/mcp_server_app.py +++ b/api/apps/restful_apis/mcp_api.py @@ -1,5 +1,5 @@ # -# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,20 +13,49 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# + from quart import Response, request -from api.apps import current_user, login_required +from api.apps import current_user, login_required from api.db.db_models import MCPServer from api.db.services.mcp_server_service import MCPServerService from api.db.services.user_service import TenantService -from common.constants import RetCode, VALID_MCP_SERVER_TYPES - -from common.misc_utils import get_uuid, thread_pool_exec from api.utils.api_utils import get_data_error_result, get_json_result, get_mcp_tools, get_request_json, server_error_response, validate_request from api.utils.web_utils import get_float, safe_json_parse +from common.constants import VALID_MCP_SERVER_TYPES from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions +from common.misc_utils import get_uuid, thread_pool_exec -@manager.route("/list", methods=["POST"]) # noqa: F821 + +def _get_mcp_ids_from_args() -> list[str]: + mcp_ids = request.args.getlist("mcp_ids") + if mcp_ids: + return [mcp_id for item in mcp_ids for mcp_id in item.split(",") if mcp_id] + mcp_ids = request.args.get("mcp_id", "") + return [mcp_id for mcp_id in mcp_ids.split(",") if mcp_id] + + +def _export_mcp_servers(mcp_ids: list[str]) -> dict | None: + exported_servers = {} + for mcp_id in mcp_ids: + e, mcp_server = MCPServerService.get_by_id(mcp_id) + if e and mcp_server.tenant_id == current_user.id: + server_key = mcp_server.name + exported_servers[server_key] = { + "type": mcp_server.server_type, + "url": mcp_server.url, + "name": mcp_server.name, + "authorization_token": mcp_server.variables.get("authorization_token", ""), + "tools": mcp_server.variables.get("tools", {}), + } + + if not exported_servers: + return None + + return {"mcpServers": exported_servers} + + +@manager.route("/mcp/servers", methods=["GET"]) # noqa: F821 @login_required async def list_mcp() -> Response: keywords = request.args.get("keywords", "") @@ -38,8 +67,7 @@ async def list_mcp() -> Response: else: desc = True - req = await get_request_json() - mcp_ids = req.get("mcp_ids", []) + mcp_ids = _get_mcp_ids_from_args() try: servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or [] total = len(servers) @@ -52,22 +80,27 @@ async def list_mcp() -> Response: return server_error_response(e) -@manager.route("/detail", methods=["GET"]) # noqa: F821 +@manager.route("/mcp/servers/", methods=["GET"]) # noqa: F821 @login_required -def detail() -> Response: - mcp_id = request.args["mcp_id"] +def detail(mcp_id: str) -> Response: try: + if request.args.get("mode") == "download": + exported_servers = _export_mcp_servers([mcp_id]) + if exported_servers is None: + return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}") + return get_json_result(data=exported_servers) + mcp_server = MCPServerService.get_or_none(id=mcp_id, tenant_id=current_user.id) if mcp_server is None: - return get_json_result(code=RetCode.NOT_FOUND, data=None) + return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}") return get_json_result(data=mcp_server.to_dict()) except Exception as e: return server_error_response(e) -@manager.route("/create", methods=["POST"]) # noqa: F821 +@manager.route("/mcp/servers", methods=["POST"]) # noqa: F821 @login_required @validate_request("name", "url", "server_type") async def create() -> Response: @@ -107,7 +140,7 @@ async def create() -> Response: mcp_server = MCPServer(id=server_name, name=server_name, url=url, 
server_type=server_type, variables=variables, headers=headers) server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout) if err_message: - return get_data_error_result(err_message) + return get_data_error_result(message=err_message) tools = server_tools[server_name] tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool} @@ -115,20 +148,18 @@ async def create() -> Response: req["variables"] = variables if not MCPServerService.insert(**req): - return get_data_error_result("Failed to create MCP server.") + return get_data_error_result(message="Failed to create MCP server.") return get_json_result(data=req) except Exception as e: return server_error_response(e) -@manager.route("/update", methods=["POST"]) # noqa: F821 +@manager.route("/mcp/servers/", methods=["PUT"]) # noqa: F821 @login_required -@validate_request("mcp_id") -async def update() -> Response: +async def update(mcp_id: str) -> Response: req = await get_request_json() - mcp_id = req.get("mcp_id", "") e, mcp_server = MCPServerService.get_by_id(mcp_id) if not e or mcp_server.tenant_id != current_user.id: return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}") @@ -153,13 +184,12 @@ async def update() -> Response: try: req["tenant_id"] = current_user.id - req.pop("mcp_id", None) req["id"] = mcp_id mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers) server_tools, err_message = await thread_pool_exec(get_mcp_tools, [mcp_server], timeout) if err_message: - return get_data_error_result(err_message) + return get_data_error_result(message=err_message) tools = server_tools[server_name] tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool} @@ -178,25 +208,22 @@ async def update() -> Response: return server_error_response(e) -@manager.route("/rm", methods=["POST"]) # noqa: F821 +@manager.route("/mcp/servers/", methods=["DELETE"]) # noqa: F821 @login_required -@validate_request("mcp_ids") -async def rm() -> Response: - req = await get_request_json() - mcp_ids = req.get("mcp_ids", []) - +async def rm(mcp_id: str) -> Response: try: - req["tenant_id"] = current_user.id - - if not MCPServerService.delete_by_ids(mcp_ids): - return get_data_error_result(message=f"Failed to delete MCP servers {mcp_ids}") + e, mcp_server = MCPServerService.get_by_id(mcp_id) + if not e or mcp_server.tenant_id != current_user.id: + return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}") + if not MCPServerService.delete_by_ids([mcp_id]): + return get_data_error_result(message=f"Failed to delete MCP servers {[mcp_id]}") return get_json_result(data=True) except Exception as e: return server_error_response(e) -@manager.route("/import", methods=["POST"]) # noqa: F821 +@manager.route("/mcp/servers/import", methods=["POST"]) # noqa: F821 @login_required @validate_request("mcpServers") async def import_multiple() -> Response: @@ -263,144 +290,10 @@ async def import_multiple() -> Response: return server_error_response(e) -@manager.route("/export", methods=["POST"]) # noqa: F821 +@manager.route("/mcp/servers//test", methods=["POST"]) # noqa: F821 @login_required -@validate_request("mcp_ids") -async def export_multiple() -> Response: - req = await get_request_json() - mcp_ids = req.get("mcp_ids", []) - - if not mcp_ids: - return get_data_error_result(message="No MCP server IDs provided.") - - try: - exported_servers 
= {} - - for mcp_id in mcp_ids: - e, mcp_server = MCPServerService.get_by_id(mcp_id) - - if e and mcp_server.tenant_id == current_user.id: - server_key = mcp_server.name - - exported_servers[server_key] = { - "type": mcp_server.server_type, - "url": mcp_server.url, - "name": mcp_server.name, - "authorization_token": mcp_server.variables.get("authorization_token", ""), - "tools": mcp_server.variables.get("tools", {}), - } - - return get_json_result(data={"mcpServers": exported_servers}) - except Exception as e: - return server_error_response(e) - - -@manager.route("/list_tools", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("mcp_ids") -async def list_tools() -> Response: - req = await get_request_json() - mcp_ids = req.get("mcp_ids", []) - if not mcp_ids: - return get_data_error_result(message="No MCP server IDs provided.") - - timeout = get_float(req, "timeout", 10) - - results = {} - tool_call_sessions = [] - try: - for mcp_id in mcp_ids: - e, mcp_server = MCPServerService.get_by_id(mcp_id) - - if e and mcp_server.tenant_id == current_user.id: - server_key = mcp_server.id - - cached_tools = mcp_server.variables.get("tools", {}) - - tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables) - tool_call_sessions.append(tool_call_session) - - try: - tools = await thread_pool_exec(tool_call_session.get_tools, timeout) - except Exception as e: - return get_data_error_result(message=f"MCP list tools error: {e}") - - results[server_key] = [] - for tool in tools: - tool_dict = tool.model_dump() - cached_tool = cached_tools.get(tool_dict["name"], {}) - - tool_dict["enabled"] = cached_tool.get("enabled", True) - results[server_key].append(tool_dict) - - return get_json_result(data=results) - except Exception as e: - return server_error_response(e) - finally: - # PERF: blocking call to close sessions — consider moving to background thread or task queue - await thread_pool_exec(close_multiple_mcp_toolcall_sessions, tool_call_sessions) - - -@manager.route("/test_tool", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("mcp_id", "tool_name", "arguments") -async def test_tool() -> Response: - req = await get_request_json() - mcp_id = req.get("mcp_id", "") - if not mcp_id: - return get_data_error_result(message="No MCP server ID provided.") - - timeout = get_float(req, "timeout", 10) - - tool_name = req.get("tool_name", "") - arguments = req.get("arguments", {}) - if not all([tool_name, arguments]): - return get_data_error_result(message="Require provide tool name and arguments.") - - tool_call_sessions = [] - try: - e, mcp_server = MCPServerService.get_by_id(mcp_id) - if not e or mcp_server.tenant_id != current_user.id: - return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}") - - tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables) - tool_call_sessions.append(tool_call_session) - result = await thread_pool_exec(tool_call_session.tool_call, tool_name, arguments, timeout) - - # PERF: blocking call to close sessions — consider moving to background thread or task queue - await thread_pool_exec(close_multiple_mcp_toolcall_sessions, tool_call_sessions) - return get_json_result(data=result) - except Exception as e: - return server_error_response(e) - - -@manager.route("/cache_tools", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("mcp_id", "tools") -async def cache_tool() -> Response: - req = await get_request_json() - mcp_id = req.get("mcp_id", "") - if not mcp_id: - return 
get_data_error_result(message="No MCP server ID provided.") - tools = req.get("tools", []) - - e, mcp_server = MCPServerService.get_by_id(mcp_id) - if not e or mcp_server.tenant_id != current_user.id: - return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}") - - variables = mcp_server.variables - tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool} - variables["tools"] = tools - - if not MCPServerService.filter_update([MCPServer.id == mcp_id, MCPServer.tenant_id == current_user.id], {"variables": variables}): - return get_data_error_result(message="Failed to updated MCP server.") - - return get_json_result(data=tools) - - -@manager.route("/test_mcp", methods=["POST"]) # noqa: F821 @validate_request("url", "server_type") -async def test_mcp() -> Response: +async def test_mcp(mcp_id: str) -> Response: req = await get_request_json() url = req.get("url", "") @@ -415,7 +308,7 @@ async def test_mcp() -> Response: headers = safe_json_parse(req.get("headers", {})) variables = safe_json_parse(req.get("variables", {})) - mcp_server = MCPServer(id=f"{server_type}: {url}", server_type=server_type, url=url, headers=headers, variables=variables) + mcp_server = MCPServer(id=mcp_id, server_type=server_type, url=url, headers=headers, variables=variables) result = [] try: @@ -426,7 +319,6 @@ async def test_mcp() -> Response: except Exception as e: return get_data_error_result(message=f"Test MCP error: {e}") finally: - # PERF: blocking call to close sessions — consider moving to background thread or task queue await thread_pool_exec(close_multiple_mcp_toolcall_sessions, [tool_call_session]) for tool in tools: diff --git a/api/apps/restful_apis/memory_api.py b/api/apps/restful_apis/memory_api.py index 8f92661e700..c361d816b60 100644 --- a/api/apps/restful_apis/memory_api.py +++ b/api/apps/restful_apis/memory_api.py @@ -130,7 +130,7 @@ async def delete_memory(memory_id): @login_required async def list_memory(): filter_params = { - k: request.args.get(k) for k in ["memory_type", "tenant_id", "storage_type"] if k in request.args + k: request.args.get(k) for k in ["memory_type", "tenant_id", "owner_ids", "storage_type"] if k in request.args } keywords = request.args.get("keywords") page = int(request.args.get("page", 1)) diff --git a/api/apps/restful_apis/openai_api.py b/api/apps/restful_apis/openai_api.py new file mode 100644 index 00000000000..baa011f32a8 --- /dev/null +++ b/api/apps/restful_apis/openai_api.py @@ -0,0 +1,300 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import json +import time + +from quart import Response, jsonify + +from api.apps import current_user, login_required +from api.db.services.dialog_service import DialogService, async_chat +from api.db.services.doc_metadata_service import DocMetadataService +from api.db.services.tenant_llm_service import TenantLLMService +from api.utils.api_utils import get_error_data_result, get_request_json, validate_request +from common.constants import RetCode, StatusEnum +from common.metadata_utils import convert_conditions, meta_filter +from common.token_utils import num_tokens_from_string +from rag.prompts.generator import chunks_format + +def _validate_llm_id(llm_id, tenant_id, llm_setting=None): + if not llm_id: + return None + + llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(llm_id) + model_type = (llm_setting or {}).get("model_type") + if model_type not in {"chat", "image2text"}: + model_type = "chat" + + if not TenantLLMService.query( + tenant_id=tenant_id, + llm_name=llm_name, + llm_factory=llm_factory, + model_type=model_type, + ): + return f"`llm_id` {llm_id} doesn't exist" + return None + + +import logging +from api.utils.reference_metadata_utils import enrich_chunks_with_document_metadata + +def _build_reference_chunks(reference, include_metadata=False, metadata_fields=None): + chunks = chunks_format(reference) + if not include_metadata: + logging.debug("Skipping document metadata enrichment (include_metadata=False)") + return chunks + + normalized_fields = None + if metadata_fields is not None: + if not isinstance(metadata_fields, list): + return chunks + normalized_fields = {f for f in metadata_fields if isinstance(f, str)} + if not normalized_fields: + return chunks + + logging.debug( + "Enriching %d chunks with document metadata (fields: %s)", + len(chunks), + "ALL" if normalized_fields is None else list(normalized_fields), + ) + + enrich_chunks_with_document_metadata( + chunks, + normalized_fields, + kb_field="dataset_id", + doc_field="document_id", + ) + + return chunks + + +def _build_sse_response(body): + resp = Response(body, mimetype="text/event-stream") + resp.headers.add_header("Cache-control", "no-cache") + resp.headers.add_header("Connection", "keep-alive") + resp.headers.add_header("X-Accel-Buffering", "no") + resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") + return resp + + +@manager.route("/openai//chat/completions", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("model", "messages") +async def openai_chat_completions(chat_id): + req = await get_request_json() + + extra_body = req.get("extra_body") or {} + if extra_body and not isinstance(extra_body, dict): + return get_error_data_result("extra_body must be an object.") + + need_reference = bool(extra_body.get("reference", False)) + reference_metadata = extra_body.get("reference_metadata") or {} + if reference_metadata and not isinstance(reference_metadata, dict): + return get_error_data_result("reference_metadata must be an object.") + include_reference_metadata = bool(reference_metadata.get("include", False)) + metadata_fields = reference_metadata.get("fields") + if metadata_fields is not None and not isinstance(metadata_fields, list): + return get_error_data_result("reference_metadata.fields must be an array.") + + messages = req.get("messages", []) + if len(messages) < 1: + return get_error_data_result("You have to provide messages.") + if messages[-1]["role"] != "user": + return get_error_data_result("The last content of this conversation is 
not from user.") + + prompt = messages[-1]["content"] + context_token_used = sum(num_tokens_from_string(message["content"]) for message in messages) + requested_model = req.get("model", "") or "" + completion_id = f"chatcmpl-{chat_id}" + + dia = DialogService.query(tenant_id=current_user.id, id=chat_id, status=StatusEnum.VALID.value) + if not dia: + return get_error_data_result(f"You don't own the chat {chat_id}") + dia = dia[0] + + using_placeholder_model = requested_model == "model" + if using_placeholder_model: + requested_model = dia.llm_id or requested_model + else: + llm_id_error = _validate_llm_id(requested_model, current_user.id, {"model_type": "chat"}) + if llm_id_error: + return get_error_data_result(message=llm_id_error, code=RetCode.ARGUMENT_ERROR) + dia.llm_id = requested_model + if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=requested_model): + return get_error_data_result(message=f"Cannot use specified model {requested_model}.") + + metadata_condition = extra_body.get("metadata_condition") or {} + if metadata_condition and not isinstance(metadata_condition, dict): + return get_error_data_result(message="metadata_condition must be an object.") + + doc_ids_str = None + if metadata_condition: + metas = DocMetadataService.get_flatted_meta_by_kbs(dia.kb_ids or []) + filtered_doc_ids = meta_filter( + metas, + convert_conditions(metadata_condition), + metadata_condition.get("logic", "and"), + ) + if metadata_condition.get("conditions") and not filtered_doc_ids: + filtered_doc_ids = ["-999"] + doc_ids_str = ",".join(filtered_doc_ids) if filtered_doc_ids else None + + msg = [] + for message in messages: + if message["role"] == "system": + continue + if message["role"] == "assistant" and not msg: + continue + msg.append(message) + + tools = None + toolcall_session = None + stream_mode = req.get("stream", True) + + if stream_mode: + async def streamed_response_generator(): + token_used = 0 + last_ans = {} + full_content = "" + final_answer = None + final_reference = None + in_think = False + response = { + "id": completion_id, + "choices": [ + { + "delta": { + "content": "", + "role": "assistant", + "function_call": None, + "tool_calls": None, + "reasoning_content": "", + }, + "finish_reason": None, + "index": 0, + "logprobs": None, + } + ], + "created": int(time.time()), + "model": requested_model, + "object": "chat.completion.chunk", + "system_fingerprint": "", + "usage": None, + } + + try: + chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference} + if doc_ids_str: + chat_kwargs["doc_ids"] = doc_ids_str + async for ans in async_chat(dia, msg, True, **chat_kwargs): + last_ans = ans + if ans.get("final"): + if ans.get("answer"): + full_content = ans["answer"] + response["choices"][0]["delta"]["content"] = full_content + response["choices"][0]["delta"]["reasoning_content"] = None + yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n" + final_answer = full_content + final_reference = ans.get("reference", {}) + continue + if ans.get("start_to_think"): + in_think = True + continue + if ans.get("end_to_think"): + in_think = False + continue + delta = ans.get("answer") or "" + if not delta: + continue + token_used += num_tokens_from_string(delta) + if in_think: + response["choices"][0]["delta"]["reasoning_content"] = delta + response["choices"][0]["delta"]["content"] = None + else: + full_content += delta + response["choices"][0]["delta"]["content"] = delta + response["choices"][0]["delta"]["reasoning_content"] = None + 
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n" + except Exception as e: + response["choices"][0]["delta"]["content"] = "**ERROR**: " + str(e) + yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n" + + response["choices"][0]["delta"]["content"] = None + response["choices"][0]["delta"]["reasoning_content"] = None + response["choices"][0]["finish_reason"] = "stop" + prompt_tokens = num_tokens_from_string(prompt) + response["usage"] = { + "prompt_tokens": prompt_tokens, + "completion_tokens": token_used, + "total_tokens": prompt_tokens + token_used, + } + if need_reference: + reference_payload = final_reference if final_reference is not None else last_ans.get("reference", []) + response["choices"][0]["delta"]["reference"] = _build_reference_chunks( + reference_payload, + include_metadata=include_reference_metadata, + metadata_fields=metadata_fields, + ) + response["choices"][0]["delta"]["final_content"] = final_answer if final_answer is not None else full_content + yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n" + yield "data:[DONE]\n\n" + + return _build_sse_response(streamed_response_generator()) + + answer = None + chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference} + if doc_ids_str: + chat_kwargs["doc_ids"] = doc_ids_str + async for ans in async_chat(dia, msg, False, **chat_kwargs): + answer = ans + break + + content = answer["answer"] + response = { + "id": completion_id, + "object": "chat.completion", + "created": int(time.time()), + "model": requested_model, + "usage": { + "prompt_tokens": num_tokens_from_string(prompt), + "completion_tokens": num_tokens_from_string(content), + "total_tokens": num_tokens_from_string(prompt) + num_tokens_from_string(content), + "completion_tokens_details": { + "reasoning_tokens": context_token_used, + "accepted_prediction_tokens": num_tokens_from_string(content), + "rejected_prediction_tokens": 0, + }, + }, + "choices": [ + { + "message": { + "role": "assistant", + "content": content, + }, + "logprobs": None, + "finish_reason": "stop", + "index": 0, + } + ], + } + if need_reference: + response["choices"][0]["message"]["reference"] = _build_reference_chunks( + answer.get("reference", {}), + include_metadata=include_reference_metadata, + metadata_fields=metadata_fields, + ) + + return jsonify(response) diff --git a/api/apps/plugin_app.py b/api/apps/restful_apis/plugin_api.py similarity index 93% rename from api/apps/plugin_app.py rename to api/apps/restful_apis/plugin_api.py index fb0a7bb6106..6d53fbc6267 100644 --- a/api/apps/plugin_app.py +++ b/api/apps/restful_apis/plugin_api.py @@ -21,7 +21,7 @@ from agent.plugin import GlobalPluginManager -@manager.route('/llm_tools', methods=['GET']) # noqa: F821 +@manager.route('/plugin/tools', methods=['GET']) # noqa: F821 @login_required def llm_tools() -> Response: tools = GlobalPluginManager.get_llm_tools() diff --git a/api/apps/restful_apis/search_api.py b/api/apps/restful_apis/search_api.py index 82a357f306b..c56d0ff8344 100644 --- a/api/apps/restful_apis/search_api.py +++ b/api/apps/restful_apis/search_api.py @@ -14,7 +14,10 @@ # limitations under the License. 
# -from quart import request +import json + +from quart import Response, request +from api.db.services.dialog_service import async_ask from api.apps import current_user, login_required from api.constants import DATASET_NAME_LIMIT @@ -168,3 +171,46 @@ def delete_search(search_id): return get_json_result(data=True) except Exception as e: return server_error_response(e) + + +@manager.route("/searches//completion", methods=["POST"]) # noqa: F821 +@manager.route("/searches//completions", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("question") +async def completion(search_id): + if not SearchService.accessible4deletion(search_id, current_user.id): + return get_json_result( + data=False, + message="No authorization.", + code=RetCode.AUTHENTICATION_ERROR, + ) + + req = await get_request_json() + uid = current_user.id + search_app = SearchService.get_detail(search_id) + if not search_app: + return get_data_error_result(message=f"Cannot find search {search_id}") + + search_config = search_app.get("search_config", {}) + kb_ids = search_config.get("kb_ids") or req.get("kb_ids") or [] + if not kb_ids: + return get_data_error_result(message="`kb_ids` is required.") + + async def stream(): + nonlocal req, uid, kb_ids, search_config + try: + async for ans in async_ask(req["question"], kb_ids, uid, search_config=search_config): + yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" + except Exception as ex: + yield "data:" + json.dumps( + {"code": 500, "message": str(ex), "data": {"answer": "**ERROR**: " + str(ex), "reference": []}}, + ensure_ascii=False, + ) + "\n\n" + yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" + + resp = Response(stream(), mimetype="text/event-stream") + resp.headers.add_header("Cache-control", "no-cache") + resp.headers.add_header("Connection", "keep-alive") + resp.headers.add_header("X-Accel-Buffering", "no") + resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") + return resp diff --git a/api/apps/api_app.py b/api/apps/restful_apis/stats_api.py similarity index 97% rename from api/apps/api_app.py rename to api/apps/restful_apis/stats_api.py index 0d5d62334ed..7185194327d 100644 --- a/api/apps/api_app.py +++ b/api/apps/restful_apis/stats_api.py @@ -20,7 +20,7 @@ from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response from api.apps import login_required, current_user -@manager.route('/stats', methods=['GET']) # noqa: F821 +@manager.route('/system/stats', methods=['GET']) # noqa: F821 @login_required def stats(): try: diff --git a/api/apps/restful_apis/system_api.py b/api/apps/restful_apis/system_api.py index 467d9111d90..55c34c25a34 100644 --- a/api/apps/restful_apis/system_api.py +++ b/api/apps/restful_apis/system_api.py @@ -14,25 +14,31 @@ # limitations under the License. 
# +import json +import logging +from datetime import datetime +from timeit import default_timer as timer + from quart import jsonify from api.apps import login_required, current_user from api.utils.api_utils import get_json_result, get_data_error_result, server_error_response, generate_confirmation_token -from api.utils.health_utils import run_health_checks +from api.utils.health_utils import run_health_checks, get_oceanbase_status from common.versions import get_ragflow_version -from datetime import datetime from common.time_utils import current_timestamp, datetime_format from api.db.db_models import APIToken from api.db.services.api_service import APITokenService +from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.user_service import UserTenantService from common.log_utils import get_log_levels, set_log_level +from common import settings +from rag.utils.redis_conn import REDIS_CONN @manager.route("/system/ping", methods=["GET"]) # noqa: F821 async def ping(): return "pong", 200 @manager.route("/system/version", methods=["GET"]) # noqa: F821 -@login_required def version(): """ Get the current version of the application. @@ -53,6 +59,174 @@ def version(): """ return get_json_result(data=get_ragflow_version()) + +@manager.route("/system/status", methods=["GET"]) # noqa: F821 +@login_required +def status(): + """ + Get the system status. + --- + tags: + - System + security: + - ApiKeyAuth: [] + responses: + 200: + description: System is operational. + schema: + type: object + properties: + es: + type: object + description: Elasticsearch status. + storage: + type: object + description: Storage status. + database: + type: object + description: Database status. + 503: + description: Service unavailable. + schema: + type: object + properties: + error: + type: string + description: Error message. 
+ """ + res = {} + st = timer() + try: + res["doc_engine"] = settings.docStoreConn.health() + res["doc_engine"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0) + except Exception as e: + res["doc_engine"] = { + "type": "unknown", + "status": "red", + "elapsed": "{:.1f}".format((timer() - st) * 1000.0), + "error": str(e), + } + + st = timer() + try: + settings.STORAGE_IMPL.health() + res["storage"] = { + "storage": settings.STORAGE_IMPL_TYPE.lower(), + "status": "green", + "elapsed": "{:.1f}".format((timer() - st) * 1000.0), + } + except Exception as e: + res["storage"] = { + "storage": settings.STORAGE_IMPL_TYPE.lower(), + "status": "red", + "elapsed": "{:.1f}".format((timer() - st) * 1000.0), + "error": str(e), + } + + st = timer() + try: + KnowledgebaseService.get_by_id("x") + res["database"] = { + "database": settings.DATABASE_TYPE.lower(), + "status": "green", + "elapsed": "{:.1f}".format((timer() - st) * 1000.0), + } + except Exception as e: + res["database"] = { + "database": settings.DATABASE_TYPE.lower(), + "status": "red", + "elapsed": "{:.1f}".format((timer() - st) * 1000.0), + "error": str(e), + } + + st = timer() + try: + if not REDIS_CONN.health(): + raise Exception("Lost connection!") + res["redis"] = { + "status": "green", + "elapsed": "{:.1f}".format((timer() - st) * 1000.0), + } + except Exception as e: + res["redis"] = { + "status": "red", + "elapsed": "{:.1f}".format((timer() - st) * 1000.0), + "error": str(e), + } + + task_executor_heartbeats = {} + try: + task_executors = REDIS_CONN.smembers("TASKEXE") + now = datetime.now().timestamp() + for task_executor_id in task_executors: + heartbeats = REDIS_CONN.zrangebyscore(task_executor_id, now - 60 * 30, now) + heartbeats = [json.loads(heartbeat) for heartbeat in heartbeats] + task_executor_heartbeats[task_executor_id] = heartbeats + except Exception: + logging.exception("get task executor heartbeats failed!") + res["task_executor_heartbeats"] = task_executor_heartbeats + + return get_json_result(data=res) + + +@manager.route("/system/oceanbase/status", methods=["GET"]) # noqa: F821 +@login_required +def oceanbase_status(): + """ + Get OceanBase health status and performance metrics. + --- + tags: + - System + security: + - ApiKeyAuth: [] + responses: + 200: + description: OceanBase status retrieved successfully. + schema: + type: object + properties: + status: + type: string + description: Status (alive/timeout). + message: + type: object + description: Detailed status information including health and performance metrics. + """ + try: + status_info = get_oceanbase_status() + return get_json_result(data=status_info) + except Exception as e: + return get_json_result( + data={ + "status": "error", + "message": f"Failed to get OceanBase status: {str(e)}" + }, + code=500 + ) + + +@manager.route("/system/config", methods=["GET"]) # noqa: F821 +def get_config(): + """ + Get system configuration. 
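+
+    No authentication is required; the payload exposes the registration and password-login
+    toggles (registerEnabled, disablePasswordLogin) returned by the handler below.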
+ --- + tags: + - System + responses: + 200: + description: Return system configuration + schema: + type: object + properties: + registerEnable: + type: integer 0 means disabled, 1 means enabled + description: Whether user registration is enabled + """ + return get_json_result(data={ + "registerEnabled": settings.REGISTER_ENABLED, + "disablePasswordLogin": settings.DISABLE_PASSWORD_LOGIN, + }) + @manager.route("/system/healthz", methods=["GET"]) # noqa: F821 def healthz(): result, all_ok = run_health_checks() diff --git a/api/apps/restful_apis/task_api.py b/api/apps/restful_apis/task_api.py new file mode 100644 index 00000000000..2bd7a41802f --- /dev/null +++ b/api/apps/restful_apis/task_api.py @@ -0,0 +1,101 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging +from datetime import datetime + +from api.apps import login_required +from api.db.services.task_service import TaskService, CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID +from api.utils.api_utils import ( + get_json_result, + get_request_json, + validate_request, +) +from common.constants import RetCode, TaskStatus +from rag.utils.redis_conn import REDIS_CONN + + +@manager.route("/tasks//cancel", methods=["POST"]) # noqa: F821 +@login_required +async def cancel_task(task_id): + """Cancel a running task. + """ + return await _cancel_task(task_id) + + +@manager.route("/tasks/", methods=["PATCH"]) # noqa: F821 +@login_required +@validate_request("action") +async def patch_task(task_id): + req = await get_request_json() + action = req.get("action") + + if action != "stop": + return get_json_result( + code=RetCode.ARGUMENT_ERROR, + message=f"Invalid action '{action}'. Only 'stop' is supported.", + ) + + return await _cancel_task(task_id) + + +async def _cancel_task(task_id): + """ + Sets a Redis cancel flag, updates the task progress to -1 (cancelled), + and marks the associated document's run status as CANCEL if applicable. + """ + try: + REDIS_CONN.set(f"{task_id}-cancel", "x") + except Exception as e: + logging.exception("Failed to set cancel flag for task %s: %s", task_id, str(e)) + return get_json_result( + code=RetCode.CONNECTION_ERROR, + message="Failed to stop task", + ) + + exists, task = TaskService.get_by_id(task_id) + if not exists: + return get_json_result(data=True) + + # Append a cancellation message so the user can see it in progress_msg. + try: + cancel_msg = f"\n{datetime.now().strftime('%H:%M:%S')} Task stopped by user." + # Only transition to -1 if the task is still in a non-terminal state, + # mirroring TaskService.update_progress semantics. 
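+        # Rough SQL equivalent of the guarded update below (illustrative; assumes the Task
+        # model maps to a "task" table):
+        #     UPDATE task SET progress = -1, progress_msg = progress_msg || :cancel_msg
+        #     WHERE id = :task_id AND progress >= 0 AND progress < 1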
+ TaskService.model.update( + progress_msg=TaskService.model.progress_msg + cancel_msg, + progress=-1, + ).where( + (TaskService.model.id == task_id) + & (TaskService.model.progress >= 0) + & (TaskService.model.progress < 1) + ).execute() + except Exception as e: + logging.warning("Failed to update task %s progress after cancellation: %s", task_id, str(e)) + + # If the task belongs to a document, also mark the document's run status as + # cancelled so that the UI reflects the state correctly. + try: + from api.db.services.document_service import DocumentService + doc_id = task.doc_id + if doc_id and doc_id not in (CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID): + _, doc = DocumentService.get_by_id(doc_id) + if doc and str(doc.run) in (TaskStatus.RUNNING.value, TaskStatus.SCHEDULE.value): + DocumentService.update_by_id(doc_id, {"run": TaskStatus.CANCEL.value, "progress": 0}) + except Exception as e: + logging.warning("Failed to update document run status for task %s: %s", task_id, str(e)) + + logging.info(f"Cancel task succeeded: task_id={task_id} doc_id={task.doc_id}") + return get_json_result(data=True) diff --git a/api/apps/tenant_app.py b/api/apps/restful_apis/tenant_api.py similarity index 59% rename from api/apps/tenant_app.py rename to api/apps/restful_apis/tenant_api.py index be6305e8911..4d45337cb0b 100644 --- a/api/apps/tenant_app.py +++ b/api/apps/restful_apis/tenant_api.py @@ -1,5 +1,5 @@ # -# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,48 +13,56 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import logging import asyncio +import logging + +from api.apps import current_user, login_required from api.db import UserTenantRole from api.db.db_models import UserTenant -from api.db.services.user_service import UserTenantService, UserService - +from api.db.services.user_service import UserService, UserTenantService +from api.utils.api_utils import ( + get_data_error_result, + get_json_result, + get_request_json, + server_error_response, + validate_request, +) +from api.utils.web_utils import send_invite_email +from common import settings from common.constants import RetCode, StatusEnum from common.misc_utils import get_uuid from common.time_utils import delta_seconds -from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request -from api.utils.web_utils import send_invite_email -from common import settings -from api.apps import login_required, current_user -@manager.route("//user/list", methods=["GET"]) # noqa: F821 +@manager.route("/tenants//users", methods=["GET"]) # noqa: F821 @login_required def user_list(tenant_id): if current_user.id != tenant_id: return get_json_result( data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR) + message="No authorization.", + code=RetCode.AUTHENTICATION_ERROR, + ) try: users = UserTenantService.get_by_tenant_id(tenant_id) - for u in users: - u["delta_seconds"] = delta_seconds(str(u["update_date"])) + for user in users: + user["delta_seconds"] = delta_seconds(str(user["update_date"])) return get_json_result(data=users) - except Exception as e: - return server_error_response(e) + except Exception as exc: + return server_error_response(exc) -@manager.route('//user', methods=['POST']) # noqa: F821 +@manager.route("/tenants//users", methods=["POST"]) # noqa: F821 @login_required @validate_request("email") async def create(tenant_id): if current_user.id != tenant_id: return get_json_result( data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR) + message="No authorization.", + code=RetCode.AUTHENTICATION_ERROR, + ) req = await get_request_json() invite_user_email = req["email"] @@ -71,7 +79,8 @@ async def create(tenant_id): if user_tenant_role == UserTenantRole.OWNER: return get_data_error_result(message=f"{invite_user_email} is the owner of the team.") return get_data_error_result( - message=f"{invite_user_email} is in the team, but the role: {user_tenant_role} is invalid.") + message=f"{invite_user_email} is in the team, but the role: {user_tenant_role} is invalid." 
+ ) UserTenantService.save( id=get_uuid(), @@ -79,10 +88,10 @@ async def create(tenant_id): tenant_id=tenant_id, invited_by=current_user.id, role=UserTenantRole.INVITE, - status=StatusEnum.VALID.value) + status=StatusEnum.VALID.value, + ) try: - user_name = "" _, user = UserService.get_by_id(current_user.id) if user: @@ -93,52 +102,62 @@ async def create(tenant_id): to_email=invite_user_email, invite_url=settings.MAIL_FRONTEND_URL, tenant_id=tenant_id, - inviter=user_name or current_user.email + inviter=user_name or current_user.email, ) ) - except Exception as e: - logging.exception(f"Failed to send invite email to {invite_user_email}: {e}") - return get_json_result(data=False, message="Failed to send invite email.", code=RetCode.SERVER_ERROR) - usr = invite_users[0].to_dict() - usr = {k: v for k, v in usr.items() if k in ["id", "avatar", "email", "nickname"]} + except Exception as exc: + logging.exception(f"Failed to send invite email to {invite_user_email}: {exc}") + return get_json_result( + data=False, + message="Failed to send invite email.", + code=RetCode.SERVER_ERROR, + ) - return get_json_result(data=usr) + user = invite_users[0].to_dict() + user = {k: v for k, v in user.items() if k in ["id", "avatar", "email", "nickname"]} + return get_json_result(data=user) -@manager.route('//user/', methods=['DELETE']) # noqa: F821 +@manager.route("/tenants//users", methods=["DELETE"]) # noqa: F821 @login_required -def rm(tenant_id, user_id): +@validate_request("user_id") +async def rm(tenant_id): + req = await get_request_json() + user_id = req["user_id"] if current_user.id != tenant_id and current_user.id != user_id: return get_json_result( data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR) + message="No authorization.", + code=RetCode.AUTHENTICATION_ERROR, + ) try: UserTenantService.filter_delete([UserTenant.tenant_id == tenant_id, UserTenant.user_id == user_id]) return get_json_result(data=True) - except Exception as e: - return server_error_response(e) + except Exception as exc: + return server_error_response(exc) -@manager.route("/list", methods=["GET"]) # noqa: F821 +@manager.route("/tenants", methods=["GET"]) # noqa: F821 @login_required def tenant_list(): try: users = UserTenantService.get_tenants_by_user_id(current_user.id) - for u in users: - u["delta_seconds"] = delta_seconds(str(u["update_date"])) + for user in users: + user["delta_seconds"] = delta_seconds(str(user["update_date"])) return get_json_result(data=users) - except Exception as e: - return server_error_response(e) + except Exception as exc: + return server_error_response(exc) -@manager.route("/agree/", methods=["PUT"]) # noqa: F821 +@manager.route("/tenants/", methods=["PATCH"]) # noqa: F821 @login_required def agree(tenant_id): try: - UserTenantService.filter_update([UserTenant.tenant_id == tenant_id, UserTenant.user_id == current_user.id], - {"role": UserTenantRole.NORMAL}) + UserTenantService.filter_update( + [UserTenant.tenant_id == tenant_id, UserTenant.user_id == current_user.id], + {"role": UserTenantRole.NORMAL}, + ) return get_json_result(data=True) - except Exception as e: - return server_error_response(e) + except Exception as exc: + return server_error_response(exc) diff --git a/api/apps/user_app.py b/api/apps/restful_apis/user_api.py similarity index 75% rename from api/apps/user_app.py rename to api/apps/restful_apis/user_api.py index 74248992696..714453ac6fa 100644 --- a/api/apps/user_app.py +++ b/api/apps/restful_apis/user_api.py @@ -13,7 +13,6 @@ # See the License for the 
specific language governing permissions and # limitations under the License. # -import json import logging import string import os @@ -60,10 +59,9 @@ captcha_key, ) from common import settings -from common.http_client import async_request -@manager.route("/login", methods=["POST", "GET"]) # noqa: F821 +@manager.route("/auth/login", methods=["POST"]) # noqa: F821 async def login(): """ User login endpoint. @@ -140,7 +138,7 @@ async def login(): ) -@manager.route("/login/channels", methods=["GET"]) # noqa: F821 +@manager.route("/auth/login/channels", methods=["GET"]) # noqa: F821 async def get_login_channels(): """ Get all supported authentication channels. @@ -161,7 +159,7 @@ async def get_login_channels(): return get_json_result(data=[], message=f"Load channels failure, error: {str(e)}", code=RetCode.EXCEPTION_ERROR) -@manager.route("/login/", methods=["GET"]) # noqa: F821 +@manager.route("/auth/login/", methods=["GET"]) # noqa: F821 async def oauth_login(channel): channel_config = settings.OAUTH_CONFIG.get(channel) if not channel_config: @@ -174,7 +172,7 @@ async def oauth_login(channel): return redirect(auth_url) -@manager.route("/oauth/callback/", methods=["GET"]) # noqa: F821 +@manager.route("/auth/oauth//callback", methods=["GET"]) # noqa: F821 async def oauth_callback(channel): """ Handle the OAuth/OIDC callback for various channels dynamically. @@ -269,224 +267,7 @@ async def oauth_callback(channel): return redirect(f"/?error={str(e)}") -@manager.route("/github_callback", methods=["GET"]) # noqa: F821 -async def github_callback(): - """ - **Deprecated**, Use `/oauth/callback/` instead. - - GitHub OAuth callback endpoint. - --- - tags: - - OAuth - parameters: - - in: query - name: code - type: string - required: true - description: Authorization code from GitHub. - responses: - 200: - description: Authentication successful. 
- schema: - type: object - """ - res = await async_request( - "POST", - settings.GITHUB_OAUTH.get("url"), - data={ - "client_id": settings.GITHUB_OAUTH.get("client_id"), - "client_secret": settings.GITHUB_OAUTH.get("secret_key"), - "code": request.args.get("code"), - }, - headers={"Accept": "application/json"}, - ) - res = res.json() - if "error" in res: - return redirect("/?error=%s" % res["error_description"]) - - if "user:email" not in res["scope"].split(","): - return redirect("/?error=user:email not in scope") - - session["access_token"] = res["access_token"] - session["access_token_from"] = "github" - user_info = await user_info_from_github(session["access_token"]) - email_address = user_info["email"] - users = UserService.query(email=email_address) - user_id = get_uuid() - if not users: - # User isn't try to register - try: - try: - avatar = await download_img(user_info["avatar_url"]) - except Exception as e: - logging.exception(e) - avatar = "" - users = user_register( - user_id, - { - "access_token": session["access_token"], - "email": email_address, - "avatar": avatar, - "nickname": user_info["login"], - "login_channel": "github", - "last_login_time": get_format_time(), - "is_superuser": False, - }, - ) - if not users: - raise Exception(f"Fail to register {email_address}.") - if len(users) > 1: - raise Exception(f"Same email: {email_address} exists!") - - # Try to log in - user = users[0] - login_user(user) - return redirect("/?auth=%s" % user.get_id()) - except Exception as e: - rollback_user_registration(user_id) - logging.exception(e) - return redirect("/?error=%s" % str(e)) - - # User has already registered, try to log in - user = users[0] - user.access_token = get_uuid() - if user and hasattr(user, 'is_active') and user.is_active == "0": - return redirect("/?error=user_inactive") - login_user(user) - user.save() - return redirect("/?auth=%s" % user.get_id()) - - -@manager.route("/feishu_callback", methods=["GET"]) # noqa: F821 -async def feishu_callback(): - """ - Feishu OAuth callback endpoint. - --- - tags: - - OAuth - parameters: - - in: query - name: code - type: string - required: true - description: Authorization code from Feishu. - responses: - 200: - description: Authentication successful. 
- schema: - type: object - """ - app_access_token_res = await async_request( - "POST", - settings.FEISHU_OAUTH.get("app_access_token_url"), - data=json.dumps( - { - "app_id": settings.FEISHU_OAUTH.get("app_id"), - "app_secret": settings.FEISHU_OAUTH.get("app_secret"), - } - ), - headers={"Content-Type": "application/json; charset=utf-8"}, - ) - app_access_token_res = app_access_token_res.json() - if app_access_token_res["code"] != 0: - return redirect("/?error=%s" % app_access_token_res) - - res = await async_request( - "POST", - settings.FEISHU_OAUTH.get("user_access_token_url"), - data=json.dumps( - { - "grant_type": settings.FEISHU_OAUTH.get("grant_type"), - "code": request.args.get("code"), - } - ), - headers={ - "Content-Type": "application/json; charset=utf-8", - "Authorization": f"Bearer {app_access_token_res['app_access_token']}", - }, - ) - res = res.json() - if res["code"] != 0: - return redirect("/?error=%s" % res["message"]) - - if "contact:user.email:readonly" not in res["data"]["scope"].split(): - return redirect("/?error=contact:user.email:readonly not in scope") - session["access_token"] = res["data"]["access_token"] - session["access_token_from"] = "feishu" - user_info = await user_info_from_feishu(session["access_token"]) - email_address = user_info["email"] - users = UserService.query(email=email_address) - user_id = get_uuid() - if not users: - # User isn't try to register - try: - try: - avatar = await download_img(user_info["avatar_url"]) - except Exception as e: - logging.exception(e) - avatar = "" - users = user_register( - user_id, - { - "access_token": session["access_token"], - "email": email_address, - "avatar": avatar, - "nickname": user_info["en_name"], - "login_channel": "feishu", - "last_login_time": get_format_time(), - "is_superuser": False, - }, - ) - if not users: - raise Exception(f"Fail to register {email_address}.") - if len(users) > 1: - raise Exception(f"Same email: {email_address} exists!") - - # Try to log in - user = users[0] - login_user(user) - return redirect("/?auth=%s" % user.get_id()) - except Exception as e: - rollback_user_registration(user_id) - logging.exception(e) - return redirect("/?error=%s" % str(e)) - - # User has already registered, try to log in - user = users[0] - if user and hasattr(user, 'is_active') and user.is_active == "0": - return redirect("/?error=user_inactive") - user.access_token = get_uuid() - login_user(user) - user.save() - return redirect("/?auth=%s" % user.get_id()) - - -async def user_info_from_feishu(access_token): - headers = { - "Content-Type": "application/json; charset=utf-8", - "Authorization": f"Bearer {access_token}", - } - res = await async_request("GET", "https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers) - user_info = res.json()["data"] - user_info["email"] = None if user_info.get("email") == "" else user_info["email"] - return user_info - - -async def user_info_from_github(access_token): - headers = {"Accept": "application/json", "Authorization": f"token {access_token}"} - res = await async_request("GET", f"https://api.github.com/user?access_token={access_token}", headers=headers) - user_info = res.json() - email_info_response = await async_request( - "GET", - f"https://api.github.com/user/emails?access_token={access_token}", - headers=headers, - ) - email_info = email_info_response.json() - user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"] - return user_info - - -@manager.route("/logout", methods=["GET"]) # noqa: F821 
+@manager.route("/auth/logout", methods=["POST"]) # noqa: F821 @login_required async def log_out(): """ @@ -508,7 +289,7 @@ async def log_out(): return get_json_result(data=True) -@manager.route("/setting", methods=["POST"]) # noqa: F821 +@manager.route("/users/me", methods=["PATCH"]) # noqa: F821 @login_required async def setting_user(): """ @@ -576,7 +357,7 @@ async def setting_user(): return get_json_result(data=False, message="Update failure!", code=RetCode.EXCEPTION_ERROR) -@manager.route("/info", methods=["GET"]) # noqa: F821 +@manager.route("/users/me", methods=["GET"]) # noqa: F821 @login_required async def user_profile(): """ @@ -667,7 +448,7 @@ def user_register(user_id, user): return UserService.query(email=user["email"]) -@manager.route("/register", methods=["POST"]) # noqa: F821 +@manager.route("/users", methods=["POST"]) # noqa: F821 @validate_request("nickname", "email", "password") async def user_add(): """ @@ -761,7 +542,7 @@ async def user_add(): ) -@manager.route("/tenant_info", methods=["GET"]) # noqa: F821 +@manager.route("/users/me/models", methods=["GET"]) # noqa: F821 @login_required async def tenant_info(): """ @@ -799,7 +580,7 @@ async def tenant_info(): return server_error_response(e) -@manager.route("/set_tenant_info", methods=["POST"]) # noqa: F821 +@manager.route("/users/me/models", methods=["PATCH"]) # noqa: F821 @login_required @validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id") async def set_tenant_info(): @@ -849,7 +630,7 @@ async def set_tenant_info(): return server_error_response(e) -@manager.route("/forget/captcha", methods=["GET"]) # noqa: F821 +@manager.route("/auth/password/forgot/captcha", methods=["POST"]) # noqa: F821 async def forget_get_captcha(): """ GET /forget/captcha?email= @@ -877,7 +658,7 @@ async def forget_get_captcha(): return response -@manager.route("/forget/otp", methods=["POST"]) # noqa: F821 +@manager.route("/auth/password/forgot/otp", methods=["POST"]) # noqa: F821 async def forget_send_otp(): """ POST /forget/otp @@ -947,7 +728,7 @@ def _verified_key(email: str) -> str: return f"otp:verified:{email}" -@manager.route("/forget/verify-otp", methods=["POST"]) # noqa: F821 +@manager.route("/auth/password/forgot/otp/verify", methods=["POST"]) # noqa: F821 async def forget_verify_otp(): """ Verify email + OTP only. On success: @@ -1008,7 +789,7 @@ async def forget_verify_otp(): return get_json_result(data=True, code=RetCode.SUCCESS, message="otp verified") -@manager.route("/forget/reset-password", methods=["POST"]) # noqa: F821 +@manager.route("/auth/password/reset", methods=["POST"]) # noqa: F821 async def forget_reset_password(): """ Reset password after successful OTP verification. diff --git a/api/apps/sdk/agents.py b/api/apps/sdk/agents.py deleted file mode 100644 index f7f36fa19f0..00000000000 --- a/api/apps/sdk/agents.py +++ /dev/null @@ -1,938 +0,0 @@ -# -# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -import asyncio -import base64 -import hashlib -import hmac -import ipaddress -import json -import logging -import time -from typing import Any, cast - -import jwt - -from agent.canvas import Canvas -from api.apps.services.canvas_replica_service import CanvasReplicaService -from api.db import CanvasCategory -from api.db.services.canvas_service import UserCanvasService -from api.db.services.file_service import FileService -from api.db.services.user_service import UserService -from api.db.services.user_canvas_version import UserCanvasVersionService -from common.constants import RetCode -from common.misc_utils import get_uuid -from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, get_request_json, token_required -from api.utils.api_utils import get_result -from quart import request, Response -from rag.utils.redis_conn import REDIS_CONN - - -def _get_user_nickname(user_id: str) -> str: - exists, user = UserService.get_by_id(user_id) - if not exists: - return user_id - return str(getattr(user, "nickname", "") or user_id) - - -@manager.route('/agents', methods=['GET']) # noqa: F821 -@token_required -def list_agents(tenant_id): - id = request.args.get("id") - title = request.args.get("title") - if id or title: - canvas = UserCanvasService.query(id=id, title=title, user_id=tenant_id) - if not canvas: - return get_error_data_result("The agent doesn't exist.") - page_number = int(request.args.get("page", 1)) - items_per_page = int(request.args.get("page_size", 30)) - order_by = request.args.get("orderby", "update_time") - if str(request.args.get("desc","false")).lower() == "false": - desc = False - else: - desc = True - canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, order_by, desc, id, title) - return get_result(data=canvas) - - -@manager.route("/agents", methods=["POST"]) # noqa: F821 -@token_required -async def create_agent(tenant_id: str): - req: dict[str, Any] = cast(dict[str, Any], await get_request_json()) - req["user_id"] = tenant_id - - if req.get("dsl") is not None: - try: - req["dsl"] = CanvasReplicaService.normalize_dsl(req["dsl"]) - except ValueError as e: - return get_json_result(data=False, message=str(e), code=RetCode.ARGUMENT_ERROR) - else: - return get_json_result(data=False, message="No DSL data in request.", code=RetCode.ARGUMENT_ERROR) - - if req.get("title") is not None: - req["title"] = req["title"].strip() - else: - return get_json_result(data=False, message="No title in request.", code=RetCode.ARGUMENT_ERROR) - - if UserCanvasService.query(user_id=tenant_id, title=req["title"]): - return get_data_error_result(message=f"Agent with title {req['title']} already exists.") - - agent_id = get_uuid() - req["id"] = agent_id - - if not UserCanvasService.save(**req): - return get_data_error_result(message="Fail to create agent.") - - owner_nickname = _get_user_nickname(tenant_id) - UserCanvasVersionService.save_or_replace_latest( - user_canvas_id=agent_id, - title=UserCanvasVersionService.build_version_title(owner_nickname, req.get("title")), - dsl=req["dsl"] - ) - - return get_json_result(data=True) - - -@manager.route("/agents/", methods=["PUT"]) # noqa: F821 -@token_required -async def update_agent(tenant_id: str, agent_id: str): - req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await get_request_json())).items() if v is not None} - req["user_id"] = tenant_id - - if req.get("dsl") is not None: - try: - req["dsl"] = CanvasReplicaService.normalize_dsl(req["dsl"]) - except ValueError as e: - return 
get_json_result(data=False, message=str(e), code=RetCode.ARGUMENT_ERROR) - - if req.get("title") is not None: - req["title"] = req["title"].strip() - - if not UserCanvasService.query(user_id=tenant_id, id=agent_id): - return get_json_result( - data=False, message="Only owner of canvas authorized for this operation.", - code=RetCode.OPERATING_ERROR) - - _, current_agent = UserCanvasService.get_by_id(agent_id) - agent_title_for_version = req.get("title") or (current_agent.title if current_agent else "") - owner_nickname = _get_user_nickname(tenant_id) - - UserCanvasService.update_by_id(agent_id, req) - - if req.get("dsl") is not None: - UserCanvasVersionService.save_or_replace_latest( - user_canvas_id=agent_id, - title=UserCanvasVersionService.build_version_title(owner_nickname, agent_title_for_version), - dsl=req["dsl"] - ) - - return get_json_result(data=True) - - -@manager.route("/agents/", methods=["DELETE"]) # noqa: F821 -@token_required -def delete_agent(tenant_id: str, agent_id: str): - if not UserCanvasService.query(user_id=tenant_id, id=agent_id): - return get_json_result( - data=False, message="Only owner of canvas authorized for this operation.", - code=RetCode.OPERATING_ERROR) - - UserCanvasService.delete_by_id(agent_id) - return get_json_result(data=True) - -@manager.route("/webhook/", methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"]) # noqa: F821 -@manager.route("/webhook_test/",methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"],) # noqa: F821 -async def webhook(agent_id: str): - is_test = request.path.startswith("/api/v1/webhook_test") - start_ts = time.time() - - # 1. Fetch canvas by agent_id - exists, cvs = UserCanvasService.get_by_id(agent_id) - if not exists: - return get_data_error_result(code=RetCode.BAD_REQUEST,message="Canvas not found."),RetCode.BAD_REQUEST - - # 2. Check canvas category - if cvs.canvas_category == CanvasCategory.DataFlow: - return get_data_error_result(code=RetCode.BAD_REQUEST,message="Dataflow can not be triggered by webhook."),RetCode.BAD_REQUEST - - # 3. Load DSL from canvas - dsl = getattr(cvs, "dsl", None) - if not isinstance(dsl, dict): - return get_data_error_result(code=RetCode.BAD_REQUEST,message="Invalid DSL format."),RetCode.BAD_REQUEST - - # 4. Check webhook configuration in DSL - webhook_cfg = {} - components = dsl.get("components", {}) - for k, _ in components.items(): - cpn_obj = components[k]["obj"] - if cpn_obj["component_name"].lower() == "begin" and cpn_obj["params"]["mode"] == "Webhook": - webhook_cfg = cpn_obj["params"] - - if not webhook_cfg: - return get_data_error_result(code=RetCode.BAD_REQUEST,message="Webhook not configured for this agent."),RetCode.BAD_REQUEST - - # 5. Validate request method against webhook_cfg.methods - allowed_methods = webhook_cfg.get("methods", []) - request_method = request.method.upper() - if allowed_methods and request_method not in allowed_methods: - return get_data_error_result( - code=RetCode.BAD_REQUEST,message=f"HTTP method '{request_method}' not allowed for this webhook." - ),RetCode.BAD_REQUEST - - # 6. Validate webhook security - async def validate_webhook_security(security_cfg: dict): - """Validate webhook security rules based on security configuration.""" - - if not security_cfg: - return # No security config → allowed by default - - # 1. Validate max body size - await _validate_max_body_size(security_cfg) - - # 2. Validate IP whitelist - _validate_ip_whitelist(security_cfg) - - # # 3. Validate rate limiting - _validate_rate_limit(security_cfg) - - # 4. 
Validate authentication - auth_type = security_cfg.get("auth_type", "none") - - if auth_type == "none": - return - - if auth_type == "token": - _validate_token_auth(security_cfg) - - elif auth_type == "basic": - _validate_basic_auth(security_cfg) - - elif auth_type == "jwt": - _validate_jwt_auth(security_cfg) - - else: - raise Exception(f"Unsupported auth_type: {auth_type}") - - async def _validate_max_body_size(security_cfg): - """Check request size does not exceed max_body_size.""" - max_size = security_cfg.get("max_body_size") - if not max_size: - return - - # Convert "10MB" → bytes - units = {"kb": 1024, "mb": 1024**2} - size_str = max_size.lower() - - for suffix, factor in units.items(): - if size_str.endswith(suffix): - limit = int(size_str.replace(suffix, "")) * factor - break - else: - raise Exception("Invalid max_body_size format") - MAX_LIMIT = 10 * 1024 * 1024 # 10MB - if limit > MAX_LIMIT: - raise Exception("max_body_size exceeds maximum allowed size (10MB)") - - content_length = request.content_length or 0 - if content_length > limit: - raise Exception(f"Request body too large: {content_length} > {limit}") - - def _validate_ip_whitelist(security_cfg): - """Allow only IPs listed in ip_whitelist.""" - whitelist = security_cfg.get("ip_whitelist", []) - if not whitelist: - return - - client_ip = request.remote_addr - - - for rule in whitelist: - if "/" in rule: - # CIDR notation - if ipaddress.ip_address(client_ip) in ipaddress.ip_network(rule, strict=False): - return - else: - # Single IP - if client_ip == rule: - return - - raise Exception(f"IP {client_ip} is not allowed by whitelist") - - def _validate_rate_limit(security_cfg): - """Simple in-memory rate limiting.""" - rl = security_cfg.get("rate_limit") - if not rl: - return - - limit = int(rl.get("limit", 60)) - if limit <= 0: - raise Exception("rate_limit.limit must be > 0") - per = rl.get("per", "minute") - - window = { - "second": 1, - "minute": 60, - "hour": 3600, - "day": 86400, - }.get(per) - - if not window: - raise Exception(f"Invalid rate_limit.per: {per}") - - capacity = limit - rate = limit / window - cost = 1 - - key = f"rl:tb:{agent_id}" - now = time.time() - - try: - res = REDIS_CONN.lua_token_bucket( - keys=[key], - args=[capacity, rate, now, cost], - client=REDIS_CONN.REDIS, - ) - - allowed = int(res[0]) - if allowed != 1: - raise Exception("Too many requests (rate limit exceeded)") - - except Exception as e: - raise Exception(f"Rate limit error: {e}") - - def _validate_token_auth(security_cfg): - """Validate header-based token authentication.""" - token_cfg = security_cfg.get("token",{}) - header = token_cfg.get("token_header") - token_value = token_cfg.get("token_value") - - provided = request.headers.get(header) - if provided != token_value: - raise Exception("Invalid token authentication") - - def _validate_basic_auth(security_cfg): - """Validate HTTP Basic Auth credentials.""" - auth_cfg = security_cfg.get("basic_auth", {}) - username = auth_cfg.get("username") - password = auth_cfg.get("password") - - auth = request.authorization - if not auth or auth.username != username or auth.password != password: - raise Exception("Invalid Basic Auth credentials") - - def _validate_jwt_auth(security_cfg): - """Validate JWT token in Authorization header.""" - jwt_cfg = security_cfg.get("jwt", {}) - secret = jwt_cfg.get("secret") - if not secret: - raise Exception("JWT secret not configured") - - auth_header = request.headers.get("Authorization", "") - if not auth_header.startswith("Bearer "): - raise 
Exception("Missing Bearer token") - - token = auth_header[len("Bearer "):].strip() - if not token: - raise Exception("Empty Bearer token") - - alg = (jwt_cfg.get("algorithm") or "HS256").upper() - - decode_kwargs = { - "key": secret, - "algorithms": [alg], - } - options = {} - if jwt_cfg.get("audience"): - decode_kwargs["audience"] = jwt_cfg["audience"] - options["verify_aud"] = True - else: - options["verify_aud"] = False - - if jwt_cfg.get("issuer"): - decode_kwargs["issuer"] = jwt_cfg["issuer"] - options["verify_iss"] = True - else: - options["verify_iss"] = False - try: - decoded = jwt.decode( - token, - options=options, - **decode_kwargs, - ) - except Exception as e: - raise Exception(f"Invalid JWT: {str(e)}") - - raw_required_claims = jwt_cfg.get("required_claims", []) - if isinstance(raw_required_claims, str): - required_claims = [raw_required_claims] - elif isinstance(raw_required_claims, (list, tuple, set)): - required_claims = list(raw_required_claims) - else: - required_claims = [] - - required_claims = [ - c for c in required_claims - if isinstance(c, str) and c.strip() - ] - - RESERVED_CLAIMS = {"exp", "sub", "aud", "iss", "nbf", "iat"} - for claim in required_claims: - if claim in RESERVED_CLAIMS: - raise Exception(f"Reserved JWT claim cannot be required: {claim}") - - for claim in required_claims: - if claim not in decoded: - raise Exception(f"Missing JWT claim: {claim}") - - return decoded - - try: - security_config=webhook_cfg.get("security", {}) - await validate_webhook_security(security_config) - except Exception as e: - return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST - if not isinstance(cvs.dsl, str): - dsl = json.dumps(cvs.dsl, ensure_ascii=False) - try: - canvas = Canvas(dsl, cvs.user_id, agent_id, canvas_id=agent_id) - except Exception as e: - resp=get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)) - resp.status_code = RetCode.BAD_REQUEST - return resp - - # 7. Parse request body - async def parse_webhook_request(content_type): - """Parse request based on content-type and return structured data.""" - - # 1. Query - query_data = {k: v for k, v in request.args.items()} - - # 2. Headers - header_data = {k: v for k, v in request.headers.items()} - - # 3. 
Body - ctype = request.headers.get("Content-Type", "").split(";")[0].strip() - if ctype and ctype != content_type: - raise ValueError( - f"Invalid Content-Type: expect '{content_type}', got '{ctype}'" - ) - - body_data: dict = {} - - try: - if ctype == "application/json": - body_data = await request.get_json() or {} - - elif ctype == "multipart/form-data": - nonlocal canvas - form = await request.form - files = await request.files - - body_data = {} - - for key, value in form.items(): - body_data[key] = value - - if len(files) > 10: - raise Exception("Too many uploaded files") - for key, file in files.items(): - desc = FileService.upload_info( - cvs.user_id, # user - file, # FileStorage - None # url (None for webhook) - ) - file_parsed= await canvas.get_files_async([desc]) - body_data[key] = file_parsed - - elif ctype == "application/x-www-form-urlencoded": - form = await request.form - body_data = dict(form) - - else: - # text/plain / octet-stream / empty / unknown - raw = await request.get_data() - if raw: - try: - body_data = json.loads(raw.decode("utf-8")) - except Exception: - body_data = {} - else: - body_data = {} - - except Exception: - body_data = {} - - return { - "query": query_data, - "headers": header_data, - "body": body_data, - "content_type": ctype, - } - - def extract_by_schema(data, schema, name="section"): - """ - Extract only fields defined in schema. - Required fields must exist. - Optional fields default to type-based default values. - Type validation included. - """ - props = schema.get("properties", {}) - required = schema.get("required", []) - - extracted = {} - - for field, field_schema in props.items(): - field_type = field_schema.get("type") - - # 1. Required field missing - if field in required and field not in data: - raise Exception(f"{name} missing required field: {field}") - - # 2. Optional → default value - if field not in data: - extracted[field] = default_for_type(field_type) - continue - - raw_value = data[field] - - # 3. Auto convert value - try: - value = auto_cast_value(raw_value, field_type) - except Exception as e: - raise Exception(f"{name}.{field} auto-cast failed: {str(e)}") - - # 4. 
Type validation - if not validate_type(value, field_type): - raise Exception( - f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}" - ) - - extracted[field] = value - - return extracted - - - def default_for_type(t): - """Return default value for the given schema type.""" - if t == "file": - return [] - if t == "object": - return {} - if t == "boolean": - return False - if t == "number": - return 0 - if t == "string": - return "" - if t and t.startswith("array"): - return [] - if t == "null": - return None - return None - - def auto_cast_value(value, expected_type): - """Convert string values into schema type when possible.""" - - # Non-string values already good - if not isinstance(value, str): - return value - - v = value.strip() - - # Boolean - if expected_type == "boolean": - if v.lower() in ["true", "1"]: - return True - if v.lower() in ["false", "0"]: - return False - raise Exception(f"Cannot convert '{value}' to boolean") - - # Number - if expected_type == "number": - # integer - if v.isdigit() or (v.startswith("-") and v[1:].isdigit()): - return int(v) - - # float - try: - return float(v) - except Exception: - raise Exception(f"Cannot convert '{value}' to number") - - # Object - if expected_type == "object": - try: - parsed = json.loads(v) - if isinstance(parsed, dict): - return parsed - else: - raise Exception("JSON is not an object") - except Exception: - raise Exception(f"Cannot convert '{value}' to object") - - # Array - if expected_type.startswith("array"): - try: - parsed = json.loads(v) - if isinstance(parsed, list): - return parsed - else: - raise Exception("JSON is not an array") - except Exception: - raise Exception(f"Cannot convert '{value}' to array") - - # String (accept original) - if expected_type == "string": - return value - - # File - if expected_type == "file": - return value - # Default: do nothing - return value - - - def validate_type(value, t): - """Validate value type against schema type t.""" - if t == "file": - return isinstance(value, list) - - if t == "string": - return isinstance(value, str) - - if t == "number": - return isinstance(value, (int, float)) - - if t == "boolean": - return isinstance(value, bool) - - if t == "object": - return isinstance(value, dict) - - # array / array / array - if t.startswith("array"): - if not isinstance(value, list): - return False - - if "<" in t and ">" in t: - inner = t[t.find("<") + 1 : t.find(">")] - - # Check each element type - for item in value: - if not validate_type(item, inner): - return False - - return True - - return True - parsed = await parse_webhook_request(webhook_cfg.get("content_types")) - SCHEMA = webhook_cfg.get("schema", {"query": {}, "headers": {}, "body": {}}) - - # Extract strictly by schema - try: - query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query") - header_clean = extract_by_schema(parsed["headers"], SCHEMA.get("headers", {}), name="headers") - body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body") - except Exception as e: - return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST - - clean_request = { - "query": query_clean, - "headers": header_clean, - "body": body_clean, - "input": parsed - } - - execution_mode = webhook_cfg.get("execution_mode", "Immediately") - response_cfg = webhook_cfg.get("response", {}) - - def append_webhook_trace(agent_id: str, start_ts: float,event: dict, ttl=600): - key = f"webhook-trace-{agent_id}-logs" - - raw = REDIS_CONN.get(key) 
- obj = json.loads(raw) if raw else {"webhooks": {}} - - ws = obj["webhooks"].setdefault( - str(start_ts), - {"start_ts": start_ts, "events": []} - ) - - ws["events"].append({ - "ts": time.time(), - **event - }) - - REDIS_CONN.set_obj(key, obj, ttl) - - if execution_mode == "Immediately": - status = response_cfg.get("status", 200) - try: - status = int(status) - except (TypeError, ValueError): - return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}")),RetCode.BAD_REQUEST - - if not (200 <= status <= 399): - return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}, must be between 200 and 399")),RetCode.BAD_REQUEST - - body_tpl = response_cfg.get("body_template", "") - - def parse_body(body: str): - if not body: - return None, "application/json" - - try: - parsed = json.loads(body) - return parsed, "application/json" - except (json.JSONDecodeError, TypeError): - return body, "text/plain" - - - body, content_type = parse_body(body_tpl) - resp = Response( - json.dumps(body, ensure_ascii=False) if content_type == "application/json" else body, - status=status, - content_type=content_type, - ) - - async def background_run(): - try: - async for ans in canvas.run( - query="", - user_id=cvs.user_id, - webhook_payload=clean_request - ): - if is_test: - append_webhook_trace(agent_id, start_ts, ans) - - if is_test: - append_webhook_trace( - agent_id, - start_ts, - { - "event": "finished", - "elapsed_time": time.time() - start_ts, - "success": True, - } - ) - - cvs.dsl = json.loads(str(canvas)) - UserCanvasService.update_by_id(cvs.user_id, cvs.to_dict()) - - except Exception as e: - logging.exception("Webhook background run failed") - if is_test: - try: - append_webhook_trace( - agent_id, - start_ts, - { - "event": "error", - "message": str(e), - "error_type": type(e).__name__, - } - ) - append_webhook_trace( - agent_id, - start_ts, - { - "event": "finished", - "elapsed_time": time.time() - start_ts, - "success": False, - } - ) - except Exception: - logging.exception("Failed to append webhook trace") - - asyncio.create_task(background_run()) - return resp - else: - async def sse(): - nonlocal canvas - contents: list[str] = [] - status = 200 - try: - async for ans in canvas.run( - query="", - user_id=cvs.user_id, - webhook_payload=clean_request, - ): - if ans["event"] == "message": - content = ans["data"]["content"] - if ans["data"].get("start_to_think", False): - content = "" - elif ans["data"].get("end_to_think", False): - content = "" - if content: - contents.append(content) - if ans["event"] == "message_end": - status = int(ans["data"].get("status", status)) - if is_test: - append_webhook_trace( - agent_id, - start_ts, - ans - ) - if is_test: - append_webhook_trace( - agent_id, - start_ts, - { - "event": "finished", - "elapsed_time": time.time() - start_ts, - "success": True, - } - ) - final_content = "".join(contents) - return { - "message": final_content, - "success": True, - "code": status, - } - - except Exception as e: - if is_test: - append_webhook_trace( - agent_id, - start_ts, - { - "event": "error", - "message": str(e), - "error_type": type(e).__name__, - } - ) - append_webhook_trace( - agent_id, - start_ts, - { - "event": "finished", - "elapsed_time": time.time() - start_ts, - "success": False, - } - ) - return {"code": 400, "message": str(e),"success":False} - - result = await sse() - return Response( - json.dumps(result), - status=result["code"], - mimetype="application/json", - ) - - 
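The webhook handler deleted above enforces its `rate_limit` security option with a token-bucket check run inside Redis via a Lua script, using capacity equal to the configured limit and a refill rate of limit-per-window. The following is a single-process sketch of the same idea for illustration only; it is not the removed Redis-backed implementation, and the class name is made up.

```python
import time


class TokenBucket:
    """Single-process sketch of the token-bucket check the removed webhook
    handler delegates to a Redis Lua script (capacity = limit, refill rate =
    limit / window seconds, cost 1 per request)."""

    def __init__(self, limit: int, window_seconds: int):
        self.capacity = float(limit)
        self.rate = limit / window_seconds  # tokens added back per second
        self.tokens = float(limit)
        self.updated = time.monotonic()

    def allow(self, cost: float = 1.0) -> bool:
        now = time.monotonic()
        # Refill proportionally to elapsed time, capped at capacity.
        self.tokens = min(self.capacity, self.tokens + (now - self.updated) * self.rate)
        self.updated = now
        if self.tokens >= cost:
            self.tokens -= cost
            return True
        return False


# Example: a config of {"limit": 60, "per": "minute"} becomes capacity 60, window 60s.
bucket = TokenBucket(limit=60, window_seconds=60)
assert bucket.allow()  # the first request inside the window is admitted
```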
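For the `jwt` auth type, the same deleted handler verifies tokens with PyJWT and only turns on audience/issuer verification when those fields are present in the webhook security config. A standalone illustration of that PyJWT pattern follows; it is not RAGFlow code, and the secret, config, and claims are invented for the example.

```python
import jwt  # PyJWT


def decode_webhook_jwt(token: str, cfg: dict) -> dict:
    """Decode a Bearer token the way the removed _validate_jwt_auth did:
    HS256 by default, audience/issuer verified only when configured."""
    decode_kwargs = {
        "key": cfg["secret"],
        "algorithms": [(cfg.get("algorithm") or "HS256").upper()],
    }
    options = {}
    if cfg.get("audience"):
        decode_kwargs["audience"] = cfg["audience"]
        options["verify_aud"] = True
    else:
        options["verify_aud"] = False
    if cfg.get("issuer"):
        decode_kwargs["issuer"] = cfg["issuer"]
        options["verify_iss"] = True
    else:
        options["verify_iss"] = False
    return jwt.decode(token, options=options, **decode_kwargs)


if __name__ == "__main__":
    # Hypothetical config and token, purely for demonstration.
    cfg = {"secret": "example-secret", "audience": "ragflow-webhook"}
    token = jwt.encode({"aud": "ragflow-webhook", "agent": "demo"}, cfg["secret"], algorithm="HS256")
    print(decode_webhook_jwt(token, cfg))
```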
-@manager.route("/webhook_trace/", methods=["GET"]) # noqa: F821 -async def webhook_trace(agent_id: str): - def encode_webhook_id(start_ts: str) -> str: - WEBHOOK_ID_SECRET = "webhook_id_secret" - sig = hmac.new( - WEBHOOK_ID_SECRET.encode("utf-8"), - start_ts.encode("utf-8"), - hashlib.sha256, - ).digest() - return base64.urlsafe_b64encode(sig).decode("utf-8").rstrip("=") - - def decode_webhook_id(enc_id: str, webhooks: dict) -> str | None: - for ts in webhooks.keys(): - if encode_webhook_id(ts) == enc_id: - return ts - return None - since_ts = request.args.get("since_ts", type=float) - webhook_id = request.args.get("webhook_id") - - key = f"webhook-trace-{agent_id}-logs" - raw = REDIS_CONN.get(key) - - if since_ts is None: - now = time.time() - return get_json_result( - data={ - "webhook_id": None, - "events": [], - "next_since_ts": now, - "finished": False, - } - ) - - if not raw: - return get_json_result( - data={ - "webhook_id": None, - "events": [], - "next_since_ts": since_ts, - "finished": False, - } - ) - - obj = json.loads(raw) - webhooks = obj.get("webhooks", {}) - - if webhook_id is None: - candidates = [ - float(k) for k in webhooks.keys() if float(k) > since_ts - ] - - if not candidates: - return get_json_result( - data={ - "webhook_id": None, - "events": [], - "next_since_ts": since_ts, - "finished": False, - } - ) - - start_ts = min(candidates) - real_id = str(start_ts) - webhook_id = encode_webhook_id(real_id) - - return get_json_result( - data={ - "webhook_id": webhook_id, - "events": [], - "next_since_ts": start_ts, - "finished": False, - } - ) - - real_id = decode_webhook_id(webhook_id, webhooks) - - if not real_id: - return get_json_result( - data={ - "webhook_id": webhook_id, - "events": [], - "next_since_ts": since_ts, - "finished": True, - } - ) - - ws = webhooks.get(str(real_id)) - events = ws.get("events", []) - new_events = [e for e in events if e.get("ts", 0) > since_ts] - - next_ts = since_ts - for e in new_events: - next_ts = max(next_ts, e["ts"]) - - finished = any(e.get("event") == "finished" for e in new_events) - - return get_json_result( - data={ - "webhook_id": webhook_id, - "events": new_events, - "next_since_ts": next_ts, - "finished": finished, - } - ) diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/sdk/dify_retrieval.py index e6dd61d035e..e85a1d439c5 100644 --- a/api/apps/sdk/dify_retrieval.py +++ b/api/apps/sdk/dify_retrieval.py @@ -122,6 +122,8 @@ async def retrieval(tenant_id): retrieval_setting = req.get("retrieval_setting", {}) similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0)) top = int(retrieval_setting.get("top_k", 1024)) + if top <= 0: + return build_error_result(message="`top_k` must be greater than 0", code=RetCode.DATA_ERROR) metadata_condition = req.get("metadata_condition", {}) or {} metas = DocMetadataService.get_flatted_meta_by_kbs([kb_id]) diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index bff583e4976..cf297c4b250 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -13,59 +13,41 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import datetime -import re +import logging from io import BytesIO -import xxhash -from pydantic import BaseModel, Field, validator from quart import request, send_file -from api.db.db_models import APIToken, Document, File, Task +from api.db.db_models import APIToken, Document, Task from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type from api.db.services.doc_metadata_service import DocMetadataService from api.db.services.document_service import DocumentService from api.db.services.file2document_service import File2DocumentService -from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle from api.db.services.task_service import TaskService, cancel_all_task_of, queue_tasks from api.db.services.tenant_llm_service import TenantLLMService from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_request_json, get_result, server_error_response, token_required -from api.utils.image_utils import store_chunk_image from common import settings -from common.constants import FileSource, LLMType, ParserType, RetCode, TaskStatus +from common.constants import LLMType, RetCode, TaskStatus from common.metadata_utils import convert_conditions, meta_filter -from common.misc_utils import thread_pool_exec -from common.string_utils import is_content_empty, remove_redundant_spaces -from common.tag_feature_utils import validate_tag_features -from rag.app.qa import beAdoc, rmPrefix from rag.app.tag import label_question -from rag.nlp import rag_tokenizer, search +from rag.nlp import search from rag.prompts.generator import cross_languages, keyword_extraction MAXIMUM_OF_UPLOADING_FILES = 256 -class Chunk(BaseModel): - id: str = "" - content: str = "" - document_id: str = "" - docnm_kwd: str = "" - important_keywords: list = Field(default_factory=list) - tag_kwd: list = Field(default_factory=list) - questions: list = Field(default_factory=list) - question_tks: str = "" - image_id: str = "" - available: bool = True - positions: list[list[int]] = Field(default_factory=list) +from api.utils.reference_metadata_utils import ( + enrich_chunks_with_document_metadata, + resolve_reference_metadata_preferences, +) + +def _resolve_reference_metadata(req: dict, search_config: dict | None = None): + return resolve_reference_metadata_preferences(req, search_config) - @validator("positions") - def validate_positions(cls, value): - for sublist in value: - if len(sublist) != 5: - raise ValueError("Each sublist in positions must have a length of 5") - return value +def _enrich_chunks_with_document_metadata(chunks: list[dict], metadata_fields=None) -> None: + enrich_chunks_with_document_metadata(chunks, metadata_fields) @manager.route("/datasets//documents/", methods=["GET"]) # noqa: F821 @@ -134,15 +116,30 @@ async def download_doc(document_id): if len(token) != 2: return get_error_data_result(message="Authorization is not valid!") token = token[1] + logging.info("Beta API token lookup attempted for document download") objs = APIToken.query(beta=token) if not objs: + logging.warning("Beta API token lookup failed for document download: invalid API key") return get_error_data_result(message='Authentication error: API key is invalid!"') + if len(objs) > 1: + logging.error("Beta API token lookup is ambiguous for document download: matches=%s", len(objs)) + return 
get_error_data_result(message="Authentication error: API key configuration is ambiguous.") + tenant_id = objs[0].tenant_id + logging.info("Beta API token authorized for document download: tenant_id=%s", tenant_id) if not document_id: return get_error_data_result(message="Specify document_id please.") doc = DocumentService.query(id=document_id) if not doc: return get_error_data_result(message=f"The dataset not own the document {document_id}.") + if not KnowledgebaseService.query(id=doc[0].kb_id, tenant_id=tenant_id): + logging.warning( + "cross-tenant access denied for document download: tenant_id=%s kb_id=%s document_id=%s", + tenant_id, + doc[0].kb_id, + document_id, + ) + return get_error_data_result(message="You do not have access to this document.") # The process of downloading doc_id, doc_location = File2DocumentService.get_storage_address(doc_id=document_id) # minio address file_stream = settings.STORAGE_IMPL.get(doc_id, doc_location) @@ -158,171 +155,6 @@ async def download_doc(document_id): ) -@manager.route("/datasets//metadata/update", methods=["POST"]) # noqa: F821 -@token_required -async def metadata_batch_update(dataset_id, tenant_id): - if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): - return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ") - - req = await get_request_json() - selector = req.get("selector", {}) or {} - updates = req.get("updates", []) or [] - deletes = req.get("deletes", []) or [] - - if not isinstance(selector, dict): - return get_error_data_result(message="selector must be an object.") - if not isinstance(updates, list) or not isinstance(deletes, list): - return get_error_data_result(message="updates and deletes must be lists.") - - metadata_condition = selector.get("metadata_condition", {}) or {} - if metadata_condition and not isinstance(metadata_condition, dict): - return get_error_data_result(message="metadata_condition must be an object.") - - document_ids = selector.get("document_ids", []) or [] - if document_ids and not isinstance(document_ids, list): - return get_error_data_result(message="document_ids must be a list.") - - for upd in updates: - if not isinstance(upd, dict) or not upd.get("key") or "value" not in upd: - return get_error_data_result(message="Each update requires key and value.") - for d in deletes: - if not isinstance(d, dict) or not d.get("key"): - return get_error_data_result(message="Each delete requires key.") - - if document_ids: - kb_doc_ids = KnowledgebaseService.list_documents_by_ids([dataset_id]) - target_doc_ids = set(kb_doc_ids) - invalid_ids = set(document_ids) - set(kb_doc_ids) - if invalid_ids: - return get_error_data_result(message=f"These documents do not belong to dataset {dataset_id}: {', '.join(invalid_ids)}") - target_doc_ids = set(document_ids) - - if metadata_condition: - metas = DocMetadataService.get_flatted_meta_by_kbs([dataset_id]) - filtered_ids = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))) - target_doc_ids = target_doc_ids & filtered_ids - if metadata_condition.get("conditions") and not target_doc_ids: - return get_result(data={"updated": 0, "matched_docs": 0}) - - target_doc_ids = list(target_doc_ids) - updated = DocMetadataService.batch_update_metadata(dataset_id, target_doc_ids, updates, deletes) - return get_result(data={"updated": updated, "matched_docs": len(target_doc_ids)}) - - -@manager.route("/datasets//documents", methods=["DELETE"]) # noqa: F821 -@token_required -async def 
delete(tenant_id, dataset_id): - """ - Delete documents from a dataset. - --- - tags: - - Documents - security: - - ApiKeyAuth: [] - parameters: - - in: path - name: dataset_id - type: string - required: true - description: ID of the dataset. - - in: body - name: body - description: Document deletion parameters. - required: true - schema: - type: object - properties: - ids: - type: array - items: - type: string - description: | - List of document IDs to delete. - If omitted, `null`, or an empty array is provided, no documents will be deleted. - - in: header - name: Authorization - type: string - required: true - description: Bearer token for authentication. - responses: - 200: - description: Documents deleted successfully. - schema: - type: object - """ - if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): - return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ") - req = await get_request_json() - if not req: - return get_result() - - doc_ids = req.get("ids") - if not doc_ids: - if req.get("delete_all") is True: - doc_ids = [doc.id for doc in DocumentService.query(kb_id=dataset_id)] - if not doc_ids: - return get_result() - else: - return get_result() - - doc_list = doc_ids - - unique_doc_ids, duplicate_messages = check_duplicate_ids(doc_list, "document") - doc_list = unique_doc_ids - - root_folder = FileService.get_root_folder(tenant_id) - pf_id = root_folder["id"] - FileService.init_knowledgebase_docs(pf_id, tenant_id) - errors = "" - not_found = [] - success_count = 0 - for doc_id in doc_list: - try: - e, doc = DocumentService.get_by_id(doc_id) - if not e: - not_found.append(doc_id) - continue - tenant_id = DocumentService.get_tenant_id(doc_id) - if not tenant_id: - return get_error_data_result(message="Tenant not found!") - - b, n = File2DocumentService.get_storage_address(doc_id=doc_id) - - if not DocumentService.remove_document(doc, tenant_id): - return get_error_data_result(message="Database error (Document removal)!") - - f2d = File2DocumentService.get_by_document_id(doc_id) - FileService.filter_delete( - [ - File.source_type == FileSource.KNOWLEDGEBASE, - File.id == f2d[0].file_id, - ] - ) - File2DocumentService.delete_by_document_id(doc_id) - - settings.STORAGE_IMPL.rm(b, n) - success_count += 1 - except Exception as e: - errors += str(e) - - if not_found: - return get_result(message=f"Documents not found: {not_found}", code=RetCode.DATA_ERROR) - - if errors: - return get_result(message=errors, code=RetCode.SERVER_ERROR) - - if duplicate_messages: - if success_count > 0: - return get_result( - message=f"Partially deleted {success_count} datasets with {len(duplicate_messages)} errors", - data={"success_count": success_count, "errors": duplicate_messages}, - ) - else: - return get_error_data_result(message=";".join(duplicate_messages)) - - return get_result() - - DOC_STOP_PARSING_INVALID_STATE_MESSAGE = "Can't stop parsing document that has not started or already completed" DOC_STOP_PARSING_INVALID_STATE_ERROR_CODE = "DOC_STOP_PARSING_INVALID_STATE" @@ -495,642 +327,6 @@ async def stop_parsing(tenant_id, dataset_id): return get_result() -@manager.route("/datasets//documents//chunks", methods=["GET"]) # noqa: F821 -@token_required -async def list_chunks(tenant_id, dataset_id, document_id): - """ - List chunks of a document. - --- - tags: - - Chunks - security: - - ApiKeyAuth: [] - parameters: - - in: path - name: dataset_id - type: string - required: true - description: ID of the dataset. 
- - in: path - name: document_id - type: string - required: true - description: ID of the document. - - in: query - name: page - type: integer - required: false - default: 1 - description: Page number. - - in: query - name: page_size - type: integer - required: false - default: 30 - description: Number of items per page. - - in: query - name: id - type: string - required: false - default: "" - description: Chunk id. - - in: header - name: Authorization - type: string - required: true - description: Bearer token for authentication. - responses: - 200: - description: List of chunks. - schema: - type: object - properties: - total: - type: integer - description: Total number of chunks. - chunks: - type: array - items: - type: object - properties: - id: - type: string - description: Chunk ID. - content: - type: string - description: Chunk content. - document_id: - type: string - description: ID of the document. - important_keywords: - type: array - items: - type: string - description: Important keywords. - tag_kwd: - type: array - items: - type: string - description: Tag keywords. - image_id: - type: string - description: Image ID associated with the chunk. - doc: - type: object - description: Document details. - """ - if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): - return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") - doc = DocumentService.query(id=document_id, kb_id=dataset_id) - if not doc: - return get_error_data_result(message=f"You don't own the document {document_id}.") - doc = doc[0] - req = request.args - doc_id = document_id - page = int(req.get("page", 1)) - size = int(req.get("page_size", 30)) - question = req.get("keywords", "") - query = { - "doc_ids": [doc_id], - "page": page, - "size": size, - "question": question, - "sort": True, - } - if "available" in req: - query["available_int"] = 1 if req["available"] == "true" else 0 - key_mapping = { - "chunk_num": "chunk_count", - "kb_id": "dataset_id", - "token_num": "token_count", - "parser_id": "chunk_method", - } - run_mapping = { - "0": "UNSTART", - "1": "RUNNING", - "2": "CANCEL", - "3": "DONE", - "4": "FAIL", - } - doc = doc.to_dict() - renamed_doc = {} - for key, value in doc.items(): - new_key = key_mapping.get(key, key) - renamed_doc[new_key] = value - if key == "run": - renamed_doc["run"] = run_mapping.get(str(value)) - - res = {"total": 0, "chunks": [], "doc": renamed_doc} - if req.get("id"): - chunk = settings.docStoreConn.get(req.get("id"), search.index_name(tenant_id), [dataset_id]) - if not chunk: - return get_result(message=f"Chunk not found: {dataset_id}/{req.get('id')}", code=RetCode.NOT_FOUND) - k = [] - for n in chunk.keys(): - if re.search(r"(_vec$|_sm_|_tks|_ltks)", n): - k.append(n) - for n in k: - del chunk[n] - if not chunk: - return get_error_data_result(f"Chunk `{req.get('id')}` not found.") - res["total"] = 1 - final_chunk = { - "id": chunk.get("id", chunk.get("chunk_id")), - "content": chunk["content_with_weight"], - "document_id": chunk.get("doc_id", chunk.get("document_id")), - "docnm_kwd": chunk["docnm_kwd"], - "important_keywords": chunk.get("important_kwd", []), - "questions": chunk.get("question_kwd", []), - "dataset_id": chunk.get("kb_id", chunk.get("dataset_id")), - "image_id": chunk.get("img_id", ""), - "available": bool(chunk.get("available_int", 1)), - "positions": chunk.get("position_int", []), - "tag_kwd": chunk.get("tag_kwd", []), - "tag_feas": chunk.get("tag_feas", {}), - } - res["chunks"].append(final_chunk) - _ = 
Chunk(**final_chunk) - - elif settings.docStoreConn.index_exist(search.index_name(tenant_id), dataset_id): - sres = await settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True) - res["total"] = sres.total - for id in sres.ids: - d = { - "id": id, - "content": (remove_redundant_spaces(sres.highlight[id]) if question and id in sres.highlight else sres.field[id].get("content_with_weight", "")), - "document_id": sres.field[id]["doc_id"], - "docnm_kwd": sres.field[id]["docnm_kwd"], - "important_keywords": sres.field[id].get("important_kwd", []), - "tag_kwd": sres.field[id].get("tag_kwd", []), - "questions": sres.field[id].get("question_kwd", []), - "dataset_id": sres.field[id].get("kb_id", sres.field[id].get("dataset_id")), - "image_id": sres.field[id].get("img_id", ""), - "available": bool(int(sres.field[id].get("available_int", "1"))), - "positions": sres.field[id].get("position_int", []), - } - res["chunks"].append(d) - _ = Chunk(**d) # validate the chunk - return get_result(data=res) - - -@manager.route( # noqa: F821 - "/datasets//documents//chunks", methods=["POST"] -) -@token_required -async def add_chunk(tenant_id, dataset_id, document_id): - """ - Add a chunk to a document. - --- - tags: - - Chunks - security: - - ApiKeyAuth: [] - parameters: - - in: path - name: dataset_id - type: string - required: true - description: ID of the dataset. - - in: path - name: document_id - type: string - required: true - description: ID of the document. - - in: body - name: body - description: Chunk data. - required: true - schema: - type: object - properties: - content: - type: string - required: true - description: Content of the chunk. - important_keywords: - type: array - items: - type: string - description: Important keywords. - image_base64: - type: string - description: Base64-encoded image to associate with the chunk. - - in: header - name: Authorization - type: string - required: true - description: Bearer token for authentication. - responses: - 200: - description: Chunk added successfully. - schema: - type: object - properties: - chunk: - type: object - properties: - id: - type: string - description: Chunk ID. - content: - type: string - description: Chunk content. - document_id: - type: string - description: ID of the document. - important_keywords: - type: array - items: - type: string - description: Important keywords. 
- """ - if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): - return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") - doc = DocumentService.query(id=document_id, kb_id=dataset_id) - if not doc: - return get_error_data_result(message=f"You don't own the document {document_id}.") - doc = doc[0] - req = await get_request_json() - if is_content_empty(req.get("content")): - return get_error_data_result(message="`content` is required") - if "important_keywords" in req: - if not isinstance(req["important_keywords"], list): - return get_error_data_result("`important_keywords` is required to be a list") - if "questions" in req: - if not isinstance(req["questions"], list): - return get_error_data_result("`questions` is required to be a list") - chunk_id = xxhash.xxh64((req["content"] + document_id).encode("utf-8")).hexdigest() - d = { - "id": chunk_id, - "content_ltks": rag_tokenizer.tokenize(req["content"]), - "content_with_weight": req["content"], - } - d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) - d["important_kwd"] = req.get("important_keywords", []) - d["important_tks"] = rag_tokenizer.tokenize(" ".join(req.get("important_keywords", []))) - d["question_kwd"] = [str(q).strip() for q in req.get("questions", []) if str(q).strip()] - d["question_tks"] = rag_tokenizer.tokenize("\n".join(req.get("questions", []))) - d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19] - d["create_timestamp_flt"] = datetime.datetime.now().timestamp() - d["kb_id"] = dataset_id - d["docnm_kwd"] = doc.name - d["doc_id"] = document_id - if "tag_kwd" in req: - if not isinstance(req["tag_kwd"], list): - return get_error_data_result("`tag_kwd` is required to be a list") - if not all(isinstance(t, str) for t in req["tag_kwd"]): - return get_error_data_result("`tag_kwd` must be a list of strings") - d["tag_kwd"] = req["tag_kwd"] - if "tag_feas" in req: - try: - d["tag_feas"] = validate_tag_features(req["tag_feas"]) - except ValueError as exc: - return get_error_data_result(f"`tag_feas` {exc}") - import base64 - - image_base64 = req.get("image_base64", None) - if image_base64: - d["img_id"] = "{}-{}".format(dataset_id, chunk_id) - d["doc_type_kwd"] = "image" - - tenant_embd_id = DocumentService.get_tenant_embd_id(document_id) - if tenant_embd_id: - model_config = get_model_config_by_id(tenant_embd_id) - else: - embd_id = DocumentService.get_embd_id(document_id) - model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING.value, embd_id) - embd_mdl = TenantLLMService.model_instance(model_config) - v, c = embd_mdl.encode([doc.name, req["content"] if not d["question_kwd"] else "\n".join(d["question_kwd"])]) - v = 0.1 * v[0] + 0.9 * v[1] - d["q_%d_vec" % len(v)] = v.tolist() - settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id) - - if image_base64: - store_chunk_image(dataset_id, chunk_id, base64.b64decode(image_base64)) - - DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0) - # rename keys - key_mapping = { - "id": "id", - "content_with_weight": "content", - "doc_id": "document_id", - "important_kwd": "important_keywords", - "tag_kwd": "tag_kwd", - "question_kwd": "questions", - "kb_id": "dataset_id", - "create_timestamp_flt": "create_timestamp", - "create_time": "create_time", - "document_keyword": "document", - "img_id": "image_id", - } - renamed_chunk = {} - for key, value in d.items(): - if key in key_mapping: - new_key = key_mapping.get(key, key) - 
renamed_chunk[new_key] = value - _ = Chunk(**renamed_chunk) # validate the chunk - return get_result(data={"chunk": renamed_chunk}) - # return get_result(data={"chunk_id": chunk_id}) - - -@manager.route( # noqa: F821 - "datasets//documents//chunks", methods=["DELETE"] -) -@token_required -async def rm_chunk(tenant_id, dataset_id, document_id): - """ - Remove chunks from a document. - --- - tags: - - Chunks - security: - - ApiKeyAuth: [] - parameters: - - in: path - name: dataset_id - type: string - required: true - description: ID of the dataset. - - in: path - name: document_id - type: string - required: true - description: ID of the document. - - in: body - name: body - description: Chunk removal parameters. - required: true - schema: - type: object - properties: - chunk_ids: - type: array - items: - type: string - description: | - List of chunk IDs to remove. - If omitted, `null`, or an empty array is provided, no chunks will be deleted. - - in: header - name: Authorization - type: string - required: true - description: Bearer token for authentication. - responses: - 200: - description: Chunks removed successfully. - schema: - type: object - """ - if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): - return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") - docs = DocumentService.get_by_ids([document_id]) - if not docs: - raise LookupError(f"Can't find the document with ID {document_id}!") - req = await get_request_json() - if not req: - return get_result() - - chunk_ids = req.get("chunk_ids") - if not chunk_ids: - if req.get("delete_all") is True: - doc = docs[0] - # Clean up storage assets while index rows still exist for discovery - DocumentService.delete_chunk_images(doc, tenant_id) - condition = {"doc_id": document_id} - chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id) - if chunk_number != 0: - DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0) - return get_result(message=f"deleted {chunk_number} chunks") - else: - return get_result() - - condition = {"doc_id": document_id} - unique_chunk_ids, duplicate_messages = check_duplicate_ids(chunk_ids, "chunk") - condition["id"] = unique_chunk_ids - chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id) - if chunk_number != 0: - DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0) - if chunk_number != len(unique_chunk_ids): - if len(unique_chunk_ids) == 0: - return get_result(message=f"deleted {chunk_number} chunks") - return get_error_data_result(message=f"rm_chunk deleted chunks {chunk_number}, expect {len(unique_chunk_ids)}") - if duplicate_messages: - return get_result( - message=f"Partially deleted {chunk_number} chunks with {len(duplicate_messages)} errors", - data={"success_count": chunk_number, "errors": duplicate_messages}, - ) - return get_result(message=f"deleted {chunk_number} chunks") - - -@manager.route( # noqa: F821 - "/datasets//documents//chunks/", methods=["PUT"] -) -@token_required -async def update_chunk(tenant_id, dataset_id, document_id, chunk_id): - """ - Update a chunk within a document. - --- - tags: - - Chunks - security: - - ApiKeyAuth: [] - parameters: - - in: path - name: dataset_id - type: string - required: true - description: ID of the dataset. - - in: path - name: document_id - type: string - required: true - description: ID of the document. 
- - in: path - name: chunk_id - type: string - required: true - description: ID of the chunk to update. - - in: body - name: body - description: Chunk update parameters. - required: true - schema: - type: object - properties: - content: - type: string - description: Updated content of the chunk. - important_keywords: - type: array - items: - type: string - description: Updated important keywords. - tag_kwd: - type: array - items: - type: string - description: Updated tag keywords. - available: - type: boolean - description: Availability status of the chunk. - - in: header - name: Authorization - type: string - required: true - description: Bearer token for authentication. - responses: - 200: - description: Chunk updated successfully. - schema: - type: object - """ - chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id]) - if chunk is None: - return get_error_data_result(f"Can't find this chunk {chunk_id}") - if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): - return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") - doc = DocumentService.query(id=document_id, kb_id=dataset_id) - if not doc: - return get_error_data_result(message=f"You don't own the document {document_id}.") - doc = doc[0] - req = await get_request_json() - content = req.get("content") - if content is not None: - if is_content_empty(content): - return get_error_data_result(message="`content` is required") - else: - content = chunk.get("content_with_weight", "") - d = {"id": chunk_id, "content_with_weight": content} - d["content_ltks"] = rag_tokenizer.tokenize(d["content_with_weight"]) - d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) - if "important_keywords" in req: - if not isinstance(req["important_keywords"], list): - return get_error_data_result("`important_keywords` should be a list") - d["important_kwd"] = req.get("important_keywords", []) - d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_keywords"])) - if "questions" in req: - if not isinstance(req["questions"], list): - return get_error_data_result("`questions` should be a list") - d["question_kwd"] = [str(q).strip() for q in req.get("questions", []) if str(q).strip()] - d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["questions"])) - if "available" in req: - d["available_int"] = int(req["available"]) - if "positions" in req: - if not isinstance(req["positions"], list): - return get_error_data_result("`positions` should be a list") - d["position_int"] = req["positions"] - if "tag_kwd" in req: - if not isinstance(req["tag_kwd"], list): - return get_error_data_result("`tag_kwd` should be a list") - if not all(isinstance(t, str) for t in req["tag_kwd"]): - return get_error_data_result("`tag_kwd` must be a list of strings") - d["tag_kwd"] = req["tag_kwd"] - if "tag_feas" in req: - try: - d["tag_feas"] = validate_tag_features(req["tag_feas"]) - except ValueError as exc: - return get_error_data_result(f"`tag_feas` {exc}") - tenant_embd_id = DocumentService.get_tenant_embd_id(document_id) - if tenant_embd_id: - model_config = get_model_config_by_id(tenant_embd_id) - else: - embd_id = DocumentService.get_embd_id(document_id) - model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING.value, embd_id) - embd_mdl = TenantLLMService.model_instance(model_config) - if doc.parser_id == ParserType.QA: - arr = [t for t in re.split(r"[\n\t]", d["content_with_weight"]) if len(t) > 1] - if len(arr) != 2: - return 
get_error_data_result(message="Q&A must be separated by TAB/ENTER key.") - q, a = rmPrefix(arr[0]), rmPrefix(arr[1]) - d = beAdoc(d, arr[0], arr[1], not any([rag_tokenizer.is_chinese(t) for t in q + a])) - - v, c = embd_mdl.encode([doc.name, d["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])]) - v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] - d["q_%d_vec" % len(v)] = v.tolist() - settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id) - return get_result() - - -@manager.route( # noqa: F821 - "/datasets//documents//chunks/switch", methods=["POST"] -) -@token_required -async def switch_chunks(tenant_id, dataset_id, document_id): - """ - Switch availability of specified chunks (same as chunk_app switch). - --- - tags: - - Chunks - security: - - ApiKeyAuth: [] - parameters: - - in: path - name: dataset_id - type: string - required: true - description: ID of the dataset. - - in: path - name: document_id - type: string - required: true - description: ID of the document. - - in: body - name: body - required: true - schema: - type: object - properties: - chunk_ids: - type: array - items: - type: string - description: List of chunk IDs to switch. - available_int: - type: integer - description: 1 for available, 0 for unavailable. - available: - type: boolean - description: Availability status (alternative to available_int). - - in: header - name: Authorization - type: string - required: true - description: Bearer token for authentication. - responses: - 200: - description: Chunks availability switched successfully. - """ - if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): - return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") - req = await get_request_json() - if not req.get("chunk_ids"): - return get_error_data_result(message="`chunk_ids` is required.") - if "available_int" not in req and "available" not in req: - return get_error_data_result(message="`available_int` or `available` is required.") - available_int = int(req["available_int"]) if "available_int" in req else (1 if req.get("available") else 0) - try: - - def _switch_sync(): - e, doc = DocumentService.get_by_id(document_id) - if not e: - return get_error_data_result(message="Document not found!") - if not doc or str(doc.kb_id) != str(dataset_id): - return get_error_data_result(message="Document not found!") - for cid in req["chunk_ids"]: - if not settings.docStoreConn.update( - {"id": cid}, - {"available_int": available_int}, - search.index_name(tenant_id), - doc.kb_id, - ): - return get_error_data_result(message="Index updating failure") - return get_result(data=True) - - return await thread_pool_exec(_switch_sync) - except Exception as e: - return server_error_response(e) - - @manager.route("/retrieval", methods=["POST"]) # noqa: F821 @token_required async def retrieval_test(tenant_id): @@ -1268,6 +464,8 @@ async def retrieval_test(tenant_id): similarity_threshold = float(req.get("similarity_threshold", 0.2)) vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) top = int(req.get("top_k", 1024)) + if top <= 0: + return get_error_data_result("`top_k` must be greater than 0") highlight_val = req.get("highlight", None) if highlight_val is None: highlight = False @@ -1280,6 +478,7 @@ async def retrieval_test(tenant_id): return get_error_data_result("`highlight` should be a boolean") else: return get_error_data_result("`highlight` should be a boolean") + 
include_metadata, metadata_fields = _resolve_reference_metadata(req) try: tenant_ids = list(set([kb.tenant_id for kb in kbs])) e, kb = KnowledgebaseService.get_by_id(kb_ids[0]) @@ -1338,6 +537,15 @@ async def retrieval_test(tenant_id): for c in ranks["chunks"]: c.pop("vector", None) + if include_metadata: + logging.info( + "sdk.retrieval reference_metadata enabled dataset_ids=%s fields=%s chunks=%s", + kb_ids, + sorted(metadata_fields) if metadata_fields else None, + len(ranks["chunks"]), + ) + enrich_chunks_with_document_metadata(ranks["chunks"], metadata_fields) + ##rename keys renamed_chunks = [] for chunk in ranks["chunks"]: diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 82e048ff17b..11960dcf65c 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -14,47 +14,44 @@ # limitations under the License. # import json -import copy import re -import time -import os -import tempfile import logging -from quart import Response, jsonify, request - -from common.token_utils import num_tokens_from_string +from quart import Response, request from agent.canvas import Canvas from api.db.db_models import APIToken from api.db.services.api_service import API4ConversationService -from api.db.services.canvas_service import UserCanvasService, completion_openai +from api.db.services.canvas_service import UserCanvasService from api.db.services.canvas_service import completion as agent_completion -from api.db.services.conversation_service import ConversationService from api.db.services.user_canvas_version import UserCanvasVersionService from api.db.services.conversation_service import async_iframe_completion as iframe_completion -from api.db.services.conversation_service import async_completion as rag_completion -from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap +from api.db.services.dialog_service import DialogService, async_ask, gen_mindmap from api.db.services.doc_metadata_service import DocMetadataService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle -from common.metadata_utils import apply_meta_data_filter, convert_conditions, meta_filter +from common.metadata_utils import apply_meta_data_filter from api.db.services.search_service import SearchService from api.db.services.user_service import UserTenantService from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_by_id, \ get_model_config_by_type_and_name from common.misc_utils import get_uuid -from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \ +from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_json_result, \ get_result, get_request_json, server_error_response, token_required, validate_request from rag.app.tag import label_question from rag.prompts.template import load_prompt -from rag.prompts.generator import cross_languages, keyword_extraction, chunks_format +from rag.prompts.generator import cross_languages, keyword_extraction from common.constants import RetCode, LLMType, StatusEnum from common import settings +from api.utils.reference_metadata_utils import ( + enrich_chunks_with_document_metadata, + resolve_reference_metadata_preferences, +) + +logger = logging.getLogger(__name__) -@manager.route("/agents//sessions", methods=["POST"]) # noqa: F821 @token_required async def create_agent_session(tenant_id, agent_id): req = await get_request_json() @@ -92,558 
+89,6 @@ async def create_agent_session(tenant_id, agent_id): return get_result(data=conv) -@manager.route("/chats//completions", methods=["POST"]) # noqa: F821 -@token_required -async def chat_completion(tenant_id, chat_id): - req = await get_request_json() - if not req: - req = {"question": ""} - if not req.get("session_id"): - req["question"] = "" - dia = DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value) - if not dia: - return get_error_data_result(f"You don't own the chat {chat_id}") - dia = dia[0] - if req.get("session_id"): - if not ConversationService.query(id=req["session_id"], dialog_id=chat_id): - return get_error_data_result(f"You don't own the session {req['session_id']}") - - metadata_condition = req.get("metadata_condition") or {} - if metadata_condition and not isinstance(metadata_condition, dict): - return get_error_data_result(message="metadata_condition must be an object.") - - if metadata_condition and req.get("question"): - metas = DocMetadataService.get_flatted_meta_by_kbs(dia.kb_ids or []) - filtered_doc_ids = meta_filter( - metas, - convert_conditions(metadata_condition), - metadata_condition.get("logic", "and"), - ) - if metadata_condition.get("conditions") and not filtered_doc_ids: - filtered_doc_ids = ["-999"] - - if filtered_doc_ids: - req["doc_ids"] = ",".join(filtered_doc_ids) - else: - req.pop("doc_ids", None) - - if req.get("stream", True): - resp = Response(rag_completion(tenant_id, chat_id, **req), mimetype="text/event-stream") - resp.headers.add_header("Cache-control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") - - return resp - else: - answer = None - async for ans in rag_completion(tenant_id, chat_id, **req): - answer = ans - break - return get_result(data=answer) - - -@manager.route("/chats_openai//chat/completions", methods=["POST"]) # noqa: F821 -@validate_request("model", "messages") # noqa: F821 -@token_required -async def chat_completion_openai_like(tenant_id, chat_id): - """ - OpenAI-like chat completion API that simulates the behavior of OpenAI's completions endpoint. - - This function allows users to interact with a model and receive responses based on a series of historical messages. - If `stream` is set to True (by default), the response will be streamed in chunks, mimicking the OpenAI-style API. - Set `stream` to False explicitly, the response will be returned in a single complete answer. - - Reference: - - - If `stream` is True, the final answer and reference information will appear in the **last chunk** of the stream. - - If `stream` is False, the reference will be included in `choices[0].message.reference`. - - If `extra_body.reference_metadata.include` is True, each reference chunk may include `document_metadata` in both streaming and non-streaming responses. - - Example usage: - - curl -X POST https://ragflow_address.com/api/v1/chats_openai//chat/completions \ - -H "Content-Type: application/json" \ - -H "Authorization: Bearer $RAGFLOW_API_KEY" \ - -d '{ - "model": "model", - "messages": [{"role": "user", "content": "Say this is a test!"}], - "stream": true - }' - - Alternatively, you can use Python's `OpenAI` client: - - NOTE: Streaming via `client.chat.completions.create(stream=True, ...)` does - not return `reference` currently. The only way to return `reference` is - non-stream mode with `with_raw_response`. 
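The reference does still appear in the last streamed data chunk, so a caller that reads the raw SSE stream (rather than the typed OpenAI client) can recover it in streaming mode. A minimal sketch, with placeholder address, API key, and chat ID, assuming the `requests` package is available:

```python
import json
import requests  # assumed available; any streaming HTTP client works

BASE_URL = "http://ragflow_address/api/v1"             # placeholder address
HEADERS = {"Authorization": "Bearer ragflow-api-key"}  # placeholder key

payload = {
    "model": "model",
    "messages": [{"role": "user", "content": "Say this is a test!"}],
    "stream": True,
    # This endpoint reads the reference flags from "extra_body" in the JSON body.
    "extra_body": {"reference": True},
}

last_chunk = None
with requests.post(f"{BASE_URL}/chats_openai/<chat_id>/chat/completions",
                   headers=HEADERS, json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue
        body = line[len("data:"):].strip()
        if body == "[DONE]":
            break
        last_chunk = json.loads(body)

# The final data chunk carries usage plus, when requested, the full answer
# ("final_content") and the "reference" list on choices[0].delta.
if last_chunk:
    delta = last_chunk["choices"][0]["delta"]
    print("answer:", delta.get("final_content"))
    print("reference:", delta.get("reference"))
```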
- - from openai import OpenAI - import json - - model = "model" - client = OpenAI(api_key="ragflow-api-key", base_url=f"http://ragflow_address/api/v1/chats_openai/") - - stream = True - reference = True - - request_kwargs = dict( - model="model", - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Who are you?"}, - {"role": "assistant", "content": "I am an AI assistant named..."}, - {"role": "user", "content": "Can you tell me how to install neovim"}, - ], - extra_body={ - "reference": reference, - "reference_metadata": { - "include": True, - "fields": ["author", "year", "source"], - }, - "metadata_condition": { - "logic": "and", - "conditions": [ - { - "name": "author", - "comparison_operator": "is", - "value": "bob" - } - ] - } - }, - ) - - if stream: - completion = client.chat.completions.create(stream=True, **request_kwargs) - for chunk in completion: - print(chunk) - else: - resp = client.chat.completions.with_raw_response.create( - stream=False, **request_kwargs - ) - print("status:", resp.http_response.status_code) - raw_text = resp.http_response.text - print("raw:", raw_text) - - data = json.loads(raw_text) - print("assistant:", data["choices"][0]["message"].get("content")) - print("reference:", data["choices"][0]["message"].get("reference")) - - """ - req = await get_request_json() - - extra_body = req.get("extra_body") or {} - if extra_body and not isinstance(extra_body, dict): - return get_error_data_result("extra_body must be an object.") - - need_reference = bool(extra_body.get("reference", False)) - reference_metadata = extra_body.get("reference_metadata") or {} - if reference_metadata and not isinstance(reference_metadata, dict): - return get_error_data_result("reference_metadata must be an object.") - include_reference_metadata = bool(reference_metadata.get("include", False)) - metadata_fields = reference_metadata.get("fields") - if metadata_fields is not None and not isinstance(metadata_fields, list): - return get_error_data_result("reference_metadata.fields must be an array.") - - messages = req.get("messages", []) - # To prevent empty [] input - if len(messages) < 1: - return get_error_data_result("You have to provide messages.") - if messages[-1]["role"] != "user": - return get_error_data_result("The last content of this conversation is not from user.") - - prompt = messages[-1]["content"] - # Treat context tokens as reasoning tokens - context_token_used = sum(num_tokens_from_string(message["content"]) for message in messages) - - dia = DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value) - if not dia: - return get_error_data_result(f"You don't own the chat {chat_id}") - dia = dia[0] - - metadata_condition = extra_body.get("metadata_condition") or {} - if metadata_condition and not isinstance(metadata_condition, dict): - return get_error_data_result(message="metadata_condition must be an object.") - - doc_ids_str = None - if metadata_condition: - metas = DocMetadataService.get_flatted_meta_by_kbs(dia.kb_ids or []) - filtered_doc_ids = meta_filter( - metas, - convert_conditions(metadata_condition), - metadata_condition.get("logic", "and"), - ) - if metadata_condition.get("conditions") and not filtered_doc_ids: - filtered_doc_ids = ["-999"] - doc_ids_str = ",".join(filtered_doc_ids) if filtered_doc_ids else None - - # Filter system and non-sense assistant messages - msg = [] - for m in messages: - if m["role"] == "system": - continue - if m["role"] == "assistant" and not msg: - 
continue - msg.append(m) - - # tools = get_tools() - # toolcall_session = SimpleFunctionCallServer() - tools = None - toolcall_session = None - - if req.get("stream", True): - # The value for the usage field on all chunks except for the last one will be null. - # The usage field on the last chunk contains token usage statistics for the entire request. - # The choices field on the last chunk will always be an empty array []. - async def streamed_response_generator(chat_id, dia, msg): - token_used = 0 - last_ans = {} - full_content = "" - full_reasoning = "" - final_answer = None - final_reference = None - in_think = False - response = { - "id": f"chatcmpl-{chat_id}", - "choices": [ - { - "delta": { - "content": "", - "role": "assistant", - "function_call": None, - "tool_calls": None, - "reasoning_content": "", - }, - "finish_reason": None, - "index": 0, - "logprobs": None, - } - ], - "created": int(time.time()), - "model": "model", - "object": "chat.completion.chunk", - "system_fingerprint": "", - "usage": None, - } - - try: - chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference} - if doc_ids_str: - chat_kwargs["doc_ids"] = doc_ids_str - async for ans in async_chat(dia, msg, True, **chat_kwargs): - last_ans = ans - if ans.get("final"): - if ans.get("answer"): - full_content = ans["answer"] - response["choices"][0]["delta"]["content"] = full_content - response["choices"][0]["delta"]["reasoning_content"] = None - yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n" - final_answer = full_content - final_reference = ans.get("reference", {}) - continue - if ans.get("start_to_think"): - in_think = True - continue - if ans.get("end_to_think"): - in_think = False - continue - delta = ans.get("answer") or "" - if not delta: - continue - token_used += num_tokens_from_string(delta) - if in_think: - full_reasoning += delta - response["choices"][0]["delta"]["reasoning_content"] = delta - response["choices"][0]["delta"]["content"] = None - else: - full_content += delta - response["choices"][0]["delta"]["content"] = delta - response["choices"][0]["delta"]["reasoning_content"] = None - yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n" - except Exception as e: - response["choices"][0]["delta"]["content"] = "**ERROR**: " + str(e) - yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n" - - # The last chunk - response["choices"][0]["delta"]["content"] = None - response["choices"][0]["delta"]["reasoning_content"] = None - response["choices"][0]["finish_reason"] = "stop" - prompt_tokens = num_tokens_from_string(prompt) - response["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": token_used, "total_tokens": prompt_tokens + token_used} - if need_reference: - reference_payload = final_reference if final_reference is not None else last_ans.get("reference", []) - response["choices"][0]["delta"]["reference"] = _build_reference_chunks( - reference_payload, - include_metadata=include_reference_metadata, - metadata_fields=metadata_fields, - ) - response["choices"][0]["delta"]["final_content"] = final_answer if final_answer is not None else full_content - yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n" - yield "data:[DONE]\n\n" - - resp = Response(streamed_response_generator(chat_id, dia, msg), mimetype="text/event-stream") - resp.headers.add_header("Cache-control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - resp.headers.add_header("Content-Type", 
"text/event-stream; charset=utf-8") - return resp - else: - answer = None - chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference} - if doc_ids_str: - chat_kwargs["doc_ids"] = doc_ids_str - async for ans in async_chat(dia, msg, False, **chat_kwargs): - # focus answer content only - answer = ans - break - content = answer["answer"] - - response = { - "id": f"chatcmpl-{chat_id}", - "object": "chat.completion", - "created": int(time.time()), - "model": req.get("model", ""), - "usage": { - "prompt_tokens": num_tokens_from_string(prompt), - "completion_tokens": num_tokens_from_string(content), - "total_tokens": num_tokens_from_string(prompt) + num_tokens_from_string(content), - "completion_tokens_details": { - "reasoning_tokens": context_token_used, - "accepted_prediction_tokens": num_tokens_from_string(content), - "rejected_prediction_tokens": 0, # 0 for simplicity - }, - }, - "choices": [ - { - "message": { - "role": "assistant", - "content": content, - }, - "logprobs": None, - "finish_reason": "stop", - "index": 0, - } - ], - } - if need_reference: - response["choices"][0]["message"]["reference"] = _build_reference_chunks( - answer.get("reference", {}), - include_metadata=include_reference_metadata, - metadata_fields=metadata_fields, - ) - - return jsonify(response) - - -@manager.route("/agents_openai//chat/completions", methods=["POST"]) # noqa: F821 -@validate_request("model", "messages") # noqa: F821 -@token_required -async def agents_completion_openai_compatibility(tenant_id, agent_id): - req = await get_request_json() - messages = req.get("messages", []) - if not messages: - return get_error_data_result("You must provide at least one message.") - if not UserCanvasService.query(user_id=tenant_id, id=agent_id): - return get_error_data_result(f"You don't own the agent {agent_id}") - - filtered_messages = [m for m in messages if m["role"] in ["user", "assistant"]] - prompt_tokens = sum(num_tokens_from_string(m["content"]) for m in filtered_messages) - if not filtered_messages: - return jsonify( - get_data_openai( - id=agent_id, - content="No valid messages found (user or assistant).", - finish_reason="stop", - model=req.get("model", ""), - completion_tokens=num_tokens_from_string("No valid messages found (user or assistant)."), - prompt_tokens=prompt_tokens, - ) - ) - - question = next((m["content"] for m in reversed(messages) if m["role"] == "user"), "") - - stream = req.pop("stream", False) - if stream: - resp = Response( - completion_openai( - tenant_id, - agent_id, - question, - session_id=req.pop("session_id", req.get("id", "")) or req.get("metadata", {}).get("id", ""), - stream=True, - **req, - ), - mimetype="text/event-stream", - ) - resp.headers.add_header("Cache-control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") - return resp - else: - # For non-streaming, just return the response directly - async for response in completion_openai( - tenant_id, - agent_id, - question, - session_id=req.pop("session_id", req.get("id", "")) or req.get("metadata", {}).get("id", ""), - stream=False, - **req, - ): - return jsonify(response) - - return None - - -@manager.route("/agents//completions", methods=["POST"]) # noqa: F821 -@token_required -async def agent_completions(tenant_id, agent_id): - req = await get_request_json() - return_trace = bool(req.get("return_trace", False)) - - if req.get("stream", True): 
- - async def generate(): - trace_items = [] - async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req): - if isinstance(answer, str): - try: - ans = json.loads(answer[5:]) # remove "data:" - except Exception: - continue - - event = ans.get("event") - if event == "node_finished": - if return_trace: - data = ans.get("data", {}) - trace_items.append( - { - "component_id": data.get("component_id"), - "trace": [copy.deepcopy(data)], - } - ) - ans.setdefault("data", {})["trace"] = trace_items - answer = "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n" - yield answer - - if event not in ["message", "message_end"]: - continue - - yield answer - - yield "data:[DONE]\n\n" - - resp = Response(generate(), mimetype="text/event-stream") - resp.headers.add_header("Cache-control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") - return resp - - full_content = "" - reference = {} - final_ans = "" - trace_items = [] - structured_output = {} - async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req): - try: - ans = json.loads(answer[5:]) - - if ans["event"] == "message": - full_content += ans["data"]["content"] - - if ans.get("data", {}).get("reference", None): - reference.update(ans["data"]["reference"]) - - if ans.get("event") == "node_finished": - data = ans.get("data", {}) - node_out = data.get("outputs", {}) - component_id = data.get("component_id") - if component_id is not None and "structured" in node_out: - structured_output[component_id] = copy.deepcopy(node_out["structured"]) - if return_trace: - trace_items.append( - { - "component_id": data.get("component_id"), - "trace": [copy.deepcopy(data)], - } - ) - - final_ans = ans - except Exception as e: - return get_result(data=f"**ERROR**: {str(e)}") - final_ans["data"]["content"] = full_content - final_ans["data"]["reference"] = reference - if structured_output: - final_ans["data"]["structured"] = structured_output - if return_trace and final_ans: - final_ans["data"]["trace"] = trace_items - return get_result(data=final_ans) - - -@manager.route("/agents//sessions", methods=["GET"]) # noqa: F821 -@token_required -async def list_agent_session(tenant_id, agent_id): - if not UserCanvasService.query(user_id=tenant_id, id=agent_id): - return get_error_data_result(message=f"You don't own the agent {agent_id}.") - id = request.args.get("id") - user_id = request.args.get("user_id") - page_number = int(request.args.get("page", 1)) - items_per_page = int(request.args.get("page_size", 30)) - orderby = request.args.get("orderby", "update_time") - if request.args.get("desc") == "False" or request.args.get("desc") == "false": - desc = False - else: - desc = True - # dsl defaults to True in all cases except for False and false - include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false" - total, convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id, - user_id, include_dsl) - if not convs: - return get_result(data=[]) - for conv in convs: - conv["messages"] = conv.pop("message") - infos = conv["messages"] - for info in infos: - if "prompt" in info: - info.pop("prompt") - conv["agent_id"] = conv.pop("dialog_id") - # Fix for session listing endpoint - if conv["reference"]: - messages = conv["messages"] - message_num = 0 - chunk_num = 0 - # Ensure reference is a list type to 
prevent KeyError - if not isinstance(conv["reference"], list): - conv["reference"] = [] - while message_num < len(messages): - if message_num != 0 and messages[message_num]["role"] != "user": - chunk_list = [] - # Add boundary and type checks to prevent KeyError - if chunk_num < len(conv["reference"]) and conv["reference"][chunk_num] is not None and isinstance( - conv["reference"][chunk_num], dict) and "chunks" in conv["reference"][chunk_num]: - chunks = conv["reference"][chunk_num]["chunks"] - for chunk in chunks: - # Ensure chunk is a dictionary before calling get method - if not isinstance(chunk, dict): - continue - new_chunk = { - "id": chunk.get("chunk_id", chunk.get("id")), - "content": chunk.get("content_with_weight", chunk.get("content")), - "document_id": chunk.get("doc_id", chunk.get("document_id")), - "document_name": chunk.get("docnm_kwd", chunk.get("document_name")), - "dataset_id": chunk.get("kb_id", chunk.get("dataset_id")), - "image_id": chunk.get("image_id", chunk.get("img_id")), - "positions": chunk.get("positions", chunk.get("position_int")), - } - chunk_list.append(new_chunk) - chunk_num += 1 - messages[message_num]["reference"] = chunk_list - message_num += 1 - del conv["reference"] - return get_result(data=convs) - - @manager.route("/agents//sessions", methods=["DELETE"]) # noqa: F821 @token_required async def delete_agent_session(tenant_id, agent_id): @@ -697,97 +142,6 @@ async def delete_agent_session(tenant_id, agent_id): return get_result() -@manager.route("/sessions/ask", methods=["POST"]) # noqa: F821 -@token_required -async def ask_about(tenant_id): - req = await get_request_json() - if not req.get("question"): - return get_error_data_result("`question` is required.") - if not req.get("dataset_ids"): - return get_error_data_result("`dataset_ids` is required.") - if not isinstance(req.get("dataset_ids"), list): - return get_error_data_result("`dataset_ids` should be a list.") - req["kb_ids"] = req.pop("dataset_ids") - for kb_id in req["kb_ids"]: - if not KnowledgebaseService.accessible(kb_id, tenant_id): - return get_error_data_result(f"You don't own the dataset {kb_id}.") - kbs = KnowledgebaseService.query(id=kb_id) - kb = kbs[0] - if kb.chunk_num == 0: - return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file") - uid = tenant_id - - async def stream(): - nonlocal req, uid - try: - async for ans in async_ask(req["question"], req["kb_ids"], uid): - yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" - except Exception as e: - yield "data:" + json.dumps( - {"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, - ensure_ascii=False) + "\n\n" - yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" - - resp = Response(stream(), mimetype="text/event-stream") - resp.headers.add_header("Cache-control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") - return resp - - -@manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821 -@token_required -async def related_questions(tenant_id): - req = await get_request_json() - if not req.get("question"): - return get_error_data_result("`question` is required.") - question = req["question"] - industry = req.get("industry", "") - chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) - chat_mdl = 
LLMBundle(tenant_id, chat_model_config) - prompt = """ -Objective: To generate search terms related to the user's search keywords, helping users find more valuable information. -Instructions: - - Based on the keywords provided by the user, generate 5-10 related search terms. - - Each search term should be directly or indirectly related to the keyword, guiding the user to find more valuable information. - - Use common, general terms as much as possible, avoiding obscure words or technical jargon. - - Keep the term length between 2-4 words, concise and clear. - - DO NOT translate, use the language of the original keywords. -""" - if industry: - prompt += f" - Ensure all search terms are relevant to the industry: {industry}.\n" - prompt += """ -### Example: -Keywords: Chinese football -Related search terms: -1. Current status of Chinese football -2. Reform of Chinese football -3. Youth training of Chinese football -4. Chinese football in the Asian Cup -5. Chinese football in the World Cup - -Reason: - - When searching, users often only use one or two keywords, making it difficult to fully express their information needs. - - Generating related search terms can help users dig deeper into relevant information and improve search efficiency. - - At the same time, related terms can also help search engines better understand user needs and return more accurate search results. - -""" - ans = await chat_mdl.async_chat( - prompt, - [ - { - "role": "user", - "content": f""" -Keywords: {question} -Related search terms: - """, - } - ], - {"temperature": 0.9}, - ) - return get_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)]) - @manager.route("/chatbots//completions", methods=["POST"]) # noqa: F821 async def chatbot_completions(dialog_id): @@ -800,20 +154,69 @@ async def chatbot_completions(dialog_id): objs = APIToken.query(beta=token) if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') + tenant_id = objs[0].tenant_id + exists, dialog = DialogService.get_by_id(dialog_id) + if (not exists + or getattr(dialog, "tenant_id", None) != tenant_id + or str(getattr(dialog, "status", "")) != StatusEnum.VALID.value): + logger.warning( + "Denied chatbot access: reason=%s tenant_id=%s dialog_id=%s user_id=%s session_id=%s", + "no access to this chatbot", + tenant_id, + dialog_id, + req.get("user_id"), + req.get("session_id"), + ) + return get_error_data_result(message="Authentication error: no access to this chatbot!") if "quote" not in req: req["quote"] = False + def _validate_iframe_access(): + if req.get("session_id"): + exists, conv = API4ConversationService.get_by_id(req.get("session_id")) + if not exists: + raise AssertionError("Session not found!") + if conv.dialog_id != dialog_id: + raise AssertionError("Session does not belong to this dialog") + if tenant_id and conv.user_id and conv.user_id != tenant_id: + raise AssertionError("Session does not belong to this tenant") + if req.get("stream", True): - resp = Response(iframe_completion(dialog_id, **req), mimetype="text/event-stream") + try: + _validate_iframe_access() + except AssertionError: + logger.warning( + "Denied chatbot completion stream: reason=%s tenant_id=%s dialog_id=%s user_id=%s session_id=%s", + "no access to this chatbot", + tenant_id, + dialog_id, + req.get("user_id"), + req.get("session_id"), + ) + return get_error_data_result(message="Authentication error: no access to this chatbot!") + + resp = Response(iframe_completion(dialog_id, tenant_id=tenant_id, 
**req), mimetype="text/event-stream") resp.headers.add_header("Cache-control", "no-cache") resp.headers.add_header("Connection", "keep-alive") resp.headers.add_header("X-Accel-Buffering", "no") resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") return resp - async for answer in iframe_completion(dialog_id, **req): - return get_result(data=answer) + try: + _validate_iframe_access() + async for answer in iframe_completion(dialog_id, tenant_id=tenant_id, **req): + return get_result(data=answer) + except AssertionError: + logger.warning( + "Denied chatbot completion: reason=%s tenant_id=%s dialog_id=%s user_id=%s session_id=%s", + "no access to this chatbot", + tenant_id, + dialog_id, + req.get("user_id"), + req.get("session_id"), + ) + return get_error_data_result(message="Authentication error: no access to this chatbot!") return None @@ -826,11 +229,23 @@ async def chatbots_inputs(dialog_id): objs = APIToken.query(beta=token) if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') - - e, dialog = DialogService.get_by_id(dialog_id) - if not e: - return get_error_data_result(f"Can't find dialog by ID: {dialog_id}") - + tenant_id = objs[0].tenant_id + exists, dialog = DialogService.get_by_id(dialog_id) + if (not exists + or getattr(dialog, "tenant_id", None) != tenant_id + or str(getattr(dialog, "status", "")) != StatusEnum.VALID.value): + request_args = getattr(request, "args", {}) or {} + request_user_id = request_args.get("user_id") if hasattr(request_args, "get") else None + request_session_id = request_args.get("session_id") if hasattr(request_args, "get") else None + logger.warning( + "Denied chatbot access: reason=%s tenant_id=%s dialog_id=%s user_id=%s session_id=%s", + "no access to this chatbot", + tenant_id, + dialog_id, + request_user_id, + request_session_id, + ) + return get_error_data_result(message="Authentication error: no access to this chatbot!") return get_result( data={ "title": dialog.name, @@ -971,12 +386,15 @@ async def retrieval_test_embedded(): vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) use_kg = req.get("use_kg", False) top = int(req.get("top_k", 1024)) + if top <= 0: + return get_error_data_result("`top_k` must be greater than 0") langs = req.get("cross_languages", []) rerank_id = req.get("rerank_id", "") tenant_rerank_id = req.get("tenant_rerank_id", "") tenant_id = objs[0].tenant_id if not tenant_id: return get_error_data_result(message="permission denined.") + search_config = {} async def _retrieval(): nonlocal similarity_threshold, vector_similarity_weight, top, rerank_id @@ -987,8 +405,11 @@ async def _retrieval(): meta_data_filter = {} chat_mdl = None if req.get("search_id", ""): - search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {}) - meta_data_filter = search_config.get("meta_data_filter", {}) + nonlocal search_config + detail = SearchService.get_detail(req.get("search_id", "")) + if detail: + search_config = detail.get("search_config", {}) + meta_data_filter = search_config.get("meta_data_filter", {}) if meta_data_filter.get("method") in ["auto", "semi_auto"]: chat_id = search_config.get("chat_id", "") if chat_id: @@ -1012,8 +433,15 @@ async def _retrieval(): chat_mdl = LLMBundle(tenant_id, chat_model_config) if meta_data_filter: - metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids) - local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, _question, chat_mdl, local_doc_ids) + local_doc_ids = await 
apply_meta_data_filter( + meta_data_filter, + None, + _question, + chat_mdl, + local_doc_ids, + kb_ids=kb_ids, + metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs(kb_ids), + ) tenants = UserTenantService.query(user_id=tenant_id) for kb_id in kb_ids: @@ -1064,6 +492,11 @@ async def _retrieval(): for c in ranks["chunks"]: c.pop("vector", None) + + include_metadata, metadata_fields = _resolve_reference_metadata(req, search_config) + if include_metadata: + enrich_chunks_with_document_metadata(ranks["chunks"], metadata_fields) + ranks["labels"] = labels return get_json_result(data=ranks) @@ -1179,126 +612,6 @@ async def mindmap(): return server_error_response(Exception(mind_map["error"])) return get_json_result(data=mind_map) -@manager.route("/sequence2txt", methods=["POST"]) # noqa: F821 -@token_required -async def sequence2txt(tenant_id): - req = await request.form - stream_mode = req.get("stream", "false").lower() == "true" - files = await request.files - if "file" not in files: - return get_error_data_result(message="Missing 'file' in multipart form-data") - - uploaded = files["file"] - - ALLOWED_EXTS = { - ".wav", ".mp3", ".m4a", ".aac", - ".flac", ".ogg", ".webm", - ".opus", ".wma" - } - - filename = uploaded.filename or "" - suffix = os.path.splitext(filename)[-1].lower() - if suffix not in ALLOWED_EXTS: - return get_error_data_result(message= - f"Unsupported audio format: {suffix}. " - f"Allowed: {', '.join(sorted(ALLOWED_EXTS))}" - ) - fd, temp_audio_path = tempfile.mkstemp(suffix=suffix) - os.close(fd) - await uploaded.save(temp_audio_path) - - try: - default_asr_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.SPEECH2TEXT) - except Exception as e: - return get_error_data_result(message=str(e)) - asr_mdl=LLMBundle(tenant_id, default_asr_model_config) - if not stream_mode: - text = asr_mdl.transcription(temp_audio_path) - try: - os.remove(temp_audio_path) - except Exception as e: - logging.error(f"Failed to remove temp audio file: {str(e)}") - return get_json_result(data={"text": text}) - async def event_stream(): - try: - for evt in asr_mdl.stream_transcription(temp_audio_path): - yield f"data: {json.dumps(evt, ensure_ascii=False)}\n\n" - except Exception as e: - err = {"event": "error", "text": str(e)} - yield f"data: {json.dumps(err, ensure_ascii=False)}\n\n" - finally: - try: - os.remove(temp_audio_path) - except Exception as e: - logging.error(f"Failed to remove temp audio file: {str(e)}") - - return Response(event_stream(), content_type="text/event-stream") - -@manager.route("/tts", methods=["POST"]) # noqa: F821 -@token_required -async def tts(tenant_id): - req = await get_request_json() - text = req["text"] - - try: - default_tts_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.TTS) - except Exception as e: - return get_error_data_result(message=str(e)) - tts_mdl = LLMBundle(tenant_id, default_tts_model_config) - - def stream_audio(): - try: - for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text): - for chunk in tts_mdl.tts(txt): - yield chunk - except Exception as e: - yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8") - - resp = Response(stream_audio(), mimetype="audio/mpeg") - resp.headers.add_header("Cache-Control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - - return resp - - -def _build_reference_chunks(reference, include_metadata=False, metadata_fields=None): - 
chunks = chunks_format(reference) - if not include_metadata: - return chunks - - doc_ids_by_kb = {} - for chunk in chunks: - kb_id = chunk.get("dataset_id") - doc_id = chunk.get("document_id") - if not kb_id or not doc_id: - continue - doc_ids_by_kb.setdefault(kb_id, set()).add(doc_id) - - if not doc_ids_by_kb: - return chunks - - meta_by_doc = {} - for kb_id, doc_ids in doc_ids_by_kb.items(): - meta_map = DocMetadataService.get_metadata_for_documents(list(doc_ids), kb_id) - if meta_map: - meta_by_doc.update(meta_map) - - if metadata_fields is not None: - metadata_fields = {f for f in metadata_fields if isinstance(f, str)} - if not metadata_fields: - return chunks - - for chunk in chunks: - doc_id = chunk.get("document_id") - if not doc_id: - continue - meta = meta_by_doc.get(doc_id) - if not meta: - continue - if metadata_fields is not None: - meta = {k: v for k, v in meta.items() if k in metadata_fields} - if meta: - chunk["document_metadata"] = meta - return chunks +def _resolve_reference_metadata(req, search_config=None): + return resolve_reference_metadata_preferences(req, search_config) diff --git a/api/apps/services/canvas_replica_service.py b/api/apps/services/canvas_replica_service.py index a2aa56b6f96..17b6c99cb02 100644 --- a/api/apps/services/canvas_replica_service.py +++ b/api/apps/services/canvas_replica_service.py @@ -160,7 +160,7 @@ def bootstrap( @classmethod def load_for_run(cls, canvas_id: str, tenant_id: str, runtime_user_id: str): - """Load current runtime replica used by /completion.""" + """Load current runtime replica used by /completions.""" replica_key = cls._replica_key(canvas_id, str(tenant_id), str(runtime_user_id)) return cls._read_payload(replica_key) diff --git a/api/apps/services/dataset_api_service.py b/api/apps/services/dataset_api_service.py index 8cb718467a3..9e49596539c 100644 --- a/api/apps/services/dataset_api_service.py +++ b/api/apps/services/dataset_api_service.py @@ -16,6 +16,7 @@ import logging import json import os +import re from common.constants import PAGERANK_FLD from common import settings from api.db.db_models import File @@ -25,10 +26,31 @@ from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.connector_service import Connector2KbService from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID, TaskService -from api.db.services.user_service import TenantService, UserService +from api.db.services.user_service import TenantService, UserService, UserTenantService +from api.db.services.tenant_llm_service import TenantLLMService from common.constants import FileSource, StatusEnum from api.utils.api_utils import deep_merge, get_parser_config, remap_dictionary_keys, verify_embedding_availability +_VALID_INDEX_TYPES = {"graph", "raptor", "mindmap"} + +_INDEX_TYPE_TO_TASK_TYPE = { + "graph": "graphrag", + "raptor": "raptor", + "mindmap": "mindmap", +} + +_INDEX_TYPE_TO_TASK_ID_FIELD = { + "graph": "graphrag_task_id", + "raptor": "raptor_task_id", + "mindmap": "mindmap_task_id", +} + +_INDEX_TYPE_TO_DISPLAY_NAME = { + "graph": "Graph", + "raptor": "RAPTOR", + "mindmap": "Mindmap", +} + async def create_dataset(tenant_id: str, req: dict): """ @@ -61,12 +83,7 @@ async def create_dataset(tenant_id: str, req: dict): req["parser_config"] = parser_cfg req.update(ext_fields) - e, create_dict = KnowledgebaseService.create_with_name( - name=req.pop("name", None), - tenant_id=tenant_id, - parser_id=req.pop("parser_id", None), - **req - ) + e, create_dict = 
KnowledgebaseService.create_with_name(name=req.pop("name", None), tenant_id=tenant_id, parser_id=req.pop("parser_id", None), **req) if not e: return False, create_dict @@ -132,12 +149,12 @@ async def delete_datasets(tenant_id: str, ids: list = None, delete_all: bool = F ] ) File2DocumentService.delete_by_document_id(doc.id) - FileService.filter_delete( - [File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name]) + FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name]) # Drop index for this dataset try: from rag.nlp import search + idxnm = search.index_name(kb.tenant_id) settings.docStoreConn.delete_idx(idxnm, kb_id) except Exception as e: @@ -158,6 +175,57 @@ async def delete_datasets(tenant_id: str, ids: list = None, delete_all: bool = F return True, {"success_count": success_count, "errors": errors[:5]} +def get_dataset(dataset_id: str, tenant_id: str): + """ + Get a single dataset. + + :param dataset_id: dataset ID + :param tenant_id: tenant ID + :return: (success, result) or (success, error_message) + """ + if not dataset_id: + return False, 'Lack of "Dataset ID"' + + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return False, f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'" + + ok, kb = KnowledgebaseService.get_by_id(dataset_id) + if not ok: + return False, "Invalid Dataset ID" + + response_data = remap_dictionary_keys(kb.to_dict()) + response_data["size"] = DocumentService.get_total_size_by_kb_id(dataset_id) + response_data["connectors"] = list(Connector2KbService.list_connectors(dataset_id)) + return True, response_data + + +def get_ingestion_summary(dataset_id: str, tenant_id: str): + """ + Get ingestion summary for a dataset. + + :param dataset_id: dataset ID + :param tenant_id: tenant ID + :return: (success, result) or (success, error_message) + """ + if not dataset_id: + return False, 'Lack of "Dataset ID"' + + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return False, f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'" + + ok, kb = KnowledgebaseService.get_by_id(dataset_id) + if not ok: + return False, "Invalid Dataset ID" + + status = DocumentService.get_parsing_status_by_kb_ids([dataset_id]).get(dataset_id, {}) + return True, { + "doc_num": kb.doc_num, + "chunk_num": kb.chunk_num, + "token_num": kb.token_num, + "status": status, + } + + async def update_dataset(tenant_id: str, dataset_id: str, req: dict): """ Update a dataset. 
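Both new helpers follow this module's `(ok, payload-or-message)` convention, so a route can forward either branch directly. A minimal sketch of such a caller, where the route path and handler name are hypothetical while `token_required`, `get_result`, and `get_error_data_result` are the helpers already used by the SDK routes in this diff:

```python
from api.apps.services.dataset_api_service import get_ingestion_summary
from api.utils.api_utils import get_error_data_result, get_result, token_required


@manager.route("/datasets/<dataset_id>/ingestion_summary", methods=["GET"])  # noqa: F821  # hypothetical path
@token_required
async def dataset_ingestion_summary(tenant_id, dataset_id):
    ok, payload = get_ingestion_summary(dataset_id, tenant_id)
    if not ok:
        # On failure the helper returns a human-readable message instead of data.
        return get_error_data_result(message=payload)
    # payload: {"doc_num": ..., "chunk_num": ..., "token_num": ..., "status": {...}}
    return get_result(data=payload)
```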
@@ -195,7 +263,7 @@ async def update_dataset(tenant_id: str, dataset_id: str, req: dict): parser_cfg["metadata"] = fields parser_cfg["enable_metadata"] = auto_meta.get("enabled", True) req["parser_config"] = parser_cfg - + # Merge ext fields with req req.update(ext_fields) @@ -232,16 +300,13 @@ async def update_dataset(tenant_id: str, dataset_id: str, req: dict): req["pipeline_id"] = "" if "name" in req and req["name"].lower() != kb.name.lower(): - exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, - status=StatusEnum.VALID.value) + exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value) if exists: return False, f"Dataset name '{req['name']}' already exists" if "embd_id" in req: if not req["embd_id"]: req["embd_id"] = kb.embd_id - if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id: - return False, f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}" ok, err = verify_embedding_availability(req["embd_id"], tenant_id) if not ok: return False, err @@ -252,13 +317,13 @@ async def update_dataset(tenant_id: str, dataset_id: str, req: dict): if req["pagerank"] > 0: from rag.nlp import search - settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, - search.index_name(kb.tenant_id), kb.id) + + settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id) else: # Elasticsearch requires PAGERANK_FLD be non-zero! from rag.nlp import search - settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, - search.index_name(kb.tenant_id), kb.id) + + settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id) if "parse_type" in req: del req["parse_type"] @@ -317,27 +382,13 @@ def list_datasets(tenant_id: str, args: dict): else: tenants = TenantService.get_joined_tenants_by_user_id(tenant_id) tenant_ids = [m["tenant_id"] for m in tenants] - kbs, total = KnowledgebaseService.get_list( - tenant_ids, - tenant_id, - page, - page_size, - orderby, - desc, - kb_id, - name, - keywords, - parser_id - ) + kbs, total = KnowledgebaseService.get_list(tenant_ids, tenant_id, page, page_size, orderby, desc, kb_id, name, keywords, parser_id) users = UserService.get_by_ids([m["tenant_id"] for m in kbs]) user_map = {m.id: m.to_dict() for m in users} response_data_list = [] for kb in kbs: user_dict = user_map.get(kb["tenant_id"], {}) - kb.update({ - "nickname": user_dict.get("nickname", ""), - "tenant_avatar": user_dict.get("avatar", "") - }) + kb.update({"nickname": user_dict.get("nickname", ""), "tenant_avatar": user_dict.get("avatar", "")}) response_data_list.append(remap_dictionary_keys(kb)) return True, {"data": response_data_list, "total": total} @@ -354,13 +405,11 @@ async def get_knowledge_graph(dataset_id: str, tenant_id: str): return False, "No authorization." 
_, kb = KnowledgebaseService.get_by_id(dataset_id) - req = { - "kb_id": [dataset_id], - "knowledge_graph_kwd": ["graph"] - } + req = {"kb_id": [dataset_id], "knowledge_graph_kwd": ["graph"]} obj = {"graph": {}, "mind_map": {}} from rag.nlp import search + if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id): return True, obj sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id]) @@ -380,8 +429,7 @@ async def get_knowledge_graph(dataset_id: str, tenant_id: str): obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256] if "edges" in obj["graph"]: node_id_set = {o["id"] for o in obj["graph"]["nodes"]} - filtered_edges = [o for o in obj["graph"]["edges"] if - o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set] + filtered_edges = [o for o in obj["graph"]["edges"] if o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set] obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128] return True, obj @@ -398,20 +446,28 @@ def delete_knowledge_graph(dataset_id: str, tenant_id: str): return False, "No authorization." _, kb = KnowledgebaseService.get_by_id(dataset_id) from rag.nlp import search - settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, + from rag.graphrag.phase_markers import clear_phase_markers + settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation", "community_report"]}, search.index_name(kb.tenant_id), dataset_id) + # Wiping the graph invalidates any phase-completion markers used to + # short-circuit resolution / community detection on resume. + clear_phase_markers(dataset_id) return True, True -def run_graphrag(dataset_id: str, tenant_id: str): +def run_index(dataset_id: str, tenant_id: str, index_type: str): """ - Run GraphRAG for a dataset. + Run an indexing task (graph/raptor/mindmap) for a dataset. :param dataset_id: dataset ID :param tenant_id: tenant ID + :param index_type: one of "graph", "raptor", "mindmap" :return: (success, result) or (success, error_message) """ + if index_type not in _VALID_INDEX_TYPES: + return False, f"Invalid index type '{index_type}'. Must be one of {sorted(_VALID_INDEX_TYPES)}" + if not dataset_id: return False, 'Lack of "Dataset ID"' if not KnowledgebaseService.accessible(dataset_id, tenant_id): @@ -421,14 +477,18 @@ def run_graphrag(dataset_id: str, tenant_id: str): if not ok: return False, "Invalid Dataset ID" - task_id = kb.graphrag_task_id - if task_id: - ok, task = TaskService.get_by_id(task_id) + task_type = _INDEX_TYPE_TO_TASK_TYPE[index_type] + task_id_field = _INDEX_TYPE_TO_TASK_ID_FIELD[index_type] + display_name = _INDEX_TYPE_TO_DISPLAY_NAME[index_type] + + existing_task_id = getattr(kb, task_id_field, None) + if existing_task_id: + ok, task = TaskService.get_by_id(existing_task_id) if not ok: - logging.warning(f"A valid GraphRAG task id is expected for Dataset {dataset_id}") + logging.warning(f"A valid {display_name} task id is expected for Dataset {dataset_id}") if task and task.progress not in [-1, 1]: - return False, f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running." + return False, f"Task {existing_task_id} in progress with status {task.progress}. A {display_name} Task is already running." 
documents, _ = DocumentService.get_by_kb_id( kb_id=dataset_id, @@ -447,24 +507,29 @@ def run_graphrag(dataset_id: str, tenant_id: str): sample_document = documents[0] document_ids = [document["id"] for document in documents] - task_id = queue_raptor_o_graphrag_tasks(sample_doc=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) + task_id = queue_raptor_o_graphrag_tasks(sample_doc=sample_document, ty=task_type, priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) - if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}): - logging.warning(f"Cannot save graphrag_task_id for Dataset {dataset_id}") + if not KnowledgebaseService.update_by_id(kb.id, {task_id_field: task_id}): + logging.warning(f"Cannot save {task_id_field} for Dataset {dataset_id}") - return True, {"graphrag_task_id": task_id} + return True, {"task_id": task_id} -def trace_graphrag(dataset_id: str, tenant_id: str): +def trace_index(dataset_id: str, tenant_id: str, index_type: str): """ - Trace GraphRAG task for a dataset. + Trace an indexing task (graph/raptor/mindmap) for a dataset. :param dataset_id: dataset ID :param tenant_id: tenant ID + :param index_type: one of "graph", "raptor", "mindmap" :return: (success, result) or (success, error_message) """ + if index_type not in _VALID_INDEX_TYPES: + return False, f"Invalid index type '{index_type}'. Must be one of {sorted(_VALID_INDEX_TYPES)}" + if not dataset_id: return False, 'Lack of "Dataset ID"' + if not KnowledgebaseService.accessible(dataset_id, tenant_id): return False, "No authorization." @@ -472,7 +537,8 @@ def trace_graphrag(dataset_id: str, tenant_id: str): if not ok: return False, "Invalid Dataset ID" - task_id = kb.graphrag_task_id + task_id_field = _INDEX_TYPE_TO_TASK_ID_FIELD[index_type] + task_id = getattr(kb, task_id_field, None) if not task_id: return True, {} @@ -483,9 +549,9 @@ def trace_graphrag(dataset_id: str, tenant_id: str): return True, task.to_dict() -def run_raptor(dataset_id: str, tenant_id: str): +def list_tags(dataset_id: str, tenant_id: str): """ - Run RAPTOR for a dataset. + List tags for a dataset. :param dataset_id: dataset ID :param tenant_id: tenant ID @@ -493,6 +559,118 @@ def run_raptor(dataset_id: str, tenant_id: str): """ if not dataset_id: return False, 'Lack of "Dataset ID"' + + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return False, "No authorization." + + tenants = UserTenantService.get_tenants_by_user_id(tenant_id) + tags = [] + for tenant in tenants: + tags += settings.retriever.all_tags(tenant["tenant_id"], [dataset_id]) + return True, tags + + +def aggregate_tags(dataset_ids: list[str], tenant_id: str): + """ + Aggregate tags across multiple datasets. 
+ + :param dataset_ids: list of dataset IDs + :param tenant_id: tenant ID + :return: (success, result) or (success, error_message) + """ + if not dataset_ids: + return False, 'Lack of "dataset_ids"' + + for dataset_id in dataset_ids: + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return False, f"No authorization for dataset '{dataset_id}'" + + dataset_ids_by_tenant = {} + for dataset_id in dataset_ids: + ok, kb = KnowledgebaseService.get_by_id(dataset_id) + if not ok: + return False, f"Invalid Dataset ID '{dataset_id}'" + dataset_ids_by_tenant.setdefault(kb.tenant_id, []).append(dataset_id) + + merged = {} + for kb_tenant_id, kb_ids in dataset_ids_by_tenant.items(): + for bucket in settings.retriever.all_tags(kb_tenant_id, kb_ids): + tag = bucket["value"] + merged[tag] = merged.get(tag, 0) + bucket["count"] + + return True, [{"value": tag, "count": count} for tag, count in merged.items()] + + +def get_flattened_metadata(dataset_ids: list[str], tenant_id: str): + """ + Get flattened metadata for datasets. + + :param dataset_ids: list of dataset IDs + :param tenant_id: tenant ID + :return: (success, result) or (success, error_message) + """ + if not dataset_ids: + return False, 'Lack of "dataset_ids"' + + for dataset_id in dataset_ids: + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return False, f"No authorization for dataset '{dataset_id}'" + + from api.db.services.doc_metadata_service import DocMetadataService + + return True, DocMetadataService.get_flatted_meta_by_kbs(dataset_ids) + + +def get_auto_metadata(dataset_id: str, tenant_id: str): + """ + Get auto-metadata configuration for a dataset. + + :param dataset_id: dataset ID + :param tenant_id: tenant ID + :return: (success, result) or (success, error_message) + """ + kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) + if kb is None: + return False, f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'" + parser_cfg = kb.parser_config or {} + return True, {"metadata": parser_cfg.get("metadata") or [], "built_in_metadata": parser_cfg.get("built_in_metadata") or []} + + +async def update_auto_metadata(dataset_id: str, tenant_id: str, cfg: dict): + """ + Update auto-metadata configuration for a dataset. + + :param dataset_id: dataset ID + :param tenant_id: tenant ID + :param cfg: auto-metadata configuration + :return: (success, result) or (success, error_message) + """ + kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) + if kb is None: + return False, f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'" + + parser_cfg = kb.parser_config or {} + parser_cfg["metadata"] = cfg.get("metadata") + parser_cfg["built_in_metadata"] = cfg.get("built_in_metadata") + + if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": parser_cfg}): + return False, "Update auto-metadata error.(Database error)" + + return True, cfg + + +def delete_tags(dataset_id: str, tenant_id: str, tags: list[str]): + """ + Delete tags from a dataset. + + :param dataset_id: dataset ID + :param tenant_id: tenant ID + :param tags: list of tags to delete + :return: (success, result) or (success, error_message) + """ + if not dataset_id: + return False, 'Lack of "Dataset ID"' + if not KnowledgebaseService.accessible(dataset_id, tenant_id): return False, "No authorization." 
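# The hunk below completes delete_tags: once the dataset is validated, each tag in
# `tags` is removed with one docStoreConn.update call scoped to this dataset's index,
# and an empty dict is returned on success.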
@@ -500,14 +678,178 @@ def run_raptor(dataset_id: str, tenant_id: str): if not ok: return False, "Invalid Dataset ID" - task_id = kb.raptor_task_id + from rag.nlp import search + + for t in tags: + settings.docStoreConn.update({"tag_kwd": t, "kb_id": [dataset_id]}, {"remove": {"tag_kwd": t}}, search.index_name(kb.tenant_id), dataset_id) + + return True, {} + + +def list_ingestion_logs( + dataset_id: str, + tenant_id: str, + page: int, + page_size: int, + orderby: str, + desc: bool, + operation_status: list = None, + create_date_from: str = None, + create_date_to: str = None, + log_type: str = "dataset", + keywords: str = None, +): + """ + List ingestion logs for a dataset. + + :param dataset_id: dataset ID + :param tenant_id: tenant ID + :param page: page number + :param page_size: items per page + :param orderby: order by field + :param desc: descending order + :param operation_status: filter by operation status + :param create_date_from: filter start date + :param create_date_to: filter end date + :param log_type: "dataset" or "file" + :param keywords: search keywords for file logs + :return: (success, result) or (success, error_message) + """ + if not dataset_id: + return False, 'Lack of "Dataset ID"' + + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return False, "No authorization." + + from api.db.services.pipeline_operation_log_service import PipelineOperationLogService + + allowed_log_types = {"dataset", "file"} + if log_type not in allowed_log_types: + logging.warning( + "list_ingestion_logs invalid log_type: dataset_id=%s tenant_id=%s log_type=%s", + dataset_id, + tenant_id, + log_type, + ) + return False, 'Invalid "log_type", expected "dataset" or "file"' + + logging.info( + "list_ingestion_logs: dataset_id=%s tenant_id=%s log_type=%s page=%s page_size=%s", + dataset_id, + tenant_id, + log_type, + page, + page_size, + ) + + if log_type == "file": + logs, total = PipelineOperationLogService.get_file_logs_by_kb_id(dataset_id, page, page_size, orderby, desc, keywords, operation_status or [], None, None, create_date_from, create_date_to) + else: + logs, total = PipelineOperationLogService.get_dataset_logs_by_kb_id(dataset_id, page, page_size, orderby, desc, operation_status or [], create_date_from, create_date_to, keywords) + return True, {"total": total, "logs": logs} + + +def get_ingestion_log(dataset_id: str, tenant_id: str, log_id: str): + """ + Get a single ingestion log. + + :param dataset_id: dataset ID + :param tenant_id: tenant ID + :param log_id: log ID + :return: (success, result) or (success, error_message) + """ + if not dataset_id: + return False, 'Lack of "Dataset ID"' + + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return False, "No authorization." + + from api.db.services.pipeline_operation_log_service import PipelineOperationLogService + + fields = PipelineOperationLogService.get_dataset_logs_fields() + log = PipelineOperationLogService.model.select(*fields).where((PipelineOperationLogService.model.id == log_id) & (PipelineOperationLogService.model.kb_id == dataset_id)).first() + if not log: + return False, "Log not found" + + return True, log.to_dict() + + +def delete_index(dataset_id: str, tenant_id: str, index_type: str, wipe: bool = True): + """ + Delete an indexing task (graph/raptor/mindmap) for a dataset. 
+ + :param dataset_id: dataset ID + :param tenant_id: tenant ID + :param index_type: one of "graph", "raptor", "mindmap" + :param wipe: when True (default) the persisted artefacts (graph rows, + raptor summaries) are removed from the doc store and any GraphRAG + phase-completion markers are cleared. Pass False to cancel the + running task while keeping prior progress so it can be resumed. + :return: (success, result) or (success, error_message) + """ + if index_type not in _VALID_INDEX_TYPES: + return False, f"Invalid index type '{index_type}'. Must be one of {sorted(_VALID_INDEX_TYPES)}" + + if not dataset_id: + return False, 'Lack of "Dataset ID"' + + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return False, "No authorization." + + ok, kb = KnowledgebaseService.get_by_id(dataset_id) + if not ok: + return False, "Invalid Dataset ID" + + task_id_field = _INDEX_TYPE_TO_TASK_ID_FIELD[index_type] + task_finish_at_field = f"{task_id_field.replace('_task_id', '_task_finish_at')}" + task_id = getattr(kb, task_id_field, None) + + logging.info("delete_index: dataset=%s index_type=%s wipe=%s", dataset_id, index_type, wipe) + if task_id: - ok, task = TaskService.get_by_id(task_id) - if not ok: - logging.warning(f"A valid RAPTOR task id is expected for Dataset {dataset_id}") + from rag.utils.redis_conn import REDIS_CONN + + try: + REDIS_CONN.set(f"{task_id}-cancel", "x") + except Exception as e: + logging.exception(e) + TaskService.delete_by_id(task_id) + + if wipe and index_type == "graph": + from rag.nlp import search + from rag.graphrag.phase_markers import clear_phase_markers + settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation", "community_report"]}, + search.index_name(kb.tenant_id), dataset_id) + # Wiping the graph invalidates any phase-completion markers used to + # short-circuit resolution / community detection on resume. + clear_phase_markers(dataset_id) + logging.info("delete_index: cleared GraphRAG artefacts and phase markers for dataset=%s", dataset_id) + elif wipe and index_type == "raptor": + from rag.nlp import search + + settings.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), dataset_id) + + KnowledgebaseService.update_by_id(kb.id, {task_id_field: "", task_finish_at_field: None}) + return True, {} - if task and task.progress not in [-1, 1]: - return False, f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running." + +def run_embedding(dataset_id: str, tenant_id: str): + """ + Run embedding for all documents in a dataset. + + :param dataset_id: dataset ID + :param tenant_id: tenant ID + :return: (success, result) or (success, error_message) + """ + if not dataset_id: + return False, 'Lack of "Dataset ID"' + + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return False, "No authorization." 
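# The remainder of run_embedding (next hunk) fetches every document in the dataset and
# re-schedules each one through DocumentService.run, sharing a single kb_table_num_map
# cache across calls; the caller receives the number of documents scheduled.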
+ + ok, kb = KnowledgebaseService.get_by_id(dataset_id) + if not ok: + return False, "Invalid Dataset ID" documents, _ = DocumentService.get_by_kb_id( kb_id=dataset_id, @@ -523,23 +865,22 @@ def run_raptor(dataset_id: str, tenant_id: str): if not documents: return False, f"No documents in Dataset {dataset_id}" - sample_document = documents[0] - document_ids = [document["id"] for document in documents] - - task_id = queue_raptor_o_graphrag_tasks(sample_doc=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) - - if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}): - logging.warning(f"Cannot save raptor_task_id for Dataset {dataset_id}") + kb_table_num_map = {} + for doc in documents: + doc["tenant_id"] = tenant_id + DocumentService.run(tenant_id, doc, kb_table_num_map) - return True, {"raptor_task_id": task_id} + return True, {"scheduled_count": len(documents)} -def trace_raptor(dataset_id: str, tenant_id: str): +def rename_tag(dataset_id: str, tenant_id: str, from_tag: str, to_tag: str): """ - Trace RAPTOR task for a dataset. + Rename a tag in a dataset. :param dataset_id: dataset ID :param tenant_id: tenant ID + :param from_tag: original tag name + :param to_tag: new tag name :return: (success, result) or (success, error_message) """ if not dataset_id: @@ -552,78 +893,522 @@ def trace_raptor(dataset_id: str, tenant_id: str): if not ok: return False, "Invalid Dataset ID" - task_id = kb.raptor_task_id - if not task_id: - return True, {} + from rag.nlp import search - ok, task = TaskService.get_by_id(task_id) - if not ok: - return False, "RAPTOR Task Not Found or Error Occurred" + settings.docStoreConn.update({"tag_kwd": from_tag, "kb_id": [dataset_id]}, {"remove": {"tag_kwd": from_tag.strip()}, "add": {"tag_kwd": to_tag}}, search.index_name(kb.tenant_id), dataset_id) - return True, task.to_dict() + return True, {"from": from_tag, "to": to_tag} -def get_auto_metadata(dataset_id: str, tenant_id: str): +async def search(dataset_id: str, tenant_id: str, req: dict): """ - Get auto-metadata configuration for a dataset. + Search (retrieval test) within a dataset. 
:param dataset_id: dataset ID :param tenant_id: tenant ID + :param req: search request :return: (success, result) or (success, error_message) """ - kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) - if kb is None: - return False, f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'" + from api.db.joint_services.tenant_model_service import ( + get_model_config_by_id, + get_model_config_by_type_and_name, + get_tenant_default_model_by_type, + ) + from api.db.services.doc_metadata_service import DocMetadataService + from api.db.services.llm_service import LLMBundle + from api.db.services.search_service import SearchService + from api.db.services.user_service import UserTenantService + from common.constants import LLMType + from common.metadata_utils import apply_meta_data_filter + from rag.app.tag import label_question + from rag.prompts.generator import cross_languages, keyword_extraction + + logging.debug( + "search(dataset=%s, tenant=%s, question_len=%s)", + dataset_id, + tenant_id, + len(req.get("question", "")), + ) - parser_cfg = kb.parser_config or {} - metadata = parser_cfg.get("metadata") or [] - enabled = parser_cfg.get("enable_metadata", bool(metadata)) - # Normalize to AutoMetadataConfig-like JSON - fields = [] - for f in metadata: - if not isinstance(f, dict): - continue - fields.append( - { - "name": f.get("name", ""), - "type": f.get("type", ""), - "description": f.get("description"), - "examples": f.get("examples"), - "restrict_values": f.get("restrict_values", False), - } + page = int(req.get("page", 1)) + size = int(req.get("size", 30)) + question = req.get("question", "") + doc_ids = req.get("doc_ids", []) + use_kg = req.get("use_kg", False) + top = max(1, min(int(req.get("top_k", 1024)), 2048)) + langs = req.get("cross_languages", []) + + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + logging.warning("search access denied: dataset=%s tenant=%s", dataset_id, tenant_id) + return False, "Only owner of dataset authorized for this operation." + + e, kb = KnowledgebaseService.get_by_id(dataset_id) + if not e: + logging.warning("search dataset not found: dataset=%s", dataset_id) + return False, "Dataset not found!" 
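# Below, doc_ids is validated and the metadata filter is resolved: it can come from a
# saved search configuration (search_id) or directly from the request, and the
# "auto"/"semi_auto" methods additionally need a chat model to derive the filter from
# the question before it is used to narrow local_doc_ids.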
+ + if doc_ids is not None and not isinstance(doc_ids, list): + return False, "`doc_ids` should be a list" + local_doc_ids = list(doc_ids) if doc_ids else [] + + meta_data_filter = {} + chat_mdl = None + if req.get("search_id", ""): + search_detail = SearchService.get_detail(req.get("search_id", "")) + if not search_detail: + logging.warning("search config not found: search_id=%s", req.get("search_id", "")) + return False, "Invalid search_id" + search_config = search_detail.get("search_config", {}) + meta_data_filter = search_config.get("meta_data_filter", {}) + if meta_data_filter.get("method") in ["auto", "semi_auto"]: + chat_id = search_config.get("chat_id", "") + if chat_id: + chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, search_config["chat_id"]) + else: + chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) + chat_mdl = LLMBundle(tenant_id, chat_model_config) + else: + meta_data_filter = req.get("meta_data_filter") or {} + if meta_data_filter.get("method") in ["auto", "semi_auto"]: + chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) + chat_mdl = LLMBundle(tenant_id, chat_model_config) + + if meta_data_filter: + local_doc_ids = await apply_meta_data_filter( + meta_data_filter, + None, + question, + chat_mdl, + local_doc_ids, + kb_ids=[dataset_id], + metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs([dataset_id]), ) - return True, {"enabled": enabled, "fields": fields} + tenant_ids = [] + tenants = UserTenantService.query(user_id=tenant_id) + for tenant in tenants: + if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=dataset_id): + tenant_ids.append(tenant.tenant_id) + break + else: + return False, "Only owner of dataset authorized for this operation." 
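# Note the for/else above: the else branch runs only if the loop finishes without a
# break, i.e. none of the user's joined tenants owns this dataset, so access is refused.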
+ + _question = question + if langs: + _question = await cross_languages(kb.tenant_id, None, _question, langs) + if kb.tenant_embd_id: + embd_model_config = get_model_config_by_id(kb.tenant_embd_id) + elif kb.embd_id: + embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) + else: + embd_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.EMBEDDING) + embd_mdl = LLMBundle(kb.tenant_id, embd_model_config) + + rerank_mdl = None + if req.get("tenant_rerank_id"): + rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"]) + rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) + elif req.get("rerank_id"): + rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"]) + rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) + + if req.get("keyword", False): + default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) + chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config) + _question += await keyword_extraction(chat_mdl, _question) + + labels = label_question(_question, [kb]) + ranks = await settings.retriever.retrieval( + _question, + embd_mdl, + tenant_ids, + [dataset_id], + page, + size, + float(req.get("similarity_threshold", 0.0)), + float(req.get("vector_similarity_weight", 0.3)), + doc_ids=local_doc_ids, + top=top, + rerank_mdl=rerank_mdl, + rank_feature=labels, + ) -async def update_auto_metadata(dataset_id: str, tenant_id: str, cfg: dict): + if use_kg: + try: + default_chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) + ck = await settings.kg_retriever.retrieval(_question, tenant_ids, [dataset_id], embd_mdl, LLMBundle(kb.tenant_id, default_chat_model_config)) + if ck["content_with_weight"]: + ranks["chunks"].insert(0, ck) + except Exception: + logging.warning("search KG retrieval failed: dataset=%s tenant=%s", dataset_id, tenant_id, exc_info=True) + total = ranks.get("total", 0) + ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids) + ranks["total"] = total + + for c in ranks["chunks"]: + c.pop("vector", None) + ranks["labels"] = labels + + return True, ranks + + +def check_embedding(dataset_id: str, tenant_id: str, req: dict): """ - Update auto-metadata configuration for a dataset. + Check embedding model compatibility by sampling random chunks, + re-embedding them with the new model, and computing cosine similarity. 
:param dataset_id: dataset ID :param tenant_id: tenant ID - :param cfg: auto-metadata configuration + :param req: request body with embd_id :return: (success, result) or (success, error_message) """ - kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) - if kb is None: - return False, f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'" + import random - parser_cfg = kb.parser_config or {} - fields = [] - for f in cfg.get("fields", []): - fields.append( - { - "name": f.get("name", ""), - "type": f.get("type", ""), - "description": f.get("description"), - "examples": f.get("examples"), - "restrict_values": f.get("restrict_values", False), - } + import numpy as np + from common.constants import RetCode + from common.doc_store.doc_store_base import OrderByExpr + from rag.nlp import search + + from api.db.joint_services.tenant_model_service import ( + get_model_config_by_type_and_name, + ) + from api.db.services.llm_service import LLMBundle + from common.constants import LLMType + + def _guess_vec_field(src: dict): + for k in src or {}: + if k.endswith("_vec"): + return k + return None + + def _as_float_vec(v): + if v is None: + return [] + if isinstance(v, str): + return [float(x) for x in v.split("\t") if x != ""] + if isinstance(v, (list, tuple, np.ndarray)): + return [float(x) for x in v] + return [] + + def _to_1d(x): + a = np.asarray(x, dtype=np.float32) + return a.reshape(-1) + + def _cos_sim(a, b, eps=1e-12): + a = _to_1d(a) + b = _to_1d(b) + na = np.linalg.norm(a) + nb = np.linalg.norm(b) + if na < eps or nb < eps: + return 0.0 + return float(np.dot(a, b) / (na * nb)) + + def sample_random_chunks_with_vectors( + docStoreConn, + tenant_id: str, + kb_id: str, + n: int = 5, + base_fields=("docnm_kwd", "doc_id", "content_with_weight", "page_num_int", "position_int", "top_int"), + ): + index_nm = search.index_name(tenant_id) + + res0 = docStoreConn.search( + select_fields=[], highlight_fields=[], + condition={"kb_id": kb_id, "available_int": 1}, + match_expressions=[], order_by=OrderByExpr(), + offset=0, limit=1, + index_names=index_nm, knowledgebase_ids=[kb_id], ) - parser_cfg["metadata"] = fields - parser_cfg["enable_metadata"] = cfg.get("enabled", True) + total = docStoreConn.get_total(res0) + if total <= 0: + return [] + + n = min(n, total) + offsets = sorted(random.sample(range(min(total, 1000)), n)) + out = [] + + for off in offsets: + res1 = docStoreConn.search( + select_fields=list(base_fields), + highlight_fields=[], + condition={"kb_id": kb_id, "available_int": 1}, + match_expressions=[], order_by=OrderByExpr(), + offset=off, limit=1, + index_names=index_nm, knowledgebase_ids=[kb_id], + ) + ids = docStoreConn.get_doc_ids(res1) + if not ids: + continue - if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": parser_cfg}): - return False, "Update auto-metadata error.(Database error)" + cid = ids[0] + full_doc = docStoreConn.get(cid, index_nm, [kb_id]) or {} + vec_field = _guess_vec_field(full_doc) + vec = _as_float_vec(full_doc.get(vec_field)) + + out.append({ + "chunk_id": cid, + "kb_id": kb_id, + "doc_id": full_doc.get("doc_id"), + "doc_name": full_doc.get("docnm_kwd"), + "vector_field": vec_field, + "vector_dim": len(vec), + "vector": vec, + "page_num_int": full_doc.get("page_num_int"), + "position_int": full_doc.get("position_int"), + "top_int": full_doc.get("top_int"), + "content_with_weight": full_doc.get("content_with_weight") or "", + "question_kwd": full_doc.get("question_kwd") or [], + }) + return out + + def _clean(s: str): + 
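# _clean strips short HTML-like tags (e.g. "<em>") from the chunk text before
# re-embedding, so stray markup does not skew the similarity comparison.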
        return re.sub(r"<([^<>]{0,12})?>", " ", s or "").strip()
+
+    if not dataset_id:
+        return False, 'Lack of "Dataset ID"'
+
+    if not KnowledgebaseService.accessible(dataset_id, tenant_id):
+        return False, "No authorization."
+
+    ok, kb = KnowledgebaseService.get_by_id(dataset_id)
+    if not ok:
+        return False, "Invalid Dataset ID"
+
+    embd_id = req.get("embd_id", "")
+    if not embd_id:
+        return False, "`embd_id` is required."
+
+    logging.info("check_embedding: dataset=%s tenant=%s embd_id=%s", dataset_id, tenant_id, embd_id)
+
+    ok, err = verify_embedding_availability(embd_id, tenant_id)
+    if not ok:
+        return False, err
+
+    embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, embd_id)
+    emb_mdl = LLMBundle(kb.tenant_id, embd_model_config)
+
+    n = int(req.get("check_num", 5))
+    samples = sample_random_chunks_with_vectors(settings.docStoreConn, tenant_id=kb.tenant_id, kb_id=dataset_id, n=n)
+    logging.info("check_embedding: dataset=%s sampled=%d chunks", dataset_id, len(samples))
+
+    results, eff_sims = [], []
+    mode = "content_only"
+    for ck in samples:
+        title = ck.get("doc_name") or "Title"
+
+        txt_in = "\n".join(ck.get("question_kwd") or []) or ck.get("content_with_weight") or ""
+        txt_in = _clean(txt_in)
+        if not txt_in:
+            results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"})
+            continue
+
+        if not ck.get("vector"):
+            results.append({"chunk_id": ck["chunk_id"], "reason": "no_stored_vector"})
+            continue
+
+        try:
+            v, _ = emb_mdl.encode([title, txt_in])
+            assert len(v[1]) == len(ck["vector"]), (
+                f"The dimension ({len(v[1])}) of given embedding model is different from the original ({len(ck['vector'])})"
+            )
+            sim_content = _cos_sim(v[1], ck["vector"])
+            title_w = 0.1
+            qv_mix = title_w * v[0] + (1 - title_w) * v[1]
+            sim_mix = _cos_sim(qv_mix, ck["vector"])
+            sim = sim_content
+            mode = "content_only"
+            if sim_mix > sim:
+                sim = sim_mix
+                mode = "title+content"
+        except Exception as e:
+            return False, f"Embedding failure. {e}"
+
+        eff_sims.append(sim)
+        results.append({
+            "chunk_id": ck["chunk_id"],
+            "doc_id": ck["doc_id"],
+            "doc_name": ck["doc_name"],
+            "vector_field": ck["vector_field"],
+            "vector_dim": ck["vector_dim"],
+            "cos_sim": round(sim, 6),
+        })
+
+    summary = {
+        "kb_id": dataset_id,
+        "model": embd_id,
+        "sampled": len(samples),
+        "valid": len(eff_sims),
+        "avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6),
+        "min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6),
+        "max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6),
+        "match_mode": mode,
+    }
+
+    data = {"summary": summary, "results": results}
+    if not eff_sims:
+        logging.warning("check_embedding: dataset=%s no comparable chunks", dataset_id)
+        return False, "No embedded chunks are available to compare."
+    if summary["avg_cos_sim"] >= 0.9:
+        logging.info("check_embedding: dataset=%s compatible avg_cos_sim=%s valid=%d", dataset_id, summary["avg_cos_sim"], len(eff_sims))
+        return True, data
+    logging.warning("check_embedding: dataset=%s not_effective avg_cos_sim=%s valid=%d", dataset_id, summary["avg_cos_sim"], len(eff_sims))
+    return "not_effective", {"code": RetCode.NOT_EFFECTIVE, "message": "Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.", "data": data}
+
+
+async def search_datasets(tenant_id: str, req: dict):
+    """
+    Search (retrieval test) across multiple datasets.
+ + :param tenant_id: tenant ID + :param req: search request containing dataset_ids and other params + :return: (success, result) or (success, error_message) + """ + from api.db.joint_services.tenant_model_service import ( + get_model_config_by_id, + get_model_config_by_type_and_name, + get_tenant_default_model_by_type, + ) + from api.db.services.doc_metadata_service import DocMetadataService + from api.db.services.llm_service import LLMBundle + from api.db.services.search_service import SearchService + from api.db.services.user_service import UserTenantService + from common.constants import LLMType + from common.metadata_utils import apply_meta_data_filter + from rag.app.tag import label_question + from rag.prompts.generator import cross_languages, keyword_extraction + + kb_ids = req.get("dataset_ids", []) + page = int(req.get("page", 1)) + size = int(req.get("size", 30)) + question = req.get("question", "") + doc_ids = req.get("doc_ids", []) + use_kg = req.get("use_kg", False) + top = max(1, min(int(req.get("top_k", 1024)), 2048)) + langs = req.get("cross_languages", []) + + logging.debug( + "search_datasets(datasets=%s, tenant=%s, question_len=%s)", + kb_ids, + tenant_id, + len(question), + ) + + # Access check for all datasets + for kb_id in kb_ids: + if not KnowledgebaseService.accessible(kb_id, tenant_id): + logging.warning("search_datasets access denied: dataset=%s tenant=%s", kb_id, tenant_id) + return False, f"Only owner of dataset {kb_id} authorized for this operation." + + kbs = KnowledgebaseService.get_by_ids(kb_ids) + if not kbs: + return False, "Datasets not found!" + + # All datasets must use the same embedding model + embd_nms = list(set([TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs])) + if len(embd_nms) != 1: + return False, "Datasets use different embedding models." 
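# Cross-dataset retrieval requires one shared embedding model: vectors produced by
# different models live in different vector spaces, so their similarity scores are not
# comparable, hence the early return above when more than one model is detected.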
+ + if doc_ids is not None and not isinstance(doc_ids, list): + return False, "`doc_ids` should be a list" + local_doc_ids = list(doc_ids) if doc_ids else [] + + meta_data_filter = {} + chat_mdl = None + if req.get("search_id", ""): + search_detail = SearchService.get_detail(req.get("search_id", "")) + if not search_detail: + logging.warning("search config not found: search_id=%s", req.get("search_id", "")) + return False, "Invalid search_id" + search_config = search_detail.get("search_config", {}) + meta_data_filter = search_config.get("meta_data_filter", {}) + if meta_data_filter.get("method") in ["auto", "semi_auto"]: + chat_id = search_config.get("chat_id", "") + if chat_id: + chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, search_config["chat_id"]) + else: + chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) + chat_mdl = LLMBundle(tenant_id, chat_model_config) + else: + meta_data_filter = req.get("meta_data_filter") or {} + if meta_data_filter.get("method") in ["auto", "semi_auto"]: + chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) + chat_mdl = LLMBundle(tenant_id, chat_model_config) + + if meta_data_filter: + local_doc_ids = await apply_meta_data_filter( + meta_data_filter, + None, + question, + chat_mdl, + local_doc_ids, + kb_ids=kb_ids, + metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs(kb_ids), + ) + + tenant_ids = [] + tenants = UserTenantService.query(user_id=tenant_id) + for tenant in tenants: + if any(KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id) for kb_id in kb_ids): + tenant_ids.append(tenant.tenant_id) + break + else: + return False, "Only owner of datasets authorized for this operation." + + kb = kbs[0] + _question = question + if langs: + _question = await cross_languages(kb.tenant_id, None, _question, langs) + if kb.tenant_embd_id: + embd_model_config = get_model_config_by_id(kb.tenant_embd_id) + elif kb.embd_id: + embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) + else: + embd_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.EMBEDDING) + embd_mdl = LLMBundle(kb.tenant_id, embd_model_config) + + rerank_mdl = None + if req.get("tenant_rerank_id"): + rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"]) + rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) + elif req.get("rerank_id"): + rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"]) + rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) + + if req.get("keyword", False): + default_chat_model_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) + chat_mdl = LLMBundle(kb.tenant_id, default_chat_model_config) + _question += await keyword_extraction(chat_mdl, _question) + + labels = label_question(_question, kbs) + ranks = await settings.retriever.retrieval( + _question, + embd_mdl, + tenant_ids, + kb_ids, + page, + size, + float(req.get("similarity_threshold", 0.0)), + float(req.get("vector_similarity_weight", 0.3)), + doc_ids=local_doc_ids, + top=top, + rerank_mdl=rerank_mdl, + rank_feature=labels, + ) + + if use_kg: + try: + default_chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) + ck = await settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl, LLMBundle(kb.tenant_id, default_chat_model_config)) + if ck["content_with_weight"]: + ranks["chunks"].insert(0, ck) + except Exception: + 
logging.warning("search_datasets KG retrieval failed: datasets=%s tenant=%s", kb_ids, tenant_id, exc_info=True) + total = ranks.get("total", 0) + ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids) + ranks["total"] = total + + for c in ranks["chunks"]: + c.pop("vector", None) + ranks["labels"] = labels - return True, {"enabled": parser_cfg["enable_metadata"], "fields": fields} + return True, ranks diff --git a/api/apps/services/document_api_service.py b/api/apps/services/document_api_service.py index 82dfa37e353..59abbd25072 100644 --- a/api/apps/services/document_api_service.py +++ b/api/apps/services/document_api_service.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import logging + from api.db.services.document_service import DocumentService from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService @@ -58,7 +60,7 @@ def update_document_name_only(document_id, req_doc_name): ) return None -def update_chunk_method_only(req, doc, dataset_id, tenant_id): +def update_chunk_method(req, doc, tenant_id): """ Update chunk method only (without validation). @@ -69,28 +71,56 @@ def update_chunk_method_only(req, doc, dataset_id, tenant_id): Args: req: The request dictionary containing chunk_method and parser_config. doc: The document model from the database. - dataset_id: The ID of the dataset containing the document. tenant_id: The tenant ID for the document store. Returns: None if successful, or an error result dictionary if failed. """ if doc.parser_id.lower() != req["chunk_method"].lower(): - # if chunk method changed - e = DocumentService.update_by_id( - doc.id, - { - "parser_id": req["chunk_method"], - "progress": 0, - "progress_msg": "", - "run": TaskStatus.UNSTART.value, - }, - ) - if not e: - return get_error_data_result(message="Document not found!") + # if chunk method changed, reset document for reparse + result = reset_document_for_reparse(doc, tenant_id, parser_id=req["chunk_method"]) + if result: + return result if not req.get("parser_config"): req["parser_config"] = get_parser_config(req["chunk_method"], req.get("parser_config")) DocumentService.update_parser_config(doc.id, req["parser_config"]) + return None + + +def reset_document_for_reparse(doc, tenant_id, parser_id=None, pipeline_id=None): + """ + Reset document for reparsing. + + Updates the parser_id and/or pipeline_id for a document, resets its progress, + clears existing chunks from the document store, and removes chunk images. + + Args: + doc: The document model from the database. + tenant_id: The tenant ID for the document store. + parser_id: Optional new parser_id (chunk method). If None, keeps existing. + pipeline_id: Optional new pipeline_id. If None, keeps existing. + + Returns: + None if successful, or an error result dictionary if failed. 
+ """ + + # Build update fields + update_fields = { + "progress": 0, + "progress_msg": "", + "run": TaskStatus.UNSTART.value, + } + if parser_id is not None: + update_fields["parser_id"] = parser_id + if pipeline_id is not None: + update_fields["pipeline_id"] = pipeline_id + + # Update document + e = DocumentService.update_by_id(doc.id, update_fields) + if not e: + return get_error_data_result(message="Document not found!") + + # Delete chunks from document store if doc.token_num > 0: e = DocumentService.increment_chunk_num( doc.id, @@ -98,12 +128,20 @@ def update_chunk_method_only(req, doc, dataset_id, tenant_id): doc.token_num * -1, doc.chunk_num * -1, doc.process_duration * -1, - ) + ) if not e: return get_error_data_result(message="Document not found!") - settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id) + settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) + + # Delete chunk images + try: + DocumentService.delete_chunk_images(doc, tenant_id) + except Exception as e: + logging.error(f"error when delete chunk images:{e}") + return None + def update_document_status_only(status:int, doc, kb): """ Update document status only (without validation). diff --git a/api/apps/services/file_api_service.py b/api/apps/services/file_api_service.py index d6fe9248a50..cfde3de2948 100644 --- a/api/apps/services/file_api_service.py +++ b/api/apps/services/file_api_service.py @@ -67,14 +67,14 @@ async def upload_file(tenant_id: str, pf_id: str, file_objs: list): if not e: return False, "Folder not found!" last_folder = await thread_pool_exec( - FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names, len_id_list + FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names, len_id_list, tenant_id, tenant_id ) else: e, file = await thread_pool_exec(FileService.get_by_id, file_id_list[len_id_list - 2]) if not e: return False, "Folder not found!" last_folder = await thread_pool_exec( - FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names, len_id_list + FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names, len_id_list, tenant_id, tenant_id ) filetype = filename_type(file_obj_names[file_len - 1]) @@ -121,7 +121,7 @@ async def create_folder(tenant_id: str, name: str, pf_id: str = None, file_type: if FileService.query(name=name, parent_id=pf_id): return False, "Duplicated folder name in the same folder." - if file_type == FileType.FOLDER.value: + if (file_type or "").lower() == FileType.FOLDER.value: ft = FileType.FOLDER.value else: ft = FileType.VIRTUAL.value @@ -158,6 +158,7 @@ def list_files(tenant_id: str, args: dict): root_folder = FileService.get_root_folder(tenant_id) pf_id = root_folder["id"] FileService.init_knowledgebase_docs(pf_id, tenant_id) + FileService.init_skills_folder(pf_id, tenant_id) e, file = FileService.get_by_id(pf_id) if not e: @@ -173,92 +174,305 @@ def list_files(tenant_id: str, args: dict): -def get_parent_folder(file_id: str): +def get_parent_folder(file_id: str, user_id: str = None): """ - Get parent folder of a file. + Get parent folder of a file with permission check. :param file_id: file ID + :param user_id: user ID for permission validation :return: (success, result) or (success, error_message) """ + from api.common.check_team_permission import check_file_team_permission + e, file = FileService.get_by_id(file_id) if not e: return False, "Folder not found!" 
+ # Permission check + if user_id and not check_file_team_permission(file, user_id): + return False, "No authorization." + parent_folder = FileService.get_parent_folder(file_id) return True, {"parent_folder": parent_folder.to_json()} -def get_all_parent_folders(file_id: str): +def get_all_parent_folders(file_id: str, user_id: str = None): """ - Get all ancestor folders of a file. + Get all ancestor folders of a file with permission check. :param file_id: file ID + :param user_id: user ID for permission validation :return: (success, result) or (success, error_message) """ + from api.common.check_team_permission import check_file_team_permission + e, file = FileService.get_by_id(file_id) if not e: return False, "Folder not found!" + # Permission check + if user_id and not check_file_team_permission(file, user_id): + return False, "No authorization." + parent_folders = FileService.get_all_parent_folders(file_id) return True, {"parent_folders": [pf.to_json() for pf in parent_folders]} -async def delete_files(uid: str, file_ids: list): +async def delete_files(uid: str, file_ids: list, auth_header: str = ""): """ Delete files/folders with team permission check and recursive deletion. :param uid: user ID :param file_ids: list of file IDs to delete + :param auth_header: Authorization header for Go backend API calls :return: (success, result) or (success, error_message) """ - def _delete_single_file(file): + errors: list[str] = [] + success_count = 0 + + def _get_space_uuid_by_name(tenant_id, space_name, authorization): + """Get space UUID by space name from Go backend""" + try: + import requests + + host = getattr(settings, 'HOST_IP', '127.0.0.1') + # Go service runs on port+4 (9384 by default) + port = getattr(settings, 'HOST_PORT', 9380) + 4 + service_url = f"http://{host}:{port}" + + # List all spaces and find the one matching the name + url = f"{service_url}/api/v1/skills/spaces" + headers = {"Content-Type": "application/json"} + if authorization: + headers["Authorization"] = authorization + + response = requests.get(url, headers=headers, timeout=10) + + if response.status_code == 200: + data = response.json() + if data.get("code") == 0: + spaces = data.get("data", {}).get("spaces", []) + for space in spaces: + if space.get("name") == space_name: + return space.get("id") + except Exception as e: + logging.warning(f"Error getting space UUID: {e}") + return None + + def _delete_skill_index(tenant_id, space_name, skill_name, authorization): + """Delete skill index from Go backend. + + Returns: + bool: True if deletion succeeded (HTTP 200), False otherwise. 
+ """ + try: + import requests + from urllib.parse import quote + + # Construct service URL from settings + host = getattr(settings, 'HOST_IP', '127.0.0.1') + # Go service runs on port+4 (9384 by default) + port = getattr(settings, 'HOST_PORT', 9380) + 4 + service_url = f"http://{host}:{port}" + + # Get space UUID from space name + space_uuid = _get_space_uuid_by_name(tenant_id, space_name, authorization) + space_id = space_uuid if space_uuid else space_name + + url = f"{service_url}/api/v1/skills/index?skill_id={quote(skill_name)}&space_id={quote(space_id)}" + headers = {"Content-Type": "application/json"} + if authorization: + headers["Authorization"] = authorization + + response = requests.delete(url, headers=headers, timeout=10) + if response.status_code == 200: + try: + data = response.json() + if data.get("code") == 0: + logging.info( + f"Successfully deleted skill index: space={space_name}, skill={skill_name}, " + f"status={response.status_code}, code=0" + ) + return True + else: + app_code = data.get("code", "unknown") + app_msg = data.get("message", "no message") + logging.error( + f"Failed to delete skill index: space={space_name}, skill={skill_name}, " + f"status={response.status_code}, app_code={app_code}, app_msg={app_msg}, " + f"response={response.text}" + ) + return False + except ValueError as json_err: + # JSON decode error - treat as failure + logging.error( + f"Failed to parse delete response JSON: space={space_name}, skill={skill_name}, " + f"error={json_err}, raw_response={response.text}" + ) + return False + else: + logging.error( + f"Failed to delete skill index: space={space_name}, skill={skill_name}, " + f"status={response.status_code}, response={response.text}" + ) + return False + except Exception as e: + logging.error( + f"Exception deleting skill index: space={space_name}, skill={skill_name}, error={e}" + ) + return False + + def _delete_single_file(file) -> int: try: if file.location: settings.STORAGE_IMPL.rm(file.parent_id, file.location) except Exception as e: logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}") + errors.append(f"Failed to remove object {file.parent_id}/{file.location}: {e}") informs = File2DocumentService.get_by_file_id(file.id) for inform in informs: doc_id = inform.document_id e, doc = DocumentService.get_by_id(doc_id) - if e and doc: - tenant_id = DocumentService.get_tenant_id(doc_id) - if tenant_id: - DocumentService.remove_document(doc, tenant_id) - File2DocumentService.delete_by_file_id(file.id) + if not e or not doc: + errors.append(f"Document not found for file {file.id}: {doc_id}") + continue + + tenant_id = DocumentService.get_tenant_id(doc_id) + if not tenant_id: + errors.append(f"Tenant not found for document {doc_id}") + continue + + if not DocumentService.remove_document(doc, tenant_id): + errors.append(f"Failed to remove document {doc_id} for file {file.id}") - FileService.delete(file) + try: + File2DocumentService.delete_by_file_id(file.id) + except Exception as e: + logging.exception(f"Fail to remove file-document relations for file {file.id}, error: {e}") + errors.append(f"Failed to remove file-document relations for file {file.id}: {e}") - def _delete_folder_recursive(folder, tenant_id): + try: + FileService.delete(file) + except Exception as e: + logging.exception(f"Fail to delete file record {file.id}, error: {e}") + errors.append(f"Failed to delete file record {file.id}: {e}") + else: + return 1 + + return 0 + + def _find_ancestor_skill_space(folder_id, tenant_id): + """Walk up the 
folder hierarchy to find an ancestor with source_type == 'skill_space'. + + Returns: + tuple: (success, folder) where folder has source_type == 'skill_space', or (False, None) + """ + visited = set() + current_id = folder_id + while current_id and current_id not in visited: + visited.add(current_id) + success, folder = FileService.get_by_id(current_id) + if not success or not folder: + return False, None + if folder.source_type == "skill_space": + return True, folder + # Move to parent + current_id = folder.parent_id + return False, None + + def _delete_folder_recursive(folder, tenant_id) -> int: + deleted = 0 + current_space_name = None + is_space_folder = folder.source_type == "skill_space" + is_skill_folder = False + + if not is_space_folder: + parent_success, parent_folder = FileService.get_by_id(folder.parent_id) + if parent_success and parent_folder and parent_folder.source_type == "skill_space": + is_skill_folder = True + current_space_name = parent_folder.name + logging.info(f"Identified skill folder '{folder.name}' (parent space: {current_space_name})") + else: + ancestor_success, ancestor_folder = _find_ancestor_skill_space(folder.parent_id, tenant_id) + if ancestor_success and ancestor_folder: + is_skill_folder = True + current_space_name = ancestor_folder.name + logging.info(f"Identified skill folder '{folder.name}' (ancestor space: {current_space_name})") + + if is_space_folder: + current_space_name = folder.name + logging.info(f"Processing space folder '{folder.name}' - will delete all skill indexes within") + + if is_skill_folder and current_space_name and not is_space_folder: + logging.info(f"Deleting skill index for skill '{folder.name}' in space '{current_space_name}'") + index_deleted = _delete_skill_index(tenant_id, current_space_name, folder.name, auth_header) + if not index_deleted: + logging.error( + f"Aborting folder deletion due to index deletion failure: " + f"folder={folder.name}, space={current_space_name}" + ) + errors.append( + f"Failed to delete skill index for folder '{folder.name}' in space '{current_space_name}'. " + f"Folder deletion aborted to prevent orphaned indexes." + ) + return deleted sub_files = FileService.list_all_files_by_parent_id(folder.id) + logging.info(f"Folder '{folder.name}': found {len(sub_files)} children to delete") + for sub_file in sub_files: if sub_file.type == FileType.FOLDER.value: - _delete_folder_recursive(sub_file, tenant_id) + deleted += _delete_folder_recursive(sub_file, tenant_id) + else: + deleted += _delete_single_file(sub_file) + try: + FileService.delete(folder) + except Exception as e: + logging.exception(f"Fail to delete folder record {folder.id}, error: {e}") + errors.append(f"Failed to delete folder record {folder.id}: {e}") + else: + deleted += 1 + + try: + if hasattr(settings.STORAGE_IMPL, 'remove_bucket'): + logging.info(f"Removing storage bucket for folder '{folder.name}' (id={folder.id})") + settings.STORAGE_IMPL.remove_bucket(folder.id) else: - _delete_single_file(sub_file) - FileService.delete(folder) + logging.debug(f"Storage implementation does not support remove_bucket, skipping for folder '{folder.name}'") + except Exception as e: + logging.warning(f"Failed to remove storage bucket for folder '{folder.name}' (id={folder.id}): {e}") + + return deleted def _rm_sync(): + nonlocal success_count for file_id in file_ids: e, file = FileService.get_by_id(file_id) if not e or not file: - return False, "File or Folder not found!" 
+ errors.append(f"File or Folder not found: {file_id}") + continue if not file.tenant_id: - return False, "Tenant not found!" + errors.append(f"Tenant not found for file {file_id}") + continue if not check_file_team_permission(file, uid): - return False, "No authorization." + errors.append(f"No authorization for file {file_id}") + continue if file.source_type == FileSource.KNOWLEDGEBASE: continue + if file.source_type == "skill_space": + continue + if file.type == FileType.FOLDER.value: - _delete_folder_recursive(file, uid) + success_count += _delete_folder_recursive(file, uid) continue - _delete_single_file(file) + success_count += _delete_single_file(file) - return True, True + if errors: + return False, {"success_count": success_count, "errors": errors} + return True, {"success_count": success_count} return await thread_pool_exec(_rm_sync) @@ -307,6 +521,18 @@ async def move_files(uid: str, src_file_ids: list, dest_file_id: str = None, new if f.name == new_name: return False, "Duplicated file name in the same folder." + if dest_folder: + for file in files: + if file.type == FileType.FOLDER.value and file.id == dest_folder.id: + return False, "Cannot move a folder to itself." + # Check if any source folder is an ancestor of the destination folder + # to prevent infinite recursion in _move_entry_recursive + dest_ancestors = FileService.get_all_parent_folders(dest_folder.id) + dest_ancestor_ids = {f.id for f in dest_ancestors} + for file in files: + if file.type == FileType.FOLDER.value and file.id in dest_ancestor_ids: + return False, "Cannot move a folder into its own subfolder." + def _move_entry_recursive(source_file_entry, dest_folder_entry, override_name=None): effective_name = override_name or source_file_entry.name diff --git a/api/apps/services/memory_api_service.py b/api/apps/services/memory_api_service.py index 1b640cff66b..9040f0ce445 100644 --- a/api/apps/services/memory_api_service.py +++ b/api/apps/services/memory_api_service.py @@ -29,6 +29,49 @@ from common.time_utils import current_timestamp, timestamp_to_date +def _split_filter_values(values): + if not values: + return [] + if isinstance(values, str): + values = [values] + res = [] + for value in values: + if not value: + continue + if isinstance(value, str): + res.extend([v.strip() for v in value.split(",") if v.strip()]) + else: + res.append(value) + return res + + +def _joined_tenant_ids(user_id: str) -> set[str]: + user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(user_id) + return {user_id, *[tenant["tenant_id"] for tenant in user_tenants]} + + +def _memory_accessible(memory) -> bool: + if memory.tenant_id == current_user.id: + return True + if memory.permissions != TenantPermission.TEAM.value: + return False + return memory.tenant_id in _joined_tenant_ids(current_user.id) + + +def _require_memory_access(memory_id: str): + memory = MemoryService.get_by_memory_id(memory_id) + if not memory or not _memory_accessible(memory): + raise NotFoundException(f"Memory '{memory_id}' not found.") + return memory + + +def _filter_accessible_memories(memory_ids: list[str]): + memory_ids = _split_filter_values(memory_ids) + if not memory_ids: + return [] + return [memory for memory in MemoryService.get_by_ids(memory_ids) if _memory_accessible(memory)] + + async def create_memory(memory_info: dict): """ :param memory_info: { @@ -137,9 +180,7 @@ async def update_memory(memory_id: str, new_memory_setting: dict): for field in ["avatar", "description", "system_prompt", "user_prompt"]: if field in new_memory_setting: 
update_dict[field] = new_memory_setting[field] - current_memory = MemoryService.get_by_memory_id(memory_id) - if not current_memory: - raise NotFoundException(f"Memory '{memory_id}' not found.") + current_memory = _require_memory_access(memory_id) memory_dict = current_memory.to_dict() memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)}) @@ -168,9 +209,7 @@ async def update_memory(memory_id: str, new_memory_setting: dict): async def delete_memory(memory_id): - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - raise NotFoundException(f"Memory '{memory_id}' not found.") + memory = _require_memory_access(memory_id) MemoryService.delete_memory(memory_id) if MessageService.has_index(memory.tenant_id, memory_id): MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id) @@ -188,19 +227,16 @@ async def list_memory(filter_params: dict, keywords: str, page: int=1, page_size :param page: int :param page_size: int """ - filter_dict: dict = {"storage_type": filter_params.get("storage_type")} - tenant_ids = filter_params.get("tenant_id") - if not filter_params.get("tenant_id"): - # restrict to current user's tenants - user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id) - filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants] + filter_dict: dict = {"storage_type": filter_params.get("storage_type"), "accessible_user_id": current_user.id} + allowed_tenant_ids = _joined_tenant_ids(current_user.id) + tenant_ids = _split_filter_values(filter_params.get("tenant_id") or filter_params.get("owner_ids")) + if tenant_ids: + filter_dict["tenant_id"] = [tenant_id for tenant_id in tenant_ids if tenant_id in allowed_tenant_ids] + if not filter_dict["tenant_id"]: + return {"memory_list": [], "total_count": 0} else: - if len(tenant_ids) == 1 and ',' in tenant_ids[0]: - tenant_ids = tenant_ids[0].split(',') - filter_dict["tenant_id"] = tenant_ids - memory_types = filter_params.get("memory_type") - if memory_types and len(memory_types) == 1 and ',' in memory_types[0]: - memory_types = memory_types[0].split(',') + filter_dict["tenant_id"] = list(allowed_tenant_ids) + memory_types = _split_filter_values(filter_params.get("memory_type")) filter_dict["memory_type"] = memory_types memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size) @@ -212,15 +248,13 @@ async def list_memory(filter_params: dict, keywords: str, page: int=1, page_size async def get_memory_config(memory_id): memory = MemoryService.get_with_owner_name_by_id(memory_id) - if not memory: + if not memory or not _memory_accessible(memory): raise NotFoundException(f"Memory '{memory_id}' not found.") return format_ret_data_from_memory(memory) async def get_memory_messages(memory_id, agent_ids: list[str], keywords: str, page: int=1, page_size: int = 50): - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - raise NotFoundException(f"Memory '{memory_id}' not found.") + memory = _require_memory_access(memory_id) messages = MessageService.list_message( memory.tenant_id, memory_id, agent_ids, keywords, page, page_size) agent_name_mapping = {} @@ -253,13 +287,14 @@ async def add_message(memory_ids: list[str], message_dict: dict): "message_type": str } """ - return await queue_save_to_memory_task(memory_ids, message_dict) + accessible_memory_ids = [memory.id for memory in _filter_accessible_memories(memory_ids)] + if not accessible_memory_ids: + return False, "Memory not found." 
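# add_message only queues the message for memories the caller can access; inaccessible
# ids are dropped, and a generic "Memory not found." error is returned when none remain.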
+ return await queue_save_to_memory_task(accessible_memory_ids, message_dict) async def forget_message(memory_id: str, message_id: int): - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - raise NotFoundException(f"Memory '{memory_id}' not found.") + memory = _require_memory_access(memory_id) forget_time = timestamp_to_date(current_timestamp()) update_succeed = MessageService.update_message( @@ -272,9 +307,7 @@ async def forget_message(memory_id: str, message_id: int): async def update_message_status(memory_id: str, message_id: int, status: bool): - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - raise NotFoundException(f"Memory '{memory_id}' not found.") + memory = _require_memory_access(memory_id) update_succeed = MessageService.update_message( {"memory_id": memory_id, "message_id": int(message_id)}, @@ -300,6 +333,11 @@ async def search_message(filter_dict: dict, params: dict): "top_n": int } """ + memory_ids = _split_filter_values(filter_dict.get("memory_id")) + accessible_memory_ids = [memory.id for memory in _filter_accessible_memories(memory_ids)] + if not accessible_memory_ids: + return [] + filter_dict = {**filter_dict, "memory_id": accessible_memory_ids} return query_message(filter_dict, params) @@ -313,11 +351,14 @@ async def get_messages(memory_ids: list[str], agent_id: str = "", session_id: st :param limit: maximum number of messages to return :return: list of recent messages """ - memory_list = MemoryService.get_by_ids(memory_ids) + memory_list = _filter_accessible_memories(memory_ids) + if not memory_list: + return [] uids = [memory.tenant_id for memory in memory_list] + accessible_memory_ids = [memory.id for memory in memory_list] res = MessageService.get_recent_messages( uids, - memory_ids, + accessible_memory_ids, agent_id, session_id, limit @@ -334,11 +375,9 @@ async def get_message_content(memory_id: str, message_id: int): :return: message content :raises NotFoundException: if memory or message not found """ - memory = MemoryService.get_by_memory_id(memory_id) - if not memory: - raise NotFoundException(f"Memory '{memory_id}' not found.") + memory = _require_memory_access(memory_id) res = MessageService.get_by_message_id(memory_id, message_id, memory.tenant_id) if res: return res - raise NotFoundException(f"Message '{message_id}' in memory '{memory_id}' not found.") \ No newline at end of file + raise NotFoundException(f"Message '{message_id}' in memory '{memory_id}' not found.") diff --git a/api/apps/system_app.py b/api/apps/system_app.py deleted file mode 100644 index 833a7819dd5..00000000000 --- a/api/apps/system_app.py +++ /dev/null @@ -1,197 +0,0 @@ -# -# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License -# -import logging -from datetime import datetime -import json - -from api.apps import login_required - -from api.db.services.knowledgebase_service import KnowledgebaseService -from api.utils.api_utils import ( - get_json_result, -) - -from timeit import default_timer as timer - -from rag.utils.redis_conn import REDIS_CONN -from api.utils.health_utils import get_oceanbase_status -from common import settings - -@manager.route("/status", methods=["GET"]) # noqa: F821 -@login_required -def status(): - """ - Get the system status. - --- - tags: - - System - security: - - ApiKeyAuth: [] - responses: - 200: - description: System is operational. - schema: - type: object - properties: - es: - type: object - description: Elasticsearch status. - storage: - type: object - description: Storage status. - database: - type: object - description: Database status. - 503: - description: Service unavailable. - schema: - type: object - properties: - error: - type: string - description: Error message. - """ - res = {} - st = timer() - try: - res["doc_engine"] = settings.docStoreConn.health() - res["doc_engine"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0) - except Exception as e: - res["doc_engine"] = { - "type": "unknown", - "status": "red", - "elapsed": "{:.1f}".format((timer() - st) * 1000.0), - "error": str(e), - } - - st = timer() - try: - settings.STORAGE_IMPL.health() - res["storage"] = { - "storage": settings.STORAGE_IMPL_TYPE.lower(), - "status": "green", - "elapsed": "{:.1f}".format((timer() - st) * 1000.0), - } - except Exception as e: - res["storage"] = { - "storage": settings.STORAGE_IMPL_TYPE.lower(), - "status": "red", - "elapsed": "{:.1f}".format((timer() - st) * 1000.0), - "error": str(e), - } - - st = timer() - try: - KnowledgebaseService.get_by_id("x") - res["database"] = { - "database": settings.DATABASE_TYPE.lower(), - "status": "green", - "elapsed": "{:.1f}".format((timer() - st) * 1000.0), - } - except Exception as e: - res["database"] = { - "database": settings.DATABASE_TYPE.lower(), - "status": "red", - "elapsed": "{:.1f}".format((timer() - st) * 1000.0), - "error": str(e), - } - - st = timer() - try: - if not REDIS_CONN.health(): - raise Exception("Lost connection!") - res["redis"] = { - "status": "green", - "elapsed": "{:.1f}".format((timer() - st) * 1000.0), - } - except Exception as e: - res["redis"] = { - "status": "red", - "elapsed": "{:.1f}".format((timer() - st) * 1000.0), - "error": str(e), - } - - task_executor_heartbeats = {} - try: - task_executors = REDIS_CONN.smembers("TASKEXE") - now = datetime.now().timestamp() - for task_executor_id in task_executors: - heartbeats = REDIS_CONN.zrangebyscore(task_executor_id, now - 60 * 30, now) - heartbeats = [json.loads(heartbeat) for heartbeat in heartbeats] - task_executor_heartbeats[task_executor_id] = heartbeats - except Exception: - logging.exception("get task executor heartbeats failed!") - res["task_executor_heartbeats"] = task_executor_heartbeats - - return get_json_result(data=res) - -@manager.route("/oceanbase/status", methods=["GET"]) # noqa: F821 -@login_required -def oceanbase_status(): - """ - Get OceanBase health status and performance metrics. - --- - tags: - - System - security: - - ApiKeyAuth: [] - responses: - 200: - description: OceanBase status retrieved successfully. - schema: - type: object - properties: - status: - type: string - description: Status (alive/timeout). 
- message: - type: object - description: Detailed status information including health and performance metrics. - """ - try: - status_info = get_oceanbase_status() - return get_json_result(data=status_info) - except Exception as e: - return get_json_result( - data={ - "status": "error", - "message": f"Failed to get OceanBase status: {str(e)}" - }, - code=500 - ) - - -@manager.route("/config", methods=["GET"]) # noqa: F821 -def get_config(): - """ - Get system configuration. - --- - tags: - - System - responses: - 200: - description: Return system configuration - schema: - type: object - properties: - registerEnable: - type: integer 0 means disabled, 1 means enabled - description: Whether user registration is enabled - """ - return get_json_result(data={ - "registerEnabled": settings.REGISTER_ENABLED, - "disablePasswordLogin": settings.DISABLE_PASSWORD_LOGIN, - }) diff --git a/api/db/__init__.py b/api/db/__init__.py index 0ebd9f56f3f..6d7ed9fcb97 100644 --- a/api/db/__init__.py +++ b/api/db/__init__.py @@ -74,3 +74,4 @@ class PipelineTaskType(StrEnum): KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase" +SKILLS_FOLDER_NAME="skills" diff --git a/api/db/db_models.py b/api/db/db_models.py index 433ed78afe2..5fe64586c04 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -55,7 +55,7 @@ from common.time_utils import current_timestamp, timestamp_to_date, date_string_to_timestamp from common.decorator import singleton -from common.constants import ParserType +from common.constants import ParserType, MAXIMUM_TASK_PAGE_NUMBER from common import settings @@ -726,7 +726,7 @@ def __str__(self): return self.email def get_id(self): - jwt = Serializer(secret_key=settings.SECRET_KEY) + jwt = Serializer(secret_key=settings.get_secret_key()) return jwt.dumps(str(self.access_token)) class Meta: @@ -945,7 +945,7 @@ class Task(DataBaseModel): id = CharField(max_length=32, primary_key=True) doc_id = CharField(max_length=32, null=False, index=True) from_page = IntegerField(default=0) - to_page = IntegerField(default=100000000) + to_page = IntegerField(default=MAXIMUM_TASK_PAGE_NUMBER) task_type = CharField(max_length=32, null=False, default="") priority = IntegerField(default=0) diff --git a/api/db/joint_services/tenant_model_service.py b/api/db/joint_services/tenant_model_service.py index f53f83ab957..645d7563812 100644 --- a/api/db/joint_services/tenant_model_service.py +++ b/api/db/joint_services/tenant_model_service.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import logging import os import enum from common import settings @@ -20,14 +21,22 @@ from api.db.services.llm_service import LLMService from api.db.services.tenant_llm_service import TenantLLMService, TenantService +logger = logging.getLogger(__name__) + def get_model_config_by_id(tenant_model_id: int) -> dict: found, model_config = TenantLLMService.get_by_id(tenant_model_id) if not found: raise LookupError(f"Tenant Model with id {tenant_model_id} not found") config_dict = model_config.to_dict() + api_key, is_tools, api_key_payload = TenantLLMService._decode_api_key_config(config_dict.get("api_key", "")) + config_dict["api_key"] = api_key + if api_key_payload is not None: + config_dict["api_key_payload"] = api_key_payload + if is_tools is not None: + config_dict["is_tools"] = is_tools llm = LLMService.query(llm_name=config_dict["llm_name"]) - if llm: + if "is_tools" not in config_dict and llm: config_dict["is_tools"] = llm[0].is_tools return config_dict @@ -57,6 +66,31 @@ def get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_nam "api_base": embedding_cfg["base_url"], "model_type": LLMType.EMBEDDING.value, } + elif model_type_val == LLMType.CHAT.value: + # Retry as CHAT with pure_model_name first; then fall back to a multimodal model registered under IMAGE2TEXT. + model_config = TenantLLMService.get_api_key(tenant_id, pure_model_name, LLMType.CHAT.value) + if not model_config: + model_config = TenantLLMService.get_api_key(tenant_id, pure_model_name, LLMType.IMAGE2TEXT.value) + if not model_config: + raise LookupError(f"Tenant Model with name {model_name} and type {model_type_val} not found") + config_dict = model_config.to_dict() + elif model_type_val == LLMType.IMAGE2TEXT.value: + model_config = TenantLLMService.get_api_key(tenant_id, pure_model_name, LLMType.IMAGE2TEXT.value) + if not model_config: + # Fall back to a chat model only if it has declared IMAGE2TEXT capability (tag check via llm table) + chat_config = TenantLLMService.get_api_key(tenant_id, pure_model_name, LLMType.CHAT.value) + logger.debug("IMAGE2TEXT config not found for %s; chat_config found: %s", pure_model_name, chat_config is not None) + if chat_config: + llm_entry = LLMService.query(fid=chat_config.llm_factory, llm_name=chat_config.llm_name) + tags = [t.strip() for t in (llm_entry[0].tags or "").split(",")] if llm_entry else [] + logger.debug("LLM tags for %s/%s: %s", chat_config.llm_factory, chat_config.llm_name, tags) + if "IMAGE2TEXT" in tags: + logger.debug("Promoting chat config to IMAGE2TEXT for %s", pure_model_name) + model_config = chat_config + if not model_config: + raise LookupError(f"Tenant Model with name {model_name} and type {model_type_val} not found") + config_dict = model_config.to_dict() + config_dict["model_type"] = LLMType.IMAGE2TEXT.value else: model_config = TenantLLMService.get_api_key(tenant_id, pure_model_name, model_type_val) if not model_config: @@ -65,14 +99,26 @@ def get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_nam else: # model_name without @factory config_dict = model_config.to_dict() + api_key, is_tools, api_key_payload = TenantLLMService._decode_api_key_config(config_dict.get("api_key", "")) + config_dict["api_key"] = api_key + if api_key_payload is not None: + config_dict["api_key_payload"] = api_key_payload + if is_tools is not None: + config_dict["is_tools"] = is_tools config_model_type = config_dict.get("model_type") config_model_type = config_model_type.value if hasattr(config_model_type, "value") else config_model_type - if 
config_model_type != model_type_val: + if config_model_type != model_type_val and not ( + model_type_val == LLMType.CHAT.value + and config_model_type == LLMType.IMAGE2TEXT.value + ) and not ( + model_type_val == LLMType.IMAGE2TEXT.value + and config_model_type == LLMType.CHAT.value + ): raise LookupError( f"Tenant Model with name {model_name} has type {config_model_type}, expected {model_type_val}" ) llm = LLMService.query(llm_name=config_dict["llm_name"]) - if llm: + if "is_tools" not in config_dict and llm: config_dict["is_tools"] = llm[0].is_tools return config_dict diff --git a/api/db/services/api_service.py b/api/db/services/api_service.py index be41dc1b642..8f60a1c5ab5 100644 --- a/api/db/services/api_service.py +++ b/api/db/services/api_service.py @@ -44,6 +44,14 @@ def delete_by_tenant_id(cls, tenant_id): class API4ConversationService(CommonService): model = API4Conversation + @staticmethod + def _normalize_query_date(value, is_end=False): + if "T" in value: + value = datetime.fromisoformat(value.replace("Z", "+00:00")).astimezone().replace(tzinfo=None).strftime("%Y-%m-%d %H:%M:%S") + elif len(value) == 10: + value = f"{value} 23:59:59" if is_end else f"{value} 00:00:00" + return value + @classmethod @DB.connection_context() def get_list(cls, dialog_id, tenant_id, @@ -62,10 +70,11 @@ def get_list(cls, dialog_id, tenant_id, sessions = sessions.where(cls.model.user_id == user_id) if keywords: sessions = sessions.where(peewee.fn.LOWER(cls.model.message).contains(keywords.lower())) + date_field = cls.model.update_date if orderby.startswith("update_") else cls.model.create_date if from_date: - sessions = sessions.where(cls.model.create_date >= from_date) + sessions = sessions.where(date_field >= cls._normalize_query_date(from_date)) if to_date: - sessions = sessions.where(cls.model.create_date <= to_date) + sessions = sessions.where(date_field <= cls._normalize_query_date(to_date, is_end=True)) if exp_user_id: sessions = sessions.where(cls.model.exp_user_id == exp_user_id) if desc: diff --git a/api/db/services/canvas_service.py b/api/db/services/canvas_service.py index 98925fa246a..4a5734e155d 100644 --- a/api/db/services/canvas_service.py +++ b/api/db/services/canvas_service.py @@ -139,10 +139,17 @@ def get_basic_info_by_canvas_ids(cls, canvas_id): @classmethod @DB.connection_context() - def get_by_tenant_ids(cls, joined_tenant_ids, user_id, - page_number, items_per_page, - orderby, desc, keywords, canvas_category=None - ): + def get_by_tenant_ids( + cls, + joined_tenant_ids, + user_id, + page_number, + items_per_page, + orderby, + desc, + keywords, + canvas_category=None, + ): fields = [ cls.model.id, cls.model.avatar, @@ -201,7 +208,11 @@ def accessible(cls, canvas_id, tenant_id): return False tids = [t.tenant_id for t in UserTenantService.query(user_id=tenant_id)] - if c["user_id"] != canvas_id and c["user_id"] not in tids: + if c["user_id"] == tenant_id: + return True + if c["user_id"] not in tids: + return False + if c["permission"] != TenantPermission.TEAM.value: return False return True @@ -210,8 +221,6 @@ def get_agent_dsl_with_release(cls, agent_id, release_mode=False, tenant_id=None e, cvs = cls.get_by_id(agent_id) if not e: raise LookupError("Agent not found.") - if tenant_id and cvs.user_id != tenant_id: - raise PermissionError("You do not own the agent.") if release_mode: released_version = UserCanvasVersionService.get_latest_released(agent_id) diff --git a/api/db/services/connector_service.py b/api/db/services/connector_service.py index 85d495d9d63..9f7b0e6ded1 100644 
--- a/api/db/services/connector_service.py +++ b/api/db/services/connector_service.py @@ -29,6 +29,7 @@ from api.utils.common import hash128 from common.misc_utils import get_uuid from common.constants import TaskStatus +from common.settings import TIMEZONE from common.time_utils import current_timestamp, timestamp_to_date class ConnectorService(CommonService): @@ -99,7 +100,7 @@ def cleanup_stale_documents_for_task( return 0, [] source_type = f"{conn.source}/{conn.id}" - retain_doc_ids = {hash128(file.id) for file in file_list} + retain_doc_ids = {hash128(f"{connector_id}:{file.id}") for file in file_list} existing_docs = DocumentService.list_doc_headers_by_kb_and_source_type( kb_id, source_type, @@ -179,14 +180,14 @@ def list_sync_tasks(cls, connector_id=None, page_number=None, items_per_page=15) else: database_type = os.getenv("DB_TYPE", "mysql") if "postgres" in database_type.lower(): - interval_expr = SQL("make_interval(mins => t2.refresh_freq)") + expr = SQL(f"NOW() AT TIME ZONE '{TIMEZONE}' - make_interval(mins => t2.refresh_freq)") else: - interval_expr = SQL("INTERVAL `t2`.`refresh_freq` MINUTE") + expr = SQL("NOW() - INTERVAL `t2`.`refresh_freq` MINUTE") query = query.where( Connector.input_type == InputType.POLL, Connector.status == TaskStatus.SCHEDULE, cls.model.status == TaskStatus.SCHEDULE, - cls.model.update_date < (fn.NOW() - interval_expr) + cls.model.update_date < expr ) query = query.distinct().order_by(cls.model.update_time.desc()) diff --git a/api/db/services/conversation_service.py b/api/db/services/conversation_service.py index 5a205b14219..2603676e98e 100644 --- a/api/db/services/conversation_service.py +++ b/api/db/services/conversation_service.py @@ -14,6 +14,7 @@ # limitations under the License. # import time +import logging from uuid import uuid4 from common.constants import StatusEnum from api.db.db_models import Conversation, DB @@ -26,6 +27,9 @@ from rag.prompts.generator import chunks_format +logger = logging.getLogger(__name__) + + class ConversationService(CommonService): model = Conversation @@ -201,9 +205,23 @@ async def async_completion(tenant_id, chat_id, question, name="New session", ses break yield answer -async def async_iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs): - e, dia = DialogService.get_by_id(dialog_id) - assert e, "Dialog not found" +async def async_iframe_completion(dialog_id, question, session_id=None, stream=True, tenant_id=None, **kwargs): + if tenant_id: + exists, dia = DialogService.get_by_id(dialog_id) + if (not exists + or getattr(dia, "tenant_id", None) != tenant_id + or str(getattr(dia, "status", "")) != StatusEnum.VALID.value): + logger.warning( + "Dialog lookup failed for tenant-scoped iframe completion: " + "tenant_id=%s dialog_id=%s required_status=%s", + tenant_id, + dialog_id, + StatusEnum.VALID.value, + ) + raise AssertionError("Dialog not found") + else: + e, dia = DialogService.get_by_id(dialog_id) + assert e, "Dialog not found" if not session_id: session_id = get_uuid() conv = { @@ -228,6 +246,7 @@ async def async_iframe_completion(dialog_id, question, session_id=None, stream=T session_id = session_id e, conv = API4ConversationService.get_by_id(session_id) assert e, "Session not found!" 
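        # Callers of async_iframe_completion are not part of this hunk; the snippet below
        # is an assumed illustration of how an API-token-scoped route would thread the new
        # tenant_id argument through, so that a dialog_id taken from the request can no
        # longer address another tenant's dialog:
        #
        #     async for ans in async_iframe_completion(
        #         dialog_id, question, session_id=session_id, stream=True, tenant_id=tenant_id
        #     ):
        #         yield ans
        #
        # Internal callers that omit tenant_id keep the original unscoped lookup.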
+ assert conv.dialog_id == dialog_id, "Session does not belong to this dialog" if not conv.message: conv.message = [] diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index cadf76c2aa8..6f981efb5e6 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -18,7 +18,10 @@ import logging import re import time +import uuid from copy import deepcopy + +logger = logging.getLogger(__name__) from datetime import datetime from functools import partial from timeit import default_timer as timer @@ -33,6 +36,10 @@ from api.db.services.langfuse_service import TenantLangfuseService from api.db.services.llm_service import LLMBundle from common.metadata_utils import apply_meta_data_filter +from api.utils.reference_metadata_utils import ( + enrich_chunks_with_document_metadata, + resolve_reference_metadata_preferences, +) from api.db.services.tenant_llm_service import TenantLLMService from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type from common.time_utils import current_timestamp, datetime_format @@ -41,13 +48,22 @@ from rag.advanced_rag import DeepResearcher from rag.app.tag import label_question from rag.nlp.search import index_name -from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \ - PROMPT_JINJA_ENV, ASK_SUMMARY +from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, PROMPT_JINJA_ENV, ASK_SUMMARY from common.token_utils import num_tokens_from_string from rag.utils.tavily_conn import Tavily from common.string_utils import remove_redundant_spaces from common import settings +def _resolve_reference_metadata(request_payload=None, config=None): + return resolve_reference_metadata_preferences(request_payload or {}, config) + +def _enrich_chunks_with_document_metadata(chunks, metadata_fields=None): + enrich_chunks_with_document_metadata(chunks, metadata_fields) + +def _chunk_kb_id_for_doc(row_dict, kb_ids, doc_id): + if len(kb_ids or []) == 1: + return kb_ids[0] + return row_dict.get("kb_id") or row_dict.get("kb_id_kwd") def _normalize_internet_flag(value): if isinstance(value, bool): @@ -70,6 +86,15 @@ def _should_use_web_search(prompt_config, internet=None): return normalized is True +def _resolve_reference_metadata(config, request_payload=None): + return resolve_reference_metadata_preferences(request_payload or {}, config) + + +def _enrich_chunks_with_document_metadata(chunks, metadata_fields=None): + enrich_chunks_with_document_metadata(chunks, metadata_fields) + + + class DialogService(CommonService): model = Dialog @@ -168,8 +193,7 @@ def get_by_tenant_ids( cls.model.select(*fields) .join(User, on=(cls.model.tenant_id == User.id)) .where( - (cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id)) - & (cls.model.status == StatusEnum.VALID.value), + (cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value), ) ) if id: @@ -210,22 +234,14 @@ def get_all_dialogs_by_tenant_id(cls, tenant_id): @classmethod @DB.connection_context() def get_null_tenant_llm_id_row(cls): - fields = [ - cls.model.id, - cls.model.tenant_id, - cls.model.llm_id - ] + fields = [cls.model.id, cls.model.tenant_id, cls.model.llm_id] objs = 
cls.model.select(*fields).where(cls.model.tenant_llm_id.is_null()) return list(objs) @classmethod @DB.connection_context() def get_null_tenant_rerank_id_row(cls): - fields = [ - cls.model.id, - cls.model.tenant_id, - cls.model.rerank_id - ] + fields = [cls.model.id, cls.model.tenant_id, cls.model.rerank_id] objs = cls.model.select(*fields).where(cls.model.tenant_rerank_id.is_null()) return list(objs) @@ -241,7 +257,7 @@ async def async_chat_solo(dialog, messages, stream=True): else: text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True) attachments = "\n\n".join(text_attachments) - + if dialog.llm_id: model_config = get_model_config_by_type_and_name(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) elif dialog.tenant_llm_id: @@ -460,11 +476,11 @@ def find_and_replace(pattern, group_index=1, repl=lambda digits: f"ID:{digits}") parts = [] last_idx = 0 for match in matches: - parts.append(answer[last_idx:match.start()]) + parts.append(answer[last_idx : match.start()]) try: i = int(match.group(group_index)) except Exception: - parts.append(answer[match.start():match.end()]) + parts.append(answer[match.start() : match.end()]) last_idx = match.end() continue @@ -473,7 +489,7 @@ def find_and_replace(pattern, group_index=1, repl=lambda digits: f"ID:{digits}") digits_original = answer[digit_start:digit_end] parts.append(f"[{repl(digits_original)}]") else: - parts.append(answer[match.start():match.end()]) + parts.append(answer[match.start() : match.end()]) last_idx = match.end() parts.append(answer[last_idx:]) @@ -534,7 +550,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): attachments = None if "doc_ids" in kwargs: attachments = [doc_id for doc_id in kwargs["doc_ids"].split(",") if doc_id] - attachments_= "" + attachments_ = "" image_attachments = [] image_files = [] if "doc_ids" in messages[-1]: @@ -547,6 +563,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): attachments_ = "\n\n".join(text_attachments) prompt_config = dialog.prompt_config + include_reference_metadata, metadata_fields = _resolve_reference_metadata(prompt_config, request_payload=kwargs) field_map = KnowledgebaseService.get_field_map(dialog.kb_ids) logging.debug(f"field_map retrieved: {field_map}") # try to use sql if field mapping is good to go @@ -555,6 +572,14 @@ async def async_chat(dialog, messages, stream=True, **kwargs): ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids) # For aggregate queries (COUNT, SUM, etc.), chunks may be empty but answer is still valid if ans and (ans.get("reference", {}).get("chunks") or ans.get("answer")): + if include_reference_metadata and ans.get("reference", {}).get("chunks"): + if len(dialog.kb_ids) != 1 and any(not c.get("kb_id") for c in ans["reference"]["chunks"]): + logging.warning( + "Skipping some _enrich_chunks_with_document_metadata results because " + "dialog.kb_ids has %d entries and use_sql returned chunks without kb_id.", + len(dialog.kb_ids), + ) + _enrich_chunks_with_document_metadata(ans["reference"]["chunks"], metadata_fields) yield ans return else: @@ -584,13 +609,14 @@ async def async_chat(dialog, messages, stream=True, **kwargs): questions = [await cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])] if dialog.meta_data_filter: - metas = DocMetadataService.get_flatted_meta_by_kbs(dialog.kb_ids) attachments = await apply_meta_data_filter( dialog.meta_data_filter, - metas, + None, questions[-1], 
chat_mdl, attachments, + kb_ids=dialog.kb_ids, + metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs(dialog.kb_ids), ) if prompt_config.get("keyword", False): @@ -623,7 +649,8 @@ async def async_chat(dialog, messages, stream=True, **kwargs): internet_enabled=use_web_search, ) queue = asyncio.Queue() - async def callback(msg:str): + + async def callback(msg: str): nonlocal queue await queue.put(msg + "
") @@ -632,9 +659,9 @@ async def callback(msg:str): while True: msg = await queue.get() if msg.find("") == 0: - yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, "start_to_think": True} + yield {"answer": "", "reference": {}, "audio_binary": None, "final": False} elif msg.find("") == 0: - yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, "end_to_think": True} + yield {"answer": "", "reference": {}, "audio_binary": None, "final": False} break else: yield {"answer": msg, "reference": {}, "audio_binary": None, "final": False} @@ -670,25 +697,31 @@ async def callback(msg:str): kbinfos["doc_aggs"].extend(tav_res["doc_aggs"]) if prompt_config.get("use_kg"): default_chat_model = get_tenant_default_model_by_type(dialog.tenant_id, LLMType.CHAT) - ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, - LLMBundle(dialog.tenant_id, default_chat_model)) + ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, default_chat_model)) if ck["content_with_weight"]: kbinfos["chunks"].insert(0, ck) + if include_reference_metadata: + logging.debug( + "reference_metadata enrichment enabled for async_chat: chunk_count=%d metadata_fields=%s", + len(kbinfos.get("chunks", [])), + metadata_fields, + ) + _enrich_chunks_with_document_metadata(kbinfos.get("chunks", []), metadata_fields) + knowledges = kb_prompt(kbinfos, max_tokens) logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges))) retrieval_ts = timer() if not knowledges and prompt_config.get("empty_response"): empty_res = prompt_config["empty_response"] - yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), - "audio_binary": tts(tts_mdl, empty_res), "final": True} + yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res), "final": True} return kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges) gen_conf = dialog.llm_setting - msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)+attachments_}] + msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs) + attachments_}] prompt4citation = "" if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): prompt4citation = citation_prompt() @@ -783,8 +816,7 @@ def decorate_answer(answer): if langfuse_tracer: langfuse_generation = langfuse_tracer.start_generation( - trace_context=trace_context, name="chat", model=llm_model_config["llm_name"], - input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg} + trace_context=trace_context, name="chat", model=llm_model_config["llm_name"], input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg} ) if stream: @@ -802,7 +834,7 @@ def decorate_answer(answer): yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "final": False} full_answer = last_state.full_text if last_state else "" if full_answer: - final = decorate_answer(thought + full_answer) + final = decorate_answer(_extract_visible_answer(thought + full_answer)) final["final"] = True final["audio_binary"] = None yield final @@ -821,6 +853,25 @@ def decorate_answer(answer): async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None): + """Answer a natural-language question by generating and executing SQL against the document 
index. + + Detects the active document engine (Infinity, OceanBase, or Elasticsearch), asks the + chat model to produce the appropriate SQL, injects a validated kb_id filter, executes + the query, and returns formatted results with optional source citations. + + Args: + question: Natural-language question from the user. + field_map: Mapping of field names to types describing the indexed document schema. + tenant_id: Tenant identifier used to derive the target index/table name. + chat_mdl: LLM bundle used to generate SQL from the question. + quota: Whether to enforce token-quota checks (default True). + kb_ids: Optional list of knowledge-base UUIDs to restrict the query scope. + + Returns: + A dict with keys ``answer`` (formatted response string), ``reference`` + (dict of supporting document chunks and doc_aggs), and ``prompt`` + (the system prompt used), or ``None`` if SQL generation or execution fails. + """ logging.debug(f"use_sql: Question: {question}") # Determine which document engine we're using @@ -831,12 +882,20 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N else: doc_engine = "es" + def _assert_valid_uuid(value: str, label: str = "id") -> None: + try: + uuid.UUID(str(value)) + except (ValueError, AttributeError, TypeError): + logger.warning("SQL injection guard rejected invalid %s value (length=%d)", label, len(str(value))) + raise ValueError(f"Invalid {label} format: {value!r}") + # Construct the full table name # For Elasticsearch: ragflow_{tenant_id} (kb_id is in WHERE clause) # For Infinity: ragflow_{tenant_id}_{kb_id} (each KB has its own table) base_table = index_name(tenant_id) if doc_engine == "infinity" and kb_ids and len(kb_ids) == 1: - # Infinity: append kb_id to table name + # Infinity: append kb_id to table name — validate before interpolating + _assert_valid_uuid(kb_ids[0], "kb_id") table_name = f"{base_table}_{kb_ids[0]}" logging.debug(f"use_sql: Using Infinity table name: {table_name}") else: @@ -847,13 +906,20 @@ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=N expected_doc_name_column = "docnm" if doc_engine == "infinity" else "docnm_kwd" def has_source_columns(columns): + """Return True if the result set contains the columns needed to build source citations.""" normalized_names = {str(col.get("name", "")).lower() for col in columns} return "doc_id" in normalized_names and bool({"docnm_kwd", "docnm"} & normalized_names) def is_aggregate_sql(sql_text): + """Return True if *sql_text* contains an aggregate function (COUNT, SUM, AVG, MAX, MIN, DISTINCT).""" return bool(re.search(r"(count|sum|avg|max|min|distinct)\s*\(", (sql_text or "").lower())) def normalize_sql(sql): + """Strip LLM artefacts from *sql* and return a clean, executable SQL string. + + Removes ```` reasoning blocks, Chinese reasoning markers, markdown + code fences, and trailing semicolons that some engines reject. + """ logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}") # Remove think blocks if present (format: ...) sql = re.sub(r"\n.*?\n\s*", "", sql, flags=re.DOTALL) @@ -862,18 +928,28 @@ def normalize_sql(sql): sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE) sql = re.sub(r"```\s*$", "", sql, flags=re.IGNORECASE) # Remove trailing semicolon that ES SQL parser doesn't like - return sql.rstrip().rstrip(';').strip() + return sql.rstrip().rstrip(";").strip() def add_kb_filter(sql): + """Inject a validated kb_id WHERE filter into *sql* for ES/OceanBase engines. 
+ + Infinity encodes the knowledge-base scope in the table name, so this + function is a no-op for that engine. All kb_id values are validated as + canonical UUIDs before interpolation to prevent SQL injection. + """ # Add kb_id filter for ES/OS only (Infinity already has it in table name) if doc_engine == "infinity" or not kb_ids: return sql + # Validate all kb_ids are UUIDs before interpolating into SQL + for kid in kb_ids: + _assert_valid_uuid(kid, "kb_id") + # Build kb_filter: single KB or multiple KBs with OR if len(kb_ids) == 1: kb_filter = f"kb_id = '{kb_ids[0]}'" else: - kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")" + kb_filter = "(" + " OR ".join([f"kb_id = '{kid}'" for kid in kb_ids]) + ")" if "where " not in sql.lower(): o = sql.lower().split("order by") @@ -886,6 +962,7 @@ def add_kb_filter(sql): return sql def is_row_count_question(q: str) -> bool: + """Return True if *q* is asking for a total row count of a dataset or table.""" q = (q or "").lower() if not re.search(r"\bhow many rows\b|\bnumber of rows\b|\brow count\b", q): return False @@ -895,11 +972,7 @@ def is_row_count_question(q: str) -> bool: if doc_engine == "infinity": # Build Infinity prompts with JSON extraction context json_field_names = list(field_map.keys()) - row_count_override = ( - f"SELECT COUNT(*) AS rows FROM {table_name}" - if is_row_count_question(question) - else None - ) + row_count_override = f"SELECT COUNT(*) AS rows FROM {table_name}" if is_row_count_question(question) else None sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column. JSON Extraction: json_extract_string(chunk_data, '$.FieldName') @@ -923,19 +996,12 @@ def is_row_count_question(q: str) -> bool: {} Question: {} Write SQL using json_extract_string() with exact field names. Include doc_id, docnm for data queries. Only SQL.""".format( - table_name, - ", ".join(json_field_names), - "\n".join([f" - {field}" for field in json_field_names]), - question + table_name, ", ".join(json_field_names), "\n".join([f" - {field}" for field in json_field_names]), question ) elif doc_engine == "oceanbase": # Build OceanBase prompts with JSON extraction context json_field_names = list(field_map.keys()) - row_count_override = ( - f"SELECT COUNT(*) AS rows FROM {table_name}" - if is_row_count_question(question) - else None - ) + row_count_override = f"SELECT COUNT(*) AS rows FROM {table_name}" if is_row_count_question(question) else None sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column. JSON Extraction: json_extract_string(chunk_data, '$.FieldName') @@ -959,10 +1025,7 @@ def is_row_count_question(q: str) -> bool: {} Question: {} Write SQL using json_extract_string() with exact field names. Include doc_id, docnm_kwd for data queries. Only SQL.""".format( - table_name, - ", ".join(json_field_names), - "\n".join([f" - {field}" for field in json_field_names]), - question + table_name, ", ".join(json_field_names), "\n".join([f" - {field}" for field in json_field_names]), question ) else: # Build ES/OS prompts with direct field access @@ -980,11 +1043,7 @@ def is_row_count_question(q: str) -> bool: Available fields: {} Question: {} -Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.""".format( - table_name, - "\n".join([f" - {k} ({v})" for k, v in field_map.items()]), - question - ) +Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. 
Only SQL.""".format(table_name, "\n".join([f" - {k} ({v})" for k, v in field_map.items()]), question) tried_times = 0 @@ -1022,13 +1081,7 @@ async def repair_table_for_missing_source_columns(previous_sql): The previous SQL result is missing required source columns for citations. Rewrite SQL to keep the same query intent and include doc_id and {} in the SELECT list. For extracted JSON fields, use json_extract_string(chunk_data, '$.field_name'). -Return ONLY SQL.""".format( - table_name, - "\n".join([f" - {field}" for field in json_field_names]), - question, - previous_sql, - expected_doc_name_column - ) +Return ONLY SQL.""".format(table_name, "\n".join([f" - {field}" for field in json_field_names]), question, previous_sql, expected_doc_name_column) else: repair_prompt = """Table name: {} Available fields: @@ -1040,12 +1093,7 @@ async def repair_table_for_missing_source_columns(previous_sql): The previous SQL result is missing required source columns for citations. Rewrite SQL to keep the same query intent and include doc_id and docnm_kwd in the SELECT list. -Return ONLY SQL.""".format( - table_name, - "\n".join([f" - {k} ({v})" for k, v in field_map.items()]), - question, - previous_sql - ) +Return ONLY SQL.""".format(table_name, "\n".join([f" - {k} ({v})" for k, v in field_map.items()]), question, previous_sql) return await get_table(custom_user_prompt=repair_prompt) try: @@ -1105,11 +1153,7 @@ async def repair_table_for_missing_source_columns(previous_sql): logging.warning(f"use_sql: Non-aggregate SQL missing required source columns; retrying once. SQL: {sql}") try: repaired_tbl, repaired_sql = await repair_table_for_missing_source_columns(sql) - if ( - repaired_tbl - and len(repaired_tbl.get("rows", [])) > 0 - and has_source_columns(repaired_tbl.get("columns", [])) - ): + if repaired_tbl and len(repaired_tbl.get("rows", [])) > 0 and has_source_columns(repaired_tbl.get("columns", [])): tbl, sql = repaired_tbl, repaired_sql logging.info(f"use_sql: Source-column SQL repair succeeded. SQL: {sql}") else: @@ -1121,11 +1165,12 @@ async def repair_table_for_missing_source_columns(previous_sql): docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() == "doc_id"]) doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]]) + kb_id_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() in ["kb_id", "kb_id_kwd"]]) logging.debug(f"use_sql: All columns: {[(i, c['name']) for i, c in enumerate(tbl['columns'])]}") - logging.debug(f"use_sql: docid_idx={docid_idx}, doc_name_idx={doc_name_idx}") + logging.debug(f"use_sql: docid_idx={docid_idx}, doc_name_idx={doc_name_idx}, kb_id_idx={kb_id_idx}") - column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)] + column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx | kb_id_idx)] logging.debug(f"use_sql: column_idx={column_idx}") logging.debug(f"use_sql: field_map={field_map}") @@ -1137,9 +1182,9 @@ def map_column_name(col_name): # First, try to extract AS alias from any expression (aggregate functions, json_extract_string, etc.) 
# Pattern: anything AS alias_name - as_match = re.search(r'\s+AS\s+([^\s,)]+)', col_name, re.IGNORECASE) + as_match = re.search(r"\s+AS\s+([^\s,)]+)", col_name, re.IGNORECASE) if as_match: - alias = as_match.group(1).strip('"\'') + alias = as_match.group(1).strip("\"'") # Use the alias for display name lookup if alias in field_map: @@ -1176,11 +1221,7 @@ def map_column_name(col_name): return result # compose Markdown table - columns = ( - "|" + "|".join( - [map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + ( - "|Source|" if docid_idx and doc_name_idx else "|") - ) + columns = "|" + "|".join([map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + ("|Source|" if docid_idx and doc_name_idx else "|") line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "") @@ -1221,8 +1262,11 @@ def map_column_name(col_name): where_match = re.search(r"\bwhere\b(.+?)(?:\bgroup by\b|\border by\b|\blimit\b|$)", sql, re.IGNORECASE) if where_match: where_clause = where_match.group(1).strip() - # Build a query to get doc_id and docnm_kwd with the same WHERE clause - chunks_sql = f"select doc_id, docnm_kwd from {table_name} where {where_clause}" + # Build a query to get source fields with the same WHERE clause. + # Single-KB queries can derive kb_id from the dialog, while multi-KB + # ES/OS queries need the row value for metadata enrichment. + chunks_kb_column = ", kb_id" if not (kb_ids and len(kb_ids) == 1) else "" + chunks_sql = f"select doc_id, {expected_doc_name_column}{chunks_kb_column} from {table_name} where {where_clause}" # Add LIMIT to avoid fetching too many chunks if "limit" not in chunks_sql.lower(): chunks_sql += " limit 20" @@ -1233,8 +1277,18 @@ def map_column_name(col_name): # Build chunks reference - use case-insensitive matching chunks_did_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() == "doc_id"), None) chunks_dn_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]), None) + chunks_kb_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() in ["kb_id", "kb_id_kwd"]), None) if chunks_did_idx is not None and chunks_dn_idx is not None: - chunks = [{"doc_id": r[chunks_did_idx], "docnm_kwd": r[chunks_dn_idx]} for r in chunks_tbl["rows"]] + chunks = [] + for r in chunks_tbl["rows"]: + chunk = {"doc_id": r[chunks_did_idx], "docnm_kwd": r[chunks_dn_idx]} + row_dict = {chunks_tbl["columns"][i]["name"]: r[i] for i in range(len(chunks_tbl["columns"])) if i < len(r)} + kb_id = _chunk_kb_id_for_doc(row_dict, kb_ids, chunk["doc_id"]) + if kb_id: + chunk["kb_id"] = kb_id + elif chunks_kb_idx is not None: + chunk["kb_id"] = r[chunks_kb_idx] + chunks.append(chunk) # Build doc_aggs doc_aggs = {} for r in chunks_tbl["rows"]: @@ -1264,7 +1318,22 @@ def map_column_name(col_name): result = { "answer": "\n".join([columns, line, rows]), "reference": { - "chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]], + "chunks": [ + { + key: value + for key, value in { + "doc_id": r[docid_idx], + "docnm_kwd": r[doc_name_idx], + "kb_id": _chunk_kb_id_for_doc( + {tbl["columns"][i]["name"]: r[i] for i in range(len(tbl["columns"])) if i < len(r)}, + kb_ids, + r[docid_idx], + ), + }.items() + if value + } + for r in tbl["rows"] + ], "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()], }, "prompt": sys_prompt, @@ -1272,6 +1341,7 @@ def 
map_column_name(col_name): logging.debug(f"use_sql: Returning answer with {len(result['reference']['chunks'])} chunks from {len(doc_aggs)} documents") return result + def clean_tts_text(text: str) -> str: if not text: return "" @@ -1281,15 +1351,7 @@ def clean_tts_text(text: str) -> str: text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text) emoji_pattern = re.compile( - "[\U0001F600-\U0001F64F" - "\U0001F300-\U0001F5FF" - "\U0001F680-\U0001F6FF" - "\U0001F1E0-\U0001F1FF" - "\U00002700-\U000027BF" - "\U0001F900-\U0001F9FF" - "\U0001FA70-\U0001FAFF" - "\U0001FAD0-\U0001FAFF]+", - flags=re.UNICODE + "[\U0001f600-\U0001f64f\U0001f300-\U0001f5ff\U0001f680-\U0001f6ff\U0001f1e0-\U0001f1ff\U00002700-\U000027bf\U0001f900-\U0001f9ff\U0001fa70-\U0001faff\U0001fad0-\U0001faff]+", flags=re.UNICODE ) text = emoji_pattern.sub("", text) @@ -1301,6 +1363,7 @@ def clean_tts_text(text: str) -> str: return text + def tts(tts_mdl, text): if not tts_mdl or not text: return None @@ -1328,18 +1391,31 @@ def __init__(self) -> None: self.buffer = "" +def _extract_visible_answer(text: str) -> str: + text = text or "" + if "" not in text: + return re.sub(r"", "", text) + + thought, answer = text.rsplit("", 1) + thought = re.sub(r"", "", thought).strip() + answer = re.sub(r"", "", answer) + if not thought: + return answer + return f"{thought}{answer}" + + def _next_think_delta(state: _ThinkStreamState) -> str: full_text = state.full_text if full_text == state.last_full: return "" state.last_full = full_text - delta_ans = full_text[state.last_idx:] + delta_ans = full_text[state.last_idx :] if delta_ans.find("") == 0: state.last_idx += len("") return "" if delta_ans.find("") > 0: - delta_text = full_text[state.last_idx:state.last_idx + delta_ans.find("")] + delta_text = full_text[state.last_idx : state.last_idx + delta_ans.find("")] state.last_idx += delta_ans.find("") return delta_text if delta_ans.endswith(""): @@ -1360,7 +1436,7 @@ async def _stream_with_think_delta(stream_iter, min_tokens: int = 16): if not chunk: continue if chunk.startswith(state.last_model_full): - new_part = chunk[len(state.last_model_full):] + new_part = chunk[len(state.last_model_full) :] state.last_model_full = chunk else: new_part = chunk @@ -1394,6 +1470,7 @@ async def _stream_with_think_delta(stream_iter, min_tokens: int = 16): if state.endswith_think: yield ("marker", "", state) + async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): doc_ids = search_config.get("doc_ids", []) rerank_mdl = None @@ -1401,6 +1478,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf chat_llm_name = search_config.get("chat_id", chat_llm_name) rerank_id = search_config.get("rerank_id", "") meta_data_filter = search_config.get("meta_data_filter") + include_reference_metadata, metadata_fields = _resolve_reference_metadata(search_config) kbs = KnowledgebaseService.get_by_ids(kb_ids) embedding_list = list(set([kb.embd_id for kb in kbs])) @@ -1419,8 +1497,15 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf tenant_ids = list(set([kb.tenant_id for kb in kbs])) if meta_data_filter: - metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids) - doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids) + doc_ids = await apply_meta_data_filter( + meta_data_filter, + None, + question, + chat_mdl, + doc_ids, + kb_ids=kb_ids, + metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs(kb_ids), + ) kbinfos = await 
retriever.retrieval( question=question, @@ -1435,8 +1520,15 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf doc_ids=doc_ids, aggs=True, rerank_mdl=rerank_mdl, - rank_feature=label_question(question, kbs) + rank_feature=label_question(question, kbs), ) + if include_reference_metadata: + logging.debug( + "reference_metadata enrichment enabled for async_ask: chunk_count=%d metadata_fields=%s", + len(kbinfos.get("chunks", [])), + metadata_fields, + ) + _enrich_chunks_with_document_metadata(kbinfos.get("chunks", []), metadata_fields) knowledges = kb_prompt(kbinfos, max_tokens) sys_prompt = PROMPT_JINJA_ENV.from_string(ASK_SUMMARY).render(knowledge="\n".join(knowledges)) @@ -1445,8 +1537,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf def decorate_answer(answer): nonlocal knowledges, kbinfos, sys_prompt - answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], - embd_mdl, tkweight=0.7, vtweight=0.3) + answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3) idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] if not recall_docs: @@ -1472,7 +1563,7 @@ def decorate_answer(answer): continue yield {"answer": value, "reference": {}, "final": False} full_answer = last_state.full_text if last_state else "" - final = decorate_answer(full_answer) + final = decorate_answer(_extract_visible_answer(full_answer)) final["final"] = True yield final @@ -1505,8 +1596,15 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}): rerank_mdl = LLMBundle(tenant_id, rerank_model_config) if meta_data_filter: - metas = DocMetadataService.get_flatted_meta_by_kbs(kb_ids) - doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids) + doc_ids = await apply_meta_data_filter( + meta_data_filter, + None, + question, + chat_mdl, + doc_ids, + kb_ids=kb_ids, + metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs(kb_ids), + ) ranks = await settings.retriever.retrieval( question=question, diff --git a/api/db/services/doc_metadata_service.py b/api/db/services/doc_metadata_service.py index 7a9e435e072..1cf887c2d3f 100644 --- a/api/db/services/doc_metadata_service.py +++ b/api/db/services/doc_metadata_service.py @@ -454,19 +454,27 @@ def update_document_metadata(cls, doc_id: str, meta_fields: Dict) -> bool: # Index exists - check if document exists try: doc_exists = settings.docStoreConn.get( - index_name=index_name, - id=doc_id, - kb_id=kb_id + doc_id, + index_name, + [kb_id] ) if doc_exists: - # Document exists - use partial update + # Document exists - replace meta_fields entirely + # Use upsert to fully replace the meta_fields field + # (ES update with doc parameter does deep merge on object fields, + # which would retain old keys that should be removed) settings.docStoreConn.es.update( index=index_name, id=doc_id, refresh=True, - doc={"meta_fields": processed_meta} + body={ + "script": { + "source": "ctx._source.meta_fields = params.meta_fields", + "params": {"meta_fields": processed_meta} + } + } ) - logging.debug(f"Successfully updated metadata for document {doc_id} using ES partial update") + logging.debug(f"Successfully updated metadata for document {doc_id} using ES script update") return True except Exception 
as e: logging.debug(f"Document {doc_id} not found in index, will insert: {e}") @@ -764,6 +772,140 @@ def get_flatted_meta_by_kbs(cls, kb_ids: List[str]) -> Dict: logging.error(f"Error getting flattened metadata for KBs {kb_ids}: {e}") return {} + @classmethod + def filter_doc_ids_by_meta_pushdown( + cls, + kb_ids: List[str], + filters: List[Dict], + logic: str = "and", + limit: int = 10000, + ) -> Optional[List[str]]: + """Run a metadata filter directly against ES, returning matching doc IDs. + + Returns ``None`` to signal "push-down not viable, use the in-memory + ``meta_filter`` fallback". Reasons for ``None``: + + - Active doc store is not Elasticsearch (Infinity / OceanBase have + different filter semantics for the JSON ``meta_fields`` column). + - One of the user filters cannot be expressed in ES DSL. + - The ES request itself failed (network, mapping, missing index). + + On success returns the deduplicated, ordered list of document IDs the + ES query matched. Callers can union or intersect this with their own + base ``doc_ids`` rather than fetching the entire metadata table. + """ + from common.metadata_es_filter import ( + UnsupportedMetaFilter, + build_meta_filter_query, + extract_doc_ids, + is_pushdown_supported, + ) + + if not kb_ids: + return [] + + if settings.DOC_ENGINE_INFINITY: + # Infinity stores ``meta_fields`` as a JSON column without dotted + # field access; the in-memory path is still the reliable answer. + return None + + es_client = getattr(settings.docStoreConn, "es", None) + if es_client is None: + return None + + if not is_pushdown_supported(filters): + return None + + try: + kb = Knowledgebase.get_by_id(kb_ids[0]) + except Exception as e: + logging.warning(f"[meta_pushdown] cannot resolve tenant for kb {kb_ids[0]}: {e}") + return None + if not kb: + return None + + tenant_id = kb.tenant_id + index_name = cls._get_doc_meta_index_name(tenant_id) + + try: + if not settings.docStoreConn.index_exist(index_name, ""): + # No metadata index → no metadata-filtered docs. Returning an + # empty list (rather than ``None``) so callers don't bounce + # back to the in-memory path and re-query MySQL for nothing. + return [] + except Exception as e: + logging.warning(f"[meta_pushdown] index_exist check failed for {index_name}: {e}") + return None + + try: + query_body = build_meta_filter_query(filters, logic, kb_ids) + except UnsupportedMetaFilter as e: + logging.debug(f"[meta_pushdown] falling back to in-memory: {e.reason}") + return None + + # Only the doc id is needed downstream; trimming ``_source`` keeps the + # response small when the metadata blob is large. + request_body = { + **query_body, + "size": limit, + "_source": ["id"], + } + + try: + response = es_client.search(index=index_name, body=request_body) + except Exception as e: + logging.warning(f"[meta_pushdown] ES query failed for {index_name}: {e}") + return None + + doc_ids = extract_doc_ids(response if isinstance(response, dict) else dict(response)) + # Preserve order while removing duplicates so caller-side de-dupe stays + # cheap. 
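        #   e.g. ["d1", "d2", "d1", "d3"] -> ["d1", "d2", "d3"]
        #   (equivalent to list(dict.fromkeys(doc_ids)); kept explicit so the limit check below
        #   can inspect the deduplicated count)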
+ seen: set[str] = set() + unique: List[str] = [] + for did in doc_ids: + if did in seen: + continue + seen.add(did) + unique.append(did) + + if len(unique) >= limit: + logging.warning( + f"[meta_pushdown] hit limit {limit} for KBs {kb_ids}; some matches may be missing" + ) + + logging.debug(f"[meta_pushdown] {len(unique)} matches for KBs {kb_ids}") + return unique + + @classmethod + def get_metadata_keys_by_kbs(cls, kb_ids: List[str]) -> List[str]: + """ + Get unique metadata field names across multiple knowledge bases. + + Args: + kb_ids: List of knowledge base IDs + + Returns: + Sorted list of unique metadata field names + """ + if not kb_ids: + return [] + + logging.debug(f"get_metadata_keys_by_kbs start: n_kbs={len(kb_ids)}") + keys: set[str] = set() + try: + for kb_id in kb_ids: + results = cls._search_metadata(kb_id, condition={"kb_id": kb_id}) + for _doc_id, doc in cls._iter_search_results(results): + doc_meta = cls._extract_metadata(doc) + if not isinstance(doc_meta, dict): + continue + keys.update(str(k) for k in doc_meta.keys()) + logging.debug(f"get_metadata_keys_by_kbs end: n_keys={len(keys)}, kb_ids={kb_ids}") + return sorted(keys) + except Exception as e: + logging.error(f"Error getting metadata keys for KBs {kb_ids}: {e}") + return [] + @classmethod def get_metadata_for_documents(cls, doc_ids: Optional[List[str]], kb_id: str) -> Dict[str, Dict]: """ diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 0c6e8b89195..7992cdb6105 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -13,15 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import asyncio -import json import logging import random -import re -from concurrent.futures import ThreadPoolExecutor -from copy import deepcopy from datetime import datetime -from io import BytesIO import xxhash from peewee import fn, Case, JOIN @@ -33,13 +27,15 @@ from api.db.services.common_service import CommonService, retry_deadlock_operation from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.doc_metadata_service import DocMetadataService + +from common import settings +from common.constants import ParserType, StatusEnum, TaskStatus, SVR_CONSUMER_GROUP_NAME, MAXIMUM_TASK_PAGE_NUMBER +from common.doc_store.doc_store_base import OrderByExpr from common.misc_utils import get_uuid from common.time_utils import current_timestamp, get_format_time -from common.constants import LLMType, ParserType, StatusEnum, TaskStatus, SVR_CONSUMER_GROUP_NAME -from rag.nlp import rag_tokenizer, search + +from rag.nlp import search from rag.utils.redis_conn import REDIS_CONN -from common.doc_store.doc_store_base import OrderByExpr -from common import settings class DocumentService(CommonService): @@ -127,7 +123,7 @@ def check_doc_health(cls, tenant_id: str, filename): @classmethod @DB.connection_context() - def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix, doc_id=None, name=None, doc_ids_filter=None, return_empty_metadata=False): + def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix, name=None, doc_ids=None, return_empty_metadata=False): fields = cls.get_cls_model_fields() if keywords: docs = ( @@ -147,10 +143,8 @@ def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keyword .join(User, on=(cls.model.created_by == User.id), 
join_type=JOIN.LEFT_OUTER) .where(cls.model.kb_id == kb_id) ) - if doc_id: - docs = docs.where(cls.model.id == doc_id) - if doc_ids_filter: - docs = docs.where(cls.model.id.in_(doc_ids_filter)) + if doc_ids: + docs = docs.where(cls.model.id.in_(doc_ids)) if run_status: docs = docs.where(cls.model.run.in_(run_status)) if types: @@ -429,6 +423,9 @@ def remove_document(cls, doc, tenant_id): if not cls.delete_document_and_update_kb_counts(doc.id): return True + chunk_index_name = search.index_name(tenant_id) + chunk_index_exists = settings.docStoreConn.index_exist(chunk_index_name, doc.kb_id) + # Cancel all running tasks first Using preset function in task_service.py --- set cancel flag in Redis try: cancel_all_task_of(doc.id) @@ -444,7 +441,8 @@ def remove_document(cls, doc, tenant_id): # Delete chunk images (non-critical, log and continue) try: - cls.delete_chunk_images(doc, tenant_id) + if chunk_index_exists: + cls.delete_chunk_images(doc, tenant_id) except Exception as e: logging.warning(f"Failed to delete chunk images for document {doc.id}: {e}") @@ -458,7 +456,7 @@ def remove_document(cls, doc, tenant_id): # Delete chunks from doc store - this is critical, log errors try: - settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) + settings.docStoreConn.delete({"doc_id": doc.id}, chunk_index_name, doc.kb_id) except Exception as e: logging.error(f"Failed to delete chunks from doc store for document {doc.id}: {e}") @@ -470,23 +468,24 @@ def remove_document(cls, doc, tenant_id): # Cleanup knowledge graph references (non-critical, log and continue) try: - graph_source = settings.docStoreConn.get_fields( - settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), - ["source_id"], - ) - if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]: - settings.docStoreConn.update( - {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id}, - {"remove": {"source_id": doc.id}}, - search.index_name(tenant_id), - doc.kb_id, - ) - settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, {"removed_kwd": "Y"}, search.index_name(tenant_id), doc.kb_id) - settings.docStoreConn.delete( - {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}}, - search.index_name(tenant_id), - doc.kb_id, + if chunk_index_exists: + graph_source = settings.docStoreConn.get_fields( + settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, chunk_index_name, [doc.kb_id]), + ["source_id"], ) + if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]: + settings.docStoreConn.update( + {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id}, + {"remove": {"source_id": doc.id}}, + chunk_index_name, + doc.kb_id, + ) + settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, {"removed_kwd": "Y"}, chunk_index_name, doc.kb_id) + settings.docStoreConn.delete( + {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}}, + chunk_index_name, + doc.kb_id, + ) except Exception as e: 
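            # Knowledge-graph cleanup above is best-effort and now guarded by chunk_index_exists,
            # so a missing chunk index (e.g. one never created or already dropped) does not raise;
            # any failure here is logged as a warning and document removal continues.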
logging.warning(f"Failed to cleanup knowledge graph for document {doc.id}: {e}") @@ -679,17 +678,10 @@ def get_tenant_id_by_name(cls, name): @classmethod @DB.connection_context() def accessible(cls, doc_id, user_id): - docs = ( - cls.model.select(cls.model.id) - .join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)) - .join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)) - .where(cls.model.id == doc_id, UserTenant.user_id == user_id) - .paginate(0, 1) - ) - docs = docs.dicts() - if not docs: + e, doc = cls.get_by_id(doc_id) + if not e: return False - return True + return KnowledgebaseService.accessible(doc.kb_id, user_id) @classmethod @DB.connection_context() @@ -1002,8 +994,8 @@ def new_task(): return { "id": get_uuid(), "doc_id": fake_doc_id, - "from_page": 100000000, - "to_page": 100000000, + "from_page": MAXIMUM_TASK_PAGE_NUMBER, + "to_page": MAXIMUM_TASK_PAGE_NUMBER, "task_type": ty, "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty, "begin_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), @@ -1027,138 +1019,3 @@ def get_queue_length(priority): if not group_info: return 0 return int(group_info.get("lag", 0) or 0) - - -def doc_upload_and_parse(conversation_id, file_objs, user_id): - from api.db.services.api_service import API4ConversationService - from api.db.services.conversation_service import ConversationService - from api.db.services.dialog_service import DialogService - from api.db.services.file_service import FileService - from api.db.services.llm_service import LLMBundle - from api.db.services.user_service import TenantService - from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type - from rag.app import audio, email, naive, picture, presentation - - e, conv = ConversationService.get_by_id(conversation_id) - if not e: - e, conv = API4ConversationService.get_by_id(conversation_id) - assert e, "Conversation not found!" - - e, dia = DialogService.get_by_id(conv.dialog_id) - if not dia.kb_ids: - raise LookupError("No dataset associated with this conversation. 
Please add a dataset before uploading documents") - kb_id = dia.kb_ids[0] - e, kb = KnowledgebaseService.get_by_id(kb_id) - if not e: - raise LookupError("Can't find this dataset!") - if kb.tenant_embd_id: - embd_model_config = get_model_config_by_id(kb.tenant_embd_id) - else: - embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) - embd_mdl = LLMBundle(kb.tenant_id, embd_model_config, lang=kb.language) - - err, files = FileService.upload_document(kb, file_objs, user_id) - assert not err, "\n".join(err) - - def dummy(prog=None, msg=""): - pass - - FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email} - parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text", "table_context_size": 0, "image_context_size": 0} - exe = ThreadPoolExecutor(max_workers=12) - threads = [] - doc_nm = {} - for d, blob in files: - doc_nm[d["id"]] = d["name"] - for d, blob in files: - kwargs = {"callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": kb.tenant_id, "lang": kb.language} - threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs)) - - for (docinfo, _), th in zip(files, threads): - docs = [] - doc = {"doc_id": docinfo["id"], "kb_id": [kb.id]} - for ck in th.result(): - d = deepcopy(doc) - d.update(ck) - d["id"] = xxhash.xxh64((ck["content_with_weight"] + str(d["doc_id"])).encode("utf-8")).hexdigest() - d["create_time"] = str(datetime.now()).replace("T", " ")[:19] - d["create_timestamp_flt"] = datetime.now().timestamp() - if not d.get("image"): - docs.append(d) - continue - - output_buffer = BytesIO() - if isinstance(d["image"], bytes): - output_buffer = BytesIO(d["image"]) - else: - d["image"].save(output_buffer, format="JPEG") - - settings.STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue()) - d["img_id"] = "{}-{}".format(kb.id, d["id"]) - d.pop("image", None) - docs.append(d) - - parser_ids = {d["id"]: d["parser_id"] for d, _ in files} - docids = [d["id"] for d, _ in files] - chunk_counts = {id: 0 for id in docids} - token_counts = {id: 0 for id in docids} - es_bulk_size = 64 - - def embedding(doc_id, cnts, batch_size=16): - nonlocal embd_mdl, chunk_counts, token_counts - vectors = [] - for i in range(0, len(cnts), batch_size): - vts, c = embd_mdl.encode(cnts[i : i + batch_size]) - vectors.extend(vts.tolist()) - chunk_counts[doc_id] += len(cnts[i : i + batch_size]) - token_counts[doc_id] += c - return vectors - - idxnm = search.index_name(kb.tenant_id) - try_create_idx = True - - _, tenant = TenantService.get_by_id(kb.tenant_id) - tenant_llm_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) - llm_bdl = LLMBundle(kb.tenant_id, tenant_llm_config) - for doc_id in docids: - cks = [c for c in docs if c["doc_id"] == doc_id] - - if parser_ids[doc_id] != ParserType.PICTURE.value: - from rag.graphrag.general.mind_map_extractor import MindMapExtractor - - mindmap = MindMapExtractor(llm_bdl) - try: - mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])) - mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2) - if len(mind_map) < 32: - raise Exception("Few content: " + mind_map) - cks.append( - { - "id": get_uuid(), - "doc_id": doc_id, - "kb_id": [kb.id], - "docnm_kwd": doc_nm[doc_id], - "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", 
doc_nm[doc_id])), - "content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"), - "content_with_weight": mind_map, - "knowledge_graph_kwd": "mind_map", - } - ) - except Exception: - logging.exception("Mind map generation error") - - vectors = embedding(doc_id, [c["content_with_weight"] for c in cks]) - assert len(cks) == len(vectors) - for i, d in enumerate(cks): - v = vectors[i] - d["q_%d_vec" % len(v)] = v - for b in range(0, len(cks), es_bulk_size): - if try_create_idx: - if not settings.docStoreConn.index_exist(idxnm, kb_id): - settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]), kb.parser_id) - try_create_idx = False - settings.docStoreConn.insert(cks[b : b + es_bulk_size], idxnm, kb_id) - - DocumentService.increment_chunk_num(doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0) - - return [d["id"] for d, _ in files] diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py index 11940b88c21..db8ae4b72f5 100644 --- a/api/db/services/file_service.py +++ b/api/db/services/file_service.py @@ -23,17 +23,20 @@ from pathlib import Path from typing import Union +logger = logging.getLogger(__name__) + import xxhash from peewee import fn -from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileType +from api.db import KNOWLEDGEBASE_FOLDER_NAME, SKILLS_FOLDER_NAME, FileType from api.db.db_models import DB, Document, File, File2Document, Knowledgebase, Task from api.db.services import duplicate_name from api.db.services.common_service import CommonService from api.db.services.document_service import DocumentService from api.db.services.file2document_service import File2DocumentService from common.misc_utils import get_uuid -from common.constants import TaskStatus, FileSource, ParserType +from common.ssrf_guard import assert_url_is_safe +from common.constants import TaskStatus, FileSource, ParserType, MAXIMUM_PAGE_NUMBER from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.task_service import TaskService from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img, sanitize_path @@ -188,23 +191,24 @@ def get_all_file_ids_by_tenant_id(cls, tenant_id): @classmethod @DB.connection_context() - def create_folder(cls, file, parent_id, name, count): - from api.apps import current_user + def create_folder(cls, file, parent_id, name, count, tenant_id, created_by): # Recursively create folder structure # Args: # file: Current file object # parent_id: Parent folder ID # name: List of folder names to create # count: Current depth in creation + # tenant_id: Tenant ID + # created_by: Created by user ID # Returns: # Created file object if count > len(name) - 2: return file else: file = cls.insert( - {"id": get_uuid(), "parent_id": parent_id, "tenant_id": current_user.id, "created_by": current_user.id, "name": name[count], "location": "", "size": 0, "type": FileType.FOLDER.value} + {"id": get_uuid(), "parent_id": parent_id, "tenant_id": tenant_id, "created_by": created_by, "name": name[count], "location": "", "size": 0, "type": FileType.FOLDER.value} ) - return cls.create_folder(file, file.id, name, count + 1) + return cls.create_folder(file, file.id, name, count + 1, tenant_id, created_by) @classmethod @DB.connection_context() @@ -290,6 +294,28 @@ def new_a_file_from_kb(cls, tenant_id, name, parent_id, ty=FileType.FOLDER.value cls.save(**file) return file + @classmethod + @DB.connection_context() + def init_skills_folder(cls, root_id, tenant_id): + # Initialize skills folder if not exists + # 
Args: + # root_id: Root folder ID + # tenant_id: Tenant ID + for _ in cls.model.select().where((cls.model.name == SKILLS_FOLDER_NAME) & (cls.model.parent_id == root_id)): + return + file_id = get_uuid() + file = { + "id": file_id, + "parent_id": root_id, + "tenant_id": tenant_id, + "created_by": tenant_id, + "name": SKILLS_FOLDER_NAME, + "type": FileType.FOLDER.value, + "size": 0, + "location": "", + } + cls.save(**file) + @classmethod @DB.connection_context() def init_knowledgebase_docs(cls, root_id, tenant_id): @@ -550,7 +576,7 @@ def dummy(prog=None, msg=""): FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email} parser_config = {"chunk_token_num": 16096, "delimiter": "\n!?;。;!?", "layout_recognize": layout_recognize or "Plain Text"} - kwargs = {"lang": "English", "callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": current_user.id if current_user else tenant_id} + kwargs = {"lang": "English", "callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": MAXIMUM_PAGE_NUMBER, "tenant_id": current_user.id if current_user else tenant_id} file_type = filename_type(filename) if img_base64 and file_type == FileType.VISUAL.value: return GptV4.image2base64(blob) @@ -624,6 +650,26 @@ def delete_docs(cls, doc_ids, tenant_id): return errors + _ALLOWED_SCHEMES = {"http", "https"} + + @staticmethod + def _validate_url_for_crawl(url: str) -> tuple[str, str]: + """Raise ValueError if the URL is not safe to crawl (SSRF guard). + + Delegates to :func:`common.ssrf_guard.assert_url_is_safe`, which + validates the scheme, hostname, and every DNS-resolved address, and + returns ``(hostname, resolved_ip)`` for DNS pinning. + + Only the scheme and host (and port when present) are forwarded to the + guard so that credentials or query parameters in *url* are never + written to the log. + """ + from urllib.parse import urlparse + parsed = urlparse(url) + port_suffix = f":{parsed.port}" if parsed.port else "" + redacted = f"{parsed.scheme}://{parsed.hostname}{port_suffix}" + return assert_url_is_safe(redacted, allowed_schemes=FileService._ALLOWED_SCHEMES) + @staticmethod def upload_info(user_id, file, url: str|None=None): def structured(filename, filetype, blob, content_type): @@ -646,6 +692,53 @@ def structured(filename, filetype, blob, content_type): } if url: + import requests as _requests + from urllib.parse import urljoin as _urljoin + + _MAX_CRAWL_REDIRECTS = 10 + + # Pre-resolve the full redirect chain so that AsyncWebCrawler never + # follows a server-sent redirect to an unvalidated (potentially + # internal) host. Each hop is SSRF-checked before being followed; + # the validated (hostname, ip) pairs are pinned via Chromium's + # --host-resolver-rules so the browser cannot re-resolve any of them + # through a fresh DNS query. + current_url = url + current_hostname, current_ip = FileService._validate_url_for_crawl(current_url) + # Accumulate MAP rules for every hostname we encounter in the chain. 
+ host_pins: dict[str, str] = {current_hostname: current_ip} + + for _ in range(_MAX_CRAWL_REDIRECTS): + try: + _resp = _requests.get( + current_url, + timeout=10, + allow_redirects=False, + ) + except _requests.RequestException as _exc: + raise ValueError(f"Failed to fetch {current_url!r}: {_exc}") from _exc + + if _resp.status_code not in (301, 302, 303, 307, 308): + break + + _location = _resp.headers.get("Location") + if not _location: + break + + _next_url = _urljoin(current_url, _location) + _next_hostname, _next_ip = FileService._validate_url_for_crawl(_next_url) + host_pins[_next_hostname] = _next_ip + current_url = _next_url + else: + raise ValueError( + f"Exceeded {_MAX_CRAWL_REDIRECTS} redirects fetching {url!r}" + ) + + # Build a single MAP rule string covering every validated hostname + # in the redirect chain. Chromium uses the pinned IP for each, + # skipping DNS entirely and eliminating the rebinding window. + _map_rules = ",".join(f"MAP {h} {ip}" for h, ip in host_pins.items()) + from crawl4ai import ( AsyncWebCrawler, BrowserConfig, @@ -659,6 +752,7 @@ async def adownload(): browser_config = BrowserConfig( headless=True, verbose=False, + extra_args=[f"--host-resolver-rules={_map_rules}"], ) async with AsyncWebCrawler(config=browser_config) as crawler: crawler_config = CrawlerRunConfig( @@ -668,8 +762,10 @@ async def adownload(): pdf=True, screenshot=False ) + # Use the final resolved URL so the browser starts at the + # redirect destination rather than re-following the chain. result: CrawlResult = await crawler.arun( - url=url, + url=current_url, config=crawler_config ) return result @@ -679,7 +775,7 @@ async def adownload(): filename += ".pdf" return structured(filename, "pdf", page.pdf, page.response_headers["content-type"]) - return structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id) + return structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"]) DocumentService.check_doc_health(user_id, file.filename) return structured(file.filename, filename_type(file.filename), file.read(), file.content_type) diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py index c66d66a6821..a164287fa4e 100644 --- a/api/db/services/knowledgebase_service.py +++ b/api/db/services/knowledgebase_service.py @@ -18,7 +18,7 @@ from peewee import fn, JOIN from api.db import TenantPermission -from api.db.db_models import DB, Document, Knowledgebase, User, UserTenant, UserCanvas +from api.db.db_models import DB, Document, Knowledgebase, User, UserCanvas from api.db.services.common_service import CommonService from common.time_utils import current_timestamp, datetime_format from api.db.services import duplicate_name @@ -485,13 +485,21 @@ def accessible(cls, kb_id, user_id): # user_id: User ID # Returns: # Boolean indicating accessibility - docs = cls.model.select( - cls.model.id).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) - ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1) - docs = docs.dicts() - if not docs: + e, kb = cls.get_by_id(kb_id) + if not e: return False - return True + + if kb.status != StatusEnum.VALID.value: + return False + + if kb.tenant_id == user_id: + return True + + if kb.permission != TenantPermission.TEAM.value: + return False + + joined_tenants = TenantService.get_joined_tenants_by_user_id(user_id) + return any(tenant["tenant_id"] == kb.tenant_id for tenant in joined_tenants) 
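The rewritten `accessible` check above replaces the `UserTenant` join with an explicit decision chain: the dataset must exist and be valid, the owner always has access, and non-owners only reach team-shared datasets of tenants they have joined. A minimal, self-contained sketch of that decision flow follows; `SimpleKB` and `kb_accessible` are hypothetical stand-ins for illustration, not the real service classes, and `"1"` stands in for `StatusEnum.VALID`.

```python
from dataclasses import dataclass

@dataclass
class SimpleKB:
    tenant_id: str
    status: str       # "1" stands in for StatusEnum.VALID
    permission: str   # "me" (private) or "team"

def kb_accessible(kb: SimpleKB | None, user_id: str, joined_tenant_ids: set[str]) -> bool:
    # Mirrors the decision chain in the rewritten KnowledgebaseService.accessible.
    if kb is None or kb.status != "1":
        return False
    if kb.tenant_id == user_id:                 # the owner can always see the dataset
        return True
    if kb.permission != "team":                 # private datasets are owner-only
        return False
    return kb.tenant_id in joined_tenant_ids    # team datasets require tenant membership

kb = SimpleKB(tenant_id="tenant-a", status="1", permission="team")
assert kb_accessible(kb, "tenant-a", joined_tenant_ids=set())        # owner
assert kb_accessible(kb, "user-b", joined_tenant_ids={"tenant-a"})   # joined team member
assert not kb_accessible(kb, "user-c", joined_tenant_ids=set())      # outsider
```

Compared with the old join-based query, this form fetches the dataset by primary key once and reuses the joined-tenant lookup, which is why `DocumentService.accessible`, `get_kb_by_id`, and `get_kb_by_name` in this diff can all delegate to the same check.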
@classmethod @DB.connection_context() @@ -502,10 +510,10 @@ def get_kb_by_id(cls, kb_id, user_id): # user_id: User ID # Returns: # List containing dataset information - kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) - ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1) - kbs = kbs.dicts() - return list(kbs) + e, kb = cls.get_by_id(kb_id) + if not e or not cls.accessible(kb_id, user_id): + return [] + return [kb.to_dict()] @classmethod @DB.connection_context() @@ -516,10 +524,11 @@ def get_kb_by_name(cls, kb_name, user_id): # user_id: User ID # Returns: # List containing dataset information - kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) - ).where(cls.model.name == kb_name, UserTenant.user_id == user_id).paginate(0, 1) - kbs = kbs.dicts() - return list(kbs) + kbs = cls.query(name=kb_name, status=StatusEnum.VALID.value) + for kb in kbs: + if cls.accessible(kb.id, user_id): + return [kb.to_dict()] + return [] @classmethod @DB.connection_context() diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 6058c6b69f7..60090bb0409 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -94,7 +94,7 @@ def bind_tools(self, toolcall_session, tools): def encode(self, texts: list): if self.langfuse: - generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.model_config["llm_name"], input={"texts": texts}) + generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="encode", model=self.model_config["llm_name"], input={"texts": texts}) safe_texts = [] for text in texts: @@ -119,7 +119,7 @@ def encode(self, texts: list): def encode_queries(self, query: str): if self.langfuse: - generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode_queries", model=self.model_config["llm_name"], input={"query": query}) + generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="encode_queries", model=self.model_config["llm_name"], input={"query": query}) emd, used_tokens = self.mdl.encode_queries(query) if self.model_config["llm_factory"] == "Builtin": @@ -135,7 +135,7 @@ def encode_queries(self, query: str): def similarity(self, query: str, texts: list): if self.langfuse: - generation = self.langfuse.start_generation(trace_context=self.trace_context, name="similarity", model=self.model_config["llm_name"], input={"query": query, "texts": texts}) + generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="similarity", model=self.model_config["llm_name"], input={"query": query, "texts": texts}) sim, used_tokens = self.mdl.similarity(query, texts) if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens): @@ -149,7 +149,7 @@ def similarity(self, query: str, texts: list): def describe(self, image, max_tokens=300): if self.langfuse: - generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe", metadata={"model": self.model_config["llm_name"]}) + generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="describe", metadata={"model": self.model_config["llm_name"]}) txt, used_tokens = self.mdl.describe(image) if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens): @@ -163,7 +163,7 @@ def 
describe(self, image, max_tokens=300): def describe_with_prompt(self, image, prompt): if self.langfuse: - generation = self.langfuse.start_generation(trace_context=self.trace_context, name="describe_with_prompt", metadata={"model": self.model_config["llm_name"], "prompt": prompt}) + generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="describe_with_prompt", metadata={"model": self.model_config["llm_name"], "prompt": prompt}) txt, used_tokens = self.mdl.describe_with_prompt(image, prompt) if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens): @@ -177,7 +177,7 @@ def describe_with_prompt(self, image, prompt): def transcription(self, audio): if self.langfuse: - generation = self.langfuse.start_generation(trace_context=self.trace_context, name="transcription", metadata={"model": self.model_config["llm_name"]}) + generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="transcription", metadata={"model": self.model_config["llm_name"]}) txt, used_tokens = self.mdl.transcription(audio) if not TenantLLMService.increase_usage_by_id(self.model_config["id"], used_tokens): @@ -194,7 +194,7 @@ def stream_transcription(self, audio): supports_stream = hasattr(mdl, "stream_transcription") and callable(getattr(mdl, "stream_transcription")) if supports_stream: if self.langfuse: - generation = self.langfuse.start_generation( + generation = self.langfuse.start_observation(as_type="generation", trace_context=self.trace_context, name="stream_transcription", metadata={"model": self.model_config["llm_name"]}, @@ -228,7 +228,7 @@ def stream_transcription(self, audio): return if self.langfuse: - generation = self.langfuse.start_generation( + generation = self.langfuse.start_observation(as_type="generation", trace_context=self.trace_context, name="stream_transcription", metadata={"model": self.model_config["llm_name"]}, @@ -253,7 +253,7 @@ def stream_transcription(self, audio): def tts(self, text: str) -> Generator[bytes, None, None]: if self.langfuse: - generation = self.langfuse.start_generation(trace_context=self.trace_context, name="tts", input={"text": text}) + generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="tts", input={"text": text}) for chunk in self.mdl.tts(text): if isinstance(chunk, int): @@ -376,7 +376,7 @@ async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kw generation = None if self.langfuse: - generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.model_config["llm_name"], input={"system": system, "history": history}) + generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="chat", model=self.model_config["llm_name"], input={"system": system, "history": history}) chat_partial = partial(base_fn, system, history, gen_conf) use_kwargs = self._clean_param(chat_partial, **kwargs) @@ -417,7 +417,7 @@ async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = generation = None if self.langfuse: - generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history}) + generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": 
history}) if stream_fn: chat_partial = partial(stream_fn, system, history, gen_conf) @@ -460,7 +460,7 @@ async def async_chat_streamly_delta(self, system: str, history: list, gen_conf: generation = None if self.langfuse: - generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history}) + generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="chat_streamly", model=self.model_config["llm_name"], input={"system": system, "history": history}) if stream_fn: chat_partial = partial(stream_fn, system, history, gen_conf) diff --git a/api/db/services/memory_service.py b/api/db/services/memory_service.py index d2433d01d0e..530fc5ad9ea 100644 --- a/api/db/services/memory_service.py +++ b/api/db/services/memory_service.py @@ -92,6 +92,11 @@ def get_by_filter(cls, filter_dict: dict, keywords: str, page: int = 1, page_siz memories = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)) if filter_dict.get("tenant_id"): memories = memories.where(cls.model.tenant_id.in_(filter_dict["tenant_id"])) + if filter_dict.get("accessible_user_id"): + memories = memories.where( + (cls.model.tenant_id == filter_dict["accessible_user_id"]) | + (cls.model.permissions == "team") + ) if filter_dict.get("memory_type"): memory_type_int = calculate_memory_type(filter_dict["memory_type"]) memories = memories.where(cls.model.memory_type.bin_and(memory_type_int) > 0) diff --git a/api/db/services/pipeline_operation_log_service.py b/api/db/services/pipeline_operation_log_service.py index 344e2381b7e..ad90acb1f34 100644 --- a/api/db/services/pipeline_operation_log_service.py +++ b/api/db/services/pipeline_operation_log_service.py @@ -250,20 +250,16 @@ def get_file_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, des @DB.connection_context() def get_documents_info(cls, id): fields = [Document.id, Document.name, Document.progress, Document.kb_id] - return ( - cls.model.select(*fields) - .join(Document, on=(cls.model.document_id == Document.id)) - .where( - cls.model.id == id - ) - .dicts() - ) + return cls.model.select(*fields).join(Document, on=(cls.model.document_id == Document.id)).where(cls.model.id == id).dicts() @classmethod @DB.connection_context() - def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from=None, create_date_to=None): + def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from=None, create_date_to=None, keywords=None): fields = cls.get_dataset_logs_fields() - logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == GRAPH_RAPTOR_FAKE_DOC_ID)) + if keywords: + logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == GRAPH_RAPTOR_FAKE_DOC_ID), (fn.LOWER(cls.model.document_name).contains(keywords.lower()))) + else: + logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == GRAPH_RAPTOR_FAKE_DOC_ID)) if operation_status: logs = logs.where(cls.model.operation_status.in_(operation_status)) diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py index 80817323076..640c8fbd25e 100644 --- a/api/db/services/task_service.py +++ b/api/db/services/task_service.py @@ -29,7 +29,7 @@ from api.db.services.document_service import DocumentService from common.misc_utils 
import get_uuid from common.time_utils import current_timestamp -from common.constants import StatusEnum, TaskStatus +from common.constants import StatusEnum, TaskStatus, MAXIMUM_PAGE_NUMBER, MAXIMUM_TASK_PAGE_NUMBER from deepdoc.parser.excel_parser import RAGFlowExcelParser from rag.utils.redis_conn import REDIS_CONN from common import settings @@ -37,6 +37,7 @@ CANVAS_DEBUG_DOC_ID = "dataflow_x" GRAPH_RAPTOR_FAKE_DOC_ID = "graph_raptor_x" +TASK_MAX_LOG_LENGTH = int(os.environ.get("TASK_MAX_LOG_LENGTH", 3000)) # TEXT MAX is 64 KiB bytes! def trim_header_by_lines(text: str, max_length) -> str: # Trim header text to maximum length while preserving line breaks @@ -320,7 +321,7 @@ def update_progress(cls, id, info): if os.environ.get("MACOS"): if info["progress_msg"]: - progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000) + progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], TASK_MAX_LOG_LENGTH) cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute() if "progress" in info: prog = info["progress"] @@ -332,7 +333,7 @@ def update_progress(cls, id, info): else: with DB.lock("update_progress", -1): if info["progress_msg"]: - progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000) + progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], TASK_MAX_LOG_LENGTH) cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute() if "progress" in info: prog = info["progress"] @@ -379,7 +380,7 @@ def new_task(): "doc_id": doc["id"], "progress": 0.0, "from_page": 0, - "to_page": 100000000, + "to_page": MAXIMUM_TASK_PAGE_NUMBER, "begin_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), } @@ -395,8 +396,8 @@ def new_task(): if doc["parser_id"] == "paper": page_size = doc["parser_config"].get("task_page_size") or 22 if doc["parser_id"] in ["one", "knowledge_graph"] or do_layout != "DeepDOC" or doc["parser_config"].get("toc_extraction", False): - page_size = 10 ** 9 - page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)] + page_size = MAXIMUM_TASK_PAGE_NUMBER + page_ranges = doc["parser_config"].get("pages") or [(1, MAXIMUM_PAGE_NUMBER)] for s, e in page_ranges: s -= 1 s = max(0, s) @@ -495,7 +496,7 @@ def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: return 0 task["chunk_ids"] = prev_task["chunk_ids"] task["progress"] = 1.0 - if "from_page" in task and "to_page" in task and int(task['to_page']) - int(task['from_page']) >= 10 ** 6: + if "from_page" in task and "to_page" in task and (int(task['to_page']) - int(task['from_page']) >= 10 ** 6 or (int(task['from_page']) == MAXIMUM_TASK_PAGE_NUMBER and int(task['to_page']) == MAXIMUM_TASK_PAGE_NUMBER)): task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): " else: task["progress_msg"] = "" @@ -530,7 +531,7 @@ def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DE id=task_id, doc_id=doc_id, from_page=0, - to_page=100000000, + to_page=MAXIMUM_TASK_PAGE_NUMBER, task_type="dataflow" if not rerun else "dataflow_rerun", priority=priority, begin_at= datetime.now().strftime("%Y-%m-%d %H:%M:%S"), diff --git a/api/db/services/tenant_llm_service.py b/api/db/services/tenant_llm_service.py index a27f1352d44..ee2eab6648a 100644 --- a/api/db/services/tenant_llm_service.py +++ b/api/db/services/tenant_llm_service.py @@ -19,7 +19,7 @@ from peewee import IntegrityError from langfuse import Langfuse from common import settings -from 
common.constants import MINERU_DEFAULT_CONFIG, MINERU_ENV_KEYS, PADDLEOCR_DEFAULT_CONFIG, PADDLEOCR_ENV_KEYS, LLMType +from common.constants import MINERU_DEFAULT_CONFIG, MINERU_ENV_KEYS, OPENDATALOADER_DEFAULT_CONFIG, OPENDATALOADER_ENV_KEYS, PADDLEOCR_DEFAULT_CONFIG, PADDLEOCR_ENV_KEYS, LLMType from api.db.db_models import DB, LLMFactories, TenantLLM from api.db.services.common_service import CommonService from api.db.services.langfuse_service import TenantLangfuseService @@ -34,6 +34,42 @@ class LLMFactoriesService(CommonService): class TenantLLMService(CommonService): model = TenantLLM + @staticmethod + def _decode_api_key_config(raw_api_key: str) -> tuple[str, bool | None, str | None]: + if not raw_api_key: + return raw_api_key, None, None + + try: + parsed = json.loads(raw_api_key) + except Exception: + return raw_api_key, None, None + + if not isinstance(parsed, dict): + return raw_api_key, None, None + + is_tools = bool(parsed["is_tools"]) if "is_tools" in parsed else None + if set(parsed.keys()) <= {"api_key", "is_tools"}: + return parsed.get("api_key", ""), is_tools, None + + return parsed.get("api_key", raw_api_key), is_tools, raw_api_key + + @staticmethod + def _encode_api_key_config(raw_api_key: str, is_tools: bool | None) -> str: + if is_tools is None: + return raw_api_key + + try: + parsed = json.loads(raw_api_key or "{}") + except Exception: + parsed = None + + if isinstance(parsed, dict): + payload = dict(parsed) + payload["is_tools"] = bool(is_tools) + return json.dumps(payload) + + return json.dumps({"api_key": raw_api_key or "", "is_tools": bool(is_tools)}) + @classmethod @DB.connection_context() def get_api_key(cls, tenant_id, model_name, model_type=None): @@ -123,6 +159,12 @@ def get_model_config(cls, tenant_id, llm_type, llm_name=None): model_config = cls.get_api_key(tenant_id, mdlnm, llm_type) if model_config: model_config = model_config.to_dict() + api_key, is_tools, api_key_payload = cls._decode_api_key_config(model_config.get("api_key", "")) + model_config["api_key"] = api_key + if api_key_payload is not None: + model_config["api_key_payload"] = api_key_payload + if is_tools is not None: + model_config["is_tools"] = is_tools elif llm_type == LLMType.EMBEDDING and fid == "Builtin" and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv("TEI_MODEL", ""): embedding_cfg = settings.EMBEDDING_CFG model_config = {"llm_factory": "Builtin", "api_key": embedding_cfg["api_key"], "llm_name": mdlnm, "api_base": embedding_cfg["base_url"]} @@ -132,7 +174,7 @@ def get_model_config(cls, tenant_id, llm_type, llm_name=None): llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid) if not llm and fid: # for some cases seems fid mismatch llm = LLMService.query(llm_name=mdlnm) - if llm: + if "is_tools" not in model_config and llm: model_config["is_tools"] = llm[0].is_tools return model_config @@ -142,35 +184,36 @@ def model_instance(cls, model_config: dict, lang="Chinese", **kwargs): if not model_config: raise LookupError("Model config is required") kwargs.update({"provider": model_config["llm_factory"]}) + api_key = model_config.get("api_key_payload", model_config["api_key"]) if model_config["model_type"] == LLMType.EMBEDDING.value: if model_config["llm_factory"] not in EmbeddingModel: return None - return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) + return EmbeddingModel[model_config["llm_factory"]](api_key, model_config["llm_name"], 
base_url=model_config["api_base"]) elif model_config["model_type"] == LLMType.RERANK: if model_config["llm_factory"] not in RerankModel: return None - return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) + return RerankModel[model_config["llm_factory"]](api_key, model_config["llm_name"], base_url=model_config["api_base"]) elif model_config["model_type"] == LLMType.IMAGE2TEXT.value: if model_config["llm_factory"] not in CvModel: return None - return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs) + return CvModel[model_config["llm_factory"]](api_key, model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs) elif model_config["model_type"] == LLMType.CHAT.value: if model_config["llm_factory"] not in ChatModel: return None - return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs) + return ChatModel[model_config["llm_factory"]](api_key, model_config["llm_name"], base_url=model_config["api_base"], **kwargs) elif model_config["model_type"] == LLMType.SPEECH2TEXT: if model_config["llm_factory"] not in Seq2txtModel: return None - return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"]) + return Seq2txtModel[model_config["llm_factory"]](key=api_key, model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"]) elif model_config["model_type"] == LLMType.TTS: if model_config["llm_factory"] not in TTSModel: return None return TTSModel[model_config["llm_factory"]]( - model_config["api_key"], + api_key, model_config["llm_name"], base_url=model_config["api_base"], ) @@ -179,7 +222,7 @@ def model_instance(cls, model_config: dict, lang="Chinese", **kwargs): if model_config["llm_factory"] not in OcrModel: return None return OcrModel[model_config["llm_factory"]]( - key=model_config["api_key"], + key=api_key, model_name=model_config["llm_name"], base_url=model_config.get("api_base", ""), **kwargs, @@ -364,6 +407,67 @@ def _parse_api_key(raw: str) -> dict: idx += 1 continue + @classmethod + def _collect_opendataloader_env_config(cls) -> dict | None: + cfg = dict(OPENDATALOADER_DEFAULT_CONFIG) + found = False + for key in OPENDATALOADER_ENV_KEYS: + val = os.environ.get(key) + if val: + found = True + cfg[key] = val + return cfg if found else None + + @classmethod + @DB.connection_context() + def ensure_opendataloader_from_env(cls, tenant_id: str) -> str | None: + """ + Ensure an OpenDataLoader OCR model exists for the tenant if env variables are present. + Return the existing or newly created llm_name, or None if env not set. 
+ """ + cfg = cls._collect_opendataloader_env_config() + if not cfg: + return None + + saved_models = cls.query(tenant_id=tenant_id, llm_factory="OpenDataLoader", model_type=LLMType.OCR.value) + + def _parse_api_key(raw: str) -> dict: + try: + return json.loads(raw or "{}") + except Exception: + return {} + + for item in saved_models: + api_cfg = _parse_api_key(item.api_key) + normalized = {k: api_cfg.get(k, OPENDATALOADER_DEFAULT_CONFIG.get(k)) for k in OPENDATALOADER_ENV_KEYS} + if normalized == cfg: + return item.llm_name + + used_names = {item.llm_name for item in saved_models} + idx = 1 + base_name = "opendataloader-from-env" + while True: + candidate = f"{base_name}-{idx}" + if candidate in used_names: + idx += 1 + continue + try: + cls.save( + tenant_id=tenant_id, + llm_factory="OpenDataLoader", + llm_name=candidate, + model_type=LLMType.OCR.value, + api_key=json.dumps(cfg), + api_base="", + max_tokens=0, + ) + return candidate + except IntegrityError: + logging.warning("OpenDataLoader env model %s already exists for tenant %s, retry with next name", candidate, tenant_id) + used_names.add(candidate) + idx += 1 + continue + @classmethod @DB.connection_context() def delete_by_tenant_id(cls, tenant_id): @@ -397,7 +501,7 @@ def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs) self.llm_name = model_config["llm_name"] self.model_config = model_config self.mdl = TenantLLMService.model_instance(model_config, lang=lang, **kwargs) - assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, model_config["llm_type"], model_config["llm_name"]) + assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, model_config["model_type"], model_config["llm_name"]) self.max_length = model_config.get("max_tokens", 8192) self.is_tools = model_config.get("is_tools", False) diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index fe6f6d0d445..a041ee0819f 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -325,7 +325,7 @@ async def wrapper(*args, **kwargs): from common import settings from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer try: - jwt = Serializer(secret_key=settings.SECRET_KEY) + jwt = Serializer(secret_key=settings.get_secret_key()) raw_token = str(jwt.loads(token)) user = UserService.query(access_token=raw_token, status=StatusEnum.VALID.value) if user: diff --git a/api/utils/health_utils.py b/api/utils/health_utils.py index 288eb79ff67..34f098b8c92 100644 --- a/api/utils/health_utils.py +++ b/api/utils/health_utils.py @@ -293,7 +293,7 @@ def check_ragflow_server_alive(): url = f'http://{settings.HOST_IP}:{settings.HOST_PORT}/api/v1/system/ping' if '0.0.0.0' in url: url = url.replace('0.0.0.0', '127.0.0.1') - response = requests.get(url) + response = requests.get(url, timeout=10) if response.status_code == 200: return {"status": "alive", "message": f"Confirm elapsed: {(timer() - start_time) * 1000.0:.1f} ms."} else: diff --git a/api/utils/reference_metadata_utils.py b/api/utils/reference_metadata_utils.py new file mode 100644 index 00000000000..58d5beffb0a --- /dev/null +++ b/api/utils/reference_metadata_utils.py @@ -0,0 +1,125 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging + +logger = logging.getLogger(__name__) + + +def resolve_reference_metadata_preferences( + request_payload: dict | None = None, + config_payload: dict | None = None, +) -> tuple[bool, set[str] | None]: + """ + Resolve metadata include/fields from request and optional config. + Request values take precedence over config values. + Supports legacy request keys: include_metadata / metadata_fields. + """ + request_payload = request_payload or {} + config_payload = config_payload or {} + + config_ref = config_payload.get("reference_metadata", {}) + request_ref = request_payload.get("reference_metadata", {}) + + resolved: dict = {} + if isinstance(config_ref, dict): + resolved.update(config_ref) + if isinstance(request_ref, dict): + resolved.update(request_ref) + + if "include_metadata" in request_payload: + resolved["include"] = bool(request_payload.get("include_metadata")) + if "metadata_fields" in request_payload: + resolved["fields"] = request_payload.get("metadata_fields") + + include_metadata = bool(resolved.get("include", False)) + fields = resolved.get("fields") + if fields is None: + return include_metadata, None + if not isinstance(fields, list): + logger.warning( + "reference_metadata.fields is not a list; include_metadata=%s fields=%r type=%s resolved=%r. " + "enrich_chunks_with_document_metadata will skip enrichment.", + include_metadata, + fields, + type(fields).__name__, + resolved, + ) + return include_metadata, set() + return include_metadata, {f for f in fields if isinstance(f, str)} + + +def enrich_chunks_with_document_metadata( + chunks: list[dict], + metadata_fields: set[str] | None = None, + *, + kb_field: str = "kb_id", + doc_field: str = "doc_id", + output_field: str = "document_metadata", +) -> None: + """ + Mutates chunk payloads in-place by attaching `document_metadata`. + Field names can be customized for different chunk schemas. + """ + if metadata_fields is not None and not metadata_fields: + return + + doc_ids_by_kb: dict[str, set[str]] = {} + for chunk in chunks: + kb_ids = chunk.get(kb_field) + doc_id = chunk.get(doc_field) + if not kb_ids or not doc_id: + continue + if isinstance(kb_ids, (list, tuple)): + for kid in kb_ids: + if kid: + doc_ids_by_kb.setdefault(kid, set()).add(doc_id) + else: + doc_ids_by_kb.setdefault(kb_ids, set()).add(doc_id) + + if not doc_ids_by_kb: + return + + # Resolve service lazily so callers/tests that swap service modules at runtime + # (e.g. via monkeypatch) don't get stuck with a stale class reference. + from api.db.services.doc_metadata_service import DocMetadataService + metadata_getter = getattr(DocMetadataService, "get_metadata_for_documents", None) + if not callable(metadata_getter): + logging.warning( + "DocMetadataService.get_metadata_for_documents is unavailable; " + "skipping metadata enrichment." 
+ ) + return + + meta_by_doc: dict[str, dict] = {} + for kb_id, doc_ids in doc_ids_by_kb.items(): + meta_map = metadata_getter(list(doc_ids), kb_id) + if meta_map: + meta_by_doc.update(meta_map) + logging.debug("Fetched metadata for %d docs in kb_id=%s", len(meta_map), kb_id) + + for chunk in chunks: + doc_id = chunk.get(doc_field) + if not doc_id: + continue + meta = meta_by_doc.get(doc_id) + if not meta: + continue + if metadata_fields is not None: + meta = {k: v for k, v in meta.items() if k in metadata_fields} + if meta: + chunk[output_field] = meta + logging.debug("Enriched chunk for doc_id=%s with %d metadata fields: %s", doc_id, len(meta), list(meta.keys())) diff --git a/api/utils/validation_utils.py b/api/utils/validation_utils.py index acce4926277..94e0fa2ab83 100644 --- a/api/utils/validation_utils.py +++ b/api/utils/validation_utils.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import logging import math import pathlib import re @@ -22,16 +23,7 @@ from uuid import UUID from quart import Request -from pydantic import ( - BaseModel, - ConfigDict, - Field, - StringConstraints, - ValidationError, - field_validator, - model_validator, - ValidationInfo -) +from pydantic import BaseModel, ConfigDict, Field, StringConstraints, ValidationError, field_validator, model_validator, ValidationInfo from pydantic_core import PydanticCustomError from werkzeug.exceptions import BadRequest, UnsupportedMediaType @@ -170,12 +162,13 @@ def validate_and_parse_request_args(request: Request, validator: type[BaseModel] args = request.args.to_dict(flat=True) # Handle ext parameter: parse JSON string to dict if it's a string - if 'ext' in args and isinstance(args['ext'], str): + if "ext" in args and isinstance(args["ext"], str): import json + try: - args['ext'] = json.loads(args['ext']) + args["ext"] = json.loads(args["ext"]) except json.JSONDecodeError: - pass # Keep the string and let validation handle the error + logging.debug("Failed to decode query arg 'ext' as JSON; passing raw value to validator") try: if extras is not None: @@ -350,6 +343,7 @@ class RaptorConfig(Base): threshold: Annotated[float, Field(default=0.1, ge=0.0, le=1.0)] max_cluster: Annotated[int, Field(default=64, ge=1, le=1024)] random_seed: Annotated[int, Field(default=0, ge=0)] + scope: Annotated[Literal["file", "dataset"], Field(default="file")] auto_disable_for_structured_data: Annotated[bool, Field(default=True)] ext: Annotated[dict, Field(default={})] @@ -370,18 +364,17 @@ class ParentChildConfig(Base): class AutoMetadataField(Base): """Schema for a single auto-metadata field configuration.""" - name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=255), Field(...)] - type: Annotated[Literal["string", "list", "time"], Field(...)] + key: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=255), Field(...)] + type: Annotated[Literal["string", "list", "time", "number"], Field(...)] description: Annotated[str | None, Field(default=None, max_length=65535)] - examples: Annotated[list[str] | None, Field(default=None)] - restrict_values: Annotated[bool, Field(default=False)] + enum: Annotated[list[str] | None, Field(default=None)] class AutoMetadataConfig(Base): """Top-level auto-metadata configuration attached to a dataset.""" - enabled: Annotated[bool, Field(default=True)] - fields: Annotated[list[AutoMetadataField], Field(default_factory=list)] + metadata: Annotated[list[AutoMetadataField], 
Field(default_factory=list)] + built_in_metadata: Annotated[list[AutoMetadataField], Field(default_factory=list)] class ParserConfig(Base): @@ -401,6 +394,7 @@ class ParserConfig(Base): pages: Annotated[list[list[int]] | None, Field(default=None)] ext: Annotated[dict, Field(default={})] + class UpdateDocumentReq(Base): """ Request model for updating a document. @@ -408,9 +402,11 @@ class UpdateDocumentReq(Base): This model validates the request parameters for updating a document, including name, chunk method, enabled status, and other metadata. """ - model_config = ConfigDict(extra='ignore') + + model_config = ConfigDict(extra="ignore") name: Annotated[str | None, Field(default=None, max_length=65535)] chunk_method: Annotated[str | None, Field(default=None, max_length=65535)] + pipeline_id: Annotated[str | None, Field(default=None, max_length=65535)] enabled: Annotated[int | None, Field(default=None, ge=0, le=1)] chunk_count: Annotated[int | None, Field(default=None, ge=0)] token_count: Annotated[int | None, Field(default=None, ge=0)] @@ -425,7 +421,7 @@ def validate_document_chunk_method(cls, chunk_method: str | None): # Validate chunk method if present valid_chunk_method = {"naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one", "knowledge_graph", "email", "tag"} if chunk_method not in valid_chunk_method: - raise PydanticCustomError("format_invalid", "`chunk_method` {chunk_method} doesn't exist", {"chunk_method":chunk_method}) + raise PydanticCustomError("format_invalid", "`chunk_method` {chunk_method} doesn't exist", {"chunk_method": chunk_method}) return chunk_method @@ -435,7 +431,7 @@ def validate_document_enabled(cls, enabled: str | None): if enabled: converted = int(enabled) if converted < 0 or converted > 1: - raise PydanticCustomError("format_invalid", "`enabled` value invalid, only accept 0 or 1 but is {enabled}", {"enabled":enabled}) + raise PydanticCustomError("format_invalid", "`enabled` value invalid, only accept 0 or 1 but is {enabled}", {"enabled": enabled}) return enabled @@ -450,11 +446,12 @@ def validate_document_meta_fields(cls, meta_fields: dict | None): for k, v in meta_fields.items(): if isinstance(v, list): if not all(isinstance(i, (str, int, float)) for i in v): - raise PydanticCustomError("format_invalid", "The type is not supported in list: {v}", {"v":v}) + raise PydanticCustomError("format_invalid", "The type is not supported in list: {v}", {"v": v}) elif not isinstance(v, (str, int, float)): - raise PydanticCustomError("format_invalid", "The type is not supported: {v}", {"v":v}) + raise PydanticCustomError("format_invalid", "The type is not supported: {v}", {"v": v}) return meta_fields + class CreateDatasetReq(Base): name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)] avatar: Annotated[str | None, Field(default=None, max_length=65535)] @@ -707,8 +704,7 @@ def validate_parser_dependency(self) -> "CreateDatasetReq": @classmethod def validate_chunk_method(cls, v: Any, handler, info: ValidationInfo) -> Any: """Wrap validation to unify error messages, including type errors (e.g. 
list).""" - allowed = {"naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", - "tag", "resume"} + allowed = {"naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag", "resume"} error_msg = "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table', 'tag' or 'resume'" try: # Run inner validation (type checking) @@ -818,6 +814,70 @@ def validate_ids(cls, v_list: list[str] | None) -> list[str] | None: class DeleteDatasetReq(DeleteReq): ... +class DeleteDocumentReq(DeleteReq): + @field_validator("ids", mode="after") + @classmethod + def validate_ids(cls, v_list: list[str] | None) -> list[str] | None: + """ + Validate document IDs without enforcing UUIDv1. + + Connector-backed documents can use non-UUID identifiers, so we only + enforce uniqueness here and leave existence checks to the delete API. + """ + if v_list is None: + return None + + duplicates = [item for item, count in Counter(v_list).items() if count > 1] + if duplicates: + duplicates_str = ", ".join(duplicates) + raise PydanticCustomError( + "duplicate_uuids", + "Duplicate ids: '{duplicate_ids}'", + {"duplicate_ids": duplicates_str}, + ) + + return v_list + + +class SearchDatasetReq(BaseModel): + model_config = ConfigDict(extra="ignore") + + question: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1), Field(...)] + doc_ids: Annotated[list[str], Field(default=[])] + page: Annotated[int, Field(default=1, ge=1)] + size: Annotated[int, Field(default=30, ge=1)] + top_k: Annotated[int, Field(default=1024, ge=1)] + similarity_threshold: Annotated[float, Field(default=0.0, ge=0.0, le=1.0)] + vector_similarity_weight: Annotated[float, Field(default=0.3, ge=0.0, le=1.0)] + use_kg: Annotated[bool, Field(default=False)] + cross_languages: Annotated[list[str], Field(default=[])] + keyword: Annotated[bool, Field(default=False)] + search_id: Annotated[str | None, Field(default=None)] + rerank_id: Annotated[str | None, Field(default=None)] + tenant_rerank_id: Annotated[int | None, Field(default=None)] + meta_data_filter: Annotated[dict | None, Field(default=None)] + + +class SearchDatasetsReq(BaseModel): + model_config = ConfigDict(extra="ignore") + + dataset_ids: Annotated[list[str], Field(..., min_length=1)] + question: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1), Field(...)] + doc_ids: Annotated[list[str], Field(default=[])] + page: Annotated[int, Field(default=1, ge=1)] + size: Annotated[int, Field(default=30, ge=1)] + top_k: Annotated[int, Field(default=1024, ge=1)] + similarity_threshold: Annotated[float, Field(default=0.0, ge=0.0, le=1.0)] + vector_similarity_weight: Annotated[float, Field(default=0.3, ge=0.0, le=1.0)] + use_kg: Annotated[bool, Field(default=False)] + cross_languages: Annotated[list[str], Field(default=[])] + keyword: Annotated[bool, Field(default=False)] + search_id: Annotated[str | None, Field(default=None)] + rerank_id: Annotated[str | None, Field(default=None)] + tenant_rerank_id: Annotated[str | None, Field(default=None)] + meta_data_filter: Annotated[dict | None, Field(default=None)] + + class BaseListReq(BaseModel): model_config = ConfigDict(extra="forbid") @@ -841,6 +901,7 @@ class ListDatasetReq(BaseListReq): # ---- File Management Request Models ---- + class CreateFolderReq(Base): name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=255), Field(...)] parent_id: Annotated[str | 
None, Field(default=None)] @@ -856,7 +917,7 @@ class MoveFileReq(Base): dest_file_id: Annotated[str | None, Field(default=None)] new_name: Annotated[str | None, StringConstraints(strip_whitespace=True, min_length=1, max_length=255), Field(default=None)] - @model_validator(mode='after') + @model_validator(mode="after") def check_operation(self): if not self.dest_file_id and not self.new_name: raise ValueError("At least one of dest_file_id or new_name must be provided") @@ -876,7 +937,7 @@ class ListFileReq(BaseModel): desc: Annotated[bool, Field(default=True)] -def validate_immutable_fields(update_doc_req:UpdateDocumentReq, doc): +def validate_immutable_fields(update_doc_req: UpdateDocumentReq, doc): """ Validate that immutable fields have not been changed. @@ -906,7 +967,7 @@ def validate_immutable_fields(update_doc_req:UpdateDocumentReq, doc): return None, None -def validate_document_name(req_doc_name:str, doc, docs_from_name): +def validate_document_name(req_doc_name: str, doc, docs_from_name): """ Validate document name update. @@ -937,6 +998,7 @@ def validate_document_name(req_doc_name:str, doc, docs_from_name): return "Duplicated document name in the same dataset.", RetCode.DATA_ERROR return None, None + def validate_chunk_method(doc, chunk_method=None): """ Validate chunk method update. @@ -952,9 +1014,8 @@ def validate_chunk_method(doc, chunk_method=None): A tuple of (error_message, error_code) if validation fails, or (None, None) if validation passes. """ - if chunk_method is not None and len(chunk_method) == 0: # will not be detected in UpdateDocumentReq + if chunk_method is not None and len(chunk_method) == 0: # will not be detected in UpdateDocumentReq return "`chunk_method` (empty string) is not valid", RetCode.DATA_ERROR if doc.type == FileType.VISUAL or re.search(r"\.(ppt|pptx|pages)$", doc.name): return "Not supported yet!", RetCode.DATA_ERROR return None, None - diff --git a/api/utils/web_utils.py b/api/utils/web_utils.py index 4cb13ff7e6f..23d2421862d 100644 --- a/api/utils/web_utils.py +++ b/api/utils/web_utils.py @@ -15,11 +15,8 @@ # import base64 -import ipaddress import json import re -import socket -from urllib.parse import urlparse import aiosmtplib from email.mime.text import MIMEText from email.header import Header @@ -37,10 +34,10 @@ OTP_LENGTH = 4 -OTP_TTL_SECONDS = 5 * 60 # valid for 5 minutes -ATTEMPT_LIMIT = 5 # maximum attempts -ATTEMPT_LOCK_SECONDS = 30 * 60 # lock for 30 minutes -RESEND_COOLDOWN_SECONDS = 60 # cooldown for 1 minute +OTP_TTL_SECONDS = 5 * 60 # valid for 5 minutes +ATTEMPT_LIMIT = 5 # maximum attempts +ATTEMPT_LOCK_SECONDS = 30 * 60 # lock for 30 minutes +RESEND_COOLDOWN_SECONDS = 60 # cooldown for 1 minute CONTENT_TYPE_MAP = { @@ -188,29 +185,16 @@ def __get_pdf_from_html(path: str, timeout: int, install_driver: bool, print_opt return base64.b64decode(result["data"]) -def is_private_ip(ip: str) -> bool: - try: - ip_obj = ipaddress.ip_address(ip) - return ip_obj.is_private - except ValueError: - return False - - def is_valid_url(url: str) -> bool: if not re.match(r"(https?)://[-A-Za-z0-9+&@#/%?=~_|!:,.;]+[-A-Za-z0-9+&@#/%=~_|]", url): return False - parsed_url = urlparse(url) - hostname = parsed_url.hostname + from common.ssrf_guard import assert_url_is_safe - if not hostname: - return False try: - ip = socket.gethostbyname(hostname) - if is_private_ip(ip): - return False - except socket.gaierror: + assert_url_is_safe(url) + return True + except ValueError: return False - return True def safe_json_parse(data: str | dict) -> dict: diff --git 
a/cmd/admin_server.go b/cmd/admin_server.go index 9e876639164..3775d038b72 100644 --- a/cmd/admin_server.go +++ b/cmd/admin_server.go @@ -18,12 +18,14 @@ package main import ( "context" + "errors" "flag" "fmt" "net/http" "os" "os/signal" "ragflow/internal/cache" + "ragflow/internal/common" "ragflow/internal/engine" "syscall" "time" @@ -33,33 +35,23 @@ import ( "ragflow/internal/admin" "ragflow/internal/dao" - "ragflow/internal/logger" "ragflow/internal/server" "ragflow/internal/utility" ) -// AdminServer admin server -type AdminServer struct { - router *admin.Router - handler *admin.Handler - service *admin.Service - engine *gin.Engine - port string -} - func main() { var configPath string flag.StringVar(&configPath, "config", "", "Path to configuration file") flag.Parse() // Initialize logger - if err := logger.Init("info"); err != nil { + if err := common.Init("info"); err != nil { panic("failed to initialize logger: " + err.Error()) } // Initialize configuration if err := server.Init(configPath); err != nil { - logger.Error("Failed to initialize configuration", err) + common.Error("Failed to initialize configuration", err) os.Exit(1) } @@ -67,15 +59,15 @@ func main() { // Reinitialize logger with configured level if different if cfg.Log.Level != "" && cfg.Log.Level != "info" { - if err := logger.Init(cfg.Log.Level); err != nil { - logger.Error("Failed to reinitialize logger with configured level", err) + if err := common.Init(cfg.Log.Level); err != nil { + common.Error("Failed to reinitialize logger with configured level", err) } } // Set logger for server package - server.SetLogger(logger.Logger) + server.SetLogger(common.Logger) - logger.Info("Server mode", zap.String("mode", cfg.Server.Mode)) + common.Info("Server mode", zap.String("mode", cfg.Server.Mode)) // Set Gin mode if cfg.Server.Mode == "release" { @@ -86,26 +78,26 @@ func main() { // Initialize database if err := dao.InitDB(); err != nil { - logger.Error("Failed to initialize database", err) + common.Error("Failed to initialize database", err) os.Exit(1) } // Initialize doc engine if err := engine.Init(&cfg.DocEngine); err != nil { - logger.Fatal("Failed to initialize doc engine", zap.Error(err)) + common.Fatal("Failed to initialize doc engine", zap.Error(err)) } defer engine.Close() // Initialize Redis cache if err := cache.Init(&cfg.Redis); err != nil { - logger.Fatal("Failed to initialize Redis", zap.Error(err)) + common.Fatal("Failed to initialize Redis", zap.Error(err)) } defer cache.Close() // Initialize server variables (runtime variables that can change during operation) // This must be done after Cache is initialized if err := server.InitVariables(cache.Get()); err != nil { - logger.Warn("Failed to initialize server variables from Redis, using defaults", zap.String("error", err.Error())) + common.Warn("Failed to initialize server variables from Redis, using defaults", zap.String("error", err.Error())) } adminService := admin.NewService() @@ -113,7 +105,7 @@ func main() { // Initialize default admin user if err := adminService.InitDefaultAdmin(); err != nil { - logger.Error("Failed to initialize default admin user", err) + common.Error("Failed to initialize default admin user", err) } // Initialize router @@ -129,7 +121,7 @@ func main() { ginEngine.Use(gin.Recovery()) // Log request URL for every request ginEngine.Use(func(c *gin.Context) { - logger.Info("HTTP Request", zap.String("url", c.Request.URL.String()), zap.String("method", c.Request.Method)) + common.Info("HTTP Request", zap.String("url", 
c.Request.URL.String()), zap.String("method", c.Request.Method)) c.Next() }) @@ -144,13 +136,13 @@ func main() { } // Print RAGFlow version - logger.Info("RAGFlow version", zap.String("version", utility.GetRAGFlowVersion())) + common.Info("RAGFlow version", zap.String("version", utility.GetRAGFlowVersion())) // Print all configuration settings server.PrintAll() // Print RAGFlow Admin logo - logger.Info("" + + common.Info("" + "\n ____ ___ ______________ ___ __ _ \n" + " / __ \\/ | / ____/ ____/ /___ _ __ / | ____/ /___ ___ (_)___ \n" + " / /_/ / /| |/ / __/ /_ / / __ \\ | /| / / / /| |/ __ / __ `__ \\/ / __ \\ \n" + @@ -159,10 +151,10 @@ func main() { // Start server in a goroutine go func() { - logger.Info(fmt.Sprintf("Admin Go Version: %s", utility.GetRAGFlowVersion())) - logger.Info(fmt.Sprintf("Starting RAGFlow admin server on port: %d", cfg.Admin.Port)) - if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.Fatal("Failed to start server", zap.Error(err)) + common.Info(fmt.Sprintf("Admin Go Version: %s", utility.GetRAGFlowVersion())) + common.Info(fmt.Sprintf("Starting RAGFlow admin server on port: %d", cfg.Admin.Port)) + if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + common.Fatal("Failed to start server", zap.Error(err)) } }() @@ -171,8 +163,8 @@ func main() { signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGUSR2) sig := <-quit - logger.Info("Received signal", zap.String("signal", sig.String())) - logger.Info("Shutting down server...") + common.Info("Received signal", zap.String("signal", sig.String())) + common.Info("Shutting down server...") // Create context with timeout for graceful shutdown ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) @@ -180,8 +172,8 @@ func main() { // Shutdown server if err := srv.Shutdown(ctx); err != nil { - logger.Fatal("Server forced to shutdown", zap.Error(err)) + common.Fatal("Server forced to shutdown", zap.Error(err)) } - logger.Info("Server exited") + common.Info("Server exited") } diff --git a/cmd/ragflow_cli.go b/cmd/ragflow_cli.go index bb18a5a44e2..cc2043687cc 100644 --- a/cmd/ragflow_cli.go +++ b/cmd/ragflow_cli.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "os/signal" + "ragflow/internal/common" "syscall" "ragflow/internal/cli" @@ -17,6 +18,15 @@ func main() { os.Exit(1) } + // Initialize logger with appropriate level + logLevel := "warn" // Default to warn (quiet mode) + if args.Verbose { + logLevel = "info" + } + if err = common.Init(logLevel); err != nil { + fmt.Printf("Warning: Failed to initialize logger: %v\n", err) + } + // Show help and exit if args.ShowHelp { cli.PrintUsage() diff --git a/cmd/server_main.go b/cmd/server_main.go index d1db4ad7622..e4a634e72af 100644 --- a/cmd/server_main.go +++ b/cmd/server_main.go @@ -2,6 +2,7 @@ package main import ( "context" + "errors" "flag" "fmt" "net/http" @@ -23,7 +24,6 @@ import ( "ragflow/internal/dao" "ragflow/internal/engine" "ragflow/internal/handler" - "ragflow/internal/logger" "ragflow/internal/router" "ragflow/internal/service" "ragflow/internal/service/nlp" @@ -55,81 +55,80 @@ func main() { // Initialize logger with default level // logger.Init("info"); // set debug log level - if err := logger.Init("info"); err != nil { + if err := common.Init("info"); err != nil { panic(fmt.Sprintf("Failed to initialize logger: %v", err)) } // Initialize configuration if err := server.Init(""); err != nil { - logger.Fatal("Failed to initialize config", zap.Error(err)) + 
common.Fatal("Failed to initialize config", zap.Error(err)) } // Override port with command line argument if provided + config := server.GetConfig() if portFlag > 0 { - config := server.GetConfig() config.Server.Port = portFlag - logger.Info("Port overridden by command line argument", zap.Int("port", portFlag)) + common.Info("Port overridden by command line argument", zap.Int("port", portFlag)) } - // Load model providers configuration - if err := server.LoadModelProviders(""); err != nil { - logger.Fatal("Failed to load model providers", zap.Error(err)) + if config.Server.Port == 0 { + common.Fatal("Server port is not configured. Please specify via --port flag or config file.") } - logger.Info("Model providers loaded", zap.Int("count", len(server.GetModelProviders()))) - config := server.GetConfig() - if config.Server.Port == 0 { - logger.Fatal("Server port is not configured. Please specify via --port flag or config file.") + // Load model providers configuration + if err := server.LoadModelProviders(""); err != nil { + common.Fatal("Failed to load model providers", zap.Error(err)) } + common.Info("Model providers loaded", zap.Int("count", len(server.GetModelProviders()))) // Reinitialize logger with configured level if different if config.Log.Level != "" && config.Log.Level != "info" { - if err := logger.Init(config.Log.Level); err != nil { - logger.Error("Failed to reinitialize logger with configured level", err) + if err := common.Init(config.Log.Level); err != nil { + common.Error("Failed to reinitialize logger with configured level", err) } } - server.SetLogger(logger.Logger) + server.SetLogger(common.Logger) if config.Log.Level == "" { - config.Log.Level = logger.GetLevel() + config.Log.Level = common.GetLevel() } - logger.Info("Server mode", zap.String("mode", config.Server.Mode)) + common.Info("Server mode", zap.String("mode", config.Server.Mode)) // Print all configuration settings server.PrintAll() // Initialize database if err := dao.InitDB(); err != nil { - logger.Fatal("Failed to initialize database", zap.Error(err)) + common.Fatal("Failed to initialize database", zap.Error(err)) } // Initialize LLM factory data models from configuration file if err := dao.InitLLMFactory(); err != nil { - logger.Error("Failed to initialize LLM factory", err) + common.Error("Failed to initialize LLM factory", err) } else { - logger.Info("LLM factory initialized successfully") + common.Info("LLM factory initialized successfully") } // Initialize doc engine if err := engine.Init(&config.DocEngine); err != nil { - logger.Fatal("Failed to initialize doc engine", zap.Error(err)) + common.Fatal("Failed to initialize doc engine", zap.Error(err)) } defer engine.Close() // Initialize Redis cache if err := cache.Init(&config.Redis); err != nil { - logger.Fatal("Failed to initialize Redis", zap.Error(err)) + common.Fatal("Failed to initialize Redis", zap.Error(err)) } defer cache.Close() if err := storage.InitStorageFactory(); err != nil { - logger.Fatal("Failed to initialize storage factory", zap.Error(err)) + common.Fatal("Failed to initialize storage factory", zap.Error(err)) } // Initialize server variables (runtime variables that can change during operation) // This must be done after Cache is initialized if err := server.InitVariables(cache.Get()); err != nil { - logger.Warn("Failed to initialize server variables from Redis, using defaults", zap.String("error", err.Error())) + common.Warn("Failed to initialize server variables from Redis, using defaults", zap.String("error", err.Error())) } // 
Initialize admin status (default: unavailable=1) @@ -140,19 +139,19 @@ func main() { DictPath: "/usr/share/infinity/resource", } if err := tokenizer.Init(tokenizerCfg); err != nil { - logger.Fatal("Failed to initialize tokenizer", zap.Error(err)) + common.Fatal("Failed to initialize tokenizer", zap.Error(err)) } defer tokenizer.Close() // Initialize global QueryBuilder using tokenizer's DictPath // This ensures the Synonym uses the same wordnet directory as tokenizer if err := nlp.InitQueryBuilderFromTokenizer(tokenizerCfg.DictPath); err != nil { - logger.Fatal("Failed to initialize query builder", zap.Error(err)) + common.Fatal("Failed to initialize query builder", zap.Error(err)) } startServer(config) - logger.Info("Server exited") + common.Info("Server exited") } func startServer(config *server.Config) { @@ -181,6 +180,9 @@ func startServer(config *server.Config) { memoryService := service.NewMemoryService() modelProviderService := service.NewModelProviderService() + // Initialize doc engine for skill search + docEngine := engine.Get() + // Initialize handler layer authHandler := handler.NewAuthHandler() userHandler := handler.NewUserHandler(userService) @@ -197,10 +199,11 @@ func startServer(config *server.Config) { searchHandler := handler.NewSearchHandler(searchService, userService) fileHandler := handler.NewFileHandler(fileService, userService) memoryHandler := handler.NewMemoryHandler(memoryService) + skillSearchHandler := handler.NewSkillSearchHandler(docEngine) providerHandler := handler.NewProviderHandler(userService, modelProviderService) // Initialize router - r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, kbHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler, providerHandler) + r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, kbHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler, skillSearchHandler, providerHandler) // Create Gin engine ginEngine := gin.New() @@ -214,45 +217,49 @@ func startServer(config *server.Config) { // Setup routes r.Setup(ginEngine) - // Create HTTP server + // Create HTTP server with timeouts to prevent slow clients from blocking shutdown addr := fmt.Sprintf(":%d", config.Server.Port) srv := &http.Server{ - Addr: addr, - Handler: ginEngine, + Addr: addr, + Handler: ginEngine, + ReadHeaderTimeout: 10 * time.Second, + ReadTimeout: 60 * time.Second, + WriteTimeout: 120 * time.Second, + IdleTimeout: 120 * time.Second, } // Start server in a goroutine go func() { - logger.Info( + common.Info( "\n ____ ___ ______ ______ __\n" + " / __ \\ / | / ____// ____// /____ _ __\n" + " / /_/ // /| | / / __ / /_ / // __ \\| | /| / /\n" + " / _, _// ___ |/ /_/ // __/ / // /_/ /| |/ |/ /\n" + " /_/ |_|/_/ |_|\\____//_/ /_/ \\____/ |__/|__/\n", ) - logger.Info(fmt.Sprintf("RAGFlow Go Version: %s", utility.GetRAGFlowVersion())) - logger.Info(fmt.Sprintf("Server starting on port: %d", config.Server.Port)) - if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.Fatal("Failed to start server", zap.Error(err)) + common.Info(fmt.Sprintf("RAGFlow Go Version: %s", utility.GetRAGFlowVersion())) + common.Info(fmt.Sprintf("Server starting on port: %d", config.Server.Port)) + if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + 
common.Fatal("Failed to start server", zap.Error(err)) } }() // Get local IP address for heartbeat reporting - localIP := utility.GetLocalIP() - if localIP == "" { - localIP = "127.0.0.1" + localIP, err := utility.GetLocalIP() + if err != nil { + common.Fatal("fail to get local ip address") } // Initialize and start heartbeat reporter to admin server heartbeatService := service.NewHeartbeatSender( - logger.Logger, + common.Logger, common.ServerTypeAPI, fmt.Sprintf("ragflow-server-%d", config.Server.Port), localIP, config.Server.Port, ) - if err := heartbeatService.InitHTTPClient(); err != nil { - logger.Warn("Failed to initialize heartbeat service", zap.Error(err)) + if err = heartbeatService.InitHTTPClient(); err != nil { + common.Warn("Failed to initialize heartbeat service", zap.Error(err)) } else { // Start heartbeat reporter with 30 seconds interval heartbeatReporter := utility.NewScheduledTask("Heartbeat reporter", 3*time.Second, func() { @@ -272,15 +279,15 @@ func startServer(config *server.Config) { signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGUSR2) sig := <-quit - logger.Info(fmt.Sprintf("Receives %s signal to shutdown server", strings.ToUpper(sig.String()))) - logger.Info("Shutting down server...") + common.Info(fmt.Sprintf("Receives %s signal to shutdown server", strings.ToUpper(sig.String()))) + common.Info("Shutting down server...") // Create context with timeout for graceful shutdown ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() // Shutdown server - if err := srv.Shutdown(ctx); err != nil { - logger.Fatal("Server forced to shutdown", zap.Error(err)) + if err = srv.Shutdown(ctx); err != nil { + common.Fatal("Server forced to shutdown", zap.Error(err)) } } diff --git a/common/constants.py b/common/constants.py index b027908637d..5ab9acaa502 100644 --- a/common/constants.py +++ b/common/constants.py @@ -244,6 +244,12 @@ class ForgettingPolicy(StrEnum): SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker" TAG_FLD = "tag_feas" +# Maximum page number used as "unlimited" sentinel value. +# Parsing layer (chunk/Pdf.__call__) uses MAXIMUM_PAGE_NUMBER. +# Task/DB layer (Task model) uses MAXIMUM_PAGE_NUMBER * 1000 to avoid collision with user-specified page ranges. 
+MAXIMUM_PAGE_NUMBER = 100000 +MAXIMUM_TASK_PAGE_NUMBER = MAXIMUM_PAGE_NUMBER * 1000 + MINERU_ENV_KEYS = ["MINERU_APISERVER", "MINERU_OUTPUT_DIR", "MINERU_BACKEND", "MINERU_SERVER_URL", "MINERU_DELETE_OUTPUT"] MINERU_DEFAULT_CONFIG = { @@ -260,3 +266,8 @@ class ForgettingPolicy(StrEnum): "PADDLEOCR_ACCESS_TOKEN": None, "PADDLEOCR_ALGORITHM": "PaddleOCR-VL", } + +OPENDATALOADER_ENV_KEYS = ["OPENDATALOADER_APISERVER"] +OPENDATALOADER_DEFAULT_CONFIG = { + "OPENDATALOADER_APISERVER": "", +} diff --git a/common/data_source/airtable_connector.py b/common/data_source/airtable_connector.py index 46dcf07ee47..f1ab3004036 100644 --- a/common/data_source/airtable_connector.py +++ b/common/data_source/airtable_connector.py @@ -8,8 +8,14 @@ from common.data_source.config import AIRTABLE_CONNECTOR_SIZE_THRESHOLD, INDEX_BATCH_SIZE, DocumentSource from common.data_source.exceptions import ConnectorMissingCredentialError -from common.data_source.interfaces import LoadConnector, PollConnector -from common.data_source.models import Document, GenerateDocumentsOutput, SecondsSinceUnixEpoch +from common.data_source.interfaces import LoadConnector, PollConnector, SlimConnectorWithPermSync +from common.data_source.models import ( + Document, + GenerateDocumentsOutput, + GenerateSlimDocumentOutput, + SecondsSinceUnixEpoch, + SlimDocument, +) from common.data_source.utils import extract_size_bytes, get_file_ext class AirtableClientNotSetUpError(PermissionError): @@ -19,7 +25,7 @@ def __init__(self) -> None: ) -class AirtableConnector(LoadConnector, PollConnector): +class AirtableConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): """ Lightweight Airtable connector. @@ -39,6 +45,43 @@ def __init__( self._airtable_client: AirtableApi | None = None self.size_threshold = AIRTABLE_CONNECTOR_SIZE_THRESHOLD + def _iter_attachment_entries(self) -> Generator[tuple[str, str, str, str, str | None, dict[str, Any]], None, None]: + if not self._airtable_client: + raise ConnectorMissingCredentialError("Airtable credentials not loaded") + + table = self.airtable_client.table(self.base_id, self.table_name_or_id) + records = table.all() + + logging.info( + f"Starting Airtable attachment scan for table {self.table_name_or_id}, " + f"{len(records)} records found." + ) + + for record in records: + record_id = record.get("id") + fields = record.get("fields", {}) + created_time = record.get("createdTime") + + for field_value in fields.values(): + if not isinstance(field_value, list): + continue + + for attachment in field_value: + filename = attachment.get("filename") + attachment_id = attachment.get("id") + + if not record_id or not filename or not attachment_id: + continue + + yield ( + record_id, + attachment_id, + filename, + f"airtable:{record_id}:{attachment_id}", + created_time, + attachment, + ) + # ------------------------- # Credentials # ------------------------- @@ -64,69 +107,65 @@ def load_from_state(self) -> GenerateDocumentsOutput: if not self._airtable_client: raise ConnectorMissingCredentialError("Airtable credentials not loaded") - table = self.airtable_client.table(self.base_id, self.table_name_or_id) - records = table.all() - - logging.info( - f"Starting Airtable blob ingestion for table {self.table_name_or_id}, " - f"{len(records)} records found." 
- ) - batch: list[Document] = [] - for record in records: - record_id = record.get("id") - fields = record.get("fields", {}) - created_time = record.get("createdTime") - - for field_value in fields.values(): - # We only care about attachment fields (lists of dicts with url/filename) - if not isinstance(field_value, list): - continue + for record_id, attachment_id, filename, doc_id, created_time, attachment in self._iter_attachment_entries(): + url = attachment.get("url") + if not url or not created_time: + continue + + try: + resp = requests.get(url, timeout=30) + resp.raise_for_status() + content = resp.content + except Exception: + logging.exception( + f"Failed to download attachment {filename} " + f"(record={record_id})" + ) + continue + size_bytes = extract_size_bytes(attachment) + if ( + self.size_threshold is not None + and isinstance(size_bytes, int) + and size_bytes > self.size_threshold + ): + logging.warning( + f"{filename} exceeds size threshold of {self.size_threshold}. Skipping." + ) + continue + batch.append( + Document( + id=doc_id, + blob=content, + source=DocumentSource.AIRTABLE, + semantic_identifier=filename, + extension=get_file_ext(filename), + size_bytes=size_bytes if size_bytes else 0, + doc_updated_at=datetime.strptime(created_time, "%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=timezone.utc) + ) + ) + + if len(batch) >= self.batch_size: + yield batch + batch = [] - for attachment in field_value: - url = attachment.get("url") - filename = attachment.get("filename") - attachment_id = attachment.get("id") + if batch: + yield batch - if not url or not filename or not attachment_id: - continue + def retrieve_all_slim_docs_perm_sync( + self, + callback: Any = None, + ) -> GenerateSlimDocumentOutput: + del callback - try: - resp = requests.get(url, timeout=30) - resp.raise_for_status() - content = resp.content - except Exception: - logging.exception( - f"Failed to download attachment {filename} " - f"(record={record_id})" - ) - continue - size_bytes = extract_size_bytes(attachment) - if ( - self.size_threshold is not None - and isinstance(size_bytes, int) - and size_bytes > self.size_threshold - ): - logging.warning( - f"{filename} exceeds size threshold of {self.size_threshold}. Skipping." 
- ) - continue - batch.append( - Document( - id=f"airtable:{record_id}:{attachment_id}", - blob=content, - source=DocumentSource.AIRTABLE, - semantic_identifier=filename, - extension=get_file_ext(filename), - size_bytes=size_bytes if size_bytes else 0, - doc_updated_at=datetime.strptime(created_time, "%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=timezone.utc) - ) - ) + batch: list[SlimDocument] = [] - if len(batch) >= self.batch_size: - yield batch - batch = [] + for _, _, _, doc_id, _, _ in self._iter_attachment_entries(): + batch.append(SlimDocument(id=doc_id)) + if len(batch) >= self.batch_size: + yield batch + batch = [] if batch: yield batch @@ -165,4 +204,4 @@ def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) for doc in first_batch: print(f"- {doc.semantic_identifier} ({doc.size_bytes} bytes)") except StopIteration: - print("No documents available in Dropbox.") \ No newline at end of file + print("No documents available in Airtable.") diff --git a/common/data_source/asana_connector.py b/common/data_source/asana_connector.py index 4143c0cba0d..e3aee9c4f04 100644 --- a/common/data_source/asana_connector.py +++ b/common/data_source/asana_connector.py @@ -1,13 +1,13 @@ from collections.abc import Iterator import time -from datetime import datetime +from datetime import datetime, timezone import logging from typing import Any, Dict import asana import requests from common.data_source.config import CONTINUE_ON_CONNECTOR_FAILURE, INDEX_BATCH_SIZE, DocumentSource -from common.data_source.interfaces import LoadConnector, PollConnector -from common.data_source.models import Document, GenerateDocumentsOutput, SecondsSinceUnixEpoch +from common.data_source.interfaces import LoadConnector, PollConnector, SlimConnectorWithPermSync +from common.data_source.models import Document, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SecondsSinceUnixEpoch, SlimDocument from common.data_source.utils import extract_size_bytes, get_file_ext @@ -63,6 +63,31 @@ def get_tasks( ) -> Iterator[AsanaTask]: """Get all tasks from the projects with the given gids that were modified since the given date.
If project_gids is None, get all tasks from all projects in the workspace.""" + projects_list = self._get_project_gids_to_process(project_gids) + start_seconds = int(time.mktime(datetime.now().timetuple())) + for project_gid in projects_list: + for task in self._get_tasks_for_project( + project_gid, start_date, start_seconds + ): + yield task + logging.info(f"Completed fetching {self.task_count} tasks from Asana") + if self.api_error_count > 0: + logging.warning( + f"Encountered {self.api_error_count} API errors during task fetching" + ) + + def get_task_ids( + self, project_gids: list[str] | None, start_date: str + ) -> Iterator[str]: + """Get task gids without hydrating comments, users, or task text.""" + projects_list = self._get_project_gids_to_process(project_gids) + for project_gid in projects_list: + for task_id in self._get_task_ids_for_project(project_gid, start_date): + yield task_id + + def _get_project_gids_to_process( + self, project_gids: list[str] | None + ) -> list[str]: logging.info("Starting to fetch Asana projects") projects = self.project_api.get_projects( opts={ @@ -70,7 +95,6 @@ def get_tasks( "opt_fields": "gid,name,archived,modified_at", } ) - start_seconds = int(time.mktime(datetime.now().timetuple())) projects_list = [] project_count = 0 for project_info in projects: @@ -85,20 +109,9 @@ def get_tasks( if project_count % 100 == 0: logging.info(f"Processed {project_count} projects") logging.info(f"Found {len(projects_list)} projects to process") - for project_gid in projects_list: - for task in self._get_tasks_for_project( - project_gid, start_date, start_seconds - ): - yield task - logging.info(f"Completed fetching {self.task_count} tasks from Asana") - if self.api_error_count > 0: - logging.warning( - f"Encountered {self.api_error_count} API errors during task fetching" - ) + return projects_list - def _get_tasks_for_project( - self, project_gid: str, start_date: str, start_seconds: int - ) -> Iterator[AsanaTask]: + def _get_project_to_process(self, project_gid: str) -> dict | None: project = self.project_api.get_project(project_gid, opts={}) project_name = project.get("name", project_gid) team = project.get("team") or {} @@ -122,6 +135,35 @@ def _get_tasks_for_project( f"Processing private project in configured team: {project_name} ({project_gid})" ) + return project + + def _get_task_ids_for_project( + self, project_gid: str, start_date: str + ) -> Iterator[str]: + project = self._get_project_to_process(project_gid) + if project is None: + return + + tasks_from_api = self.tasks_api.get_tasks_for_project( + project_gid, + { + "opt_fields": "gid", + "modified_since": start_date, + }, + ) + for data in tasks_from_api: + task_id = data.get("gid") + if task_id: + yield task_id + + def _get_tasks_for_project( + self, project_gid: str, start_date: str, start_seconds: int + ) -> Iterator[AsanaTask]: + project = self._get_project_to_process(project_gid) + if project is None: + return + + project_name = project.get("name", project_gid) simple_start_date = start_date.split(".")[0].split("+")[0] logging.info( f"Fetching tasks modified since {simple_start_date} for project: {project_name} ({project_gid})" @@ -242,7 +284,7 @@ def get_attachments(self, task_gid: str) -> list[dict]: full = self.attachments_api.get_attachment( attachment_gid=gid, opts={ - "opt_fields": "name,download_url,size,created_at" + "opt_fields": "gid,name,download_url,size,created_at" } ) @@ -330,7 +372,7 @@ def get_time(self) -> str: return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) -class 
AsanaConnector(LoadConnector, PollConnector): +class AsanaConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): def __init__( self, asana_workspace_id: str, @@ -367,11 +409,22 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch | None ) -> GenerateDocumentsOutput: - start_time = datetime.fromtimestamp(start).isoformat() + start_time = datetime.fromtimestamp(start, tz=timezone.utc).isoformat() + end_time = datetime.fromtimestamp(end, tz=timezone.utc) if end is not None else None logging.info(f"Starting Asana poll from {start_time}") docs_batch: list[Document] = [] tasks = self.asana_client.get_tasks(self.project_ids_to_index, start_time) for task in tasks: + if end_time: + task_last_modified = task.last_modified + if task_last_modified.tzinfo is None: + task_last_modified = task_last_modified.replace(tzinfo=timezone.utc) + else: + task_last_modified = task_last_modified.astimezone(timezone.utc) + + if task_last_modified >= end_time: + continue + docs = self._task_to_documents(task) docs_batch.extend(docs) @@ -390,6 +443,31 @@ def load_from_state(self) -> GenerateDocumentsOutput: logging.info("Starting full index of all Asana tasks") return self.poll_source(start=0, end=None) + def retrieve_all_slim_docs_perm_sync( + self, + callback: Any = None, + ) -> GenerateSlimDocumentOutput: + del callback + + start_time = datetime.fromtimestamp(0, tz=timezone.utc).isoformat() + docs_batch: list[SlimDocument] = [] + + for task_id in self.asana_client.get_task_ids(self.project_ids_to_index, start_time): + attachments = self.asana_client.get_attachments(task_id) + + for att in attachments: + attachment_gid = att.get("gid") + if not attachment_gid: + continue + + docs_batch.append(SlimDocument(id=f"asana:{task_id}:{attachment_gid}")) + if len(docs_batch) >= self.batch_size: + yield docs_batch + docs_batch = [] + + if docs_batch: + yield docs_batch + def _task_to_documents(self, task: AsanaTask) -> list[Document]: docs: list[Document] = [] @@ -456,4 +534,4 @@ def _task_to_documents(self, task: AsanaTask) -> list[Document]: for docs in all_docs: for doc in docs: print(doc.id) - logging.info("Asana connector test completed") \ No newline at end of file + logging.info("Asana connector test completed") diff --git a/common/data_source/bitbucket/connector.py b/common/data_source/bitbucket/connector.py index f355a8945fc..0557d2a5039 100644 --- a/common/data_source/bitbucket/connector.py +++ b/common/data_source/bitbucket/connector.py @@ -269,17 +269,11 @@ def validate_checkpoint_json( def retrieve_all_slim_docs_perm_sync( self, - start: SecondsSinceUnixEpoch | None = None, - end: SecondsSinceUnixEpoch | None = None, callback: IndexingHeartbeatInterface | None = None, ) -> Iterator[list[SlimDocument]]: """Return only document IDs for all existing pull requests.""" batch: list[SlimDocument] = [] - params = self._build_params( - fields=SLIM_PR_LIST_RESPONSE_FIELDS, - start=start, - end=end, - ) + params = self._build_params(fields=SLIM_PR_LIST_RESPONSE_FIELDS) with self._client() as client: for slug in self._iter_target_repositories(client): for pr in self._iter_pull_requests_for_repo( @@ -361,10 +355,7 @@ def validate_connector_settings(self) -> None: start_time = datetime.fromtimestamp(0, tz=timezone.utc) end_time = datetime.now(timezone.utc) - for doc_batch in bitbucket.retrieve_all_slim_docs_perm_sync( - start=start_time.timestamp(), - end=end_time.timestamp(), - ): + for doc_batch in 
bitbucket.retrieve_all_slim_docs_perm_sync(): for doc in doc_batch: print(doc) @@ -385,4 +376,4 @@ def validate_connector_settings(self) -> None: except StopIteration as e: bitbucket_checkpoint = e.value break - \ No newline at end of file + diff --git a/common/data_source/blob_connector.py b/common/data_source/blob_connector.py index 1ab39189d79..7505b878ba3 100644 --- a/common/data_source/blob_connector.py +++ b/common/data_source/blob_connector.py @@ -19,7 +19,13 @@ InsufficientPermissionsError ) from common.data_source.interfaces import LoadConnector, PollConnector -from common.data_source.models import Document, SecondsSinceUnixEpoch, GenerateDocumentsOutput +from common.data_source.models import ( + Document, + SecondsSinceUnixEpoch, + GenerateDocumentsOutput, + GenerateSlimDocumentOutput, + SlimDocument, +) class BlobStorageConnector(LoadConnector, PollConnector): @@ -122,29 +128,7 @@ def _yield_blob_objects( end: datetime, ) -> GenerateDocumentsOutput: """Generate bucket objects""" - if self.s3_client is None: - raise ConnectorMissingCredentialError("Blob storage") - - paginator = self.s3_client.get_paginator("list_objects_v2") - pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix) - - # Collect all objects first to count filename occurrences - all_objects = [] - for page in pages: - if "Contents" not in page: - continue - for obj in page["Contents"]: - if obj["Key"].endswith("/"): - continue - last_modified = obj["LastModified"].replace(tzinfo=timezone.utc) - if start < last_modified <= end: - all_objects.append(obj) - - # Count filename occurrences to determine which need full paths - filename_counts: dict[str, int] = {} - for obj in all_objects: - file_name = os.path.basename(obj["Key"]) - filename_counts[file_name] = filename_counts.get(file_name, 0) + 1 + all_objects, filename_counts = self._collect_blob_objects(start, end) batch: list[Document] = [] for obj in all_objects: @@ -162,20 +146,15 @@ def _yield_blob_objects( f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping." 
) continue - + try: - blob = download_object(self.s3_client, self.bucket_name, key, self.size_threshold) + blob = download_object( + self.s3_client, self.bucket_name, key, self.size_threshold + ) if blob is None: continue - # Use full path only if filename appears multiple times - if filename_counts.get(file_name, 0) > 1: - relative_path = key - if self.prefix and key.startswith(self.prefix): - relative_path = key[len(self.prefix):] - semantic_id = relative_path.replace('/', ' / ') if relative_path else file_name - else: - semantic_id = file_name + semantic_id = self._get_semantic_id(key, file_name, filename_counts) batch.append( Document( @@ -185,7 +164,7 @@ def _yield_blob_objects( semantic_identifier=semantic_id, extension=get_file_ext(file_name), doc_updated_at=last_modified, - size_bytes=size_bytes if size_bytes else 0 + size_bytes=size_bytes if size_bytes else 0, ) ) if len(batch) == self.batch_size: @@ -194,7 +173,76 @@ def _yield_blob_objects( except Exception: logging.exception(f"Error decoding object {key}") - + + if batch: + yield batch + + def _collect_blob_objects( + self, + start: datetime, + end: datetime, + ) -> tuple[list[dict[str, Any]], dict[str, int]]: + """Collect object metadata for files in the requested window.""" + if self.s3_client is None: + raise ConnectorMissingCredentialError("Blob storage") + + paginator = self.s3_client.get_paginator("list_objects_v2") + pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix) + + # Collect all objects first to count filename occurrences + all_objects: list[dict[str, Any]] = [] + for page in pages: + if "Contents" not in page: + continue + for obj in page["Contents"]: + if obj["Key"].endswith("/"): + continue + last_modified = obj["LastModified"].replace(tzinfo=timezone.utc) + if start < last_modified <= end: + all_objects.append(obj) + + filename_counts: dict[str, int] = {} + for obj in all_objects: + file_name = os.path.basename(obj["Key"]) + filename_counts[file_name] = filename_counts.get(file_name, 0) + 1 + + return all_objects, filename_counts + + def _get_semantic_id( + self, + key: str, + file_name: str, + filename_counts: dict[str, int], + ) -> str: + """Use full relative path only when filenames collide.""" + if filename_counts.get(file_name, 0) > 1: + relative_path = key + if self.prefix and key.startswith(self.prefix): + relative_path = key[len(self.prefix):] + return relative_path.replace("/", " / ") if relative_path else file_name + return file_name + + def retrieve_all_slim_docs_perm_sync( + self, + callback: Any = None, + ) -> GenerateSlimDocumentOutput: + """Return a full current snapshot of blob object IDs without downloading content.""" + del callback + + all_objects, _ = self._collect_blob_objects( + start=datetime(1970, 1, 1, tzinfo=timezone.utc), + end=datetime.now(timezone.utc), + ) + + batch: list[SlimDocument] = [] + for obj in all_objects: + batch.append( + SlimDocument(id=f"{self.bucket_type}:{self.bucket_name}:{obj['Key']}") + ) + if len(batch) == self.batch_size: + yield batch + batch = [] + if batch: yield batch diff --git a/common/data_source/box_connector.py b/common/data_source/box_connector.py index 253029d3c92..cc44f356e87 100644 --- a/common/data_source/box_connector.py +++ b/common/data_source/box_connector.py @@ -1,7 +1,7 @@ """Box connector""" import logging from datetime import datetime, timezone -from typing import Any +from typing import Any, Generator from box_sdk_gen import BoxClient from common.data_source.config import DocumentSource, INDEX_BATCH_SIZE @@ -10,21 
+10,21 @@ ConnectorValidationError, ) from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch -from common.data_source.models import Document, GenerateDocumentsOutput +from common.data_source.models import Document, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SlimDocument from common.data_source.utils import get_file_ext + class BoxConnector(LoadConnector, PollConnector): def __init__(self, folder_id: str, batch_size: int = INDEX_BATCH_SIZE, use_marker: bool = True) -> None: self.batch_size = batch_size self.folder_id = "0" if not folder_id else folder_id self.use_marker = use_marker - + self.box_client: BoxClient | None = None def load_credentials(self, auth: Any): self.box_client = BoxClient(auth=auth) return None - def validate_connector_settings(self): if self.box_client is None: raise ConnectorMissingCredentialError("Box") @@ -35,79 +35,41 @@ def validate_connector_settings(self): logging.exception("[Box]: Failed to validate Box credentials") raise ConnectorValidationError(f"Unexpected error during Box settings validation: {e}") - - def _yield_files_recursive( - self, - folder_id: str, - start: SecondsSinceUnixEpoch | None, - end: SecondsSinceUnixEpoch | None, - relative_folder_path: str = "", - ) -> GenerateDocumentsOutput: - + def _iter_files_recursive( + self, + folder_id: str, + relative_folder_path: str = "", + ) -> Generator[tuple[Any, str], None, None]: if self.box_client is None: raise ConnectorMissingCredentialError("Box") result = self.box_client.folders.get_folder_items( folder_id=folder_id, limit=self.batch_size, - usemarker=self.use_marker + usemarker=self.use_marker, ) while True: - batch: list[Document] = [] for entry in result.entries: - if entry.type == 'file' : - file = self.box_client.files.get_file_by_id( - entry.id - ) - modified_time: SecondsSinceUnixEpoch | None = None - raw_time = ( - getattr(file, "created_at", None) - or getattr(file, "content_created_at", None) - ) - - if raw_time: - modified_time = self._box_datetime_to_epoch_seconds(raw_time) - if start is not None and modified_time <= start: - continue - if end is not None and modified_time > end: - continue - - content_bytes = self.box_client.downloads.download_file(file.id) + if entry.type == "file": + file = self.box_client.files.get_file_by_id(entry.id) semantic_identifier = ( f"{relative_folder_path} / {file.name}" if relative_folder_path else file.name ) - - batch.append( - Document( - id=f"box:{file.id}", - blob=content_bytes.read(), - source=DocumentSource.BOX, - semantic_identifier=semantic_identifier, - extension=get_file_ext(file.name), - doc_updated_at=modified_time, - size_bytes=file.size, - metadata=file.metadata - ) - ) - elif entry.type == 'folder': + yield file, semantic_identifier + elif entry.type == "folder": child_relative_path = ( f"{relative_folder_path} / {entry.name}" if relative_folder_path else entry.name ) - yield from self._yield_files_recursive( + yield from self._iter_files_recursive( folder_id=entry.id, - start=start, - end=end, - relative_folder_path=child_relative_path + relative_folder_path=child_relative_path, ) - if batch: - yield batch - if not result.next_marker: break @@ -115,9 +77,56 @@ def _yield_files_recursive( folder_id=folder_id, limit=self.batch_size, marker=result.next_marker, - usemarker=True + usemarker=True, ) + def _yield_files_recursive( + self, + folder_id: str, + start: SecondsSinceUnixEpoch | None, + end: SecondsSinceUnixEpoch | None, + relative_folder_path: str = "", + ) -> GenerateDocumentsOutput: + if 
self.box_client is None: + raise ConnectorMissingCredentialError("Box") + + batch: list[Document] = [] + for file, semantic_identifier in self._iter_files_recursive( + folder_id=folder_id, + relative_folder_path=relative_folder_path, + ): + modified_time: SecondsSinceUnixEpoch | None = None + raw_time = ( + getattr(file, "created_at", None) + or getattr(file, "content_created_at", None) + ) + + if raw_time: + modified_time = self._box_datetime_to_epoch_seconds(raw_time) + if start is not None and modified_time <= start: + continue + if end is not None and modified_time > end: + continue + + content_bytes = self.box_client.downloads.download_file(file.id) + batch.append( + Document( + id=f"box:{file.id}", + blob=content_bytes.read(), + source=DocumentSource.BOX, + semantic_identifier=semantic_identifier, + extension=get_file_ext(file.name), + doc_updated_at=modified_time, + size_bytes=file.size, + metadata=file.metadata, + ) + ) + if len(batch) >= self.batch_size: + yield batch + batch = [] + + if batch: + yield batch def _box_datetime_to_epoch_seconds(self, dt: datetime) -> SecondsSinceUnixEpoch: """Convert a Box SDK datetime to Unix epoch seconds (UTC). @@ -133,6 +142,21 @@ def _box_datetime_to_epoch_seconds(self, dt: datetime) -> SecondsSinceUnixEpoch: return SecondsSinceUnixEpoch(int(dt.timestamp())) + def retrieve_all_slim_docs_perm_sync( + self, + callback: Any = None, + ) -> GenerateSlimDocumentOutput: + del callback + + batch: list[SlimDocument] = [] + for file, _semantic_identifier in self._iter_files_recursive(folder_id=self.folder_id): + batch.append(SlimDocument(id=f"box:{file.id}")) + if len(batch) >= self.batch_size: + yield batch + batch = [] + + if batch: + yield batch def poll_source(self, start, end): return self._yield_files_recursive(folder_id=self.folder_id, start=start, end=end) diff --git a/common/data_source/confluence_connector.py b/common/data_source/confluence_connector.py index abe55b5b275..ef0d6a77600 100644 --- a/common/data_source/confluence_connector.py +++ b/common/data_source/confluence_connector.py @@ -1904,8 +1904,6 @@ def retrieve_all_slim_docs( def retrieve_all_slim_docs_perm_sync( self, - start: SecondsSinceUnixEpoch | None = None, - end: SecondsSinceUnixEpoch | None = None, callback: IndexingHeartbeatInterface | None = None, ) -> GenerateSlimDocumentOutput: """ @@ -1913,16 +1911,12 @@ def retrieve_all_slim_docs_perm_sync( Does not fetch actual text. Used primarily for incremental permission sync. 
""" return self._retrieve_all_slim_docs( - start=start, - end=end, callback=callback, include_permissions=True, ) def _retrieve_all_slim_docs( self, - start: SecondsSinceUnixEpoch | None = None, - end: SecondsSinceUnixEpoch | None = None, callback: IndexingHeartbeatInterface | None = None, include_permissions: bool = True, ) -> GenerateSlimDocumentOutput: diff --git a/common/data_source/dingtalk_ai_table_connector.py b/common/data_source/dingtalk_ai_table_connector.py index 66588d4d307..40dc44b61f5 100644 --- a/common/data_source/dingtalk_ai_table_connector.py +++ b/common/data_source/dingtalk_ai_table_connector.py @@ -22,8 +22,8 @@ from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource from common.data_source.exceptions import ConnectorMissingCredentialError, ConnectorValidationError -from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch -from common.data_source.models import Document, GenerateDocumentsOutput +from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync +from common.data_source.models import Document, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SlimDocument logger = logging.getLogger(__name__) @@ -38,7 +38,7 @@ def __init__(self) -> None: super().__init__("DingTalk Notable client is not set up. Did you forget to call load_credentials()?") -class DingTalkAITableConnector(LoadConnector, PollConnector): +class DingTalkAITableConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): """ DingTalk AI Table (Notable) connector for accessing table records. @@ -75,6 +75,9 @@ def __init__( self._client: NotableClient | None = None self._access_token: str | None = None + def _document_id(self, sheet_id: str, record_id: str) -> str: + return f"{_DINGTALK_AI_TABLE_DOC_ID_PREFIX}{self.table_id}:{sheet_id}:{record_id}" + def _create_client(self) -> NotableClient: """Create DingTalk Notable API client.""" config = open_api_models.Config() @@ -280,6 +283,8 @@ def _convert_record_to_document( record_id = record.get("id", "unknown") fields = record.get("fields", {}) + doc_id = self._document_id(sheet_id, str(record_id)) + # Convert fields to JSON string for blob content content = json.dumps(fields, ensure_ascii=False, indent=2) blob = content.encode("utf-8") @@ -304,7 +309,7 @@ def _convert_record_to_document( # Create document doc = Document( - id=f"{_DINGTALK_AI_TABLE_DOC_ID_PREFIX}{self.table_id}:{sheet_id}:{record_id}", + id=doc_id, source=DocumentSource.DINGTALK_AI_TABLE, semantic_identifier=semantic_identifier, extension=".json", @@ -316,6 +321,44 @@ def _convert_record_to_document( return doc + def retrieve_all_slim_docs_perm_sync( + self, + callback: Any = None, + ) -> GenerateSlimDocumentOutput: + """ + Enumerate current record IDs for all sheets without building document blobs. + + IDs match :meth:`_convert_record_to_document` / full ingest. 
+ """ + del callback + logger.info( + "[DingTalk Notable]: slim snapshot table_id=%s operator_id=%s", + self.table_id, + self.operator_id, + ) + sheets = self._get_all_sheets() + batch: list[SlimDocument] = [] + for sheet in sheets: + sheet_id = sheet["id"] + next_token: str | None = None + while True: + records, next_token = self._list_records( + sheet_id=sheet_id, + next_token=next_token, + ) + for record in records: + rid = record.get("id") + if not rid: + continue + batch.append(SlimDocument(id=self._document_id(sheet_id, str(rid)))) + if len(batch) >= self.batch_size: + yield batch + batch = [] + if not next_token: + break + if batch: + yield batch + def _yield_documents_from_table( self, start: SecondsSinceUnixEpoch | None = None, diff --git a/common/data_source/discord_connector.py b/common/data_source/discord_connector.py index e65a6324185..83b2b562f0e 100644 --- a/common/data_source/discord_connector.py +++ b/common/data_source/discord_connector.py @@ -13,8 +13,14 @@ from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource from common.data_source.exceptions import ConnectorMissingCredentialError -from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch -from common.data_source.models import Document, GenerateDocumentsOutput, TextSection +from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync +from common.data_source.models import ( + Document, + GenerateDocumentsOutput, + GenerateSlimDocumentOutput, + SlimDocument, + TextSection, +) _DISCORD_DOC_ID_PREFIX = "DISCORD_" _SNIPPET_LENGTH = 30 @@ -94,8 +100,12 @@ async def _fetch_filtered_channels( async def _fetch_documents_from_channel( channel: TextChannel, start_time: datetime | None, - end_time: datetime | None, -) -> AsyncIterable[Document]: +) -> AsyncIterable[DiscordMessage]: + """Yield raw Discord messages for one channel and its threads. + + This stays at the message layer so callers can decide whether they need + full Document construction or only lightweight ID accounting. 
+ """ # Discord's epoch starts at 2015-01-01 discord_epoch = datetime(2015, 1, 1, tzinfo=timezone.utc) if start_time and start_time < discord_epoch: @@ -109,39 +119,23 @@ async def _fetch_documents_from_channel( async for channel_message in channel.history( limit=None, after=start_time, - before=end_time, ): # Skip messages that are not the default type if channel_message.type != MessageType.default: continue - sections: list[TextSection] = [ - TextSection( - text=channel_message.content, - link=channel_message.jump_url, - ) - ] - - yield _convert_message_to_document(channel_message, sections) + yield channel_message for active_thread in channel.threads: async for thread_message in active_thread.history( limit=None, after=start_time, - before=end_time, ): # Skip messages that are not the default type if thread_message.type != MessageType.default: continue - sections = [ - TextSection( - text=thread_message.content, - link=thread_message.jump_url, - ) - ] - - yield _convert_message_to_document(thread_message, sections) + yield thread_message async for archived_thread in channel.archived_threads( limit=None, @@ -149,20 +143,12 @@ async def _fetch_documents_from_channel( async for thread_message in archived_thread.history( limit=None, after=start_time, - before=end_time, ): # Skip messages that are not the default type if thread_message.type != MessageType.default: continue - sections = [ - TextSection( - text=thread_message.content, - link=thread_message.jump_url, - ) - ] - - yield _convert_message_to_document(thread_message, sections) + yield thread_message def _manage_async_retrieval( @@ -171,20 +157,23 @@ def _manage_async_retrieval( channel_names: list[str], server_ids: list[int], start: datetime | None = None, - end: datetime | None = None, -) -> Iterable[Document]: +) -> Iterable[DiscordMessage]: + """Bridge the async Discord client into a synchronous iterator. + + `start` is only used as a lower bound for the underlying fetch. Callers + that need a narrower time window should apply their own filtering while + iterating so the same full scan can also support deleted-file sync. + """ # parse requested_start_date_string to datetime pull_date: datetime | None = datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(tzinfo=timezone.utc) if requested_start_date_string else None - # Set start_time to the most recent of start and pull_date, or whichever is provided + # Keep the configured start date as the full-scan lower bound. 
start_time = max(filter(None, [start, pull_date])) if start or pull_date else None - - end_time: datetime | None = end proxy_url: str | None = os.environ.get("https_proxy") or os.environ.get("http_proxy") if proxy_url: logging.info(f"Using proxy for Discord: {proxy_url}") - async def _async_fetch() -> AsyncIterable[Document]: + async def _async_fetch() -> AsyncIterable[DiscordMessage]: intents = Intents.default() intents.message_content = True async with Client(intents=intents, proxy=proxy_url) as cli: @@ -198,15 +187,13 @@ async def _async_fetch() -> AsyncIterable[Document]: ) for channel in filtered_channels: - async for doc in _fetch_documents_from_channel( + async for message in _fetch_documents_from_channel( channel=channel, start_time=start_time, - end_time=end_time, ): - print(doc) - yield doc + yield message - def run_and_yield() -> Iterable[Document]: + def run_and_yield() -> Iterable[DiscordMessage]: loop = asyncio.new_event_loop() try: # Get the async generator @@ -228,7 +215,7 @@ def run_and_yield() -> Iterable[Document]: return run_and_yield() -class DiscordConnector(LoadConnector, PollConnector): +class DiscordConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): """Discord connector for accessing Discord messages and channels""" def __init__( @@ -251,12 +238,28 @@ def discord_bot_token(self) -> str: raise ConnectorMissingCredentialError("Discord") return self._discord_bot_token - def _manage_doc_batching( + def _iter_merged_documents( self, start: datetime | None = None, end: datetime | None = None, ) -> GenerateDocumentsOutput: - doc_batch = [] + """Build merged Discord documents for the requested polling window.""" + doc_batch: list[Document] = [] + + def _message_created_at(message: DiscordMessage) -> datetime: + created_at = message.created_at + if created_at.tzinfo is None: + return created_at.replace(tzinfo=timezone.utc) + return created_at.astimezone(timezone.utc) + + def _is_in_window(message: DiscordMessage) -> bool: + created_at = _message_created_at(message) + if start is not None and created_at < start: + return False + if end is not None and created_at >= end: + return False + return True + def merge_batch(): nonlocal doc_batch id = doc_batch[0].id @@ -280,14 +283,23 @@ def merge_batch(): size_bytes=size_bytes, ) - for doc in _manage_async_retrieval( + for message in _manage_async_retrieval( token=self.discord_bot_token, requested_start_date_string=self.requested_start_date_string, channel_names=self.channel_names, server_ids=self.server_ids, start=start, - end=end, ): + if not _is_in_window(message): + continue + + sections = [ + TextSection( + text=message.content, + link=message.jump_url, + ) + ] + doc = _convert_message_to_document(message, sections) doc_batch.append(doc) if len(doc_batch) >= self.batch_size: yield [merge_batch()] @@ -296,6 +308,13 @@ def merge_batch(): if doc_batch: yield [merge_batch()] + def _manage_doc_batching( + self, + start: datetime | None = None, + end: datetime | None = None, + ) -> GenerateDocumentsOutput: + yield from self._iter_merged_documents(start=start, end=end) + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: self._discord_bot_token = credentials["discord_bot_token"] return None @@ -316,6 +335,41 @@ def load_from_state(self) -> Any: """Load messages from Discord state""" return self._manage_doc_batching(None, None) + def retrieve_all_slim_docs_perm_sync( + self, + callback: Any = None, + ) -> GenerateSlimDocumentOutput: + del callback + slim_doc_batch: list[SlimDocument] = 
[] + full_scan_batch_size = 0 + full_scan_batch_first_id: str | None = None + + for message in _manage_async_retrieval( + token=self.discord_bot_token, + requested_start_date_string=self.requested_start_date_string, + channel_names=self.channel_names, + server_ids=self.server_ids, + start=None, + ): + if full_scan_batch_first_id is None: + full_scan_batch_first_id = f"{_DISCORD_DOC_ID_PREFIX}{message.id}" + full_scan_batch_size += 1 + + if full_scan_batch_size >= self.batch_size: + slim_doc_batch.append(SlimDocument(id=full_scan_batch_first_id)) + full_scan_batch_size = 0 + full_scan_batch_first_id = None + + if len(slim_doc_batch) >= self.batch_size: + yield slim_doc_batch + slim_doc_batch = [] + + if full_scan_batch_first_id is not None: + slim_doc_batch.append(SlimDocument(id=full_scan_batch_first_id)) + + if slim_doc_batch: + yield slim_doc_batch + if __name__ == "__main__": import os diff --git a/common/data_source/dropbox_connector.py b/common/data_source/dropbox_connector.py index 0e7131d8f3b..43ab08f4b06 100644 --- a/common/data_source/dropbox_connector.py +++ b/common/data_source/dropbox_connector.py @@ -14,14 +14,14 @@ ConnectorValidationError, InsufficientPermissionsError, ) -from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch -from common.data_source.models import Document, GenerateDocumentsOutput +from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync +from common.data_source.models import Document, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SlimDocument from common.data_source.utils import get_file_ext logger = logging.getLogger(__name__) -class DropboxConnector(LoadConnector, PollConnector): +class DropboxConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): """Dropbox connector for accessing Dropbox files and folders""" def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None: @@ -87,57 +87,48 @@ def _yield_files_recursive( if self.dropbox_client is None: raise ConnectorMissingCredentialError("Dropbox") - # Collect all files first to count filename occurrences - all_files = [] - self._collect_files_recursive(path, start, end, all_files) - + all_files: list[FileMetadata] = [] + self._collect_file_entries_recursive(path, start, end, all_files) + # Count filename occurrences filename_counts: dict[str, int] = {} - for entry, _ in all_files: + for entry in all_files: filename_counts[entry.name] = filename_counts.get(entry.name, 0) + 1 - + # Process files in batches batch: list[Document] = [] - for entry, downloaded_file in all_files: - modified_time = entry.client_modified - if modified_time.tzinfo is None: - modified_time = modified_time.replace(tzinfo=timezone.utc) - else: - modified_time = modified_time.astimezone(timezone.utc) - - # Use full path only if filename appears multiple times - if filename_counts.get(entry.name, 0) > 1: - # Remove leading slash and replace slashes with ' / ' - relative_path = entry.path_display.lstrip('/') - semantic_id = relative_path.replace('/', ' / ') if relative_path else entry.name - else: - semantic_id = entry.name - + for entry in all_files: + try: + downloaded_file = self._download_file(entry.path_display) + except Exception: + logger.exception(f"[Dropbox]: Error downloading file {entry.path_display}") + continue + batch.append( Document( id=f"dropbox:{entry.id}", blob=downloaded_file, source=DocumentSource.DROPBOX, - semantic_identifier=semantic_id, + 
semantic_identifier=self._get_semantic_identifier(entry, filename_counts), extension=get_file_ext(entry.name), - doc_updated_at=modified_time, + doc_updated_at=self._normalize_modified_time(entry.client_modified), size_bytes=entry.size if getattr(entry, "size", None) is not None else len(downloaded_file), ) ) - + if len(batch) == self.batch_size: yield batch batch = [] - + if batch: yield batch - def _collect_files_recursive( + def _collect_file_entries_recursive( self, path: str, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None, - all_files: list, + all_files: list[FileMetadata], ) -> None: """Recursively collect all files matching time criteria.""" if self.dropbox_client is None: @@ -152,33 +143,56 @@ def _collect_files_recursive( while True: for entry in result.entries: if isinstance(entry, FileMetadata): - modified_time = entry.client_modified - if modified_time.tzinfo is None: - modified_time = modified_time.replace(tzinfo=timezone.utc) - else: - modified_time = modified_time.astimezone(timezone.utc) - - time_as_seconds = modified_time.timestamp() + time_as_seconds = self._normalize_modified_time(entry.client_modified).timestamp() if start is not None and time_as_seconds <= start: continue if end is not None and time_as_seconds > end: continue - try: - downloaded_file = self._download_file(entry.path_display) - all_files.append((entry, downloaded_file)) - except Exception: - logger.exception(f"[Dropbox]: Error downloading file {entry.path_display}") - continue + all_files.append(entry) elif isinstance(entry, FolderMetadata): - self._collect_files_recursive(entry.path_lower, start, end, all_files) + self._collect_file_entries_recursive(entry.path_lower, start, end, all_files) if not result.has_more: break result = self.dropbox_client.files_list_folder_continue(result.cursor) + def _normalize_modified_time(self, modified_time): + if modified_time.tzinfo is None: + return modified_time.replace(tzinfo=timezone.utc) + return modified_time.astimezone(timezone.utc) + + def _get_semantic_identifier(self, entry: FileMetadata, filename_counts: dict[str, int]) -> str: + if filename_counts.get(entry.name, 0) <= 1: + return entry.name + + relative_path = entry.path_display.lstrip("/") + return relative_path.replace("/", " / ") if relative_path else entry.name + + def retrieve_all_slim_docs_perm_sync( + self, + callback: Any = None, + ) -> GenerateSlimDocumentOutput: + del callback + + if self.dropbox_client is None: + raise ConnectorMissingCredentialError("Dropbox") + + all_files: list[FileMetadata] = [] + self._collect_file_entries_recursive("", None, None, all_files) + + batch: list[SlimDocument] = [] + for entry in all_files: + batch.append(SlimDocument(id=f"dropbox:{entry.id}")) + if len(batch) >= self.batch_size: + yield batch + batch = [] + + if batch: + yield batch + def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> GenerateDocumentsOutput: """Poll Dropbox for recent file changes""" if self.dropbox_client is None: diff --git a/common/data_source/github/connector.py b/common/data_source/github/connector.py index 258e2cf8b46..2d65c995e6b 100644 --- a/common/data_source/github/connector.py +++ b/common/data_source/github/connector.py @@ -964,11 +964,9 @@ def retrieve_slim_document( def retrieve_all_slim_docs_perm_sync( self, - start: SecondsSinceUnixEpoch | None = None, - end: SecondsSinceUnixEpoch | None = None, callback: Any = None, ) -> GenerateSlimDocumentOutput: - yield from self.retrieve_slim_document(start=start, end=end, 
callback=callback) + yield from self.retrieve_slim_document(callback=callback) def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint: return GithubConnectorCheckpoint( diff --git a/common/data_source/gitlab_connector.py b/common/data_source/gitlab_connector.py index 0d2c0dab775..dae24992b49 100644 --- a/common/data_source/gitlab_connector.py +++ b/common/data_source/gitlab_connector.py @@ -20,8 +20,11 @@ from common.data_source.interfaces import LoadConnector from common.data_source.interfaces import PollConnector from common.data_source.interfaces import SecondsSinceUnixEpoch +from common.data_source.interfaces import SlimConnectorWithPermSync from common.data_source.models import BasicExpertInfo from common.data_source.models import Document +from common.data_source.models import GenerateSlimDocumentOutput +from common.data_source.models import SlimDocument from common.data_source.utils import get_file_ext T = TypeVar("T") @@ -158,7 +161,7 @@ def _should_exclude(path: str) -> bool: return any(fnmatch.fnmatch(path, pattern) for pattern in exclude_patterns) -class GitlabConnector(LoadConnector, PollConnector): +class GitlabConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): def __init__( self, project_owner: str, @@ -313,6 +316,67 @@ def poll_source( end_datetime = datetime.fromtimestamp(end, tz=timezone.utc) return self._fetch_from_gitlab(start_datetime, end_datetime) + def retrieve_all_slim_docs_perm_sync(self, callback: Any = None) -> GenerateSlimDocumentOutput: + if self.gitlab_client is None: + raise ConnectorMissingCredentialError("Gitlab") + + project: Project = self.gitlab_client.projects.get( + f"{self.project_owner}/{self.project_name}" + ) + + slim_batch: list[SlimDocument] = [] + + def append_doc(doc_id: str): + slim_batch.append(SlimDocument(id=doc_id)) + if len(slim_batch) >= self.batch_size: + batch = slim_batch[:] + slim_batch.clear() + return batch + return None + + if self.include_code_files: + default_branch = project.default_branch + queue = deque([""]) + while queue: + current_path = queue.popleft() + files = project.repository_tree(path=current_path, all=True) + for file in files: + if _should_exclude(file["path"]): + continue + if file["type"] == "tree": + queue.append(file["path"]) + continue + if file["type"] != "blob": + continue + + file_url = f"{self.gitlab_client.url}/{self.project_owner}/{self.project_name}/-/blob/{default_branch}/{file['path']}" + batch = append_doc(file_url) + if batch: + yield batch + + if self.include_mrs: + merge_requests = project.mergerequests.list( + state=self.state_filter, + iterator=True, + ) + for mr in merge_requests: + batch = append_doc(mr.web_url) + if batch: + yield batch + + if self.include_issues: + issues = project.issues.list( + state=self.state_filter, + iterator=True, + ) + for issue in issues: + batch = append_doc(issue.web_url) + if batch: + yield batch + + if slim_batch: + yield slim_batch + if __name__ == "__main__": import os @@ -337,4 +401,4 @@ def poll_source( document_batches = connector.load_from_state() for f in document_batches: print("Batch:", f) - print("Finished loading from state.") \ No newline at end of file + print("Finished loading from state.") diff --git a/common/data_source/gmail_connector.py b/common/data_source/gmail_connector.py index 1421f9f4bf1..ea4dd993ae0 100644 --- a/common/data_source/gmail_connector.py +++ b/common/data_source/gmail_connector.py @@ -270,12 +270,10 @@ def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) def 
retrieve_all_slim_docs_perm_sync( self, - start: SecondsSinceUnixEpoch | None = None, - end: SecondsSinceUnixEpoch | None = None, callback=None, ) -> GenerateSlimDocumentOutput: """Retrieve slim documents for permission synchronization.""" - query = build_time_range_query(start, end) + query = build_time_range_query() doc_batch = [] for user_email in self._get_all_user_emails(): @@ -343,4 +341,4 @@ def retrieve_all_slim_docs_perm_sync( print(f) print("\n\n") except Exception as e: - logging.exception(f"Error loading credentials: {e}") \ No newline at end of file + logging.exception(f"Error loading credentials: {e}") diff --git a/common/data_source/google_drive/connector.py b/common/data_source/google_drive/connector.py index b44c28d74db..479c60e0b63 100644 --- a/common/data_source/google_drive/connector.py +++ b/common/data_source/google_drive/connector.py @@ -159,6 +159,7 @@ def __init__( self._creds: OAuthCredentials | ServiceAccountCredentials | None = None self._creds_dict: dict[str, Any] | None = None + self._all_drive_ids_cache: set[str] | None = None # ids of folders and shared drives that have been traversed self._retrieved_folder_and_drive_ids: set[str] = set() @@ -211,6 +212,7 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None self.include_files_shared_with_me = True self._creds_dict = new_creds_dict + self._all_drive_ids_cache = None return new_creds_dict @@ -249,7 +251,11 @@ def _get_all_user_emails(self) -> list[str]: return user_emails def get_all_drive_ids(self) -> set[str]: - return self._get_all_drives_for_user(self.primary_admin_email) + if self._all_drive_ids_cache is None: + self._all_drive_ids_cache = self._get_all_drives_for_user( + self.primary_admin_email + ) + return set(self._all_drive_ids_cache) def _get_all_drives_for_user(self, user_email: str) -> set[str]: drive_service = get_drive_service(self.creds, user_email) @@ -265,7 +271,14 @@ def _get_all_drives_for_user(self, user_email: str) -> set[str]: all_drive_ids.add(drive["id"]) if not all_drive_ids: - self.logger.warning("No drives found even though indexing shared drives was requested.") + if self._requested_shared_drive_ids: + self.logger.warning( + "No shared drives found for user %s while resolving requested shared drives.", + user_email, + ) + elif self.include_shared_drives: + log_fn = self.logger.warning if is_service_account else self.logger.info + log_fn("No shared drives found for user %s.", user_email) return all_drive_ids @@ -1087,8 +1100,6 @@ def _extract_slim_docs_from_google_drive( def retrieve_all_slim_docs_perm_sync( self, - start: SecondsSinceUnixEpoch | None = None, - end: SecondsSinceUnixEpoch | None = None, callback: IndexingHeartbeatInterface | None = None, ) -> GenerateSlimDocumentOutput: try: @@ -1096,8 +1107,6 @@ def retrieve_all_slim_docs_perm_sync( while checkpoint.completion_stage != DriveRetrievalStage.DONE: yield from self._extract_slim_docs_from_google_drive( checkpoint=checkpoint, - start=start, - end=end, ) self.logger.info("Drive perm sync: Slim doc retrieval complete") diff --git a/common/data_source/google_util/resource.py b/common/data_source/google_util/resource.py index eb060e46883..ba4199cb078 100644 --- a/common/data_source/google_util/resource.py +++ b/common/data_source/google_util/resource.py @@ -85,9 +85,19 @@ def _get_google_service( if isinstance(creds, ServiceAccountCredentials): # NOTE: https://developers.google.com/identity/protocols/oauth2/service-account#error-codes creds = creds.with_subject(user_email) - service = 
build(service_name, service_version, credentials=creds) + service = build( + service_name, + service_version, + credentials=creds, + cache_discovery=False, + ) elif isinstance(creds, OAuthCredentials): - service = build(service_name, service_version, credentials=creds) + service = build( + service_name, + service_version, + credentials=creds, + cache_discovery=False, + ) return service diff --git a/common/data_source/imap_connector.py b/common/data_source/imap_connector.py index f682676e8ed..a8c1988f6ce 100644 --- a/common/data_source/imap_connector.py +++ b/common/data_source/imap_connector.py @@ -1,5 +1,6 @@ import copy import email +import hashlib from email.header import decode_header import imaplib import logging @@ -12,14 +13,26 @@ from enum import Enum from typing import Any from typing import cast -import uuid import bs4 from pydantic import BaseModel from common.data_source.config import IMAP_CONNECTOR_SIZE_THRESHOLD, DocumentSource -from common.data_source.interfaces import CheckpointOutput, CheckpointedConnectorWithPermSync, CredentialsConnector, CredentialsProviderInterface -from common.data_source.models import BasicExpertInfo, ConnectorCheckpoint, Document, ExternalAccess, SecondsSinceUnixEpoch +from common.data_source.interfaces import ( + CheckpointOutput, + CheckpointedConnectorWithPermSync, + CredentialsConnector, + CredentialsProviderInterface, +) +from common.data_source.models import ( + BasicExpertInfo, + ConnectorCheckpoint, + Document, + ExternalAccess, + GenerateSlimDocumentOutput, + SecondsSinceUnixEpoch, + SlimDocument, +) _DEFAULT_IMAP_PORT_NUMBER = int(os.environ.get("IMAP_PORT", 993)) _IMAP_OKAY_STATUS = "OK" @@ -86,9 +99,6 @@ def _parse_date(date_str: str | None) -> datetime | None: except (TypeError, ValueError): return None - message_id = _decode(header=Header.MESSAGE_ID_HEADER) - if not message_id: - message_id = f"" # It's possible for the subject line to not exist or be an empty string. subject = _decode(header=Header.SUBJECT_HEADER) or "Unknown Subject" from_ = _decode(header=Header.FROM_HEADER) @@ -97,11 +107,27 @@ def _parse_date(date_str: str | None) -> datetime | None: to = _decode(header=Header.DELIVERED_TO_HEADER) cc = _decode(header=Header.CC_HEADER) date_str = _decode(header=Header.DATE_HEADER) - date = _parse_date(date_str=date_str) + parsed_date = _parse_date(date_str=date_str) + date = parsed_date if not date: date = datetime.now(tz=timezone.utc) + message_id = _decode(header=Header.MESSAGE_ID_HEADER) + if not message_id: + message_id = _build_stable_generated_message_id( + email_msg=email_msg, + subject=subject, + sender=from_ or "", + recipients=to or "", + cc=cc or "", + date_key=( + _as_utc(parsed_date).isoformat() + if parsed_date + else (date_str or "") + ), + ) + # If any of the above are `None`, model validation will fail. # Therefore, no guards (i.e.: `if
is None: raise RuntimeError(..)`) were written. return cls.model_validate( @@ -269,12 +295,7 @@ def _load_from_checkpoint( continue email_headers = EmailHeaders.from_email_msg(email_msg=email_msg) - msg_dt = email_headers.date - if msg_dt.tzinfo is None: - msg_dt = msg_dt.replace(tzinfo=timezone.utc) - else: - msg_dt = msg_dt.astimezone(timezone.utc) - + msg_dt = _as_utc(email_headers.date) start_dt = datetime.fromtimestamp(start, tz=timezone.utc) end_dt = datetime.fromtimestamp(end, tz=timezone.utc) @@ -339,6 +360,64 @@ def load_from_checkpoint_with_perm_sync( start=start, end=end, checkpoint=checkpoint, include_perm_sync=True ) + def retrieve_all_slim_docs_perm_sync( + self, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + callback: Any = None, + ) -> GenerateSlimDocumentOutput: + del callback + mail_client = self._get_mail_client() + start_ts = start if start is not None else 0 + end_ts = ( + end if end is not None else datetime.now(tz=timezone.utc).timestamp() + ) + start_dt = datetime.fromtimestamp(start_ts, tz=timezone.utc) + end_dt = datetime.fromtimestamp(end_ts, tz=timezone.utc) + + if self._mailboxes: + mailboxes = _sanitize_mailbox_names(self._mailboxes) + else: + mailboxes = _sanitize_mailbox_names( + _fetch_all_mailboxes_for_email_account(mail_client=mail_client) + ) + + slim_doc_batch: list[SlimDocument] = [] + for mailbox in mailboxes: + email_ids = _fetch_email_ids_in_mailbox( + mail_client=mail_client, + mailbox=mailbox, + start=start_ts, + end=end_ts, + ) + _select_mailbox(mail_client=mail_client, mailbox=mailbox) + + for email_id in email_ids: + email_msg = _fetch_email(mail_client=mail_client, email_id=email_id) + if not email_msg: + logging.warning(f"Failed to fetch message {email_id=}; skipping") + continue + + email_headers = EmailHeaders.from_email_msg(email_msg=email_msg) + msg_dt = _as_utc(email_headers.date) + if not (start_dt < msg_dt <= end_dt): + continue + + slim_doc_batch.append(SlimDocument(id=email_headers.id)) + for att in extract_attachments(email_msg): + slim_doc_batch.append( + SlimDocument( + id=_attachment_document_id(email_headers.id, att) + ) + ) + + if len(slim_doc_batch) >= _PAGE_SIZE: + yield slim_doc_batch + slim_doc_batch = [] + + if slim_doc_batch: + yield slim_doc_batch + def _fetch_all_mailboxes_for_email_account(mail_client: imaplib.IMAP4_SSL) -> list[str]: status, mailboxes_data = mail_client.list('""', "*") @@ -435,6 +514,39 @@ def _fetch_email(mail_client: imaplib.IMAP4_SSL, email_id: str) -> Message | Non return email.message_from_bytes(raw_email) +def _as_utc(dt: datetime) -> datetime: + if dt.tzinfo is None: + return dt.replace(tzinfo=timezone.utc) + return dt.astimezone(timezone.utc) + + +def _build_stable_generated_message_id( + email_msg: Message, + subject: str, + sender: str, + recipients: str, + cc: str, + date_key: str, +) -> str: + body = _extract_email_body_text(email_msg) + raw_digest = hashlib.sha256(email_msg.as_bytes()).hexdigest() + body_digest = hashlib.sha256(body.encode("utf-8")).hexdigest() + digest = hashlib.sha256( + "\n".join( + [ + subject, + date_key, + sender, + recipients, + cc, + body_digest, + raw_digest, + ] + ).encode("utf-8") + ).hexdigest() + return f"generated:{digest}" + + def _convert_email_headers_and_body_into_document( email_msg: Message, email_headers: EmailHeaders, @@ -544,6 +656,13 @@ def decode_mime_filename(raw: str | None) -> str | None: return "".join(decoded) + +def _attachment_document_id(parent_doc_id: str, att: dict) -> str: + raw_filename = 
att["filename"] + filename = decode_mime_filename(raw_filename) or "attachment.bin" + return f"{parent_doc_id}#att:{filename}" + + def attachment_to_document( parent_doc: Document, att: dict, @@ -554,7 +673,7 @@ def attachment_to_document( ext = "." + filename.split(".")[-1] if "." in filename else "" return Document( - id=f"{parent_doc.id}#att:{filename}", + id=_attachment_document_id(parent_doc.id, att), source=DocumentSource.IMAP, semantic_identifier=filename, extension=ext, @@ -574,6 +693,15 @@ def _parse_email_body( email_msg: Message, email_headers: EmailHeaders, ) -> str: + body = _extract_email_body_text(email_msg) + if not body: + logging.warning( + f"Email with {email_headers.id=} has an empty body; returning an empty string" + ) + return body + + +def _extract_email_body_text(email_msg: Message) -> str: body = None for part in email_msg.walk(): if part.is_multipart(): @@ -598,9 +726,6 @@ def _parse_email_body( continue if not body: - logging.warning( - f"Email with {email_headers.id=} has an empty body; returning an empty string" - ) return "" soup = bs4.BeautifulSoup(markup=body, features="html.parser") @@ -636,6 +761,7 @@ def _parse_singular_addr(raw_header: str) -> tuple[str, str]: if __name__ == "__main__": import time + import uuid from types import TracebackType from common.data_source.utils import load_all_docs_from_checkpoint_connector diff --git a/common/data_source/interfaces.py b/common/data_source/interfaces.py index b68a40c1e1a..324293baaba 100644 --- a/common/data_source/interfaces.py +++ b/common/data_source/interfaces.py @@ -60,8 +60,6 @@ class SlimConnectorWithPermSync(ABC): @abstractmethod def retrieve_all_slim_docs_perm_sync( self, - start: SecondsSinceUnixEpoch | None = None, - end: SecondsSinceUnixEpoch | None = None, callback: Any = None, ) -> Generator[list[SlimDocument], None, None]: """Retrieve all simplified documents (with permission sync)""" diff --git a/common/data_source/jira/connector.py b/common/data_source/jira/connector.py index db3c3f8942d..aa4082f4149 100644 --- a/common/data_source/jira/connector.py +++ b/common/data_source/jira/connector.py @@ -149,7 +149,10 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None else: logger.warning("[Jira] Scoped token requested but Jira base URL does not appear to be an Atlassian Cloud domain; scoped token ignored.") - user_email = credentials.get("jira_user_email") or credentials.get("username") + user_email = ( + credentials.get("jira_user_email") + or credentials.get("jira_username") + ) api_token = credentials.get("jira_api_token") or credentials.get("token") or credentials.get("api_token") password = credentials.get("jira_password") or credentials.get("password") rest_api_version = credentials.get("rest_api_version") @@ -377,16 +380,14 @@ def validate_checkpoint_json(self, checkpoint_json: str) -> JiraCheckpoint: def retrieve_all_slim_docs_perm_sync( self, - start: SecondsSinceUnixEpoch | None = None, - end: SecondsSinceUnixEpoch | None = None, - callback: Any = None, # noqa: ARG002 - maintained for interface compatibility + callback: Any = None, # noqa: ARG002 - callback interface hook ) -> Generator[list[SlimDocument], None, None]: """Return lightweight references to Jira issues (used for permission syncing).""" if not self.jira_client: raise ConnectorMissingCredentialError("Jira") - start_ts = start if start is not None else 0 - end_ts = end if end is not None else datetime.now(timezone.utc).timestamp() + start_ts = 0 + end_ts = datetime.now(timezone.utc).timestamp() 
jql = self._build_jql(start_ts, end_ts) checkpoint = self.build_dummy_checkpoint() @@ -962,7 +963,16 @@ def main(config: dict[str, Any] | None = None) -> None: if not base_url: raise RuntimeError("Jira base URL must be provided via config or CLI arguments.") - if not (credentials.get("jira_api_token") or (credentials.get("jira_user_email") and credentials.get("jira_password"))): + if not ( + credentials.get("jira_api_token") + or ( + ( + credentials.get("jira_user_email") + or credentials.get("jira_username") + ) + and credentials.get("jira_password") + ) + ): raise RuntimeError("Provide either an API token or both email/password for Jira authentication.") connector_options = { diff --git a/common/data_source/moodle_connector.py b/common/data_source/moodle_connector.py index 39efcf07be0..850ce5815d1 100644 --- a/common/data_source/moodle_connector.py +++ b/common/data_source/moodle_connector.py @@ -21,14 +21,19 @@ LoadConnector, PollConnector, SecondsSinceUnixEpoch, + SlimConnectorWithPermSync, +) +from common.data_source.models import ( + Document, + GenerateSlimDocumentOutput, + SlimDocument, ) -from common.data_source.models import Document from common.data_source.utils import batch_generator, rl_requests logger = logging.getLogger(__name__) -class MoodleConnector(LoadConnector, PollConnector): +class MoodleConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): """Moodle LMS connector for accessing course content""" def __init__(self, moodle_url: str, batch_size: int = INDEX_BATCH_SIZE) -> None: @@ -137,6 +142,78 @@ def poll_source( self._get_updated_content(courses, start, end) ) + @staticmethod + def _slim_doc_id_for_module(module) -> Optional[str]: + """Return the indexed document id for a Moodle module, or None. + + The id format must match the ones produced by the _process_* + helpers below. Module types that we never ingest (label, url) and + modules with no id return None. + """ + mtype = getattr(module, "modname", None) + mid = getattr(module, "id", None) + if not mtype or mid is None: + return None + if mtype in ("label", "url"): + return None + if mtype == "resource": + return f"moodle_resource_{mid}" + if mtype == "forum": + return f"moodle_forum_{mid}" + if mtype == "page": + return f"moodle_page_{mid}" + if mtype == "book": + return f"moodle_book_{mid}" + if mtype in ("assign", "quiz"): + return f"moodle_{mtype}_{mid}" + return None + + def retrieve_all_slim_docs_perm_sync( + self, + callback: Any = None, + ) -> GenerateSlimDocumentOutput: + """List the ids of every Moodle module that could be indexed. + + This is a lightweight pass over courses and modules with no file + downloads. The caller compares the returned ids against the index + and removes any indexed document whose id is not in this list. 
+ """ + del callback + if not self.moodle_client: + raise ConnectorMissingCredentialError("Moodle client not initialized") + + logger.info("Starting Moodle slim snapshot for stale-document cleanup") + courses = self._get_enrolled_courses() + if not courses: + logger.warning("No courses found for slim snapshot") + return + + batch: list[SlimDocument] = [] + total = 0 + for course in courses: + try: + contents = self._get_course_contents(course.id) + for section in contents: + for module in section.modules: + slim_id = self._slim_doc_id_for_module(module) + if slim_id is None: + continue + batch.append(SlimDocument(id=slim_id)) + total += 1 + if len(batch) >= self.batch_size: + yield batch + batch = [] + except Exception as e: + self._log_error( + f"slim snapshot for course {getattr(course, 'fullname', '?')}", + e, + ) + + if batch: + yield batch + + logger.info(f"Moodle slim snapshot completed: {total} documents listed") + @retry(tries=3, delay=1, backoff=2) def _get_enrolled_courses(self) -> list: if not self.moodle_client: diff --git a/common/data_source/notion_connector.py b/common/data_source/notion_connector.py index 30536dfb944..ea3d6d07646 100644 --- a/common/data_source/notion_connector.py +++ b/common/data_source/notion_connector.py @@ -28,9 +28,11 @@ from common.data_source.models import ( Document, GenerateDocumentsOutput, + GenerateSlimDocumentOutput, NotionBlock, NotionPage, NotionSearchResponse, + SlimDocument, TextSection, ) from common.data_source.utils import ( @@ -433,6 +435,45 @@ def _read_blocks(self, base_block_id: str, page_last_edited_time: Optional[str] return result_blocks, child_pages, attachments + def _read_slim_blocks(self, base_block_id: str) -> tuple[list[str], list[str]]: + child_pages: list[str] = [] + attachment_ids: list[str] = [] + cursor = None + + while True: + data = self._fetch_child_blocks(base_block_id, cursor) + + if data is None: + return child_pages, attachment_ids + + for result in data["results"]: + result_block_id = result["id"] + result_type = result["type"] + + if result_type in {"file", "image", "pdf", "video", "audio"}: + attachment_ids.append(result_block_id) + + if result["has_children"]: + if result_type == "child_page": + child_pages.append(result_block_id) + else: + nested_child_pages, nested_attachment_ids = self._read_slim_blocks( + result_block_id + ) + child_pages.extend(nested_child_pages) + attachment_ids.extend(nested_attachment_ids) + + if result_type == "child_database" and self.recursive_index_enabled: + _, inner_child_pages = self._read_pages_from_database(result_block_id) + child_pages.extend(inner_child_pages) + + if data["next_cursor"] is None: + break + + cursor = data["next_cursor"] + + return child_pages, attachment_ids + def _read_page_title(self, page: NotionPage) -> Optional[str]: """Extracts the title from a Notion page.""" if hasattr(page, "database_name") and page.database_name: @@ -552,6 +593,79 @@ def _recursive_load(self, start: SecondsSinceUnixEpoch | None = None, end: Secon pages = [self._fetch_page(page_id=self.root_page_id)] yield from batch_generator(self._read_pages(pages, start, end), self.batch_size) + def _read_pages_for_slim_docs( + self, + pages: list[NotionPage], + slim_indexed_pages: set[str], + ) -> Generator[SlimDocument, None, None]: + all_child_page_ids: list[str] = [] + + for page in pages: + if isinstance(page, dict): + page = NotionPage(**page) + if page.id in slim_indexed_pages: + continue + + child_page_ids, attachment_ids = self._read_slim_blocks(page.id) + 
all_child_page_ids.extend(child_page_ids) + slim_indexed_pages.add(page.id) + + yield SlimDocument(id=page.id) + for attachment_id in attachment_ids: + yield SlimDocument(id=attachment_id) + + if self.recursive_index_enabled and all_child_page_ids: + for child_page_batch_ids in batch_generator(all_child_page_ids, INDEX_BATCH_SIZE): + child_page_batch = [ + self._fetch_page(page_id) + for page_id in child_page_batch_ids + if page_id not in slim_indexed_pages + ] + yield from self._read_pages_for_slim_docs( + child_page_batch, + slim_indexed_pages, + ) + + def retrieve_all_slim_docs_perm_sync( + self, + callback: Any = None, + ) -> GenerateSlimDocumentOutput: + slim_indexed_pages: set[str] = set() + + if self.recursive_index_enabled and self.root_page_id: + root_pages = [self._fetch_page(page_id=self.root_page_id)] + yield from batch_generator( + self._read_pages_for_slim_docs(root_pages, slim_indexed_pages), + self.batch_size, + ) + return + + query_dict = { + "filter": {"property": "object", "value": "page"}, + "page_size": 100, + } + + slim_batch: list[SlimDocument] = [] + while True: + db_res = self._search_notion(query_dict) + pages = [NotionPage(**page) for page in db_res.results] + + for doc in self._read_pages_for_slim_docs(pages, slim_indexed_pages): + slim_batch.append(doc) + if len(slim_batch) >= self.batch_size: + yield slim_batch + slim_batch = [] + if callback: + callback.progress("notion_slim_document", 1) + + if db_res.has_more: + query_dict["start_cursor"] = db_res.next_cursor + else: + break + + if slim_batch: + yield slim_batch + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: """Applies integration token to headers.""" self.headers["Authorization"] = f"Bearer {credentials['notion_integration_token']}" @@ -653,4 +767,4 @@ def validate_connector_settings(self) -> None: document_batches = connector.load_from_state() for doc_batch in document_batches: for doc in doc_batch: - print(doc) \ No newline at end of file + print(doc) diff --git a/common/data_source/rdbms_connector.py b/common/data_source/rdbms_connector.py index 05628501c65..9811d2064dc 100644 --- a/common/data_source/rdbms_connector.py +++ b/common/data_source/rdbms_connector.py @@ -1,5 +1,6 @@ """RDBMS (MySQL/PostgreSQL) data source connector for importing data from relational databases.""" +import copy import hashlib import json import logging @@ -12,8 +13,13 @@ ConnectorMissingCredentialError, ConnectorValidationError, ) -from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch -from common.data_source.models import Document +from common.data_source.interfaces import ( + LoadConnector, + PollConnector, + SecondsSinceUnixEpoch, + SlimConnectorWithPermSync, +) +from common.data_source.models import Document, SlimDocument class DatabaseType(str, Enum): @@ -22,15 +28,18 @@ class DatabaseType(str, Enum): POSTGRESQL = "postgresql" -class RDBMSConnector(LoadConnector, PollConnector): +class RDBMSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): """ - RDBMS connector for importing data from MySQL and PostgreSQL databases. - - This connector allows users to: - 1. Connect to a MySQL or PostgreSQL database - 2. Execute a SQL query to extract data - 3. Map columns to content (for vectorization) and metadata - 4. Sync data in batch or incremental mode using a timestamp column + Import rows from MySQL or PostgreSQL into documents. + + The flow is: + 1. Connect to the configured database. + 2. 
Read rows from a custom SQL query, or from every table when no query is provided. + 3. Build document content from the selected content columns. + 4. Copy the selected metadata columns into document metadata. + 5. Use the configured ID column as the stable document ID, or hash the content when no ID column is set. + 6. For incremental sync, treat the timestamp column as an ordered cursor and only compare values by size. + 7. For deleted-file sync, read a slim snapshot of current row IDs and let the sync worker remove stale documents. """ def __init__( self, @@ -73,6 +82,9 @@ def __init__( self._connection = None self._credentials: Dict[str, Any] = {} + self._sync_connector_id: str | None = None + self._sync_config: Dict[str, Any] | None = None + self._pending_sync_cursor_value: Any = None def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None: """Load database credentials.""" @@ -160,98 +172,175 @@ def _get_tables(self) -> list[str]: finally: cursor.close() - def _build_query_with_time_filter( + + def _get_base_queries(self) -> list[str]: + if self.query: + return [self.query.rstrip(";")] + return [f"SELECT * FROM {table}" for table in self._get_tables()] + + + def _wrap_query(self, base_query: str, select_clause: str = "*") -> str: + return f"SELECT {select_clause} FROM ({base_query}) AS ragflow_src" + + + @staticmethod + def serialize_cursor_value(value: Any) -> Any: + # Example: + # - int cursor 42 is stored as 42 + # - datetime cursor 2026-05-07T12:34:56+00:00 is stored as + # {"__ragflow_rdbms_cursor_type__": "datetime", "value": "..."} + # Only datetime needs wrapping because connector config is JSON. + if isinstance(value, datetime): + return { + "__ragflow_rdbms_cursor_type__": "datetime", + "value": value.isoformat(), + } + return value + + + @staticmethod + def deserialize_cursor_value(value: Any) -> Any: + # Reverse the datetime wrapper above. + # Non-datetime cursors such as int/str/float are returned as-is. 
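Because connector configuration is persisted as JSON, only `datetime` cursors need the explicit wrapper; every other cursor type is already JSON-serialisable. A short usage sketch of the round trip through the two static helpers added here (the import path follows the file under change):

```python
import json
from datetime import datetime, timezone

from common.data_source.rdbms_connector import RDBMSConnector

cursor = datetime(2026, 5, 7, 12, 34, 56, tzinfo=timezone.utc)

# The wrapped value survives a JSON round trip through the stored config...
stored = json.dumps({"sync_cursor_value": RDBMSConnector.serialize_cursor_value(cursor)})
restored = RDBMSConnector.deserialize_cursor_value(json.loads(stored)["sync_cursor_value"])
assert restored == cursor

# ...while plain cursors (ints, floats, strings) pass through untouched.
assert RDBMSConnector.serialize_cursor_value(42) == 42
assert RDBMSConnector.deserialize_cursor_value(42) == 42
```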
+ if ( + isinstance(value, dict) + and value.get("__ragflow_rdbms_cursor_type__") == "datetime" + ): + return datetime.fromisoformat(value["value"]) + return value + + + def _format_sql_value(self, value: Any) -> str: + if isinstance(value, datetime): + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + if self.db_type == DatabaseType.MYSQL: + rendered = value.astimezone(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") + else: + rendered = value.astimezone(timezone.utc).isoformat() + return f"'{rendered}'" + if isinstance(value, bool): + if self.db_type == DatabaseType.POSTGRESQL: + return "TRUE" if value else "FALSE" + return "1" if value else "0" + if isinstance(value, (int, float)): + return str(value) + if isinstance(value, str): + return "'" + value.replace("'", "''") + "'" + raise ConnectorValidationError( + f"Unsupported timestamp cursor value type: {type(value).__name__}" + ) + + + def _build_time_filtered_query( self, - start: Optional[datetime] = None, - end: Optional[datetime] = None, + base_query: str, + start: Any = None, + end: Any = None, ) -> str: - """Build the query with optional time filtering for incremental sync.""" - if not self.query: - return "" # Will be handled by table discovery - base_query = self.query.rstrip(";") - if not self.timestamp_column or (start is None and end is None): - return base_query - - has_where = "where" in base_query.lower() - connector = " AND" if has_where else " WHERE" - - time_conditions = [] + return self._wrap_query(base_query) + + conditions = [] if start is not None: - if self.db_type == DatabaseType.MYSQL: - time_conditions.append(f"{self.timestamp_column} > '{start.strftime('%Y-%m-%d %H:%M:%S')}'") - else: - time_conditions.append(f"{self.timestamp_column} > '{start.isoformat()}'") - + conditions.append( + f"ragflow_src.{self.timestamp_column} > {self._format_sql_value(start)}" + ) if end is not None: - if self.db_type == DatabaseType.MYSQL: - time_conditions.append(f"{self.timestamp_column} <= '{end.strftime('%Y-%m-%d %H:%M:%S')}'") - else: - time_conditions.append(f"{self.timestamp_column} <= '{end.isoformat()}'") - - if time_conditions: - return f"{base_query}{connector} {' AND '.join(time_conditions)}" - - return base_query + conditions.append( + f"ragflow_src.{self.timestamp_column} <= {self._format_sql_value(end)}" + ) - def _row_to_document(self, row: Union[tuple, list, Dict[str, Any]], column_names: list) -> Document: - """Convert a database row to a Document.""" - row_dict = dict(zip(column_names, row)) if isinstance(row, (list, tuple)) else row - + query = self._wrap_query(base_query) + if conditions: + query = f"{query} WHERE {' AND '.join(conditions)}" + return query + + + def _build_max_timestamp_query(self, base_query: str) -> str: + return ( + f"SELECT MAX(ragflow_src.{self.timestamp_column}) " + f"FROM ({base_query}) AS ragflow_src" + ) + + + def _build_slim_query(self, base_query: str) -> str: + columns = [self.id_column] if self.id_column else self.content_columns + select_clause = ", ".join(f"ragflow_src.{column}" for column in columns) + return self._wrap_query(base_query, select_clause) + + + def _build_content(self, row_dict: Dict[str, Any]) -> str: content_parts = [] for col in self.content_columns: - if col in row_dict and row_dict[col] is not None: - value = row_dict[col] - if isinstance(value, (dict, list)): - value = json.dumps(value, ensure_ascii=False) - # Use brackets around field name and put value on a new line - # so that TxtParser preserves field boundaries after chunking. 
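For reference, a simplified standalone sketch of what the wrapped, time-filtered query ends up looking like (hypothetical helper names; the real code keeps these helpers on the connector class and renders literals per dialect via `_format_sql_value`):

```python
from datetime import datetime, timezone


def wrap_query(base_query: str, select_clause: str = "*") -> str:
    # Mirrors _wrap_query: the user-supplied query becomes a derived table,
    # so filters can always target ragflow_src regardless of the query's shape.
    return f"SELECT {select_clause} FROM ({base_query.rstrip(';')}) AS ragflow_src"


def time_filtered_query(
    base_query: str,
    timestamp_column: str,
    start: datetime | None = None,
    end: datetime | None = None,
) -> str:
    # Simplified, PostgreSQL-flavoured restatement of _build_time_filtered_query.
    conditions = []
    if start is not None:
        conditions.append(
            f"ragflow_src.{timestamp_column} > '{start.astimezone(timezone.utc).isoformat()}'"
        )
    if end is not None:
        conditions.append(
            f"ragflow_src.{timestamp_column} <= '{end.astimezone(timezone.utc).isoformat()}'"
        )
    query = wrap_query(base_query)
    return f"{query} WHERE {' AND '.join(conditions)}" if conditions else query


print(time_filtered_query(
    "SELECT id, title, body, updated_at FROM articles",
    "updated_at",
    start=datetime(2026, 1, 1, tzinfo=timezone.utc),
))
# -> SELECT * FROM (SELECT id, title, body, updated_at FROM articles) AS ragflow_src
#    WHERE ragflow_src.updated_at > '2026-01-01T00:00:00+00:00'
```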
- content_parts.append(f"【{col}】:\n{value}") - - content = "\n\n".join(content_parts) - - if self.id_column and self.id_column in row_dict: - doc_id = f"{self.db_type}:{self.database}:{row_dict[self.id_column]}" - else: - content_hash = hashlib.md5(content.encode()).hexdigest() - doc_id = f"{self.db_type}:{self.database}:{content_hash}" - + if col not in row_dict or row_dict[col] is None: + continue + value = row_dict[col] + if isinstance(value, (dict, list)): + value = json.dumps(value, ensure_ascii=False) + content_parts.append(f"【{col}】:\n{value}") + return "\n\n".join(content_parts) + + + def _build_document_id_from_row(self, row_dict: Dict[str, Any]) -> str: + if self.id_column and self.id_column in row_dict and row_dict[self.id_column] is not None: + return f"{self.db_type}:{self.database}:{row_dict[self.id_column]}" + content = self._build_content(row_dict) + content_hash = hashlib.md5(content.encode()).hexdigest() + return f"{self.db_type}:{self.database}:{content_hash}" + + + def _row_to_document( + self, + row: Union[tuple, list, Dict[str, Any]], + column_names: list[str], + ) -> Document: + """Convert a database row to a Document.""" + row_dict = dict(zip(column_names, row)) if isinstance(row, (list, tuple)) else row + content = self._build_content(row_dict) metadata = {} for col in self.metadata_columns: - if col in row_dict and row_dict[col] is not None: - value = row_dict[col] - if isinstance(value, datetime): - value = value.isoformat() - elif isinstance(value, (dict, list)): - value = json.dumps(value, ensure_ascii=False) - else: - value = str(value) - metadata[col] = value - + if col not in row_dict or row_dict[col] is None: + continue + value = row_dict[col] + if isinstance(value, datetime): + value = value.isoformat() + elif isinstance(value, (dict, list)): + value = json.dumps(value, ensure_ascii=False) + else: + value = str(value) + metadata[col] = value + doc_updated_at = datetime.now(timezone.utc) - if self.timestamp_column and self.timestamp_column in row_dict: + if self.timestamp_column and self.timestamp_column in row_dict and row_dict[self.timestamp_column] is not None: ts_value = row_dict[self.timestamp_column] if isinstance(ts_value, datetime): if ts_value.tzinfo is None: doc_updated_at = ts_value.replace(tzinfo=timezone.utc) else: - doc_updated_at = ts_value - + doc_updated_at = ts_value.astimezone(timezone.utc) + first_content_col = self.content_columns[0] if self.content_columns else "record" - semantic_id = str(row_dict.get(first_content_col, "database_record")).replace("\n", " ").replace("\r", " ").strip()[:100] + semantic_id = ( + str(row_dict.get(first_content_col, "database_record")) + .replace("\n", " ") + .replace("\r", " ") + .strip()[:100] + ) + blob = content.encode("utf-8") - return Document( - id=doc_id, - blob=content.encode("utf-8"), + id=self._build_document_id_from_row(row_dict), + blob=blob, source=DocumentSource(self.db_type.value), semantic_identifier=semantic_id, extension=".txt", doc_updated_at=doc_updated_at, - size_bytes=len(content.encode("utf-8")), + size_bytes=len(blob), metadata=metadata if metadata else None, ) + def _yield_documents_from_query( self, query: str, @@ -288,30 +377,146 @@ def _yield_documents_from_query( pass cursor.close() + + def _yield_slim_documents_from_query( + self, + query: str, + ) -> Generator[list[SlimDocument], None, None]: + connection = self._get_connection() + cursor = connection.cursor() + + try: + logging.debug(f"Executing slim query: {query[:200]}...") + cursor.execute(query) + column_names = 
[desc[0] for desc in cursor.description] + + batch: list[SlimDocument] = [] + for row in cursor: + row_dict = dict(zip(column_names, row)) if isinstance(row, (list, tuple)) else row + batch.append(SlimDocument(id=self._build_document_id_from_row(row_dict))) + if len(batch) >= self.batch_size: + yield batch + batch = [] + + if batch: + yield batch + finally: + try: + cursor.fetchall() + except Exception: + pass + cursor.close() + + + def get_max_cursor_value(self) -> Any: + if not self.timestamp_column: + return None + + max_cursor_value = None + connection = self._get_connection() + cursor = connection.cursor() + + try: + for base_query in self._get_base_queries(): + query = self._build_max_timestamp_query(base_query) + logging.debug(f"Executing max timestamp query: {query[:200]}...") + cursor.execute(query) + row = cursor.fetchone() + if row is None or row[0] is None: + continue + if max_cursor_value is None or row[0] > max_cursor_value: + max_cursor_value = row[0] + finally: + cursor.close() + + return max_cursor_value + + def _yield_documents( self, - start: Optional[datetime] = None, - end: Optional[datetime] = None, + start: Any = None, + end: Any = None, ) -> Generator[list[Document], None, None]: """Generate documents from database query results.""" - if self.query: - query = self._build_query_with_time_filter(start, end) - yield from self._yield_documents_from_query(query) - else: - tables = self._get_tables() - logging.info(f"No query specified. Loading all {len(tables)} tables: {tables}") - for table in tables: - query = f"SELECT * FROM {table}" - logging.info(f"Loading table: {table}") + base_queries = self._get_base_queries() + if not self.query: + logging.info(f"No query specified. Loading all {len(base_queries)} tables.") + + try: + for base_query in base_queries: + query = self._build_time_filtered_query(base_query, start, end) yield from self._yield_documents_from_query(query) - - self._close_connection() + finally: + self._close_connection() + def load_from_state(self) -> Generator[list[Document], None, None]: """Load all documents from the database (full sync).""" logging.debug(f"Loading all records from {self.db_type} database: {self.database}") return self._yield_documents() + + def retrieve_all_slim_docs_perm_sync( + self, + callback: Any = None, + ) -> Generator[list[SlimDocument], None, None]: + del callback + + base_queries = self._get_base_queries() + if not self.query: + logging.info(f"No query specified. 
Retrieving slim documents from all {len(base_queries)} tables.") + + try: + for base_query in base_queries: + yield from self._yield_slim_documents_from_query( + self._build_slim_query(base_query) + ) + finally: + self._close_connection() + + def prepare_sync_state(self, connector_id: str, config: Dict[str, Any]) -> None: + self._sync_connector_id = connector_id + self._sync_config = copy.deepcopy(config) + if not self.timestamp_column: + self._pending_sync_cursor_value = None + return + self._pending_sync_cursor_value = self.get_max_cursor_value() + + + def get_saved_sync_cursor_value(self) -> Any: + if self._sync_config is None: + return None + return self.deserialize_cursor_value(self._sync_config.get("sync_cursor_value")) + + + def persist_sync_state(self) -> None: + if not self.timestamp_column or self._sync_connector_id is None or self._sync_config is None: + return + + from api.db.services.connector_service import ConnectorService + + updated_conf = copy.deepcopy(self._sync_config) + updated_conf["sync_cursor_value"] = self.serialize_cursor_value( + self._pending_sync_cursor_value + ) + ConnectorService.update_by_id(self._sync_connector_id, {"config": updated_conf}) + self._sync_config = updated_conf + + + def load_from_cursor_range( + self, + start_value: Any = None, + end_value: Any = None, + ) -> Generator[list[Document], None, None]: + if end_value is None: + self._close_connection() + return iter(()) + if start_value is not None and end_value <= start_value: + self._close_connection() + return iter(()) + return self._yield_documents(start_value, end_value) + + def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> Generator[list[Document], None, None]: @@ -322,16 +527,8 @@ def poll_source( "Falling back to full sync." 
) return self.load_from_state() - - start_datetime = datetime.fromtimestamp(start, tz=timezone.utc) - end_datetime = datetime.fromtimestamp(end, tz=timezone.utc) - - logging.debug( - f"Polling {self.db_type} database {self.database} " - f"from {start_datetime} to {end_datetime}" - ) - - return self._yield_documents(start_datetime, end_datetime) + return self._yield_documents(start, end) + def validate_connector_settings(self) -> None: """Validate connector settings by testing the connection.""" diff --git a/common/data_source/rss_connector.py b/common/data_source/rss_connector.py index 85471407abc..6fad756d73b 100644 --- a/common/data_source/rss_connector.py +++ b/common/data_source/rss_connector.py @@ -1,44 +1,29 @@ import hashlib -import ipaddress -import socket from datetime import datetime, timezone from email.utils import parsedate_to_datetime from time import struct_time from typing import Any -from urllib.parse import urlparse +from urllib.parse import urljoin, urlparse import bs4 import feedparser import requests from common.data_source.config import INDEX_BATCH_SIZE, REQUEST_TIMEOUT_SECONDS, DocumentSource -from common.data_source.interfaces import LoadConnector, PollConnector -from common.data_source.models import Document, GenerateDocumentsOutput, SecondsSinceUnixEpoch +from common.data_source.interfaces import LoadConnector, PollConnector, SlimConnectorWithPermSync +from common.data_source.models import ( + Document, + GenerateDocumentsOutput, + GenerateSlimDocumentOutput, + SecondsSinceUnixEpoch, + SlimDocument, +) +from common.ssrf_guard import assert_url_is_safe, pin_dns as _pin_dns +_MAX_REDIRECTS = 10 -def _is_private_ip(ip: str) -> bool: - try: - ip_obj = ipaddress.ip_address(ip) - return ip_obj.is_private or ip_obj.is_link_local or ip_obj.is_loopback - except ValueError: - return False - -def _validate_url_no_ssrf(url: str) -> None: - parsed = urlparse(url) - hostname = parsed.hostname - if not hostname: - raise ValueError("URL must have a valid hostname") - - try: - ip = socket.gethostbyname(hostname) - if _is_private_ip(ip): - raise ValueError(f"URL resolves to private/internal IP address: {ip}") - except socket.gaierror as e: - raise ValueError(f"Failed to resolve hostname: {hostname}") from e - - -class RSSConnector(LoadConnector, PollConnector): +class RSSConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): def __init__(self, feed_url: str, batch_size: int = INDEX_BATCH_SIZE) -> None: self.feed_url = feed_url.strip() self.batch_size = batch_size @@ -61,6 +46,25 @@ def load_from_state(self) -> GenerateDocumentsOutput: def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> GenerateDocumentsOutput: yield from self._load_entries(start=start, end=end) + def retrieve_all_slim_docs_perm_sync( + self, + callback: Any = None, + ) -> GenerateSlimDocumentOutput: + del callback + + feed = self._read_feed(require_entries=False) + batch: list[SlimDocument] = [] + + for entry in feed.entries: + batch.append(SlimDocument(id=self._build_document_id(entry))) + + if len(batch) >= self.batch_size: + yield batch + batch = [] + + if batch: + yield batch + def _load_entries( self, start: SecondsSinceUnixEpoch | None = None, @@ -87,7 +91,8 @@ def _load_entries( if batch: yield batch - def _validate_feed_url(self) -> None: + def _validate_feed_url(self) -> tuple[str, str]: + """Validate ``self.feed_url`` and return ``(hostname, resolved_ip)``.""" if not self.feed_url: raise ValueError("feed_url is required") @@ -95,7 +100,7 @@ def 
_validate_feed_url(self) -> None: if parsed.scheme not in {"http", "https"} or not parsed.netloc: raise ValueError("feed_url must be a valid http or https URL") - _validate_url_no_ssrf(self.feed_url) + return assert_url_is_safe(self.feed_url) def _read_feed(self, require_entries: bool) -> Any: if self._cached_feed is not None: @@ -103,15 +108,38 @@ def _read_feed(self, require_entries: bool) -> Any: raise ValueError("RSS feed contains no entries") return self._cached_feed - self._validate_feed_url() + # Validate once to get the pinned IP for the initial request. + current_hostname, current_ip = self._validate_feed_url() + current_url = self.feed_url + + # Follow redirects manually: each hop is validated and DNS-pinned + # *before* the connection is made, closing the TOCTOU rebinding window + # that existed when allow_redirects=True was used with post-hoc checks. + response: requests.Response | None = None + for _ in range(_MAX_REDIRECTS + 1): + with _pin_dns(current_hostname, current_ip): + response = requests.get( + current_url, + timeout=REQUEST_TIMEOUT_SECONDS, + allow_redirects=False, + ) + + if response.status_code not in (301, 302, 303, 307, 308): + break + + location = response.headers.get("Location") + if not location: + break # broken redirect; let raise_for_status() handle it + + redirect_url = urljoin(current_url, location) + # Validate redirect target before following it. + current_hostname, current_ip = assert_url_is_safe(redirect_url) + current_url = redirect_url + else: + raise ValueError(f"Exceeded {_MAX_REDIRECTS} redirects fetching {self.feed_url!r}") - response = requests.get(self.feed_url, timeout=REQUEST_TIMEOUT_SECONDS, allow_redirects=True) response.raise_for_status() - final_url = getattr(response, "url", self.feed_url) - if final_url != self.feed_url and urlparse(final_url).hostname: - _validate_url_no_ssrf(final_url) - feed = feedparser.parse(response.content) if getattr(feed, "bozo", False) and not feed.entries: error = getattr(feed, "bozo_exception", None) @@ -127,7 +155,7 @@ def _read_feed(self, require_entries: bool) -> Any: def _build_document(self, entry: Any, updated_at: datetime) -> Document: link = (entry.get("link") or "").strip() title = (entry.get("title") or "").strip() - stable_key = (entry.get("id") or link or title or self.feed_url).strip() + stable_key = self._resolve_stable_key(entry) semantic_identifier = title or link or stable_key content = self._build_content(entry, semantic_identifier) blob = content.encode("utf-8") @@ -149,7 +177,7 @@ def _build_document(self, entry: Any, updated_at: datetime) -> Document: metadata["categories"] = categories return Document( - id=f"rss:{hashlib.md5(stable_key.encode('utf-8')).hexdigest()}", + id=self._build_document_id(entry), source=DocumentSource.RSS, semantic_identifier=semantic_identifier, extension=".txt", @@ -177,6 +205,15 @@ def _build_content(self, entry: Any, semantic_identifier: str) -> str: return "\n\n".join(part for part in parts if part).strip() + def _build_document_id(self, entry: Any) -> str: + stable_key = self._resolve_stable_key(entry) + return f"rss:{hashlib.md5(stable_key.encode('utf-8')).hexdigest()}" + + def _resolve_stable_key(self, entry: Any) -> str: + link = (entry.get("link") or "").strip() + title = (entry.get("title") or "").strip() + return (entry.get("id") or link or title or self.feed_url).strip() + def _resolve_entry_time(self, entry: Any) -> datetime: for field in ("updated_parsed", "published_parsed"): value = entry.get(field) diff --git 
a/common/data_source/seafile_connector.py b/common/data_source/seafile_connector.py index ef7afeecf47..66bcf954fde 100644 --- a/common/data_source/seafile_connector.py +++ b/common/data_source/seafile_connector.py @@ -20,17 +20,19 @@ CredentialExpiredError, InsufficientPermissionsError, ) -from common.data_source.interfaces import LoadConnector, PollConnector +from common.data_source.interfaces import LoadConnector, PollConnector, SlimConnectorWithPermSync from common.data_source.models import ( Document, SecondsSinceUnixEpoch, GenerateDocumentsOutput, + GenerateSlimDocumentOutput, SeafileSyncScope, + SlimDocument, ) logger = logging.getLogger(__name__) -class SeaFileConnector(LoadConnector, PollConnector): +class SeaFileConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): """SeaFile connector supporting account-, library- and directory-level sync. API endpoints used: @@ -357,8 +359,18 @@ def _get_repo_info(self) -> Optional[dict]: return self._get_repo_info_via_account(self.repo_id) @retry(tries=3, delay=1, backoff=2) - def _get_directory_entries(self, repo_id: str, path: str = "/") -> list[dict]: - """List directory contents using the appropriate endpoint.""" + def _get_directory_entries( + self, + repo_id: str, + path: str = "/", + *, + raise_on_failure: bool = False, + ) -> list[dict]: + """List directory contents using the appropriate endpoint. + + When ``raise_on_failure`` is True (used for slim snapshots), HTTP/API errors + propagate so callers do not treat a failed listing as an empty directory. + """ try: if self._use_repo_token: # GET /api/v2.1/via-repo-token/dir/?path=/foo @@ -380,6 +392,8 @@ def _get_directory_entries(self, repo_id: str, path: str = "/") -> list[dict]: logger.warning( "Error fetching directory %s in repo %s: %s", path, repo_id, e, ) + if raise_on_failure: + raise return [] @retry(tries=3, delay=1, backoff=2) @@ -412,9 +426,14 @@ def _list_files_recursive( path: str, start: datetime, end: datetime, + *, + filter_by_mtime: bool = True, + strict_listing: bool = False, ) -> list[tuple[str, dict, dict]]: files = [] - entries = self._get_directory_entries(repo_id, path) + entries = self._get_directory_entries( + repo_id, path, raise_on_failure=strict_listing, + ) for entry in entries: entry_type = entry.get("type") @@ -424,15 +443,33 @@ def _list_files_recursive( if entry_type == "dir": files.extend( self._list_files_recursive( - repo_id, repo_name, entry_path, start, end, + repo_id, + repo_name, + entry_path, + start, + end, + filter_by_mtime=filter_by_mtime, + strict_listing=strict_listing, ) ) elif entry_type == "file": modified = self._parse_mtime(entry.get("mtime")) - if start < modified <= end: + if filter_by_mtime: + if start < modified <= end: + files.append( + ( + entry_path, + entry, + {"id": repo_id, "name": repo_name}, + ) + ) + else: files.append( - (entry_path, entry, - {"id": repo_id, "name": repo_name}) + ( + entry_path, + entry, + {"id": repo_id, "name": repo_name}, + ) ) return files @@ -473,6 +510,8 @@ def _yield_seafile_documents( try: files = self._list_files_recursive( lib["id"], lib["name"], root, start, end, + filter_by_mtime=True, + strict_listing=False, ) all_files.extend(files) except Exception as e: @@ -539,4 +578,59 @@ def poll_source( for batch in self._yield_seafile_documents(start_dt, end_dt): yield batch - \ No newline at end of file + def retrieve_all_slim_docs_perm_sync( + self, + callback: Any = None, + ) -> GenerateSlimDocumentOutput: + """Full snapshot of file IDs eligible for indexing (no downloads). 
+ + Uses ``seafile:{repo_id}:{file_id}`` matching :meth:`_yield_seafile_documents`. + Listing uses strict directory reads (errors propagate) so partial snapshots + are never treated as authoritative for stale-document cleanup. + """ + del callback + logger.info( + "Starting SeaFile slim snapshot: scope=%s url=%s", + self.sync_scope.value, + self.seafile_url, + ) + + libraries = self._resolve_libraries_to_scan() + all_files: list[tuple[str, dict, dict]] = [] + for lib in libraries: + root = self._root_path_for_repo(lib["id"]) + span_start = datetime(1970, 1, 1, tzinfo=timezone.utc) + span_end = datetime.now(timezone.utc) + listed = self._list_files_recursive( + lib["id"], + lib["name"], + root, + span_start, + span_end, + filter_by_mtime=False, + strict_listing=True, + ) + all_files.extend(listed) + + batch: list[SlimDocument] = [] + total = 0 + for file_path, file_entry, library in all_files: + file_size = file_entry.get("size", 0) + if file_size > self.size_threshold: + continue + file_id = file_entry.get("id", "") + repo_id = library["id"] + batch.append(SlimDocument(id=f"seafile:{repo_id}:{file_id}")) + total += 1 + if len(batch) >= self.batch_size: + yield batch + batch = [] + + if batch: + yield batch + + logger.info( + "Completed SeaFile slim snapshot: %d documents (listed_paths=%d)", + total, + len(all_files), + ) diff --git a/common/data_source/sharepoint_connector.py b/common/data_source/sharepoint_connector.py index 7bc8e3410dc..e5684023c15 100644 --- a/common/data_source/sharepoint_connector.py +++ b/common/data_source/sharepoint_connector.py @@ -112,10 +112,8 @@ def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint: def retrieve_all_slim_docs_perm_sync( self, - start: SecondsSinceUnixEpoch | None = None, - end: SecondsSinceUnixEpoch | None = None, callback: Any = None, ) -> Any: """Retrieve all simplified documents with permission sync""" # Simplified implementation - return [] \ No newline at end of file + return [] diff --git a/common/data_source/slack_connector.py b/common/data_source/slack_connector.py index 5fabc3d00fb..162826762cd 100644 --- a/common/data_source/slack_connector.py +++ b/common/data_source/slack_connector.py @@ -528,8 +528,6 @@ def set_credentials_provider(self, credentials_provider: Any) -> None: def retrieve_all_slim_docs_perm_sync( self, - start: SecondsSinceUnixEpoch | None = None, - end: SecondsSinceUnixEpoch | None = None, callback: Any = None, ) -> GenerateSlimDocumentOutput: if self.client is None: @@ -662,4 +660,4 @@ def get_credentials(self): connector.validate_connector_settings() print("Slack connector settings validated successfully") except Exception as e: - print(f"Validation failed: {e}") \ No newline at end of file + print(f"Validation failed: {e}") diff --git a/common/data_source/teams_connector.py b/common/data_source/teams_connector.py index 0b4cd564252..98b472667a0 100644 --- a/common/data_source/teams_connector.py +++ b/common/data_source/teams_connector.py @@ -106,10 +106,8 @@ def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint: def retrieve_all_slim_docs_perm_sync( self, - start: SecondsSinceUnixEpoch | None = None, - end: SecondsSinceUnixEpoch | None = None, callback: Any = None, ) -> Any: """Retrieve all simplified documents with permission sync""" # Simplified implementation - return [] \ No newline at end of file + return [] diff --git a/common/data_source/webdav_connector.py b/common/data_source/webdav_connector.py index b860c0b61ae..6ea6558ad5b 100644 --- 
a/common/data_source/webdav_connector.py +++ b/common/data_source/webdav_connector.py @@ -17,11 +17,11 @@ CredentialExpiredError, InsufficientPermissionsError ) -from common.data_source.interfaces import LoadConnector, OnyxExtensionType, PollConnector -from common.data_source.models import Document, SecondsSinceUnixEpoch, GenerateDocumentsOutput +from common.data_source.interfaces import LoadConnector, OnyxExtensionType, PollConnector, SlimConnectorWithPermSync +from common.data_source.models import Document, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SecondsSinceUnixEpoch, SlimDocument -class WebDAVConnector(LoadConnector, PollConnector): +class WebDAVConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): """WebDAV connector for syncing files from WebDAV servers""" def __init__( @@ -102,17 +102,20 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None return None def _list_files_recursive( - self, + self, path: str, start: datetime, end: datetime, + *, + filter_by_mtime: bool = True, ) -> list[tuple[str, dict]]: """Recursively list all files in the given path Args: path: Path to list files from - start: Start datetime for filtering - end: End datetime for filtering + start: Start datetime for filtering (ignored when ``filter_by_mtime`` is False) + end: End datetime for filtering (ignored when ``filter_by_mtime`` is False) + filter_by_mtime: When False, include every supported extension without mtime window Returns: List of tuples containing (file_path, file_info) @@ -134,7 +137,14 @@ def _list_files_recursive( if item.get('type') == 'directory': try: - files.extend(self._list_files_recursive(item_path, start, end)) + files.extend( + self._list_files_recursive( + item_path, + start, + end, + filter_by_mtime=filter_by_mtime, + ) + ) except Exception as e: logging.error(f"Error recursing into directory {item_path}: {e}") continue @@ -168,10 +178,13 @@ def _list_files_recursive( logging.debug(f"File {item_path}: modified={modified}, start={start}, end={end}, include={start < modified <= end}") - if start < modified <= end: - files.append((item_path, item)) + if filter_by_mtime: + if start < modified <= end: + files.append((item_path, item)) + else: + logging.debug(f"File {item_path} filtered out by time range") else: - logging.debug(f"File {item_path} filtered out by time range") + files.append((item_path, item)) except Exception as e: logging.error(f"Error processing file {item_path}: {e}") continue @@ -323,6 +336,61 @@ def poll_source( for batch in self._yield_webdav_documents(start_datetime, end_datetime): yield batch + def retrieve_all_slim_docs_perm_sync( + self, + callback: Any = None, + ) -> GenerateSlimDocumentOutput: + """Full-tree snapshot of indexed paths for stale-document reconciliation. + + Uses the same ``webdav:{base_url}:{file_path}`` ids as :meth:`_yield_webdav_documents`, + without downloading file contents. 
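Because reconciliation is plain string-set difference, the snapshot ids must be byte-identical to the ids emitted at indexing time. A small illustration with made-up values — only the `webdav:{base_url}:{file_path}` format comes from the connector:

```python
# Made-up values; only the id format is taken from the connector above.
base_url = "https://dav.example.com/remote.php/dav"

indexed_id = f"webdav:{base_url}:/reports/q3.pdf"
snapshot_id = f"webdav:{base_url}:/reports/q3.pdf"
assert indexed_id == snapshot_id  # reconciliation compares ids by string equality

# Any drift in the scheme (here: a trailing slash on the base URL) would make
# every indexed document look stale and get cleaned up.
drifted_id = f"webdav:{base_url}/:/reports/q3.pdf"
assert drifted_id != indexed_id
```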
+ """ + del callback + if self.client is None: + raise ConnectorMissingCredentialError("WebDAV client not initialized") + + logging.info( + "Starting WebDAV slim snapshot: base_url=%s path=%s", + self.base_url, + self.remote_path, + ) + + files = self._list_files_recursive( + self.remote_path, + datetime(1970, 1, 1, tzinfo=timezone.utc), + datetime.now(timezone.utc), + filter_by_mtime=False, + ) + batch: list[SlimDocument] = [] + total = 0 + for file_path, file_info in files: + file_name = os.path.basename(file_path) + if not self._is_supported_file(file_name): + continue + size_bytes = file_info.get("size", 0) + if ( + self.size_threshold is not None + and isinstance(size_bytes, int) + and size_bytes > self.size_threshold + ): + continue + batch.append( + SlimDocument(id=f"webdav:{self.base_url}:{file_path}") + ) + total += 1 + if len(batch) >= self.batch_size: + yield batch + batch = [] + + if batch: + yield batch + + logging.info( + "Completed WebDAV slim snapshot: %d documents (listed_paths=%d)", + total, + len(files), + ) + def validate_connector_settings(self) -> None: """Validate WebDAV connector settings. diff --git a/common/data_source/zendesk_connector.py b/common/data_source/zendesk_connector.py index 85b3426fe3f..c357b500fb7 100644 --- a/common/data_source/zendesk_connector.py +++ b/common/data_source/zendesk_connector.py @@ -246,6 +246,18 @@ def _article_to_document( ) +def _is_indexable_article(article: dict[str, Any]) -> bool: + body = article.get("body") + return ( + bool(body) + and not article.get("draft") + and not any( + label in ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS + for label in article.get("label_names") or [] + ) + ) + + def _get_comment_text( comment: dict[str, Any], author_map: dict[str, BasicExpertInfo], @@ -333,6 +345,10 @@ def _ticket_to_document( ) +def _is_indexable_ticket(ticket: dict[str, Any]) -> bool: + return ticket.get("status") != "deleted" + + class ZendeskConnectorCheckpoint(ConnectorCheckpoint): # We use cursor-based paginated retrieval for articles after_cursor_articles: str | None @@ -419,14 +435,7 @@ def _retrieve_articles( has_more = response.has_more after_cursor = response.meta.get("after_cursor") for article in articles: - if ( - article.get("body") is None - or article.get("draft") - or any( - label in ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS - for label in article.get("label_names", []) - ) - ): + if not _is_indexable_article(article): continue try: @@ -498,7 +507,7 @@ def _retrieve_tickets( has_more = ticket_response.has_more next_start_time = ticket_response.meta["end_time"] for ticket in tickets: - if ticket.get("status") == "deleted": + if not _is_indexable_ticket(ticket): continue try: @@ -553,16 +562,14 @@ def _retrieve_tickets( def retrieve_all_slim_docs_perm_sync( self, - start: SecondsSinceUnixEpoch | None = None, - end: SecondsSinceUnixEpoch | None = None, callback: IndexingHeartbeatInterface | None = None, ) -> GenerateSlimDocumentOutput: slim_doc_batch: list[SlimDocument] = [] if self.content_type == "articles": - articles = _get_articles( - self.client, start_time=int(start) if start else None - ) + articles = _get_articles(self.client) for article in articles: + if not _is_indexable_article(article): + continue slim_doc_batch.append( SlimDocument( id=f"article:{article['id']}", @@ -572,10 +579,10 @@ def retrieve_all_slim_docs_perm_sync( yield slim_doc_batch slim_doc_batch = [] elif self.content_type == "tickets": - tickets = _get_tickets( - self.client, start_time=int(start) if start else None - ) + tickets = 
_get_tickets(self.client) for ticket in tickets: + if not _is_indexable_ticket(ticket): + continue slim_doc_batch.append( SlimDocument( id=f"zendesk_ticket_{ticket['id']}", @@ -664,4 +671,4 @@ def build_dummy_checkpoint(self) -> ZendeskConnectorCheckpoint: checkpoint = next_checkpoint if any_doc: - break \ No newline at end of file + break diff --git a/common/doc_store/infinity_conn_base.py b/common/doc_store/infinity_conn_base.py index 20baa34a60a..af8493b82b2 100644 --- a/common/doc_store/infinity_conn_base.py +++ b/common/doc_store/infinity_conn_base.py @@ -16,10 +16,12 @@ import logging import os +import random import re import json import time from abc import abstractmethod +from typing import Callable, TypeVar import infinity from infinity.common import ConflictType @@ -32,6 +34,117 @@ from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr +# Concurrent CREATE/DROP TABLE on the same Infinity instance can race on +# Infinity's RocksDB-backed catalog counters (e.g. ``db|1|next_table_id``). +# When two writers touch the counter at the same instant, Infinity surfaces +# error 9003 / "Resource busy" instead of waiting on a lock — turning a +# user-visible operation into an avoidable failure under modest concurrency +# (two users creating a knowledge base at the same time, batch onboarding, +# multi-replica deployments, …). +# +# We retry the metadata path (CREATE TABLE / CREATE INDEX / DROP TABLE) on +# this specific error with exponential backoff + jitter. The wrapped calls +# already use ``ConflictType.Ignore``, so re-running them on retry is +# idempotent. The retry budget is intentionally bounded (5 attempts, +# ~1.5s worst case) so a genuine outage still surfaces quickly. +# +# Tunable from the environment: +# INFINITY_META_RETRY_MAX default 5 +# INFINITY_META_RETRY_BASE_DELAY_MS default 50 + +_T = TypeVar("_T") + +# Infinity error code 9003 is raised on RocksDB transaction contention. It is +# not in the SDK's ErrorCode enum yet, so we keep the literal here. +_INFINITY_RESOURCE_BUSY_CODE = 9003 + + +def _int_env(name: str, default: int) -> int: + """Read an int from the environment without crashing on bad input. + + A misconfigured ``INFINITY_META_RETRY_MAX=`` (empty value) or non-numeric + string would otherwise raise ``ValueError`` at module import time and + take down every backend worker. We log and fall back to the default + instead. + """ + raw = os.getenv(name) + if raw is None or raw == "": + return default + try: + return int(raw) + except ValueError: + logging.getLogger(__name__).warning( + "Ignoring invalid %s=%r, falling back to %d", name, raw, default, + ) + return default + + +_META_RETRY_MAX = _int_env("INFINITY_META_RETRY_MAX", 5) +_META_RETRY_BASE_DELAY_MS = _int_env("INFINITY_META_RETRY_BASE_DELAY_MS", 50) + + +def _is_meta_contention_error(exc: BaseException) -> bool: + """Return True iff ``exc`` is the RocksDB metadata-counter "Resource busy". + + Prefer the numeric error code when the SDK exposes one — substring matching + on ``str(exc)`` is the fallback for older SDKs that surface only a tuple + or a plain string. Both surfaces are observed in the wild today. + """ + code = getattr(exc, "error_code", None) + if code is None: + # Some Infinity SDK paths raise a plain ``Exception((9003, "..."))`` + # whose ``args[0]`` carries the code. 
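A usage sketch of the retry helper defined just below, assuming the patched module (and the `infinity` SDK it imports) is importable. The flaky callable is a stand-in for the idempotent DDL calls it wraps, not a real Infinity operation:

```python
# Usage sketch only: `flaky_create_table` simulates two rounds of catalog
# contention before succeeding; the real wrapped calls use ConflictType.Ignore,
# so re-running them on retry is idempotent.
from common.doc_store.infinity_conn_base import _retry_on_meta_contention

attempts = {"n": 0}


def flaky_create_table() -> str:
    attempts["n"] += 1
    if attempts["n"] < 3:
        # Shaped like the error Infinity surfaces under RocksDB contention.
        raise Exception(9003, "Resource busy: rocksdb transaction conflict")
    return "created"


print(_retry_on_meta_contention("create_table(demo)", flaky_create_table))  # -> "created"
```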
+ args = getattr(exc, "args", None) + if args and isinstance(args, tuple) and args: + code = args[0] + if code == _INFINITY_RESOURCE_BUSY_CODE: + return True + msg = str(exc) + return "Resource busy" in msg and "rocksdb" in msg.lower() + + +def _retry_on_meta_contention( + op_name: str, + operation: Callable[[], _T], + *, + logger: logging.Logger | None = None, + max_attempts: int = _META_RETRY_MAX, + base_delay_ms: int = _META_RETRY_BASE_DELAY_MS, +) -> _T: + """Run ``operation`` and retry on RocksDB "Resource busy" errors. + + Exponential backoff with ±50% jitter to avoid a thundering herd when many + workers retry simultaneously. Any exception that does not match + :func:`_is_meta_contention_error` is re-raised immediately so genuine + failures still surface fast. + """ + log = logger or logging.getLogger(__name__) + last_exc: BaseException | None = None + for attempt in range(max_attempts): + try: + return operation() + except Exception as exc: + if not _is_meta_contention_error(exc): + raise + last_exc = exc + if attempt == max_attempts - 1: + break + base = (base_delay_ms / 1000.0) * (2 ** attempt) + sleep_for = base + random.uniform(0, base * 0.5) + log.info( + "INFINITY meta contention on %s (attempt %d/%d), " + "retrying in %.3fs: %s", + op_name, attempt + 1, max_attempts, sleep_for, exc, + ) + time.sleep(sleep_for) + log.warning( + "INFINITY meta contention on %s exhausted %d attempts: %s", + op_name, max_attempts, last_exc, + ) + assert last_exc is not None + raise last_exc + + class InfinityConnectionBase(DocStoreConnection): def __init__(self, mapping_file_name: str = "infinity_mapping.json", logger_name: str = "ragflow.infinity_conn", table_name_prefix: str="ragflow_"): from common.doc_store.infinity_conn_pool import INFINITY_CONN @@ -173,7 +286,15 @@ def exists(cln): cond = list() for k, v in condition.items(): - if not isinstance(k, str) or not v: + if not isinstance(k, str): + continue + if k == "available_int": + if v == 0: + cond.append("available_int=0") + elif v == 1: + cond.append("available_int=1") + continue + if not v: continue if self.field_keyword(k): if isinstance(v, list): @@ -266,7 +387,11 @@ def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_ inf_conn = self.connPool.get_conn() try: - inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore) + inf_db = _retry_on_meta_contention( + f"create_database({self.dbName})", + lambda: inf_conn.create_database(self.dbName, ConflictType.Ignore), + logger=self.logger, + ) # Use configured schema fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name) @@ -285,24 +410,32 @@ def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_ vector_name = f"q_{vector_size}_vec" schema[vector_name] = {"type": f"vector,{vector_size},float"} - inf_table = inf_db.create_table( - table_name, - schema, - ConflictType.Ignore, + inf_table = _retry_on_meta_contention( + f"create_table({table_name})", + lambda: inf_db.create_table( + table_name, + schema, + ConflictType.Ignore, + ), + logger=self.logger, ) - inf_table.create_index( - "q_vec_idx", - IndexInfo( - vector_name, - IndexType.Hnsw, - { - "M": "16", - "ef_construction": "50", - "metric": "cosine", - "encode": "lvq", - }, + _retry_on_meta_contention( + f"create_index(q_vec_idx, {table_name})", + lambda: inf_table.create_index( + "q_vec_idx", + IndexInfo( + vector_name, + IndexType.Hnsw, + { + "M": "16", + "ef_construction": "50", + "metric": "cosine", + "encode": "lvq", + }, + ), + 
ConflictType.Ignore, ), - ConflictType.Ignore, + logger=self.logger, ) for field_name, field_info in schema.items(): if field_info["type"] != "varchar" or "analyzer" not in field_info: @@ -311,10 +444,15 @@ def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_ if isinstance(analyzers, str): analyzers = [analyzers] for analyzer in analyzers: - inf_table.create_index( - f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}", - IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}), - ConflictType.Ignore, + idx_name = f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}" + _retry_on_meta_contention( + f"create_index({idx_name}, {table_name})", + lambda fn=field_name, an=analyzer, name=idx_name: inf_table.create_index( + name, + IndexInfo(fn, IndexType.FullText, {"ANALYZER": an}), + ConflictType.Ignore, + ), + logger=self.logger, ) # Create secondary indexes for fields with index_type @@ -323,10 +461,14 @@ def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_ continue index_config = field_info["index_type"] if isinstance(index_config, str) and index_config == "secondary": - inf_table.create_index( - f"sec_{field_name}", - IndexInfo(field_name, IndexType.Secondary), - ConflictType.Ignore, + _retry_on_meta_contention( + f"create_index(sec_{field_name}, {table_name})", + lambda fn=field_name: inf_table.create_index( + f"sec_{fn}", + IndexInfo(fn, IndexType.Secondary), + ConflictType.Ignore, + ), + logger=self.logger, ) self.logger.info(f"INFINITY created secondary index sec_{field_name} for field {field_name}") elif isinstance(index_config, dict): @@ -334,10 +476,14 @@ def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_ params = {} if "cardinality" in index_config: params = {"cardinality": index_config["cardinality"]} - inf_table.create_index( - f"sec_{field_name}", - IndexInfo(field_name, IndexType.Secondary, params), - ConflictType.Ignore, + _retry_on_meta_contention( + f"create_index(sec_{field_name}, {table_name})", + lambda fn=field_name, p=params: inf_table.create_index( + f"sec_{fn}", + IndexInfo(fn, IndexType.Secondary, p), + ConflictType.Ignore, + ), + logger=self.logger, ) self.logger.info(f"INFINITY created secondary index sec_{field_name} for field {field_name} with params {params}") @@ -355,18 +501,26 @@ def create_doc_meta_idx(self, index_name: str): """ table_name = index_name inf_conn = self.connPool.get_conn() - inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore) try: + inf_db = _retry_on_meta_contention( + f"create_database({self.dbName})", + lambda: inf_conn.create_database(self.dbName, ConflictType.Ignore), + logger=self.logger, + ) fp_mapping = os.path.join(get_project_base_directory(), "conf", "doc_meta_infinity_mapping.json") if not os.path.exists(fp_mapping): self.logger.error(f"Document metadata mapping file not found at {fp_mapping}") return False with open(fp_mapping) as f: schema = json.load(f) - inf_db.create_table( - table_name, - schema, - ConflictType.Ignore, + _retry_on_meta_contention( + f"create_table({table_name})", + lambda: inf_db.create_table( + table_name, + schema, + ConflictType.Ignore, + ), + logger=self.logger, ) # Create secondary indexes on id and kb_id for better query performance @@ -392,14 +546,14 @@ def create_doc_meta_idx(self, index_name: str): except Exception as e: self.logger.warning(f"Failed to create index on kb_id for {table_name}: {e}") - 
self.connPool.release_conn(inf_conn) self.logger.debug(f"INFINITY created document metadata table {table_name} with secondary indexes") return True except Exception as e: - self.connPool.release_conn(inf_conn) self.logger.exception(f"Error creating document metadata table {table_name}: {e}") return False + finally: + self.connPool.release_conn(inf_conn) def delete_idx(self, index_name: str, dataset_id: str): if index_name.startswith("ragflow_doc_meta_"): @@ -409,7 +563,11 @@ def delete_idx(self, index_name: str, dataset_id: str): inf_conn = self.connPool.get_conn() try: db_instance = inf_conn.get_database(self.dbName) - db_instance.drop_table(table_name, ConflictType.Ignore) + _retry_on_meta_contention( + f"drop_table({table_name})", + lambda: db_instance.drop_table(table_name, ConflictType.Ignore), + logger=self.logger, + ) self.logger.info(f"INFINITY dropped table {table_name}") finally: self.connPool.release_conn(inf_conn) diff --git a/common/metadata_es_filter.py b/common/metadata_es_filter.py new file mode 100644 index 00000000000..afe0f27386e --- /dev/null +++ b/common/metadata_es_filter.py @@ -0,0 +1,580 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Translate RAGflow document-metadata filter lists into Elasticsearch DSL. + +The legacy ``common.metadata_utils.meta_filter`` evaluates user-defined +metadata conditions in Python after loading every document's metadata into +memory. That works for small knowledge bases but degrades badly past a few +thousand documents. This module produces an equivalent ES bool query so the +filtering can be pushed down to the search engine. + +Operators handled here mirror ``meta_filter`` exactly. When a filter cannot be +translated (unknown operator, malformed value, list-typed input that the +in-memory code special-cases) the translator raises +:class:`UnsupportedMetaFilter` so callers fall back to the in-memory path +without silently changing semantics. +""" + +from __future__ import annotations + +import ast +import re +from dataclasses import dataclass, field +from typing import Any, Dict, Iterable, List, Optional, Sequence + +# Field prefix in the doc-metadata ES index. Every user metadata key lives at +# ``meta_fields.`` thanks to the dynamic object mapping in +# ``conf/doc_meta_es_mapping.json``. +META_FIELDS_PREFIX = "meta_fields" + +# Strict ``YYYY-MM-DD`` recogniser, kept consistent with the legacy in-memory +# path. Mismatched-type comparisons (string vs date, list vs scalar) fall back +# to in-memory semantics rather than guess at the right ES coercion. +_DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$") + +# Operators that the legacy filter exposes. Anything outside this set is a bug +# elsewhere; surface it instead of silently no-op'ing. 
+SUPPORTED_OPERATORS: frozenset[str] = frozenset( + { + "=", + "≠", + ">", + "<", + "≥", + "≤", + "in", + "not in", + "contains", + "not contains", + "start with", + "end with", + "empty", + "not empty", + } +) + +# ES range comparators keyed by RAGflow operator. +_RANGE_OPS: Dict[str, str] = { + ">": "gt", + "<": "lt", + "≥": "gte", + "≤": "lte", +} + +# Negative operators that diverge from ``meta_filter`` on multi-valued metadata +# fields. The in-memory path checks each value bucket independently, so a doc +# whose field is ``[a, b]`` matches ``≠ a`` (because the ``b`` bucket satisfies +# the predicate). ``must_not term: a`` in ES would exclude that doc outright. +# Without a cheap way to prove a field is single-valued at query time we refuse +# push-down for these operators and let the in-memory fallback handle them. +# ``not contains`` is not in this set: ``all(not contains)`` is equivalent to +# ``not any(contains)``, so ``must_not wildcard *X*`` matches the legacy +# semantics on both single- and multi-valued fields. +MULTIVALUE_UNSAFE_NEGATIVE_OPS: frozenset[str] = frozenset({"≠", "not in"}) + + +class UnsupportedMetaFilter(Exception): + """Raised when a metadata filter cannot be expressed as ES DSL. + + Carries the filter that failed so callers can log a precise reason and the + in-memory fallback can pick up unchanged. + """ + + def __init__(self, reason: str, filter_clause: Optional[Dict[str, Any]] = None) -> None: + super().__init__(reason) + self.reason = reason + self.filter_clause = filter_clause + + +@dataclass +class TranslatedFilter: + """A single user filter rendered as one or more ES bool clauses. + + A clause that wants the field to be present (``≠``, ``not in``, range, + ``not contains``) goes into ``must`` so the negation does not accidentally + match documents missing the key. ``must_not`` carries the actual rejection. + Pure positive filters (``=``, ``contains``, ``in``, ``exists``) fill + ``must`` only. + """ + + must: List[Dict[str, Any]] = field(default_factory=list) + must_not: List[Dict[str, Any]] = field(default_factory=list) + + def to_clauses(self) -> List[Dict[str, Any]]: + """Collapse to the ES clauses this filter contributes to a parent bool. + + Always emits a single atomic clause when there is anything to emit: + a multi-clause ``must`` (e.g. range = ``exists`` + ``range``) gets + wrapped in its own ``bool`` so an OR-logic parent ``should`` can't + match on just one half of the filter. A pure single positive clause + is returned unwrapped because there is nothing to break apart. + """ + if not self.must and not self.must_not: + return [] + if not self.must_not: + if len(self.must) == 1: + return list(self.must) + # Multi-clause positive filter — keep it atomic for OR parents. + return [{"bool": {"must": list(self.must)}}] + # Negative semantics always need wrapping so they survive being OR'd + # with siblings. + return [{"bool": {"must": list(self.must), "must_not": list(self.must_not)}}] + + +@dataclass +class MetaFilterPushdownPlan: + """Composed ES bool query body for an entire RAGflow filter request.""" + + logic: str + translated: List[TranslatedFilter] = field(default_factory=list) + + def is_empty(self) -> bool: + return not self.translated + + def to_query(self, kb_ids: Sequence[str]) -> Dict[str, Any]: + """Render the full ES query body, scoped to the given KB ids. + + The KB filter is always a ``terms`` clause so the query can serve any + number of knowledge bases without rewriting the caller. 
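The restriction on `≠` / `not in` noted for `MULTIVALUE_UNSAFE_NEGATIVE_OPS` above is easiest to see with a concrete document. A self-contained illustration with made-up data, independent of the translator:

```python
# Made-up metadata for one document whose "category" key holds two values.
doc_values = ["tech", "legal"]

# In-memory semantics: each value bucket is checked independently, so the
# document matches "≠ tech" because the "legal" bucket satisfies the predicate.
in_memory_match = any(v != "tech" for v in doc_values)        # True

# A naive ES push-down (must_not term: "tech") rejects the document outright
# because *some* bucket equals "tech".
es_must_not_match = not any(v == "tech" for v in doc_values)  # False

assert in_memory_match and not es_must_not_match
```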
+ """ + kb_clause = {"terms": {"kb_id": list(kb_ids)}} + + if self.is_empty(): + return {"query": {"bool": {"filter": [kb_clause]}}} + + sub_clauses = [t.to_clauses() for t in self.translated] + flat_clauses: List[Dict[str, Any]] = [c for group in sub_clauses for c in group] + + if self.logic == "or": + inner = { + "bool": { + "should": flat_clauses, + "minimum_should_match": 1, + } + } + else: + inner = {"bool": {"must": flat_clauses}} + + return { + "query": { + "bool": { + "filter": [kb_clause, inner], + } + } + } + + +class MetaFilterTranslator: + """Translate one user filter clause at a time into ES DSL fragments. + + Stateless aside from configuration; safe to instantiate once per request + or share at module scope. + """ + + def __init__(self, prefix: str = META_FIELDS_PREFIX) -> None: + self.prefix = prefix + + def field_name(self, key: str) -> str: + """Compose the dotted ES field path for a user metadata key.""" + return f"{self.prefix}.{key}" + + def translate(self, flt: Dict[str, Any]) -> TranslatedFilter: + """Translate a single filter dict into ES bool clauses. + + Raises ``UnsupportedMetaFilter`` for malformed input or operator/value + combinations the legacy in-memory path treats as a special case (e.g. + list-of-strings membership in ``in``/``not in``). + """ + op = flt.get("op") + key = flt.get("key") + value = flt.get("value") + + if not key or not isinstance(key, str): + raise UnsupportedMetaFilter("filter is missing a string key", flt) + if op not in SUPPORTED_OPERATORS: + raise UnsupportedMetaFilter(f"unknown operator {op!r}", flt) + + field_path = self.field_name(key) + + if op == "empty": + return self._translate_empty(field_path) + if op == "not empty": + return self._translate_not_empty(field_path) + if op == "=": + return self._translate_equal(field_path, value, flt) + if op == "≠": + return self._translate_not_equal(field_path, value, flt) + if op in _RANGE_OPS: + return self._translate_range(field_path, op, value, flt) + if op == "in": + return self._translate_in(field_path, value, flt) + if op == "not in": + return self._translate_not_in(field_path, value, flt) + if op == "contains": + return self._translate_contains(field_path, value, flt) + if op == "not contains": + return self._translate_not_contains(field_path, value, flt) + if op == "start with": + return self._translate_start_with(field_path, value, flt) + if op == "end with": + return self._translate_end_with(field_path, value, flt) + + # Unreachable: SUPPORTED_OPERATORS gate above covers every branch. + raise UnsupportedMetaFilter(f"no handler for operator {op!r}", flt) + + def _translate_empty(self, field_path: str) -> TranslatedFilter: + # "empty" matches documents whose value is missing OR equals "" — same + # falsy semantics the in-memory ``not input`` check enforces. The + # blank-string check has to target ``.keyword`` because the analyzed + # text field drops empty values during tokenisation, leaving no token + # for ``term: ""`` to match. 
+ return TranslatedFilter( + must=[ + { + "bool": { + "should": [ + {"bool": {"must_not": [{"exists": {"field": field_path}}]}}, + {"term": {_keyword_path(field_path): ""}}, + ], + "minimum_should_match": 1, + } + } + ] + ) + + def _translate_not_empty(self, field_path: str) -> TranslatedFilter: + return TranslatedFilter( + must=[{"exists": {"field": field_path}}], + must_not=[{"term": {_keyword_path(field_path): ""}}], + ) + + def _translate_equal(self, field_path: str, value: Any, flt: Dict[str, Any]) -> TranslatedFilter: + coerced = _coerce_scalar(value, flt) + return TranslatedFilter(must=[_term_or_match(field_path, coerced)]) + + def _translate_not_equal(self, field_path: str, value: Any, flt: Dict[str, Any]) -> TranslatedFilter: + coerced = _coerce_scalar(value, flt) + return TranslatedFilter( + must=[{"exists": {"field": field_path}}], + must_not=[_term_or_match(field_path, coerced)], + ) + + def _translate_range(self, field_path: str, op: str, value: Any, flt: Dict[str, Any]) -> TranslatedFilter: + coerced = _coerce_range_value(value, flt) + return TranslatedFilter( + must=[ + {"exists": {"field": field_path}}, + {"range": {field_path: {_RANGE_OPS[op]: coerced}}}, + ] + ) + + def _translate_in(self, field_path: str, value: Any, flt: Dict[str, Any]) -> TranslatedFilter: + members = _csv_or_list(value, flt) + return TranslatedFilter(must=[_terms_string_or_numeric(field_path, members)]) + + def _translate_not_in(self, field_path: str, value: Any, flt: Dict[str, Any]) -> TranslatedFilter: + members = _csv_or_list(value, flt) + return TranslatedFilter( + must=[{"exists": {"field": field_path}}], + must_not=[_terms_string_or_numeric(field_path, members)], + ) + + def _translate_contains(self, field_path: str, value: Any, flt: Dict[str, Any]) -> TranslatedFilter: + text = _coerce_string(value, flt) + return TranslatedFilter(must=[_wildcard(field_path, f"*{_escape_wildcard(text)}*")]) + + def _translate_not_contains(self, field_path: str, value: Any, flt: Dict[str, Any]) -> TranslatedFilter: + text = _coerce_string(value, flt) + return TranslatedFilter( + must=[{"exists": {"field": field_path}}], + must_not=[_wildcard(field_path, f"*{_escape_wildcard(text)}*")], + ) + + def _translate_start_with(self, field_path: str, value: Any, flt: Dict[str, Any]) -> TranslatedFilter: + text = _coerce_string(value, flt) + return TranslatedFilter( + must=[{"prefix": {_keyword_path(field_path): {"value": text, "case_insensitive": True}}}] + ) + + def _translate_end_with(self, field_path: str, value: Any, flt: Dict[str, Any]) -> TranslatedFilter: + text = _coerce_string(value, flt) + return TranslatedFilter(must=[_wildcard(field_path, f"*{_escape_wildcard(text)}")]) + + +def build_meta_filter_query( + filters: Sequence[Dict[str, Any]], + logic: str, + kb_ids: Sequence[str], + translator: Optional[MetaFilterTranslator] = None, +) -> Dict[str, Any]: + """Top-level helper: translate every filter and render the ES query body. + + Raises ``UnsupportedMetaFilter`` if any filter cannot be expressed. + """ + plan = plan_pushdown(filters, logic, translator=translator) + return plan.to_query(kb_ids) + + +def plan_pushdown( + filters: Sequence[Dict[str, Any]], + logic: str, + translator: Optional[MetaFilterTranslator] = None, +) -> MetaFilterPushdownPlan: + """Translate every filter in turn, building a single composed plan. + + Separated from ``build_meta_filter_query`` so callers can inspect or + augment the plan before binding it to a KB scope. 
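A minimal usage sketch of `build_meta_filter_query` defined above, assuming this module is importable; the filter values and KB id are made up:

```python
from common.metadata_es_filter import build_meta_filter_query

query = build_meta_filter_query(
    filters=[
        {"key": "author", "op": "=", "value": "Alice Munro"},
        {"key": "year", "op": "≥", "value": "2020"},
    ],
    logic="and",
    kb_ids=["kb_1"],
)

# String equality is lower-cased and routed to the .keyword sub-field; the range
# filter gains an `exists` guard and its value is parsed to a number; the KB
# scope is always a `terms` clause alongside the translated conditions.
assert query == {
    "query": {"bool": {"filter": [
        {"terms": {"kb_id": ["kb_1"]}},
        {"bool": {"must": [
            {"term": {"meta_fields.author.keyword": {
                "value": "alice munro", "case_insensitive": True}}},
            {"bool": {"must": [
                {"exists": {"field": "meta_fields.year"}},
                {"range": {"meta_fields.year": {"gte": 2020}}},
            ]}},
        ]}},
    ]}}
}
```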
+ """ + if logic not in {"and", "or"}: + raise UnsupportedMetaFilter(f"unknown logic {logic!r}") + + t = translator or MetaFilterTranslator() + plan = MetaFilterPushdownPlan(logic=logic) + for flt in filters: + plan.translated.append(t.translate(flt)) + return plan + + +def is_pushdown_supported(filters: Sequence[Dict[str, Any]]) -> bool: + """Cheap pre-check: do all filters look translatable without coercion? + + Used by the routing layer to skip the heavier ``plan_pushdown`` call when + the request obviously needs the in-memory fallback. + + Operators in :data:`MULTIVALUE_UNSAFE_NEGATIVE_OPS` are rejected here so a + single such filter forces the whole request to in-memory evaluation, which + is the only place we can replicate the per-bucket semantics over + multi-valued metadata fields. + """ + for flt in filters: + op = flt.get("op") + if op not in SUPPORTED_OPERATORS: + return False + if op in MULTIVALUE_UNSAFE_NEGATIVE_OPS: + return False + if not isinstance(flt.get("key"), str) or not flt.get("key"): + return False + return True + + +def extract_doc_ids(es_response: Dict[str, Any]) -> List[str]: + """Pull doc IDs out of an ES search response shaped like ``{hits:{hits:[...]}}``. + + Tolerates both the dict-typed ES 7+ response and the dict-coerced + ``ObjectApiResponse`` returned by the elasticsearch python client. + """ + hits_root = es_response.get("hits") if isinstance(es_response, dict) else None + if not hits_root: + # ``ObjectApiResponse`` is dict-like; ``.get`` works at both levels. + try: + hits_root = es_response["hits"] + except Exception: + return [] + + raw_hits: Iterable[Dict[str, Any]] + if isinstance(hits_root, dict): + raw_hits = hits_root.get("hits", []) or [] + else: + raw_hits = [] + + out: List[str] = [] + for hit in raw_hits: + if not isinstance(hit, dict): + continue + # ``id`` is mirrored into ``_source`` by the metadata writer; ``_id`` + # is the canonical identifier. Prefer ``_id`` so renames in the source + # field name don't break us. + doc_id = hit.get("_id") + if not doc_id: + source = hit.get("_source") or {} + doc_id = source.get("id") or source.get("doc_id") + if doc_id: + out.append(str(doc_id)) + return out + + +# --------------------------------------------------------------------------- +# Value coercion helpers +# --------------------------------------------------------------------------- + + +def _coerce_scalar(value: Any, flt: Dict[str, Any]) -> Any: + """Mirror the legacy ``ast.literal_eval`` then ``str.lower()`` flow. + + The in-memory filter parses values as Python literals when possible (so + ``"5"`` becomes ``5``) and lower-cases strings. For ES ``term`` queries we + need the same coercion or numeric data won't match. + """ + if value is None: + raise UnsupportedMetaFilter("scalar comparison value is None", flt) + if isinstance(value, (list, dict)): + raise UnsupportedMetaFilter("scalar comparison value is non-scalar", flt) + + s = str(value).strip() + if _DATE_RE.match(s): + return s + try: + parsed = ast.literal_eval(s) + except Exception: + parsed = s + if isinstance(parsed, str): + return parsed.lower() + if isinstance(parsed, (int, float, bool)): + return parsed + return s.lower() + + +def _coerce_range_value(value: Any, flt: Dict[str, Any]) -> Any: + """Range comparisons accept dates verbatim and numbers parsed via literal_eval. 
+ + Strings that aren't numeric or ISO dates are pushed through as-is — ES + will compare them lexically against keyword fields, which is the same + behaviour as the in-memory ``input >= value`` Python comparison after the + original ``ast.literal_eval`` failure path. + """ + if value is None: + raise UnsupportedMetaFilter("range comparison value is None", flt) + s = str(value).strip() + if _DATE_RE.match(s): + return s + try: + parsed = ast.literal_eval(s) + except Exception: + return s + if isinstance(parsed, (int, float)): + return parsed + return s + + +def _coerce_string(value: Any, flt: Dict[str, Any]) -> str: + """String operators (contains/start with/end with) need a non-empty string.""" + if value is None: + raise UnsupportedMetaFilter("string-operator value is None", flt) + if isinstance(value, (list, dict)): + raise UnsupportedMetaFilter("string-operator value must be a scalar", flt) + s = str(value) + if not s: + raise UnsupportedMetaFilter("string-operator value is empty", flt) + return s + + +def _csv_or_list(value: Any, flt: Dict[str, Any]) -> List[Any]: + """``in`` / ``not in`` accept either a real list or a comma-separated string. + + The legacy in-memory path applies ``ast.literal_eval`` to the value too. + Mirror that for parity, then trim whitespace and lower-case any strings. + """ + if value is None: + raise UnsupportedMetaFilter("membership value is None", flt) + + if isinstance(value, (list, tuple)): + members = list(value) + elif isinstance(value, str): + try: + parsed = ast.literal_eval(value) + except Exception: + parsed = value + if isinstance(parsed, (list, tuple)): + members = list(parsed) + else: + members = [m.strip() for m in value.split(",") if m.strip()] + else: + members = [value] + + if not members: + raise UnsupportedMetaFilter("membership value resolved to empty list", flt) + + normalised: List[Any] = [] + for m in members: + if isinstance(m, str): + normalised.append(m.lower().strip()) + else: + normalised.append(m) + return normalised + + +def _keyword_path(field_path: str) -> str: + """Sub-field used for exact-match string queries. + + Dynamic mapping under ``meta_fields`` indexes string values as ``text`` + with a ``.keyword`` multi-field. ``term``/``terms``/``prefix``/``wildcard`` + against the analyzed parent breaks for any multi-word value because the + inverted index stores per-token entries, not the original phrase. Routing + string queries through ``.keyword`` keeps semantics aligned with the + in-memory ``meta_filter`` (full-string compare after lower-casing). + """ + return f"{field_path}.keyword" + + +def _term_or_match(field_path: str, value: Any) -> Dict[str, Any]: + """Exact-match clause that respects how dynamic mapping indexes the value. + + String values target the ``.keyword`` sub-field with ``case_insensitive`` + so phrase values still match (the in-memory path lower-cases before + comparing). Numeric / bool values target the parent path because numeric + fields have no ``.keyword`` sub-field under default dynamic mapping. + """ + if isinstance(value, str): + return { + "term": { + _keyword_path(field_path): { + "value": value, + "case_insensitive": True, + } + } + } + return {"term": {field_path: value}} + + +def _terms_string_or_numeric(field_path: str, members: List[Any]) -> Dict[str, Any]: + """``in``/``not in`` payload that mirrors ``_term_or_match`` per element. 
+ + ES ``terms`` does not accept ``case_insensitive``, so for string members we + expand into a ``bool: should`` of case-insensitive ``term`` queries on the + keyword sub-field. Pure-numeric / bool member lists keep the cheaper + ``terms`` form on the parent path. + """ + if all(not isinstance(m, str) for m in members): + return {"terms": {field_path: members}} + return { + "bool": { + "should": [_term_or_match(field_path, m) for m in members], + "minimum_should_match": 1, + } + } + + +def _wildcard(field_path: str, pattern: str) -> Dict[str, Any]: + """Wildcard runs against ``.keyword`` so the original phrase is searched. + + ``wildcard`` against an analyzed text field walks per-token entries, which + drops phrase context (``Alice Wonderland`` becomes tokens ``alice``, + ``wonderland``). The ``.keyword`` sub-field preserves the full original + string, matching the in-memory ``str.find`` semantics. + """ + return { + "wildcard": { + _keyword_path(field_path): { + "value": pattern, + "case_insensitive": True, + } + } + } + + +def _escape_wildcard(text: str) -> str: + """Escape the two ES wildcard metacharacters so user input stays literal.""" + return text.replace("\\", "\\\\").replace("*", "\\*").replace("?", "\\?") diff --git a/common/metadata_utils.py b/common/metadata_utils.py index c919bd186af..c2fc90b5414 100644 --- a/common/metadata_utils.py +++ b/common/metadata_utils.py @@ -42,6 +42,13 @@ def convert_conditions(metadata_condition): def meta_filter(metas: dict, filters: list[dict], logic: str = "and"): doc_ids = set([]) + def normalize_string_values(value): + if isinstance(value, str): + return value.lower() + if isinstance(value, list): + return [item.lower() if isinstance(item, str) else item for item in value] + return value + def filter_out(v2docs, operator, value): ids = [] for input, docids in v2docs.items(): @@ -96,10 +103,8 @@ def filter_out(v2docs, operator, value): value = value.lower() else: # Non-comparison operators: maintain original logic - if isinstance(input, str): - input = input.lower() - if isinstance(value, str): - value = value.lower() + input = normalize_string_values(input) + value = normalize_string_values(value) matched = False try: @@ -161,11 +166,13 @@ def filter_out(v2docs, operator, value): async def apply_meta_data_filter( meta_data_filter: dict | None, - metas: dict, - question: str, + metas: dict | None = None, + question: str = "", chat_mdl: Any = None, base_doc_ids: list[str] | None = None, manual_value_resolver: Callable[[dict], dict] | None = None, + kb_ids: list[str] | None = None, + metas_loader: Callable[[], dict] | None = None, ) -> list[str] | None: """ Apply metadata filtering rules and return the filtered doc_ids. @@ -175,6 +182,20 @@ async def apply_meta_data_filter( - semi_auto: generate conditions using selected metadata keys only - manual: directly filter based on provided conditions + When ``kb_ids`` is supplied and the active doc store is Elasticsearch the + generated filter conditions are pushed down to ES via + ``DocMetadataService.filter_doc_ids_by_meta_pushdown`` instead of being + evaluated in Python over ``metas``. The in-memory ``meta_filter`` path + remains the fallback so callers without a KB scope, or backends without + push-down support, behave exactly as before. + + ``metas`` may be supplied eagerly or via ``metas_loader``. The loader is + only invoked when the metadata dict is actually needed — i.e. 
for the LLM + context in ``auto`` / ``semi_auto`` modes, or as the in-memory fallback + when push-down can't service a request. ``manual`` mode that lands on the + push-down path therefore skips the expensive + ``get_flatted_meta_by_kbs`` round-trip entirely. + Returns: list of doc_ids, ["-999"] when manual filters yield no result, or None when auto/semi_auto filters return empty. @@ -188,9 +209,28 @@ async def apply_meta_data_filter( method = meta_data_filter.get("method") + # Memoised metadata loader. ``_get_metas`` materialises the dict at most + # once per call; downstream branches that never reach an in-memory eval + # leave the loader untouched. + cached_metas: dict | None = metas + + def _get_metas() -> dict: + nonlocal cached_metas + if cached_metas is None: + cached_metas = metas_loader() if metas_loader else {} + return cached_metas + + def _evaluate(conditions: list[dict], logic: str) -> list[str]: + """Run conditions through ES push-down when possible, in-memory otherwise.""" + if conditions and kb_ids: + pushed = _try_meta_pushdown(kb_ids, conditions, logic) + if pushed is not None: + return pushed + return meta_filter(_get_metas(), conditions, logic) + if method == "auto": - filters: dict = await gen_meta_filter(chat_mdl, metas, question) - doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) + filters: dict = await gen_meta_filter(chat_mdl, _get_metas(), question) + doc_ids.extend(_evaluate(filters["conditions"], filters.get("logic", "and"))) if not doc_ids: return None elif method == "semi_auto": @@ -207,23 +247,47 @@ async def apply_meta_data_filter( constraints[key] = op if selected_keys: - filtered_metas = {key: metas[key] for key in selected_keys if key in metas} + current_metas = _get_metas() + filtered_metas = {key: current_metas[key] for key in selected_keys if key in current_metas} if filtered_metas: filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question, constraints=constraints) - doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) + doc_ids.extend(_evaluate(filters["conditions"], filters.get("logic", "and"))) if not doc_ids: return None elif method == "manual": filters = meta_data_filter.get("manual", []) if manual_value_resolver: filters = [manual_value_resolver(flt) for flt in filters] - doc_ids.extend(meta_filter(metas, filters, meta_data_filter.get("logic", "and"))) + doc_ids.extend(_evaluate(filters, meta_data_filter.get("logic", "and"))) if filters and not doc_ids: doc_ids = ["-999"] return doc_ids +def _try_meta_pushdown( + kb_ids: list[str], + conditions: list[dict], + logic: str, +) -> list[str] | None: + """Attempt the ES push-down path; return ``None`` to fall back in-memory. + + Lazy-imports ``DocMetadataService`` so this module stays usable in + environments where the API/db layer hasn't been wired up (e.g. unit tests + that exercise ``meta_filter`` directly). 
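A hypothetical call shape for the manual path, assuming the module and the surrounding services are importable; the knowledge-base id, metadata key, and value below are illustrative only:

```python
import asyncio

from common.metadata_utils import apply_meta_data_filter

doc_ids = asyncio.run(apply_meta_data_filter(
    {"method": "manual", "logic": "and",
     "manual": [{"key": "department", "op": "=", "value": "finance"}]},
    kb_ids=["kb_1"],          # enables the ES push-down path when the doc store supports it
    metas_loader=lambda: {},  # only consulted if the request falls back to in-memory filtering
))
# doc_ids holds the matching documents, or ["-999"] when the manual filter matches nothing.
```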
+ """ + try: + from api.db.services.doc_metadata_service import DocMetadataService + except Exception as e: + logging.debug(f"[apply_meta_data_filter] push-down disabled, import failed: {e}") + return None + try: + return DocMetadataService.filter_doc_ids_by_meta_pushdown(kb_ids, conditions, logic) + except Exception as e: + logging.warning(f"[apply_meta_data_filter] push-down errored, falling back: {e}") + return None + + def dedupe_list(values: list) -> list: seen = set() deduped = [] diff --git a/common/parser_config_utils.py b/common/parser_config_utils.py index 0bc7ffc28b3..daf91cc8e1a 100644 --- a/common/parser_config_utils.py +++ b/common/parser_config_utils.py @@ -29,5 +29,8 @@ def normalize_layout_recognizer(layout_recognizer_raw: Any) -> tuple[Any, str | elif lowered.endswith("@paddleocr"): parser_model_name = layout_recognizer_raw.rsplit("@", 1)[0] layout_recognizer = "PaddleOCR" + elif lowered.endswith("@opendataloader"): + parser_model_name = layout_recognizer_raw.rsplit("@", 1)[0] + layout_recognizer = "OpenDataLoader" return layout_recognizer, parser_model_name diff --git a/common/settings.py b/common/settings.py index 2b67dc34d72..49693b93701 100644 --- a/common/settings.py +++ b/common/settings.py @@ -17,6 +17,8 @@ import json import secrets import logging +from datetime import date + from common.constants import RAG_FLOW_SERVICE_NAME from common.file_utils import get_project_base_directory from common.config_utils import get_base_config, decrypt_database_config @@ -43,6 +45,8 @@ import memory.utils.infinity_conn as memory_infinity_conn import memory.utils.ob_conn as memory_ob_conn +TIMEZONE = os.getenv("TZ", "Asia/Shanghai") + LLM = None LLM_FACTORY = None LLM_BASE_URL = None @@ -137,6 +141,24 @@ def get_svr_queue_name(priority: int) -> str: def get_svr_queue_names(): return [get_svr_queue_name(priority) for priority in [1, 0]] +def init_secret_key(): + secret_key = os.environ.get("RAGFLOW_SECRET_KEY") + if secret_key and len(secret_key) >= 32: + return secret_key + + # Check if there's a configured secret key + configured_key = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key") + if configured_key and configured_key != str(date.today()) and len(configured_key) >= 32: + return configured_key + return None + + +def get_secret_key(): + global SECRET_KEY + if SECRET_KEY is None: + return _get_or_create_secret_key() + return SECRET_KEY + def _get_or_create_secret_key(): # secret_key = os.environ.get("RAGFLOW_SECRET_KEY") # if secret_key and len(secret_key) >= 32: @@ -152,7 +174,8 @@ def _get_or_create_secret_key(): generated_key = secrets.token_hex(32) secret_key = REDIS_CONN.get_or_create_secret_key("ragflow:system:secret_key", generated_key) - logging.warning("SECURITY WARNING: Using auto-generated SECRET_KEY.") + if generated_key == secret_key: + logging.warning("SECURITY WARNING: Using auto-generated SECRET_KEY.") return secret_key class StorageFactory: @@ -243,7 +266,7 @@ def init_settings(): HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port") global SECRET_KEY - SECRET_KEY = _get_or_create_secret_key() + SECRET_KEY = init_secret_key() # authentication diff --git a/common/ssrf_guard.py b/common/ssrf_guard.py new file mode 100644 index 00000000000..b60bcd4bc99 --- /dev/null +++ b/common/ssrf_guard.py @@ -0,0 +1,172 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Shared SSRF-guard utilities. + +Uses only the standard library so it can be imported from both ``api/`` and +``common/`` without pulling in any heavyweight dependencies. +""" + +import ipaddress +import logging +import socket +import threading +from contextlib import contextmanager +from urllib.parse import urlparse + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# DNS pinning — closes the TOCTOU / rebinding window between SSRF validation +# and the actual TCP connection. The monkey-patch is a no-op for any host +# that has no active pin, so it cannot affect unrelated code. +# --------------------------------------------------------------------------- + +_tl = threading.local() +_global_dns_pins: dict[str, str] = {} +_global_pin_lock = threading.Lock() +_orig_getaddrinfo = socket.getaddrinfo + + +def _getaddrinfo_with_pins(host, port, *args, **kwargs): + # Thread-local pins (synchronous callers: requests.get in the same thread) + local_pins: dict = getattr(_tl, "dns_pins", {}) + if host in local_pins: + ip = local_pins[host] + family = socket.AF_INET6 if ":" in ip else socket.AF_INET + return [(family, socket.SOCK_STREAM, 6, "", (ip, port or 0))] + # Process-global pins (async callers whose DNS resolves in executor threads) + with _global_pin_lock: + ip = _global_dns_pins.get(host) + if ip is not None: + family = socket.AF_INET6 if ":" in ip else socket.AF_INET + return [(family, socket.SOCK_STREAM, 6, "", (ip, port or 0))] + return _orig_getaddrinfo(host, port, *args, **kwargs) + + +socket.getaddrinfo = _getaddrinfo_with_pins + + +@contextmanager +def pin_dns(hostname: str, ip: str): + """Pin *hostname* → *ip* in the current thread for the duration of this context. + + Use for synchronous ``requests.get()`` callers to prevent DNS rebinding + between SSRF validation and the actual TCP connection. + """ + pins = _tl.__dict__.setdefault("dns_pins", {}) + pins[hostname] = ip + try: + yield + finally: + pins.pop(hostname, None) + + +@contextmanager +def pin_dns_global(hostname: str, ip: str): + """Pin *hostname* → *ip* across all threads for the duration of this context. + + Use for async callers (e.g. asyncio-based crawlers) where DNS resolution + may happen in thread-pool executor threads rather than the calling thread. + """ + with _global_pin_lock: + _global_dns_pins[hostname] = ip + try: + yield + finally: + with _global_pin_lock: + _global_dns_pins.pop(hostname, None) + + +_DEFAULT_ALLOWED_SCHEMES: frozenset[str] = frozenset({"http", "https"}) + + +def _effective_ip( + ip: ipaddress.IPv4Address | ipaddress.IPv6Address, +) -> ipaddress.IPv4Address | ipaddress.IPv6Address: + """Return the IPv4 equivalent for IPv4-mapped IPv6 addresses, unchanged otherwise. + + Without this normalization ``::ffff:127.0.0.1`` would pass ``is_global`` + as an IPv6Address in some Python versions, bypassing the loopback check. 
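The intended calling pattern for synchronous fetchers is validate-once-then-pin, so the address that passed the check is the address actually connected to. A minimal sketch, assuming this module and the `requests` package are available; the URL is illustrative:

```python
# Sketch of the validate-then-pin pattern for a synchronous caller.
import requests

from common.ssrf_guard import assert_url_is_safe, pin_dns

url = "https://example.com/feed.xml"
hostname, ip = assert_url_is_safe(url)   # raises ValueError for non-public targets
with pin_dns(hostname, ip):              # the validated IP is used for the TCP connect
    resp = requests.get(url, timeout=10)
```

Async callers whose DNS resolution happens in executor threads would use `pin_dns_global` instead, as the docstrings above note.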
+ """ + if isinstance(ip, ipaddress.IPv6Address): + mapped = ip.ipv4_mapped + if mapped is not None: + return mapped + return ip + + +def assert_url_is_safe( + url: str, + *, + allowed_schemes: frozenset[str] = _DEFAULT_ALLOWED_SCHEMES, +) -> tuple[str, str]: + """Raise ``ValueError`` if *url* is not safe to fetch (SSRF guard). + + Checks performed in order: + + 1. Scheme is in *allowed_schemes*. + 2. Hostname is present. + 3. **Every** address returned by ``getaddrinfo`` is globally routable + (``ip.is_global``). This is an allowlist approach: it catches private, + loopback, link-local, reserved, multicast, and all other + special-purpose ranges rather than individual deny-list flags. + IPv4-mapped IPv6 addresses (e.g. ``::ffff:127.0.0.1``) are normalised + to their IPv4 form via :func:`_effective_ip` before the check. + + Returns ``(hostname, resolved_ip)`` — the first validated public IP string + — so the caller can **pin** that address in its HTTP client and prevent + DNS-rebinding attacks (the hostname is resolved exactly once). + """ + parsed = urlparse(url) + scheme = parsed.scheme + if scheme not in allowed_schemes: + logger.warning( + "SSRF guard blocked URL with disallowed scheme: scheme=%r url=%r", + scheme, + url, + ) + raise ValueError(f"Disallowed URL scheme: {scheme!r}. Only {sorted(allowed_schemes)} are allowed.") + + hostname = parsed.hostname + if not hostname: + logger.warning("SSRF guard blocked URL with missing host: url=%r", url) + raise ValueError("URL is missing a host.") + + try: + addr_infos = socket.getaddrinfo(hostname, None) + except socket.gaierror as exc: + logger.warning("SSRF guard could not resolve hostname=%r reason=%s", hostname, exc) + raise ValueError(f"Could not resolve hostname {hostname!r}: {exc}") from exc + + resolved_ip: str | None = None + for _family, _type, _proto, _canonname, sockaddr in addr_infos: + raw_ip = ipaddress.ip_address(sockaddr[0]) + eff_ip = _effective_ip(raw_ip) + if not eff_ip.is_global: + logger.warning( + "SSRF guard blocked URL: hostname=%r resolved to non-public address=%s", + hostname, + raw_ip, + ) + raise ValueError(f"URL resolves to a non-public address ({raw_ip}), which is not allowed.") + if resolved_ip is None: + resolved_ip = str(raw_ip) + + if resolved_ip is None: + logger.warning("SSRF guard blocked URL: hostname=%r resolved to no addresses", hostname) + raise ValueError(f"Hostname {hostname!r} resolved to no addresses.") + + return hostname, resolved_ip diff --git a/conf/infinity_mapping.json b/conf/infinity_mapping.json index 77d26dd9604..5f7ed80f261 100644 --- a/conf/infinity_mapping.json +++ b/conf/infinity_mapping.json @@ -38,5 +38,6 @@ "removed_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"}, "doc_type_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"}, "toc_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"}, - "raptor_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"} + "raptor_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"}, + "raptor_layer_int": {"type": "integer", "default": 0} } diff --git a/conf/llm_factories.json b/conf/llm_factories.json index 0cadfe3679d..2fc12803d78 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -377,7 +377,7 @@ "tags": "LLM,TEXT EMBEDDING,TEXT RE-RANK,TTS,SPEECH2TEXT,MODERATION", "status": "1", "rank": "950", - "url" : "https://dashscope.aliyuncs.com/compatible-mode/v1", + "url": "https://dashscope.aliyuncs.com/compatible-mode/v1", "llm": [ { "llm_name": 
"qwen3.5-122b-a10b", @@ -421,13 +421,6 @@ "model_type": "chat", "is_tools": false }, - { - "llm_name": "deepseek-r1-distill-qwen-7b", - "tags": "LLM,CHAT,32K", - "max_tokens": 32768, - "model_type": "chat", - "is_tools": false - }, { "llm_name": "deepseek-r1-distill-qwen-14b", "tags": "LLM,CHAT,32K", @@ -1134,16 +1127,16 @@ "url": "https://api.deepseek.com/v1", "llm": [ { - "llm_name": "deepseek-chat", + "llm_name": "deepseek-v4-flash", "tags": "LLM,CHAT,", - "max_tokens": 64000, + "max_tokens": 1000000, "model_type": "chat", "is_tools": true }, { - "llm_name": "deepseek-reasoner", + "llm_name": "deepseek-v4-pro", "tags": "LLM,CHAT,", - "max_tokens": 64000, + "max_tokens": 1000000, "model_type": "chat", "is_tools": true } @@ -1557,53 +1550,52 @@ "rank": "980", "llm": [ { - "llm_name": "gemini-3-pro-preview", - "tags": "LLM,CHAT,1M,IMAGE2TEXT", - "max_tokens": 1048576, - "model_type": "image2text", - "is_tools": true + "llm_name": "gemini-3-pro-preview", + "tags": "LLM,CHAT,1M,IMAGE2TEXT", + "max_tokens": 1048576, + "model_type": "image2text", + "is_tools": true }, { - "llm_name": "gemini-2.5-flash", - "tags": "LLM,CHAT,1024K,IMAGE2TEXT", - "max_tokens": 1048576, - "model_type": "image2text", - "is_tools": true + "llm_name": "gemini-2.5-flash", + "tags": "LLM,CHAT,1024K,IMAGE2TEXT", + "max_tokens": 1048576, + "model_type": "image2text", + "is_tools": true }, { - "llm_name": "gemini-2.5-pro", - "tags": "LLM,CHAT,IMAGE2TEXT,1024K", - "max_tokens": 1048576, - "model_type": "image2text", - "is_tools": true + "llm_name": "gemini-2.5-pro", + "tags": "LLM,CHAT,IMAGE2TEXT,1024K", + "max_tokens": 1048576, + "model_type": "image2text", + "is_tools": true }, { - "llm_name": "gemini-2.5-flash-lite", - "tags": "LLM,CHAT,1024K,IMAGE2TEXT", - "max_tokens": 1048576, - "model_type": "image2text", - "is_tools": true + "llm_name": "gemini-2.5-flash-lite", + "tags": "LLM,CHAT,1024K,IMAGE2TEXT", + "max_tokens": 1048576, + "model_type": "image2text", + "is_tools": true }, { - "llm_name": "gemini-2.0-flash", - "tags": "LLM,CHAT,1024K", - "max_tokens": 1048576, - "model_type": "image2text", - "is_tools": true + "llm_name": "gemini-2.0-flash", + "tags": "LLM,CHAT,1024K", + "max_tokens": 1048576, + "model_type": "image2text", + "is_tools": true }, { - "llm_name": "gemini-2.0-flash-lite", - "tags": "LLM,CHAT,1024K", - "max_tokens": 1048576, - "model_type": "image2text", - "is_tools": true + "llm_name": "gemini-2.0-flash-lite", + "tags": "LLM,CHAT,1024K", + "max_tokens": 1048576, + "model_type": "image2text", + "is_tools": true }, - { - "llm_name": "gemini-embedding-001", - "tags": "TEXT EMBEDDING", - "max_tokens": 2048, - "model_type": "embedding" + "llm_name": "gemini-embedding-001", + "tags": "TEXT EMBEDDING", + "max_tokens": 2048, + "model_type": "embedding" } ] }, @@ -2949,20 +2941,6 @@ "model_type": "chat", "is_tools": true }, - { - "llm_name": "Pro/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", - "tags": "LLM,CHAT,32k", - "max_tokens": 32000, - "model_type": "chat", - "is_tools": true - }, - { - "llm_name": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", - "tags": "LLM,CHAT,32k", - "max_tokens": 32000, - "model_type": "chat", - "is_tools": true - }, { "llm_name": "deepseek-ai/DeepSeek-V2.5", "tags": "LLM,CHAT,32k", @@ -4247,13 +4225,6 @@ "model_type": "chat", "is_tools": false }, - { - "llm_name": "DeepSeek-R1-Distill-Qwen-7B", - "tags": "LLM,CHAT", - "max_tokens": 65792, - "model_type": "chat", - "is_tools": false - }, { "llm_name": "DeepSeek-R1-Distill-Qwen-1.5B", "tags": "LLM,CHAT", @@ -6255,6 +6226,14 @@ "rank": 
"910", "llm": [] }, + { + "name": "OpenDataLoader", + "logo": "", + "tags": "OCR", + "status": "1", + "rank": "920", + "llm": [] + }, { "name": "n1n", "logo": "", @@ -6293,6 +6272,435 @@ } ] }, + { + "name": "Astraflow", + "logo": "", + "tags": "LLM,TEXT EMBEDDING", + "status": "1", + "rank": "250", + "url": "https://api-us-ca.umodelverse.ai/v1", + "llm": [ + { + "llm_name": "claude-opus-4-7", + "tags": "LLM,CHAT,200k", + "max_tokens": 200000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "claude-opus-4-6", + "tags": "LLM,CHAT,200k", + "max_tokens": 200000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "claude-sonnet-4-5-20250929", + "tags": "LLM,CHAT,200k", + "max_tokens": 200000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "claude-haiku-4-5-20251001", + "tags": "LLM,CHAT,200k", + "max_tokens": 200000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "gpt-5.4", + "tags": "LLM,CHAT,400k", + "max_tokens": 400000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "gpt-5.4-mini", + "tags": "LLM,CHAT,400k", + "max_tokens": 400000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "gpt-5.4-nano", + "tags": "LLM,CHAT,400k", + "max_tokens": 400000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "gpt-4o-mini", + "tags": "LLM,CHAT,128k", + "max_tokens": 128000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen3-Max", + "tags": "LLM,CHAT,131k", + "max_tokens": 131072, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen3-Coder", + "tags": "LLM,CHAT,131k", + "max_tokens": 131072, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen3-32B", + "tags": "LLM,CHAT,131k", + "max_tokens": 131072, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen3-VL-235B-A22B-Instruct", + "tags": "LLM,CHAT,131k", + "max_tokens": 131072, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "kimi-k2.6", + "tags": "LLM,CHAT,200k", + "max_tokens": 200000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "glm-5.1", + "tags": "LLM,CHAT,128k", + "max_tokens": 128000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "MiniMax-M2.7", + "tags": "LLM,CHAT,1M", + "max_tokens": 1000000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "MiniMax-M2", + "tags": "LLM,CHAT,1M", + "max_tokens": 1000000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "gemini-2.5-pro", + "tags": "LLM,CHAT,1M", + "max_tokens": 1000000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "gemini-2.5-flash", + "tags": "LLM,CHAT,1M", + "max_tokens": 1000000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "qwen3-embedding-8b", + "tags": "TEXT EMBEDDING,8K", + "max_tokens": 8192, + "model_type": "embedding", + "is_tools": false + }, + { + "llm_name": "text-embedding-3-large", + "tags": "TEXT EMBEDDING,8K", + "max_tokens": 8191, + "model_type": "embedding", + "is_tools": false + }, + { + "llm_name": "text-embedding-ada-002", + "tags": "TEXT EMBEDDING,8K", + "max_tokens": 8191, + "model_type": "embedding", + "is_tools": false + } + ] + }, + { + "name": "FuturMix", + "logo": "", + "tags": "LLM,TEXT EMBEDDING,IMAGE2TEXT,SPEECH2TEXT,TTS,TEXT RE-RANK", + "status": "1", + "rank": "248", + "url": "https://futurmix.ai/v1", + "llm": [ + { + "llm_name": "claude-sonnet-4-20250514", + "tags": "LLM,CHAT,200k", + "max_tokens": 200000, + 
"model_type": "chat", + "is_tools": true + }, + { + "llm_name": "claude-3.5-haiku", + "tags": "LLM,CHAT,200k", + "max_tokens": 200000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "gpt-4o", + "tags": "LLM,CHAT,128k", + "max_tokens": 128000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "gpt-4o-mini", + "tags": "LLM,CHAT,128k", + "max_tokens": 128000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "gemini-2.5-flash", + "tags": "LLM,CHAT,1M", + "max_tokens": 1000000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "gemini-2.0-flash", + "tags": "LLM,CHAT,1M", + "max_tokens": 1000000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "deepseek-chat", + "tags": "LLM,CHAT,64k", + "max_tokens": 65536, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "deepseek-reasoner", + "tags": "LLM,CHAT,64k", + "max_tokens": 65536, + "model_type": "chat", + "is_tools": false + }, + { + "llm_name": "gpt-4o", + "tags": "IMAGE2TEXT,CHAT,128k", + "max_tokens": 128000, + "model_type": "image2text", + "is_tools": true + }, + { + "llm_name": "text-embedding-3-small", + "tags": "TEXT EMBEDDING,8K", + "max_tokens": 8191, + "model_type": "embedding", + "is_tools": false + }, + { + "llm_name": "text-embedding-3-large", + "tags": "TEXT EMBEDDING,8K", + "max_tokens": 8191, + "model_type": "embedding", + "is_tools": false + }, + { + "llm_name": "tts-1", + "tags": "TTS", + "max_tokens": 4096, + "model_type": "tts", + "is_tools": false + }, + { + "llm_name": "tts-1-hd", + "tags": "TTS", + "max_tokens": 4096, + "model_type": "tts", + "is_tools": false + }, + { + "llm_name": "whisper-1", + "tags": "SPEECH2TEXT", + "max_tokens": 25000000, + "model_type": "speech2text", + "is_tools": false + }, + { + "llm_name": "jina-reranker-v2-base-multilingual", + "tags": "RE-RANK,8k", + "max_tokens": 8192, + "model_type": "rerank", + "is_tools": false + } + ] + }, + { + "name": "Astraflow-CN", + "logo": "", + "tags": "LLM,TEXT EMBEDDING", + "status": "1", + "rank": "249", + "url": "https://api.modelverse.cn/v1", + "llm": [ + { + "llm_name": "claude-opus-4-7", + "tags": "LLM,CHAT,200k", + "max_tokens": 200000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "claude-opus-4-6", + "tags": "LLM,CHAT,200k", + "max_tokens": 200000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "claude-sonnet-4-5-20250929", + "tags": "LLM,CHAT,200k", + "max_tokens": 200000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "claude-haiku-4-5-20251001", + "tags": "LLM,CHAT,200k", + "max_tokens": 200000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "gpt-5.4", + "tags": "LLM,CHAT,400k", + "max_tokens": 400000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "gpt-5.4-mini", + "tags": "LLM,CHAT,400k", + "max_tokens": 400000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "gpt-5.4-nano", + "tags": "LLM,CHAT,400k", + "max_tokens": 400000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "gpt-4o-mini", + "tags": "LLM,CHAT,128k", + "max_tokens": 128000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen3-Max", + "tags": "LLM,CHAT,131k", + "max_tokens": 131072, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen3-Coder", + "tags": "LLM,CHAT,131k", + "max_tokens": 131072, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen3-32B", + "tags": "LLM,CHAT,131k", + 
"max_tokens": 131072, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "Qwen/Qwen3-VL-235B-A22B-Instruct", + "tags": "LLM,CHAT,131k", + "max_tokens": 131072, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "kimi-k2.6", + "tags": "LLM,CHAT,200k", + "max_tokens": 200000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "glm-5.1", + "tags": "LLM,CHAT,128k", + "max_tokens": 128000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "MiniMax-M2.7", + "tags": "LLM,CHAT,1M", + "max_tokens": 1000000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "MiniMax-M2", + "tags": "LLM,CHAT,1M", + "max_tokens": 1000000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "gemini-2.5-pro", + "tags": "LLM,CHAT,1M", + "max_tokens": 1000000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "gemini-2.5-flash", + "tags": "LLM,CHAT,1M", + "max_tokens": 1000000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "qwen3-embedding-8b", + "tags": "TEXT EMBEDDING,8K", + "max_tokens": 8192, + "model_type": "embedding", + "is_tools": false + }, + { + "llm_name": "text-embedding-3-large", + "tags": "TEXT EMBEDDING,8K", + "max_tokens": 8191, + "model_type": "embedding", + "is_tools": false + }, + { + "llm_name": "text-embedding-ada-002", + "tags": "TEXT EMBEDDING,8K", + "max_tokens": 8191, + "model_type": "embedding", + "is_tools": false + } + ] + }, { "name": "Avian", "logo": "", diff --git a/conf/mapping.json b/conf/mapping.json index f32acb02bc3..495f7c7763c 100644 --- a/conf/mapping.json +++ b/conf/mapping.json @@ -92,7 +92,7 @@ { "kwd": { "match_pattern": "regex", - "match": "^(.*_(kwd|id|ids|uid|uids)|uid)$", + "match": "^(.*_(kwd|id|ids|uid|uids)|uid|id)$", "mapping": { "type": "keyword", "similarity": "boolean", diff --git a/conf/models/aliyun.json b/conf/models/aliyun.json new file mode 100644 index 00000000000..51adef5d748 --- /dev/null +++ b/conf/models/aliyun.json @@ -0,0 +1,52 @@ +{ + "name": "Aliyun", + "url": { + "default": "https://dashscope.aliyuncs.com", + "singapore": "https://dashscope-intl.aliyuncs.com", + "us": "https://dashscope-us.aliyuncs.com" + }, + "url_suffix": { + "chat": "compatible-mode/v1/chat/completions", + "embedding": "compatible-mode/v1/embeddings", + "rerank": "compatible-api/v1/reranks", + "models": "api/v1/deployments/models" + }, + "models": [ + { + "name": "qwen-flash", + "max_tokens": 995904, + "model_types": [ + "chat" + ] + }, + { + "name": "text-embedding-v4", + "max_tokens": 8192, + "model_types": [ + "embedding" + ] + }, + { + "name": "text-embedding-v3", + "max_tokens": 8192, + "model_types": [ + "embedding" + ] + }, + { + "name": "qwen3-rerank", + "max_tokens": 8192, + "model_types": [ + "rerank" + ] + } + ], + "features": { + "thinking": { + "default_value": true, + "supported_models": [ + "qwen-flash" + ] + } + } +} \ No newline at end of file diff --git a/conf/models/deepseek.json b/conf/models/deepseek.json new file mode 100644 index 00000000000..146e11862a9 --- /dev/null +++ b/conf/models/deepseek.json @@ -0,0 +1,36 @@ +{ + "name": "DeepSeek", + "url": { + "default": "https://api.deepseek.com" + }, + "url_suffix": { + "chat": "chat/completions", + "models": "models", + "balance": "user/balance" + }, + "class": "deepseek", + "models": [ + { + "name": "deepseek-v4-flash", + "max_tokens": 1048576, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "deepseek-v4-pro", + 
"max_tokens": 1048576, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + } + ] +} \ No newline at end of file diff --git a/conf/models/gitee.json b/conf/models/gitee.json new file mode 100644 index 00000000000..630106592f2 --- /dev/null +++ b/conf/models/gitee.json @@ -0,0 +1,44 @@ +{ + "name": "Gitee", + "url": { + "default": "https://api.moark.com/v1" + }, + "url_suffix": { + "chat": "chat/completions", + "models": "models", + "status": "", + "balance": "tokens/packages/balance", + "embedding": "embedding", + "rerank": "rerank" + }, + "models": [ + { + "name": "qwen3-8b", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "qwen3-0.6b", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "glm-4.7-flash", + "max_tokens": 204800, + "model_types": [ + "chat" + ] + }, + { + "name": "BAAI/bge-reranker-v2-m3", + "max_tokens": 8192, + "model_types": [ + "rerank" + ] + } + ] +} \ No newline at end of file diff --git a/conf/models/google.json b/conf/models/google.json new file mode 100644 index 00000000000..2e4cf30525f --- /dev/null +++ b/conf/models/google.json @@ -0,0 +1,37 @@ +{ + "name": "Google", + "url": { + "default": "https://generativelanguage.googleapis.com" + }, + "url_suffix": { + "models": "v1beta/models" + }, + "class": "gemini", + "models": [ + { + "name": "gemini-2.5-flash", + "max_tokens": 1048576, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + } + ], + "features": { + "thinking": { + "default_value": true, + "supported_models": [ + "gemini-2.5-flash" + ] + }, + "reasoning_effort": { + "default_value": "high", + "supported_modes": [ + "gemini-2.5-flash" + ] + } + } +} \ No newline at end of file diff --git a/conf/models/huggingface.json b/conf/models/huggingface.json new file mode 100644 index 00000000000..c46ab4a46bd --- /dev/null +++ b/conf/models/huggingface.json @@ -0,0 +1,21 @@ +{ + "name": "HuggingFace", + "url": { + "default": "https://router.huggingface.co/v1/" + }, + "url-suffix": { + "chat": "chat/completions", + "models": "models", + "embedding": "hf-inference/models" + }, + "class": "huggingface", + "models": [ + { + "name": "openai/gpt-oss-120b:fastest", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + } + ] +} \ No newline at end of file diff --git a/conf/models/lmstudio.json b/conf/models/lmstudio.json new file mode 100644 index 00000000000..a22cbb982fe --- /dev/null +++ b/conf/models/lmstudio.json @@ -0,0 +1,8 @@ +{ + "name": "lmstudio", + "url_suffix": { + "chat": "chat/completions", + "models": "models" + }, + "class": "local" +} \ No newline at end of file diff --git a/conf/models/minimax.json b/conf/models/minimax.json new file mode 100644 index 00000000000..31760ac2597 --- /dev/null +++ b/conf/models/minimax.json @@ -0,0 +1,104 @@ +{ + "name": "MiniMax", + "url": { + "default": "https://api.minimaxi.com/", + "global": "https://api.minimax.io/" + }, + "url_suffix": { + "chat": "v1/text/chatcompletion_v2", + "models": "v1/models", + "tts": "v1/t2a_v2", + "files": "v1/files/list" + }, + "class": "minimax", + "models": [ + { + "name": "minimax-m2.7", + "max_tokens": 204800, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "minimax-m2.7-highspeed", + "max_tokens": 204800, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "minimax-m2.5", + "max_tokens": 204800, 
+ "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "minimax-m2.5-highspeed", + "max_tokens": 204800, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "minimax-m2.1", + "max_tokens": 204800, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "minimax-m2.1-highspeed", + "max_tokens": 204800, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "minimax-m2", + "max_tokens": 204800, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "minimax-m2-her", + "max_tokens": 65536, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + } + ] +} \ No newline at end of file diff --git a/conf/models/moonshot.json b/conf/models/moonshot.json new file mode 100644 index 00000000000..b9df95e0c22 --- /dev/null +++ b/conf/models/moonshot.json @@ -0,0 +1,84 @@ +{ + "name": "Moonshot", + "url": { + "default": "https://api.moonshot.cn/v1" + }, + "url_suffix": { + "chat": "chat/completions", + "models": "models", + "balance": "users/me/balance" + }, + "class": "kimi", + "models": [ + { + "name": "kimi-k2.6", + "max_tokens": 262144, + "model_types": [ + "chat", + "vision" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "kimi-k2.5", + "max_tokens": 262144, + "model_types": [ + "chat", + "vision" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "moonshot-v1-8k", + "max_tokens": 8000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "moonshot-v1-32k", + "max_tokens": 32000, + "model_types": [ + "chat" + ] + }, + { + "name": "moonshot-v1-128k", + "max_tokens": 128000, + "model_types": [ + "chat" + ] + }, + { + "name": "moonshot-v1-8k-vision-preview", + "max_tokens": 8000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "moonshot-v1-32k-vision-preview", + "max_tokens": 32000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "moonshot-v1-128k-vision-preview", + "max_tokens": 128000, + "model_types": [ + "chat", + "vision" + ] + } + ] +} \ No newline at end of file diff --git a/conf/models/nvidia.json b/conf/models/nvidia.json new file mode 100644 index 00000000000..8ba81f1fd3f --- /dev/null +++ b/conf/models/nvidia.json @@ -0,0 +1,461 @@ +{ + "name": "Nvidia", + "url": { + "default": "https://integrate.api.nvidia.com/v1" + }, + "url_suffix": { + "chat": "chat/completions", + "models": "models" + }, + "class": "nvidia", + "models": [ + { + "name": "abacusai/dracarys-llama-3.1-70b-instruct", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "bytedance/seed-oss-36b-instruct", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "deepseek-ai/deepseek-v4-flash", + "max_tokens": 1048576, + "model_types": [ + "chat" + ] + }, + { + "name": "deepseek-ai/deepseek-v4-pro", + "max_tokens": 1048576, + "model_types": [ + "chat" + ] + }, + { + "name": "deepseek-ai/deepseek-v3.2", + "max_tokens": 131072, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "deepseek-ai/deepseek-v3.1", + "max_tokens": 131072, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + 
"clear_thinking": true + } + }, + { + "name": "google/codegemma-7b", + "max_tokens": 8192, + "model_types": [ + "chat" + ] + }, + { + "name": "google/gemma-2-2b-it", + "max_tokens": 8192, + "model_types": [ + "chat" + ] + }, + { + "name": "google/gemma-4-31b-it", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "google/gemma-7b", + "max_tokens": 8192, + "model_types": [ + "chat" + ] + }, + { + "name": "ibm/granite-3.3-8b-instruct", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "meta/llama-3.1-405b-instruct", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "meta/llama-3.2-90b-vision-instruct", + "max_tokens": 131072, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "meta/llama-4-maverick-17b-128e-instruct", + "max_tokens": 1048576, + "model_types": [ + "chat" + ] + }, + { + "name": "microsoft/phi-4-mini-flash-reasoning", + "max_tokens": 131072, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "minimaxai/minimax-m2.1", + "max_tokens": 204800, + "model_types": [ + "chat" + ] + }, + { + "name": "minimaxai/minimax-m2.5", + "max_tokens": 204800, + "model_types": [ + "chat" + ] + }, + { + "name": "minimaxai/minimax-m2.7", + "max_tokens": 204800, + "model_types": [ + "chat" + ] + }, + { + "name": "mistralai/devstral-2-123b-instruct-2512", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "mistralai/magistral-small-2506", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "mistralai/mistral-7b-instruct-v0.3", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "mistralai/mistral-large-3-675b-instruct-2512", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "mistralai/mistral-medium-3-5-128b", + "max_tokens": 131072, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "mistralai/mistral-nemotron", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "mistralai/mixtral-8x22b-instruct", + "max_tokens": 65536, + "model_types": [ + "chat" + ] + }, + { + "name": "moonshotai/kimi-k2.5", + "max_tokens": 262144, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "moonshotai/kimi-k2.6", + "max_tokens": 262144, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "moonshotai/kimi-k2-instruct", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "moonshotai/kimi-k2-instruct-0905", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "moonshotai/kimi-k2-thinking", + "max_tokens": 131072, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "nvidia/gliner-pii", + "max_tokens": 4096, + "model_types": [ + "chat" + ] + }, + { + "name": "nvidia/llama-3.1-nemoguard-8b-content-safety", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "nvidia/llama-3.1-nemoguard-8b-topic-control", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "nvidia/llama-3.1-nemotron-nano-8b-v1", + "max_tokens": 8192, + "model_types": [ + "chat" + ] + }, + { + "name": "nvidia/llama-3.1-nemotron-safety-guard-8b-v3", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "nvidia/llama-3.1-nemotron-ultra-253b-v1", + "max_tokens": 131072, + "model_types": [ + "chat" + ], + "thinking": { + 
"default_value": true, + "clear_thinking": true + } + }, + { + "name": "nvidia/llama-3.2-nemoretriever-1b-vlm-embed-v1", + "max_tokens": 8192, + "model_types": [ + "embedding" + ] + }, + { + "name": "nvidia/llama-3.3-nemotron-super-49b-v1", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "nvidia/llama-3.3-nemotron-super-49b-v1.5", + "max_tokens": 131072, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "nvidia/nemoguard-jailbreak-detect", + "max_tokens": 4096, + "model_types": [ + "chat" + ] + }, + { + "name": "nvidia/nemotron-3-nano-30b-a3b", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "nvidia/nemotron-3-nano-omni-30b-a3b-reasoning", + "max_tokens": 131072, + "model_types": [ + "chat", + "vision" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "nvidia/nemotron-3-super-120b-a12b", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "nvidia/nemotron-content-safety-reasoning-4b", + "max_tokens": 8192, + "model_types": [ + "chat" + ] + }, + { + "name": "nvidia/nemotron-mini-4b-instruct", + "max_tokens": 4096, + "model_types": [ + "chat" + ] + }, + { + "name": "nvidia/nvidia-nemotron-nano-9b-v2", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "nvidia/riva-translate-4b-instruct-v1_1", + "max_tokens": 4096, + "model_types": [ + "chat" + ] + }, + { + "name": "nvidia/usdcode", + "max_tokens": 8192, + "model_types": [ + "chat" + ] + }, + { + "name": "openai/gpt-oss-120b", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "qwen/qwen2.5-coder-7b-instruct", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "qwen/qwen3-5-122b-a10b", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "qwen/qwen3-235b-a22b", + "max_tokens": 131072, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "qwen/qwen3-coder-480b-a35b-instruct", + "max_tokens": 262144, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "z-ai/glm-5", + "max_tokens": 131072, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "z-ai/glm-5.1", + "max_tokens": 131072, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "z-ai/glm-4.7", + "max_tokens": 131072, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + } + ] +} \ No newline at end of file diff --git a/conf/models/ollama.json b/conf/models/ollama.json new file mode 100644 index 00000000000..ed0a1e011b9 --- /dev/null +++ b/conf/models/ollama.json @@ -0,0 +1,8 @@ +{ + "name": "ollama", + "url_suffix": { + "chat": "chat/completions", + "models": "models" + }, + "class": "local" +} \ No newline at end of file diff --git a/conf/models/openai.json b/conf/models/openai.json index f89c6c0d1db..696c6f93b3c 100644 --- a/conf/models/openai.json +++ b/conf/models/openai.json @@ -4,8 +4,10 @@ "default": "https://api.openai.com/v1" }, "url_suffix": { - "chat": "chat/completions" + "chat": "chat/completions", + "models": "models" }, + "class": "gpt", "models": [ { "name": "gpt-5.2-pro", @@ -13,8 +15,7 @@ "model_types": [ "chat", "vision" - ], - "features": {} + ] }, { "name": "gpt-5.2", 
@@ -22,8 +23,7 @@ "model_types": [ "chat", "vision" - ], - "features": {} + ] }, { "name": "gpt-5.1", @@ -31,8 +31,7 @@ "model_types": [ "chat", "vision" - ], - "features": {} + ] }, { "name": "gpt-5.1-chat-latest", @@ -40,8 +39,7 @@ "model_types": [ "chat", "vision" - ], - "features": {} + ] }, { "name": "gpt-5", @@ -49,8 +47,7 @@ "model_types": [ "chat", "vision" - ], - "features": {} + ] }, { "name": "gpt-5-mini", @@ -58,8 +55,7 @@ "model_types": [ "chat", "vision" - ], - "features": {} + ] }, { "name": "gpt-5-nano", @@ -67,8 +63,7 @@ "model_types": [ "chat", "vision" - ], - "features": {} + ] }, { "name": "gpt-5-chat-latest", @@ -76,8 +71,7 @@ "model_types": [ "chat", "vision" - ], - "features": {} + ] }, { "name": "gpt-4.1", @@ -85,8 +79,7 @@ "model_types": [ "chat", "vision" - ], - "features": {} + ] }, { "name": "gpt-4.1-mini", @@ -94,8 +87,7 @@ "model_types": [ "chat", "vision" - ], - "features": {} + ] }, { "name": "gpt-4.1-nano", @@ -103,43 +95,14 @@ "model_types": [ "chat", "vision" - ], - "features": {} + ] }, { "name": "gpt-4.5-preview", "max_tokens": 128000, "model_types": [ "chat" - ], - "features": {} - }, - { - "name": "o3", - "max_tokens": 200000, - "model_types": [ - "chat", - "vision" - ], - "features": {} - }, - { - "name": "o4-mini", - "max_tokens": 200000, - "model_types": [ - "chat", - "vision" - ], - "features": {} - }, - { - "name": "o4-mini-high", - "max_tokens": 200000, - "model_types": [ - "chat", - "vision" - ], - "features": {} + ] }, { "name": "gpt-4o-mini", @@ -147,8 +110,7 @@ "model_types": [ "chat", "vision" - ], - "features": {} + ] }, { "name": "gpt-4o", @@ -156,88 +118,77 @@ "model_types": [ "chat", "vision" - ], - "features": {} + ] }, { "name": "gpt-3.5-turbo", "max_tokens": 4096, "model_types": [ "chat" - ], - "features": {} + ] }, { "name": "gpt-3.5-turbo-16k-0613", "max_tokens": 16385, "model_types": [ "chat" - ], - "features": {} + ] }, { "name": "text-embedding-ada-002", "max_tokens": 8191, "model_types": [ "embedding" - ], - "features": {} + ] }, { "name": "text-embedding-3-small", "max_tokens": 8191, "model_types": [ "embedding" - ], - "features": {} + ] }, { "name": "text-embedding-3-large", "max_tokens": 8191, "model_types": [ "embedding" - ], - "features": {} + ] }, { "name": "whisper-1", "max_tokens": 26214400, "model_types": [ "asr" - ], - "features": {} + ] }, { "name": "gpt-4", "max_tokens": 8191, "model_types": [ "chat" - ], - "features": {} + ] }, { "name": "gpt-4-turbo", "max_tokens": 8191, "model_types": [ "chat" - ], - "features": {} + ] }, { "name": "gpt-4-32k", "max_tokens": 32768, "model_types": [ "chat" - ], - "features": {} + ] }, { "name": "tts-1", "max_tokens": 2048, "model_types": [ "tts" - ], - "features": {} + ] } ] } \ No newline at end of file diff --git a/conf/models/openrouter.json b/conf/models/openrouter.json new file mode 100644 index 00000000000..6af1e2d15df --- /dev/null +++ b/conf/models/openrouter.json @@ -0,0 +1,49 @@ +{ + "name": "OpenRouter", + "url": { + "default": "https://openrouter.ai/api/v1" + }, + "url_suffix": { + "chat": "chat/completions", + "models": "models", + "embedding": "embeddings", + "rerank": "rerank", + "balance": "credits" + }, + "class": "openrouter", + "models": [ + { + "name": "google/gemma-4-31b-it", + "max_tokens": 262144, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "minimax/minimax-m2.5", + "max_tokens": 196608, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": 
true + } + }, + { + "name": "tencent/hy3-preview", + "max_tokens": 262144, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + } + ] +} \ No newline at end of file diff --git a/conf/models/siliconflow.json b/conf/models/siliconflow.json new file mode 100644 index 00000000000..4da3e0dcab8 --- /dev/null +++ b/conf/models/siliconflow.json @@ -0,0 +1,50 @@ +{ + "name": "SiliconFlow", + "url": { + "default": "https://api.siliconflow.cn/v1" + }, + "url_suffix": { + "chat": "chat/completions", + "models": "models", + "embedding": "embeddings", + "rerank": "rerank", + "balance": "user/info" + }, + "models": [ + { + "name": "qwen/qwen3-8b", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "qwen/qwen3.5-4b", + "max_tokens": 262144, + "model_types": [ + "chat" + ] + }, + { + "name": "tencent/hunyuan-mt-7b", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "BAAI/bge-reranker-v2-m3", + "max_tokens": 8192, + "model_types": [ + "rerank" + ] + }, + { + "name": "Qwen/Qwen3-Embedding-0.6B", + "max_tokens": 8192, + "model_types": [ + "embedding" + ] + } + ] +} diff --git a/conf/models/vllm.json b/conf/models/vllm.json new file mode 100644 index 00000000000..96ec1a2403b --- /dev/null +++ b/conf/models/vllm.json @@ -0,0 +1,8 @@ +{ + "name": "vllm", + "url_suffix": { + "chat": "chat/completions", + "models": "models" + }, + "class": "local" +} \ No newline at end of file diff --git a/conf/models/volcengine.json b/conf/models/volcengine.json new file mode 100644 index 00000000000..96a6004097a --- /dev/null +++ b/conf/models/volcengine.json @@ -0,0 +1,32 @@ +{ + "name": "VolcEngine", + "url": { + "default": "https://ark.cn-beijing.volces.com/api/v3" + }, + "url_suffix": { + "chat": "chat/completions", + "files": "files", + "embedding": "embeddings/multimodal" + }, + "class": "volcengine", + "models": [ + { + "name": "doubao-seed-2-0-pro-260215", + "max_tokens": 262144, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "doubao-embedding-vision-250615", + "max_tokens": 131072, + "model_types": [ + "embedding" + ] + } + ] +} \ No newline at end of file diff --git a/conf/models/xai.json b/conf/models/xai.json index 5e12776c92e..41fe7978f12 100644 --- a/conf/models/xai.json +++ b/conf/models/xai.json @@ -6,42 +6,37 @@ "url_suffix": { "chat": "chat/completions" }, + "class": "grok", "models": [ { "name": "grok-4", "max_tokens": 256000, - "model_types": ["chat"], - "features": {} + "model_types": ["chat"] }, { "name": "grok-3", "max_tokens": 131072, - "model_types": ["chat"], - "features": {} + "model_types": ["chat"] }, { "name": "grok-3-fast", "max_tokens": 131072, - "model_types": ["chat"], - "features": {} + "model_types": ["chat"] }, { "name": "grok-3-mini", "max_tokens": 131072, - "model_types": ["chat"], - "features": {} + "model_types": ["chat"] }, { "name": "grok-3-mini-mini-fast", "max_tokens": 131072, - "model_types": ["chat"], - "features": {} + "model_types": ["chat"] }, { "name": "grok-2-vision", "max_tokens": 32768, - "model_types": ["vision"], - "features": {} + "model_types": ["vision"] } ] } \ No newline at end of file diff --git a/conf/models/zhipu-ai.json b/conf/models/zhipu-ai.json index b38624bffe2..d1bbac649fd 100644 --- a/conf/models/zhipu-ai.json +++ b/conf/models/zhipu-ai.json @@ -7,66 +7,144 @@ "chat": "chat/completions", "async_chat": "async/chat/completions", "async_result": "async-result", - "embedding": "embedding", - 
"rerank": "rerank" + "embedding": "embeddings", + "rerank": "rerank", + "files": "files" }, + "class": "glm", "models": [ + { + "name": "glm-5", + "max_tokens": 204800, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "glm-5-turbo", + "max_tokens": 204800, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "glm-5v-turbo", + "max_tokens": 204800, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, { "name": "glm-4.7", - "max_tokens": 128000, + "max_tokens": 204800, "model_types": [ "chat" ], - "features": {} + "thinking": { + "default_value": true, + "clear_thinking": true + } }, { - "name": "glm-4.5", - "max_tokens": 128000, + "name": "glm-4.7-flashx", + "max_tokens": 204800, "model_types": [ "chat" ], - "features": {} + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "glm-4.6", + "max_tokens": 204800, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } }, { "name": "glm-4.6v-Flash", - "max_tokens": 128000, + "max_tokens": 131072, "model_types": [ "chat", "vision" ], - "features": {} + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "glm-4.5", + "max_tokens": 131072, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } }, { "name": "glm-4.5-x", - "max_tokens": 128000, + "max_tokens": 131072, "model_types": [ "chat" ], - "features": {} + "thinking": { + "default_value": true, + "clear_thinking": true + } }, { "name": "glm-4.5-air", - "max_tokens": 128000, + "max_tokens": 131072, "model_types": [ "chat" ], - "features": {} + "thinking": { + "default_value": true, + "clear_thinking": true + } }, { "name": "glm-4.5-airx", - "max_tokens": 128000, + "max_tokens": 131072, "model_types": [ "chat" ], - "features": {} + "thinking": { + "default_value": true, + "clear_thinking": true + } }, { "name": "glm-4.5-flash", - "max_tokens": 128000, + "max_tokens": 131072, "model_types": [ "chat" ], - "features": {} + "thinking": { + "default_value": true, + "clear_thinking": true + } }, { "name": "glm-4.5v", @@ -74,168 +152,119 @@ "model_types": [ "vision" ], - "features": {} + "thinking": { + "default_value": true, + "clear_thinking": true + } }, { "name": "glm-4-plus", - "max_tokens": 128000, + "max_tokens": 131072, "model_types": [ "chat" - ], - "features": {} + ] }, { "name": "glm-4-0520", - "max_tokens": 128000, + "max_tokens": 131072, "model_types": [ "chat" - ], - "features": {} + ] }, { "name": "glm-4", - "max_tokens": 128000, + "max_tokens": 131072, "model_types": [ "chat" - ], - "features": {} + ] }, { "name": "glm-4-airx", "max_tokens": 8000, "model_types": [ "chat" - ], - "features": {} + ] }, { "name": "glm-4-air", - "max_tokens": 128000, + "max_tokens": 131072, "model_types": [ "chat" - ], - "features": {} + ] }, { "name": "glm-4-flash", - "max_tokens": 128000, + "max_tokens": 131072, "model_types": [ "chat" - ], - "features": {} + ] }, { "name": "glm-4-flashx", - "max_tokens": 128000, + "max_tokens": 131072, "model_types": [ "chat" - ], - "features": {} + ] }, { "name": "glm-4-long", "max_tokens": 1000000, "model_types": [ "chat" - ], - "features": {} - }, - { - "name": "glm-3-turbo", - "max_tokens": 128000, - "model_types": [ - "chat" - ], - "features": {} + ] }, { "name": "glm-4v", "max_tokens": 2000, "model_types": [ 
"vision" - ], - "features": {} + ] }, { "name": "glm-4-9b", "max_tokens": 8192, "model_types": [ "chat" - ], - "features": {} + ] }, { "name": "embedding-2", "max_tokens": 512, "model_types": [ "embedding" - ], - "features": {} + ] }, { "name": "embedding-3", "max_tokens": 512, "model_types": [ "embedding" - ], - "features": {} + ] }, { - "name": "glm-asr", + "name": "glm-asr-2512", "max_tokens": 4096, "model_types": [ "asr" - ], - "features": {} + ] }, { "name": "glm-tts", "model_types": [ "tts" - ], - "features": {} + ] }, { "name": "glm-ocr", "model_types": [ "ocr" - ], - "features": {} + ] }, { - "name": "glm-rerank", + "name": "rerank", "model_types": [ "rerank" - ], - "features": {} - } - ], - "features": { - "thinking": { - "default_value": true, - "supported_models": [ - "glm-5.1", - "glm-5", - "glm-5v-turbo", - "glm-4.7", - "glm-4.6", - "glm-4.6v", - "glm-4.5", - "glm-4.5v" - ] - }, - "clear_thinking": { - "default_value": true, - "supported_models": [ - "glm-5.1", - "glm-5", - "glm-5v-turbo", - "glm-4.7", - "glm-4.6", - "glm-4.6v", - "glm-4.5", - "glm-4.5v" ] } - } + ] } \ No newline at end of file diff --git a/conf/skill_es_mapping.json b/conf/skill_es_mapping.json new file mode 100644 index 00000000000..a9d3cba8699 --- /dev/null +++ b/conf/skill_es_mapping.json @@ -0,0 +1,136 @@ +{ + "settings": { + "index": { + "number_of_shards": 1, + "number_of_replicas": 0, + "refresh_interval": "1000ms" + }, + "similarity": { + "scripted_sim": { + "type": "scripted", + "script": { + "source": "double idf = Math.log(1+(field.docCount-term.docFreq+0.5)/(term.docFreq + 0.5))/Math.log(1+((field.docCount-0.5)/1.5)); return query.boost * idf * Math.min(doc.freq, 1);" + } + } + } + }, + "mappings": { + "dynamic": false, + "properties": { + "skill_id": { + "type": "keyword", + "store": true + }, + "space_id": { + "type": "keyword", + "store": true + }, + "folder_id": { + "type": "keyword", + "store": true + }, + "name": { + "type": "text", + "index": false, + "store": true + }, + "name_tks": { + "type": "text", + "similarity": "scripted_sim", + "analyzer": "whitespace", + "store": true + }, + "tags": { + "type": "text", + "index": false, + "store": true + }, + "tags_tks": { + "type": "text", + "similarity": "scripted_sim", + "analyzer": "whitespace", + "store": true + }, + "description": { + "type": "text", + "index": false, + "store": true + }, + "description_tks": { + "type": "text", + "similarity": "scripted_sim", + "analyzer": "whitespace", + "store": true + }, + "content": { + "type": "text", + "index": false, + "store": true + }, + "content_tks": { + "type": "text", + "similarity": "scripted_sim", + "analyzer": "whitespace", + "store": true + }, + "q_3072_vec": { + "type": "dense_vector", + "dims": 3072, + "index": true, + "similarity": "cosine" + }, + "q_2560_vec": { + "type": "dense_vector", + "dims": 2560, + "index": true, + "similarity": "cosine" + }, + "q_1536_vec": { + "type": "dense_vector", + "dims": 1536, + "index": true, + "similarity": "cosine" + }, + "q_1024_vec": { + "type": "dense_vector", + "dims": 1024, + "index": true, + "similarity": "cosine" + }, + "q_768_vec": { + "type": "dense_vector", + "dims": 768, + "index": true, + "similarity": "cosine" + }, + "q_512_vec": { + "type": "dense_vector", + "dims": 512, + "index": true, + "similarity": "cosine" + }, + "q_256_vec": { + "type": "dense_vector", + "dims": 256, + "index": true, + "similarity": "cosine" + }, + "version": { + "type": "keyword", + "store": true + }, + "status": { + "type": "keyword", + "store": true + }, + 
"create_time": { + "type": "long", + "store": true + }, + "update_time": { + "type": "long", + "store": true + } + } + } +} diff --git a/conf/skill_infinity_mapping.json b/conf/skill_infinity_mapping.json new file mode 100644 index 00000000000..4e4766ea8f5 --- /dev/null +++ b/conf/skill_infinity_mapping.json @@ -0,0 +1,64 @@ +{ + "skill_id": { + "type": "varchar", + "default": "", + "index_type": "secondary" + }, + "space_id": { + "type": "varchar", + "default": "", + "index_type": "secondary" + }, + "folder_id": { + "type": "varchar", + "default": "" + }, + "name": { + "type": "varchar", + "default": "", + "analyzer": [ + "rag-coarse", + "rag-fine" + ] + }, + "tags": { + "type": "varchar", + "default": "", + "analyzer": [ + "rag-coarse", + "rag-fine" + ] + }, + "description": { + "type": "varchar", + "default": "", + "analyzer": [ + "rag-coarse", + "rag-fine" + ] + }, + "content": { + "type": "varchar", + "default": "", + "analyzer": [ + "rag-coarse", + "rag-fine" + ] + }, + "version": { + "type": "varchar", + "default": "1.0.0" + }, + "status": { + "type": "varchar", + "default": "1" + }, + "create_time": { + "type": "bigint", + "default": 0 + }, + "update_time": { + "type": "bigint", + "default": 0 + } +} \ No newline at end of file diff --git a/deepdoc/parser/docling_parser.py b/deepdoc/parser/docling_parser.py index a2ebc400255..948a7acb0cd 100644 --- a/deepdoc/parser/docling_parser.py +++ b/deepdoc/parser/docling_parser.py @@ -30,10 +30,12 @@ import requests from PIL import Image +from common.constants import MAXIMUM_PAGE_NUMBER + try: from docling.document_converter import DocumentConverter except Exception: - DocumentConverter = None + DocumentConverter = None try: from deepdoc.parser.pdf_parser import RAGFlowPdfParser @@ -44,6 +46,7 @@ class RAGFlowPdfParser: from deepdoc.parser.utils import extract_pdf_outlines + class DoclingContentType(str, Enum): IMAGE = "image" TABLE = "table" @@ -124,7 +127,7 @@ def check_installation(self, docling_server_url: Optional[str] = None) -> bool: self.logger.error(f"[Docling] init DocumentConverter failed: {e}") return False - def __images__(self, fnm, zoomin: int = 1, page_from=0, page_to=600, callback=None): + def __images__(self, fnm, zoomin: int = 1, page_from=0, page_to=MAXIMUM_PAGE_NUMBER, callback=None): self.page_from = page_from self.page_to = page_to bytes_io = None @@ -350,6 +353,13 @@ def _parse_pdf_remote( docling_server_url: Optional[str] = None, request_timeout: Optional[int] = None, ): + """ + Parses a PDF document using a remote Docling server. + + Prioritizes native chunking endpoints (/v1/chunk/source, /v1alpha/chunk/source) + to prevent token overflow, with a graceful fallback to standard conversion + endpoints if chunking is unavailable. 
+ """ server_url = self._effective_server_url(docling_server_url) if not server_url: raise RuntimeError("[Docling] DOCLING_SERVER_URL is not configured.") @@ -372,36 +382,48 @@ def _parse_pdf_remote( filename = Path(filepath).name or "input.pdf" b64 = base64.b64encode(pdf_bytes).decode("ascii") - v1_payload = { - "options": { - "from_formats": ["pdf"], - "to_formats": ["json", "md", "text"], - }, - "sources": [ - { - "kind": "file", - "filename": filename, - "base64_string": b64, - } - ], + + # Standard payloads + # Standard fallback payloads (no chunking) + v1_payload_standard = { + "options": {"from_formats": ["pdf"], "to_formats": ["json", "md", "text"]}, + "sources": [{"kind": "file", "filename": filename, "base64_string": b64}], + } + v1alpha_payload_standard = { + "options": {"from_formats": ["pdf"], "to_formats": ["json", "md", "text"]}, + "file_sources": [{"filename": filename, "base64_string": b64}], + } + + # --- NEW: Correct API Contract for Chunking --- + chunking_opts = { + "from_formats": ["pdf"], + "to_formats": ["json", "md", "text"], + "do_chunking": True, + "chunking_options": { + "max_tokens": 512, + "overlap": 50, + "tokenizer": "sentencepiece" # Required by Docling contract + } } - v1alpha_payload = { - "options": { - "from_formats": ["pdf"], - "to_formats": ["json", "md", "text"], - }, - "file_sources": [ - { - "filename": filename, - "base64_string": b64, - } - ], + v1_payload_chunked = { + "options": chunking_opts, + "sources": [{"kind": "file", "filename": filename, "base64_string": b64}], } + v1alpha_payload_chunked = { + "options": chunking_opts, + "file_sources": [{"filename": filename, "base64_string": b64}], + } + errors = [] response_json = None - for endpoint, payload in ( - ("/v1/convert/source", v1_payload), - ("/v1alpha/convert/source", v1alpha_payload), + is_chunked_response = False + + # Try chunked endpoints first, then fall back to standard if the server is older + for endpoint, payload, chunk_flag in ( + ("/v1/convert/source", v1_payload_chunked, True), + ("/v1alpha/convert/source", v1alpha_payload_chunked, True), + ("/v1/convert/source", v1_payload_standard, False), + ("/v1alpha/convert/source", v1alpha_payload_standard, False), ): try: resp = requests.post( @@ -411,20 +433,57 @@ def _parse_pdf_remote( ) if resp.status_code < 300: response_json = resp.json() + is_chunked_response = chunk_flag + + if chunk_flag: + self.logger.info(f"[Docling] Successfully used native chunking on: {endpoint}") + else: + self.logger.info(f"[Docling] Chunking unavailable, fell back to standard: {endpoint}") break + + # If chunking request is rejected (e.g., 422 Unprocessable Entity on older servers), + # log it and let the loop naturally fall back to the standard payload. 
+ if chunk_flag: + self.logger.warning(f"[Docling] Server rejected chunking parameters: HTTP {resp.status_code}") + continue + errors.append(f"{endpoint}: HTTP {resp.status_code} {resp.text[:300]}") + except Exception as exc: + self.logger.error(f"[Docling] Request error on {endpoint}: {exc}") errors.append(f"{endpoint}: {exc}") if response_json is None: raise RuntimeError("[Docling] remote convert failed: " + " | ".join(errors)) + sections: list[tuple[str, ...]] = [] + tables = [] + + # --- NEW: Handle Native Chunked Response --- + if is_chunked_response: + # The chunking endpoint returns an array of chunk items + chunks = response_json if isinstance(response_json, list) else response_json.get("results", []) + for chunk_data in chunks: + if not isinstance(chunk_data, dict): + continue + # Depending on the exact docling-serve spec, the text might be nested + chunk_text = chunk_data.get("text", "") + if not chunk_text and isinstance(chunk_data.get("chunk"), dict): + chunk_text = chunk_data["chunk"].get("text", "") + + if isinstance(chunk_text, str) and chunk_text.strip(): + # Feed the pre-sliced chunks directly into RAGFlow's expected format + sections.extend(self._sections_from_remote_text(chunk_text, parse_method=parse_method)) + + if callback: + callback(0.95, f"[Docling] Native chunks received: {len(sections)}") + return sections, tables + + # --- FALLBACK: Standard RAGFlow parsing for older docling servers --- docs = self._extract_remote_document_entries(response_json) if not docs: raise RuntimeError("[Docling] remote response does not contain parsed documents.") - sections: list[tuple[str, ...]] = [] - tables = [] for doc in docs: md = doc.get("md_content") txt = doc.get("text_content") diff --git a/deepdoc/parser/docx_parser.py b/deepdoc/parser/docx_parser.py index 0257a320f7f..2d56729b744 100644 --- a/deepdoc/parser/docx_parser.py +++ b/deepdoc/parser/docx_parser.py @@ -21,6 +21,7 @@ from rag.nlp import rag_tokenizer from io import BytesIO import logging +from common.constants import MAXIMUM_PAGE_NUMBER from docx.image.exceptions import ( InvalidImageStreamError, UnexpectedEndOfFileError, @@ -158,7 +159,7 @@ def blockType(b): return lines return ["\n".join(lines)] - def __call__(self, fnm, from_page=0, to_page=100000000): + def __call__(self, fnm, from_page=0, to_page=MAXIMUM_PAGE_NUMBER): self.doc = Document(fnm) if isinstance( fnm, str) else Document(BytesIO(fnm)) pn = 0 # parsed page diff --git a/deepdoc/parser/html_parser.py b/deepdoc/parser/html_parser.py index f4d360c6413..7462ad99e9f 100644 --- a/deepdoc/parser/html_parser.py +++ b/deepdoc/parser/html_parser.py @@ -52,7 +52,7 @@ def parser_txt(cls, txt, chunk_token_num): raise TypeError("txt type should be string!") temp_sections = [] - soup = BeautifulSoup(txt, "html5lib") + soup = BeautifulSoup(txt, "html.parser") # delete
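To make the new chunked-response handling in _parse_pdf_remote() above easier to follow, here is a small self-contained sketch of the two payload shapes that branch accepts. The shapes are inferred from the handler itself rather than from the docling-serve spec, and the sample texts are synthetic.

# Shapes inferred from the handler above: either a top-level list of chunk
# items, or a dict carrying a "results" list; the text may sit directly at
# item["text"] or be nested under item["chunk"]["text"].
flat_shape = [
    {"text": "First pre-sliced chunk."},
    {"text": "Second pre-sliced chunk."},
]
nested_shape = {
    "results": [
        {"chunk": {"text": "Chunk text nested one level down."}},
    ]
}


def extract_chunk_texts(response_json):
    # Mirrors the extraction order used in _parse_pdf_remote().
    chunks = response_json if isinstance(response_json, list) else response_json.get("results", [])
    texts = []
    for item in chunks:
        if not isinstance(item, dict):
            continue
        text = item.get("text", "")
        if not text and isinstance(item.get("chunk"), dict):
            text = item["chunk"].get("text", "")
        if isinstance(text, str) and text.strip():
            texts.append(text.strip())
    return texts


assert extract_chunk_texts(flat_shape) == ["First pre-sliced chunk.", "Second pre-sliced chunk."]
assert extract_chunk_texts(nested_shape) == ["Chunk text nested one level down."]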