diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000000..a6a922a221b --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,17 @@ +version: 2 +updates: + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "weekly" + timezone: "Asia/Shanghai" + day: "friday" + target-branch: "v2" + groups: + python-dependencies: + patterns: + - "*" +# ignore: +# - dependency-name: "pymupdf" +# versions: ["*"] + diff --git a/.github/workflows/build-and-push-python-pg.yml b/.github/workflows/build-and-push-python-pg.yml index bc4dc3f2c77..4640f5edbd0 100644 --- a/.github/workflows/build-and-push-python-pg.yml +++ b/.github/workflows/build-and-push-python-pg.yml @@ -33,13 +33,13 @@ jobs: - name: Checkout uses: actions/checkout@v4 with: - ref: main + ref: v1 - name: Prepare id: prepare run: | DOCKER_IMAGE=ghcr.io/1panel-dev/maxkb-python-pg DOCKER_PLATFORMS=${{ github.event.inputs.architecture }} - TAG_NAME=python3.11-pg15.8 + TAG_NAME=python3.11-pg15.14 DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:latest" echo ::set-output name=docker_image::${DOCKER_IMAGE} echo ::set-output name=version::${TAG_NAME} diff --git a/.github/workflows/build-and-push.yml b/.github/workflows/build-and-push.yml index 26d2b86d297..1e1daf2696c 100644 --- a/.github/workflows/build-and-push.yml +++ b/.github/workflows/build-and-push.yml @@ -7,7 +7,7 @@ on: inputs: dockerImageTag: description: 'Image Tag' - default: 'v1.10.3-dev' + default: 'v1.10.7-dev' required: true dockerImageTagWithLatest: description: '是否发布latest tag(正式发版时选择,测试版本切勿选择)' @@ -36,7 +36,7 @@ on: jobs: build-and-push-to-fit2cloud-registry: if: ${{ contains(github.event.inputs.registry, 'fit2cloud') }} - runs-on: ubuntu-22.04 + runs-on: ubuntu-latest steps: - name: Check Disk Space run: df -h @@ -52,10 +52,6 @@ jobs: swap-storage: true - name: Check Disk Space run: df -h - - name: Set Swap Space - uses: pierotofy/set-swap-space@master - with: - swap-size-gb: 8 
- name: Checkout uses: actions/checkout@v4 with: @@ -68,24 +64,17 @@ jobs: TAG_NAME=${{ github.event.inputs.dockerImageTag }} TAG_NAME_WITH_LATEST=${{ github.event.inputs.dockerImageTagWithLatest }} if [[ ${TAG_NAME_WITH_LATEST} == 'true' ]]; then - DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:latest" + DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:${TAG_NAME%%.*}" else DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME}" fi echo ::set-output name=buildx_args::--platform ${DOCKER_PLATFORMS} --memory-swap -1 \ - --build-arg DOCKER_IMAGE_TAG=${{ github.event.inputs.dockerImageTag }} --build-arg BUILD_AT=$(TZ=Asia/Shanghai date +'%Y-%m-%dT%H:%M') --build-arg GITHUB_COMMIT=${GITHUB_SHA::8} --no-cache \ + --build-arg DOCKER_IMAGE_TAG=${{ github.event.inputs.dockerImageTag }} --build-arg BUILD_AT=$(TZ=Asia/Shanghai date +'%Y-%m-%dT%H:%M') --build-arg GITHUB_COMMIT=`git rev-parse --short HEAD` --no-cache \ ${DOCKER_IMAGE_TAGS} . - name: Set up QEMU uses: docker/setup-qemu-action@v3 - with: - # Until https://github.com/tonistiigi/binfmt/issues/215 - image: tonistiigi/binfmt:qemu-v7.0.0-28 - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - with: - buildkitd-config-inline: | - [worker.oci] - max-parallelism = 1 - name: Login to GitHub Container Registry uses: docker/login-action@v3 with: @@ -100,11 +89,12 @@ jobs: password: ${{ secrets.FIT2CLOUD_REGISTRY_PASSWORD }} - name: Docker Buildx (build-and-push) run: | + sudo sync && echo 3 | sudo tee /proc/sys/vm/drop_caches && free -m docker buildx build --output "type=image,push=true" ${{ steps.prepare.outputs.buildx_args }} -f installer/Dockerfile build-and-push-to-dockerhub: if: ${{ contains(github.event.inputs.registry, 'dockerhub') }} - runs-on: ubuntu-22.04 + runs-on: ubuntu-latest steps: - name: Check Disk Space run: df -h @@ -120,10 +110,6 @@ jobs: swap-storage: true - name: Check Disk Space run: df -h - - name: Set Swap Space - uses: 
pierotofy/set-swap-space@master - with: - swap-size-gb: 8 - name: Checkout uses: actions/checkout@v4 with: @@ -136,24 +122,17 @@ jobs: TAG_NAME=${{ github.event.inputs.dockerImageTag }} TAG_NAME_WITH_LATEST=${{ github.event.inputs.dockerImageTagWithLatest }} if [[ ${TAG_NAME_WITH_LATEST} == 'true' ]]; then - DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:latest" + DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:${TAG_NAME%%.*}" else DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME}" fi echo ::set-output name=buildx_args::--platform ${DOCKER_PLATFORMS} --memory-swap -1 \ - --build-arg DOCKER_IMAGE_TAG=${{ github.event.inputs.dockerImageTag }} --build-arg BUILD_AT=$(TZ=Asia/Shanghai date +'%Y-%m-%dT%H:%M') --build-arg GITHUB_COMMIT=${GITHUB_SHA::8} --no-cache \ + --build-arg DOCKER_IMAGE_TAG=${{ github.event.inputs.dockerImageTag }} --build-arg BUILD_AT=$(TZ=Asia/Shanghai date +'%Y-%m-%dT%H:%M') --build-arg GITHUB_COMMIT=`git rev-parse --short HEAD` --no-cache \ ${DOCKER_IMAGE_TAGS} . - name: Set up QEMU uses: docker/setup-qemu-action@v3 - with: - # Until https://github.com/tonistiigi/binfmt/issues/215 - image: tonistiigi/binfmt:qemu-v7.0.0-28 - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - with: - buildkitd-config-inline: | - [worker.oci] - max-parallelism = 1 - name: Login to GitHub Container Registry uses: docker/login-action@v3 with: @@ -167,4 +146,5 @@ jobs: password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Docker Buildx (build-and-push) run: | + sudo sync && echo 3 | sudo tee /proc/sys/vm/drop_caches && free -m docker buildx build --output "type=image,push=true" ${{ steps.prepare.outputs.buildx_args }} -f installer/Dockerfile diff --git a/README.md b/README.md index cfe819e56ff..b4a925edb64 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@

MaxKB

-

Ready-to-use AI Chatbot

+

Open-source platform for building enterprise-grade agents

+

强大易用的企业级智能体平台

1Panel-dev%2FMaxKB | Trendshift

License: GPL v3 @@ -10,10 +11,10 @@


-MaxKB = Max Knowledge Base, it is a ready-to-use AI chatbot that integrates Retrieval-Augmented Generation (RAG) pipelines, supports robust workflows, and provides advanced MCP tool-use capabilities. MaxKB is widely applied in scenarios such as intelligent customer service, corporate internal knowledge bases, academic research, and education. +MaxKB = Max Knowledge Brain, an open-source platform for building enterprise-grade agents. MaxKB integrates Retrieval-Augmented Generation (RAG) pipelines, supports robust workflows, and provides advanced MCP tool-use capabilities. MaxKB is widely applied in scenarios such as intelligent customer service, corporate internal knowledge bases, academic research, and education. -- **RAG Pipeline**: Supports direct uploading of documents / automatic crawling of online documents, with features for automatic text splitting, vectorization, and RAG (Retrieval-Augmented Generation). This effectively reduces hallucinations in large models, providing a superior smart Q&A interaction experience. -- **Flexible Orchestration**: Equipped with a powerful workflow engine, function library and MCP tool-use, enabling the orchestration of AI processes to meet the needs of complex business scenarios. +- **RAG Pipeline**: Supports direct uploading of documents / automatic crawling of online documents, with features for automatic text splitting and vectorization. This effectively reduces hallucinations in large models, providing a superior smart Q&A interaction experience. +- **Agentic Workflow**: Equipped with a powerful workflow engine, function library and MCP tool-use, enabling the orchestration of AI processes to meet the needs of complex business scenarios. - **Seamless Integration**: Facilitates zero-coding rapid integration into third-party business systems, quickly equipping existing systems with intelligent Q&A capabilities to enhance user satisfaction. 
- **Model-Agnostic**: Supports various large models, including private models (such as DeepSeek, Llama, Qwen, etc.) and public models (like OpenAI, Claude, Gemini, etc.). - **Multi Modal**: Native support for input and output text, image, audio and video. @@ -23,7 +24,7 @@ MaxKB = Max Knowledge Base, it is a ready-to-use AI chatbot that integrates Retr Execute the script below to start a MaxKB container using Docker: ```bash -docker run -d --name=maxkb --restart=always -p 8080:8080 -v ~/.maxkb:/var/lib/postgresql/data -v ~/.python-packages:/opt/maxkb/app/sandbox/python-packages 1panel/maxkb +docker run -d --name=maxkb --restart=always -p 8080:8080 -v ~/.maxkb:/var/lib/postgresql/data -v ~/.python-packages:/opt/maxkb/app/sandbox/python-packages 1panel/maxkb:v1 ``` Access MaxKB web interface at `http://your_server_ip:8080` with default admin credentials: @@ -31,7 +32,7 @@ Access MaxKB web interface at `http://your_server_ip:8080` with default admin cr - username: admin - password: MaxKB@123.. -中国用户如遇到 Docker 镜像 Pull 失败问题,请参照该 [离线安装文档](https://maxkb.cn/docs/installation/offline_installtion/) 进行安装。 +中国用户如遇到 Docker 镜像 Pull 失败问题,请参照该 [离线安装文档](https://maxkb.cn/docs/v1/installation/offline_installtion/) 进行安装。 ## Screenshots @@ -55,8 +56,6 @@ Access MaxKB web interface at `http://your_server_ip:8080` with default admin cr ## Feature Comparison -MaxKB is positioned as an Ready-to-use RAG (Retrieval-Augmented Generation) intelligent Q&A application, rather than a middleware platform for building large model applications. The following table is merely a comparison from a functional perspective. - diff --git a/README_CN.md b/README_CN.md index e55150902ea..07fa00ea4e6 100644 --- a/README_CN.md +++ b/README_CN.md @@ -1,25 +1,25 @@

MaxKB

-

基于大模型和 RAG 的知识库问答系统

-

Ready-to-use, flexible RAG Chatbot

+

强大易用的企业级智能体平台

1Panel-dev%2FMaxKB | Trendshift - 1Panel-dev%2FMaxKB | Aliyun

English README - License: GPL v3 + License: GPL v3 Latest release - Stars - Download + Stars + Download + Gitee Stars + GitCode Stars


-MaxKB = Max Knowledge Base,是一款开箱即用的 RAG Chatbot,具备强大的工作流和 MCP 工具调用能力。它支持对接各种主流大语言模型(LLMs),广泛应用于智能客服、企业内部知识库、学术研究与教育等场景。 +MaxKB = Max Knowledge Brain,是一个强大易用的企业级智能体平台,致力于解决企业 AI 落地面临的技术门槛高、部署成本高、迭代周期长等问题,助力企业在人工智能时代赢得先机。秉承“开箱即用,伴随成长”的设计理念,MaxKB 支持企业快速接入主流大模型,高效构建专属知识库,并提供从基础问答(RAG)、复杂流程自动化(工作流)到智能体(Agent)的渐进式升级路径,全面赋能智能客服、智能办公助手等多种应用场景。 -- **开箱即用**:支持直接上传文档 / 自动爬取在线文档,支持文本自动拆分、向量化和 RAG(检索增强生成),有效减少大模型幻觉,智能问答交互体验好; -- **模型中立**:支持对接各种大模型,包括本地私有大模型(DeepSeek R1 / Llama 3 / Qwen 2 等)、国内公共大模型(通义千问 / 腾讯混元 / 字节豆包 / 百度千帆 / 智谱 AI / Kimi 等)和国外公共大模型(OpenAI / Claude / Gemini 等); +- **RAG 检索增强生成**:高效搭建本地 AI 知识库,支持直接上传文档 / 自动爬取在线文档,支持文本自动拆分、向量化,有效减少大模型幻觉,提升问答效果; - **灵活编排**:内置强大的工作流引擎、函数库和 MCP 工具调用能力,支持编排 AI 工作过程,满足复杂业务场景下的需求; -- **无缝嵌入**:支持零编码快速嵌入到第三方业务系统,让已有系统快速拥有智能问答能力,提高用户满意度。 +- **无缝嵌入**:支持零编码快速嵌入到第三方业务系统,让已有系统快速拥有智能问答能力,提高用户满意度; +- **模型中立**:支持对接各种大模型,包括本地私有大模型(DeepSeek R1 / Qwen 3 等)、国内公共大模型(通义千问 / 腾讯混元 / 字节豆包 / 百度千帆 / 智谱 AI / Kimi 等)和国外公共大模型(OpenAI / Claude / Gemini 等)。 MaxKB 三分钟视频介绍:https://www.bilibili.com/video/BV18JypYeEkj/ @@ -27,10 +27,10 @@ MaxKB 三分钟视频介绍:https://www.bilibili.com/video/BV18JypYeEkj/ ``` # Linux 机器 -docker run -d --name=maxkb --restart=always -p 8080:8080 -v ~/.maxkb:/var/lib/postgresql/data -v ~/.python-packages:/opt/maxkb/app/sandbox/python-packages registry.fit2cloud.com/maxkb/maxkb +docker run -d --name=maxkb --restart=always -p 8080:8080 -v ~/.maxkb:/var/lib/postgresql/data -v ~/.python-packages:/opt/maxkb/app/sandbox/python-packages registry.fit2cloud.com/maxkb/maxkb:v1 # Windows 机器 -docker run -d --name=maxkb --restart=always -p 8080:8080 -v C:/maxkb:/var/lib/postgresql/data -v C:/python-packages:/opt/maxkb/app/sandbox/python-packages registry.fit2cloud.com/maxkb/maxkb +docker run -d --name=maxkb --restart=always -p 8080:8080 -v C:/maxkb:/var/lib/postgresql/data -v C:/python-packages:/opt/maxkb/app/sandbox/python-packages registry.fit2cloud.com/maxkb/maxkb:v1 # 用户名: admin # 密码: MaxKB@123.. 
@@ -38,8 +38,8 @@ docker run -d --name=maxkb --restart=always -p 8080:8080 -v C:/maxkb:/var/lib/po - 你也可以通过 [1Panel 应用商店](https://apps.fit2cloud.com/1panel) 快速部署 MaxKB; - 如果是内网环境,推荐使用 [离线安装包](https://community.fit2cloud.com/#/products/maxkb/downloads) 进行安装部署; -- MaxKB 产品版本分为社区版和专业版,详情请参见:[MaxKB 产品版本对比](https://maxkb.cn/pricing.html); -- 如果您需要向团队介绍 MaxKB,可以使用这个 [官方 PPT 材料](https://maxkb.cn/download/introduce-maxkb_202503.pdf)。 +- MaxKB 不同产品产品版本的对比请参见:[MaxKB 产品版本对比](https://maxkb.cn/price); +- 如果您需要向团队介绍 MaxKB,可以使用这个 [官方 PPT 材料](https://fit2cloud.com/maxkb/download/introduce-maxkb_202507.pdf)。 如你有更多问题,可以查看使用手册,或者通过论坛与我们交流。 diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index c5a0de1a152..05ab5009c0a 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -11,7 +11,6 @@ import re import time from functools import reduce -from types import AsyncGeneratorType from typing import List, Dict from django.db.models import QuerySet @@ -33,13 +32,26 @@ Called MCP Tool: %s -```json %s -``` + """ +tool_message_json_template = """ +```json +%s +``` +""" + + +def generate_tool_message_template(name, context): + if '```' in context: + return tool_message_template % (name, context) + else: + return tool_message_template % (name, tool_message_json_template % (context)) + + def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str, reasoning_content: str): chat_model = node_variable.get('chat_model') @@ -102,19 +114,19 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo _write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content) - async def _yield_mcp_response(chat_model, message_list, mcp_servers): async with MultiServerMCPClient(json.loads(mcp_servers)) as 
client: agent = create_react_agent(chat_model, client.get_tools()) response = agent.astream({"messages": message_list}, stream_mode='messages') async for chunk in response: if isinstance(chunk[0], ToolMessage): - content = tool_message_template % (chunk[0].name, chunk[0].content) + content = generate_tool_message_template(chunk[0].name, chunk[0].content) chunk[0].content = content yield chunk[0] if isinstance(chunk[0], AIMessageChunk): yield chunk[0] + def mcp_response_generator(chat_model, message_list, mcp_servers): loop = asyncio.new_event_loop() try: @@ -130,6 +142,7 @@ def mcp_response_generator(chat_model, message_list, mcp_servers): finally: loop.close() + async def anext_async(agen): return await agen.__anext__() @@ -186,7 +199,9 @@ def save_context(self, details, workflow_manage): self.context['answer'] = details.get('answer') self.context['question'] = details.get('question') self.context['reasoning_content'] = details.get('reasoning_content') - self.answer_text = details.get('answer') + self.context['model_setting'] = details.get('model_setting') + if self.node_params.get('is_result', False): + self.answer_text = details.get('answer') def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, model_params_setting=None, @@ -216,7 +231,7 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record message_list = self.generate_message_list(system, prompt, history_message) self.context['message_list'] = message_list - if mcp_enable and mcp_servers is not None: + if mcp_enable and mcp_servers is not None and '"stdio"' not in mcp_servers: r = mcp_response_generator(chat_model, message_list, mcp_servers) return NodeResult( {'result': r, 'chat_model': chat_model, 'message_list': message_list, @@ -271,6 +286,7 @@ def get_details(self, index: int, **kwargs): "index": index, 'run_time': self.context.get('run_time'), 'system': self.context.get('system'), + 'model_setting': 
self.context.get('model_setting'), 'history_message': [{'content': message.content, 'role': message.type} for message in (self.context.get('history_message') if self.context.get( 'history_message') is not None else [])], diff --git a/apps/application/flow/step_node/application_node/impl/base_application_node.py b/apps/application/flow/step_node/application_node/impl/base_application_node.py index d962f7163bb..95445f45612 100644 --- a/apps/application/flow/step_node/application_node/impl/base_application_node.py +++ b/apps/application/flow/step_node/application_node/impl/base_application_node.py @@ -168,7 +168,8 @@ def save_context(self, details, workflow_manage): self.context['question'] = details.get('question') self.context['type'] = details.get('type') self.context['reasoning_content'] = details.get('reasoning_content') - self.answer_text = details.get('answer') + if self.node_params.get('is_result', False): + self.answer_text = details.get('answer') def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type, app_document_list=None, app_image_list=None, app_audio_list=None, child_node=None, node_data=None, @@ -178,7 +179,8 @@ def execute(self, application_id, message, chat_id, chat_record_id, stream, re_c current_chat_id = string_to_uuid(chat_id + application_id) Chat.objects.get_or_create(id=current_chat_id, defaults={ 'application_id': application_id, - 'abstract': message[0:1024] + 'abstract': message[0:1024], + 'client_id': client_id, }) if app_document_list is None: app_document_list = [] diff --git a/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py index 6a51edd6bae..1d3115e4c67 100644 --- a/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py +++ b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py @@ -15,7 +15,9 @@ class BaseReplyNode(IReplyNode): def 
save_context(self, details, workflow_manage): self.context['answer'] = details.get('answer') - self.answer_text = details.get('answer') + if self.node_params.get('is_result', False): + self.answer_text = details.get('answer') + def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult: if reply_type == 'referencing': result = self.get_reference_content(fields) diff --git a/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py b/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py index 6ddcb6e2fca..0c4d09bce5c 100644 --- a/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py +++ b/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py @@ -66,7 +66,7 @@ def save_image(image_list): for doc in document: file = QuerySet(File).filter(id=doc['file_id']).first() - buffer = io.BytesIO(file.get_byte().tobytes()) + buffer = io.BytesIO(file.get_byte()) buffer.name = doc['name'] # this is the important line for split_handle in (parse_table_handle_list + split_handles): diff --git a/apps/application/flow/step_node/form_node/impl/base_form_node.py b/apps/application/flow/step_node/form_node/impl/base_form_node.py index 7cbbe9cc1d4..dcf35dd3cfd 100644 --- a/apps/application/flow/step_node/form_node/impl/base_form_node.py +++ b/apps/application/flow/step_node/form_node/impl/base_form_node.py @@ -38,7 +38,8 @@ def save_context(self, details, workflow_manage): self.context['start_time'] = details.get('start_time') self.context['form_data'] = form_data self.context['is_submit'] = details.get('is_submit') - self.answer_text = details.get('result') + if self.node_params.get('is_result', False): + self.answer_text = details.get('result') if form_data is not None: for key in form_data: self.context[key] = form_data[key] @@ -70,7 +71,7 @@ def get_answer_list(self) -> List[Answer] | None: "chat_record_id": 
self.flow_params_serializer.data.get("chat_record_id"), 'form_data': self.context.get('form_data', {}), "is_submit": self.context.get("is_submit", False)} - form = f'{json.dumps(form_setting,ensure_ascii=False)}' + form = f'{json.dumps(form_setting, ensure_ascii=False)}' context = self.workflow_manage.get_workflow_content() form_content_format = self.workflow_manage.reset_prompt(form_content_format) prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2') @@ -85,7 +86,7 @@ def get_details(self, index: int, **kwargs): "chat_record_id": self.flow_params_serializer.data.get("chat_record_id"), 'form_data': self.context.get('form_data', {}), "is_submit": self.context.get("is_submit", False)} - form = f'{json.dumps(form_setting,ensure_ascii=False)}' + form = f'{json.dumps(form_setting, ensure_ascii=False)}' context = self.workflow_manage.get_workflow_content() form_content_format = self.workflow_manage.reset_prompt(form_content_format) prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2') diff --git a/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py b/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py index d21424f750d..0678b81243c 100644 --- a/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py +++ b/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py @@ -45,6 +45,8 @@ def get_field_value(debug_field_list, name, is_required): def valid_reference_value(_type, value, name): + if value is None: + return if _type == 'int': instance_type = int | float elif _type == 'float': @@ -65,15 +67,22 @@ def valid_reference_value(_type, value, name): def convert_value(name: str, value, _type, is_required, source, node): - if not is_required and value is None: + if not is_required and (value is None or (isinstance(value, str) and len(value) == 0)): return None if not is_required and source == 
'reference' and (value is None or len(value) == 0): return None if source == 'reference': + if value and isinstance(value, list) and len(value) == 0: + if not is_required: + return None + else: + raise Exception(f"字段:{name}类型:{_type}值:{value}必填参数") value = node.workflow_manage.get_reference_field( value[0], value[1:]) valid_reference_value(_type, value, name) + if value is None: + return None if _type == 'int': return int(value) if _type == 'float': @@ -113,7 +122,8 @@ def valid_function(function_lib, user_id): class BaseFunctionLibNodeNode(IFunctionLibNode): def save_context(self, details, workflow_manage): self.context['result'] = details.get('result') - self.answer_text = str(details.get('result')) + if self.node_params.get('is_result'): + self.answer_text = str(details.get('result')) def execute(self, function_lib_id, input_field_list, **kwargs) -> NodeResult: function_lib = QuerySet(FunctionLib).filter(id=function_lib_id).first() diff --git a/apps/application/flow/step_node/function_node/impl/base_function_node.py b/apps/application/flow/step_node/function_node/impl/base_function_node.py index 4a5c75c8132..f6127e55550 100644 --- a/apps/application/flow/step_node/function_node/impl/base_function_node.py +++ b/apps/application/flow/step_node/function_node/impl/base_function_node.py @@ -32,6 +32,8 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow): def valid_reference_value(_type, value, name): + if value is None: + return if _type == 'int': instance_type = int | float elif _type == 'float': @@ -49,13 +51,20 @@ def valid_reference_value(_type, value, name): def convert_value(name: str, value, _type, is_required, source, node): - if not is_required and value is None: + if not is_required and (value is None or (isinstance(value, str) and len(value) == 0)): return None if source == 'reference': + if value and isinstance(value, list) and len(value) == 0: + if not is_required: + return None + else: + raise 
Exception(f"字段:{name}类型:{_type}值:{value}必填参数") value = node.workflow_manage.get_reference_field( value[0], value[1:]) valid_reference_value(_type, value, name) + if value is None: + return None if _type == 'int': return int(value) if _type == 'float': @@ -84,7 +93,8 @@ def convert_value(name: str, value, _type, is_required, source, node): class BaseFunctionNodeNode(IFunctionNode): def save_context(self, details, workflow_manage): self.context['result'] = details.get('result') - self.answer_text = str(details.get('result')) + if self.node_params.get('is_result', False): + self.answer_text = str(details.get('result')) def execute(self, input_field_list, code, **kwargs) -> NodeResult: params = {field.get('name'): convert_value(field.get('name'), field.get('value'), field.get('type'), diff --git a/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py b/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py index d5cc2c5a211..16423eafd61 100644 --- a/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py +++ b/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py @@ -16,7 +16,8 @@ class BaseImageGenerateNode(IImageGenerateNode): def save_context(self, details, workflow_manage): self.context['answer'] = details.get('answer') self.context['question'] = details.get('question') - self.answer_text = details.get('answer') + if self.node_params.get('is_result', False): + self.answer_text = details.get('answer') def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id, model_params_setting, @@ -24,7 +25,8 @@ def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_t **kwargs) -> NodeResult: print(model_params_setting) application = self.workflow_manage.work_flow_post_handler.chat_info.application - tti_model = get_model_instance_by_model_user_id(model_id, 
self.flow_params_serializer.data.get('user_id'), **model_params_setting) + tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), + **model_params_setting) history_message = self.get_history_message(history_chat_record, dialogue_number) self.context['history_message'] = history_message question = self.generate_prompt_question(prompt) diff --git a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py index 3b96f15cd6f..0b405619dde 100644 --- a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py +++ b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py @@ -62,14 +62,15 @@ def file_id_to_base64(file_id: str): file = QuerySet(File).filter(id=file_id).first() file_bytes = file.get_byte() base64_image = base64.b64encode(file_bytes).decode("utf-8") - return [base64_image, what(None, file_bytes.tobytes())] + return [base64_image, what(None, file_bytes)] class BaseImageUnderstandNode(IImageUnderstandNode): def save_context(self, details, workflow_manage): self.context['answer'] = details.get('answer') self.context['question'] = details.get('question') - self.answer_text = details.get('answer') + if self.node_params.get('is_result', False): + self.answer_text = details.get('answer') def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id, model_params_setting, @@ -171,7 +172,7 @@ def generate_message_list(self, image_model, system: str, prompt: str, history_m file = QuerySet(File).filter(id=file_id).first() image_bytes = file.get_byte() base64_image = base64.b64encode(image_bytes).decode("utf-8") - image_format = what(None, image_bytes.tobytes()) + image_format = what(None, image_bytes) images.append({'type': 'image_url', 'image_url': {'url': 
f'data:image/{image_format};base64,{base64_image}'}}) messages = [HumanMessage( content=[ diff --git a/apps/application/flow/step_node/mcp_node/impl/base_mcp_node.py b/apps/application/flow/step_node/mcp_node/impl/base_mcp_node.py index 6c9fe97fc69..d5197e9ad11 100644 --- a/apps/application/flow/step_node/mcp_node/impl/base_mcp_node.py +++ b/apps/application/flow/step_node/mcp_node/impl/base_mcp_node.py @@ -14,7 +14,6 @@ def save_context(self, details, workflow_manage): self.context['result'] = details.get('result') self.context['tool_params'] = details.get('tool_params') self.context['mcp_tool'] = details.get('mcp_tool') - self.answer_text = details.get('result') def execute(self, mcp_servers, mcp_server, mcp_tool, tool_params, **kwargs) -> NodeResult: servers = json.loads(mcp_servers) @@ -27,7 +26,8 @@ async def call_tool(s, session, t, a): return s res = asyncio.run(call_tool(servers, mcp_server, mcp_tool, params)) - return NodeResult({'result': [content.text for content in res.content], 'tool_params': params, 'mcp_tool': mcp_tool}, {}) + return NodeResult( + {'result': [content.text for content in res.content], 'tool_params': params, 'mcp_tool': mcp_tool}, {}) def handle_variables(self, tool_params): # 处理参数中的变量 diff --git a/apps/application/flow/step_node/question_node/impl/base_question_node.py b/apps/application/flow/step_node/question_node/impl/base_question_node.py index 48a2639b782..e1fd5b86069 100644 --- a/apps/application/flow/step_node/question_node/impl/base_question_node.py +++ b/apps/application/flow/step_node/question_node/impl/base_question_node.py @@ -80,7 +80,8 @@ def save_context(self, details, workflow_manage): self.context['answer'] = details.get('answer') self.context['message_tokens'] = details.get('message_tokens') self.context['answer_tokens'] = details.get('answer_tokens') - self.answer_text = details.get('answer') + if self.node_params.get('is_result', False): + self.answer_text = details.get('answer') def execute(self, model_id, system, 
prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, model_params_setting=None, diff --git a/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py b/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py index c85588cd4d2..8f48823f00c 100644 --- a/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py +++ b/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py @@ -18,7 +18,9 @@ class BaseSpeechToTextNode(ISpeechToTextNode): def save_context(self, details, workflow_manage): self.context['answer'] = details.get('answer') - self.answer_text = details.get('answer') + self.context['result'] = details.get('answer') + if self.node_params.get('is_result', False): + self.answer_text = details.get('answer') def execute(self, stt_model_id, chat_id, audio, **kwargs) -> NodeResult: stt_model = get_model_instance_by_model_user_id(stt_model_id, self.flow_params_serializer.data.get('user_id')) @@ -30,7 +32,7 @@ def process_audio_item(audio_item, model): # 根据file_name 吧文件转成mp3格式 file_format = file.file_name.split('.')[-1] with tempfile.NamedTemporaryFile(delete=False, suffix=f'.{file_format}') as temp_file: - temp_file.write(file.get_byte().tobytes()) + temp_file.write(file.get_byte()) temp_file_path = temp_file.name with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as temp_amr_file: temp_mp3_path = temp_amr_file.name diff --git a/apps/application/flow/step_node/start_node/impl/base_start_node.py b/apps/application/flow/step_node/start_node/impl/base_start_node.py index bf5203274eb..24b9684714e 100644 --- a/apps/application/flow/step_node/start_node/impl/base_start_node.py +++ b/apps/application/flow/step_node/start_node/impl/base_start_node.py @@ -40,10 +40,13 @@ def save_context(self, details, workflow_manage): self.context['document'] = details.get('document_list') self.context['image'] = 
details.get('image_list') self.context['audio'] = details.get('audio_list') + self.context['other'] = details.get('other_list') self.status = details.get('status') self.err_message = details.get('err_message') for key, value in workflow_variable.items(): workflow_manage.context[key] = value + for item in details.get('global_fields', []): + workflow_manage.context[item.get('key')] = item.get('value') def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: pass @@ -59,7 +62,8 @@ def execute(self, question, **kwargs) -> NodeResult: 'question': question, 'image': self.workflow_manage.image_list, 'document': self.workflow_manage.document_list, - 'audio': self.workflow_manage.audio_list + 'audio': self.workflow_manage.audio_list, + 'other': self.workflow_manage.other_list, } return NodeResult(node_variable, workflow_variable) @@ -83,5 +87,6 @@ def get_details(self, index: int, **kwargs): 'image_list': self.context.get('image'), 'document_list': self.context.get('document'), 'audio_list': self.context.get('audio'), + 'other_list': self.context.get('other'), 'global_fields': global_fields } diff --git a/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py b/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py index 72c4d3be514..330dc5f5804 100644 --- a/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py +++ b/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py @@ -37,7 +37,9 @@ def bytes_to_uploaded_file(file_bytes, file_name="generated_audio.mp3"): class BaseTextToSpeechNode(ITextToSpeechNode): def save_context(self, details, workflow_manage): self.context['answer'] = details.get('answer') - self.answer_text = details.get('answer') + self.context['result'] = details.get('result') + if self.node_params.get('is_result', False): + self.answer_text = details.get('answer') def execute(self, tts_model_id, 
chat_id, content, model_params_setting=None, @@ -72,4 +74,5 @@ def get_details(self, index: int, **kwargs): 'content': self.context.get('content'), 'err_message': self.err_message, 'answer': self.context.get('answer'), + 'result': self.context.get('result') } diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index be91f69be9e..554b0b75f47 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -14,7 +14,7 @@ from functools import reduce from typing import List, Dict -from django.db import close_old_connections +from django.db import close_old_connections, connection from django.db.models import QuerySet from django.utils import translation from django.utils.translation import get_language @@ -238,6 +238,7 @@ def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandl base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None, document_list=None, audio_list=None, + other_list=None, start_node_id=None, start_node_data=None, chat_record=None, child_node=None): if form_data is None: @@ -248,12 +249,15 @@ def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandl document_list = [] if audio_list is None: audio_list = [] + if other_list is None: + other_list = [] self.start_node_id = start_node_id self.start_node = None self.form_data = form_data self.image_list = image_list self.document_list = document_list self.audio_list = audio_list + self.other_list = other_list self.params = params self.flow = flow self.context = {} @@ -294,8 +298,8 @@ def init_fields(self): if global_fields is not None: for global_field in global_fields: global_field_list.append({**global_field, 'node_id': node_id, 'node_name': node_name}) - field_list.sort(key=lambda f: len(f.get('node_name')), reverse=True) - global_field_list.sort(key=lambda f: len(f.get('node_name')), reverse=True) + field_list.sort(key=lambda f: 
len(f.get('node_name') + f.get('value')), reverse=True) + global_field_list.sort(key=lambda f: len(f.get('node_name') + f.get('value')), reverse=True) self.field_list = field_list self.global_field_list = global_field_list @@ -565,6 +569,8 @@ def hand_event_node_result(self, current_node, node_result_future): return None finally: current_node.node_chunk.end() + # 归还链接 + connection.close() def run_node_async(self, node): future = executor.submit(self.run_node, node) @@ -674,10 +680,16 @@ def get_next_node(self): return None @staticmethod - def dependent_node(up_node_id, node): + def dependent_node(edge, node): + up_node_id = edge.sourceNodeId if not node.node_chunk.is_end(): return False if node.id == up_node_id: + if node.context.get('branch_id', None): + if edge.sourceAnchorId == f"{node.id}_{node.context.get('branch_id', None)}_right": + return True + else: + return False if node.type == 'form-node': if node.context.get('form_data', None) is not None: return True @@ -690,9 +702,11 @@ def dependent_node_been_executed(self, node_id): @param node_id: 需要判断的节点id @return: """ - up_node_id_list = [edge.sourceNodeId for edge in self.flow.edges if edge.targetNodeId == node_id] - return all([any([self.dependent_node(up_node_id, node) for node in self.node_context]) for up_node_id in - up_node_id_list]) + up_edge_list = [edge for edge in self.flow.edges if edge.targetNodeId == node_id] + return all( + [any([self.dependent_node(edge, node) for node in self.node_context if node.id == edge.sourceNodeId]) for + edge in + up_edge_list]) def get_up_node_id_list(self, node_id): up_node_id_list = [edge.sourceNodeId for edge in self.flow.edges if edge.targetNodeId == node_id] @@ -751,7 +765,10 @@ def get_reference_field(self, node_id: str, fields: List[str]): if node_id == 'global': return INode.get_field(self.context, fields) else: - return self.get_node_by_id(node_id).get_reference_field(fields) + node = self.get_node_by_id(node_id) + if node: + return 
node.get_reference_field(fields) + return None def get_workflow_content(self): context = { diff --git a/apps/application/migrations/0015_re_database_index.py b/apps/application/migrations/0015_re_database_index.py index 740a2a2d241..cafe14e209c 100644 --- a/apps/application/migrations/0015_re_database_index.py +++ b/apps/application/migrations/0015_re_database_index.py @@ -1,9 +1,8 @@ # Generated by Django 4.2.15 on 2024-09-18 16:14 import logging -import psycopg2 +import psycopg from django.db import migrations -from psycopg2 import extensions from smartdoc.const import CONFIG @@ -17,7 +16,7 @@ def get_connect(db_name): "port": CONFIG.get('DB_PORT') } # 建立连接 - connect = psycopg2.connect(**conn_params) + connect = psycopg.connect(**conn_params) return connect @@ -28,7 +27,7 @@ def sql_execute(conn, reindex_sql: str, alter_database_sql: str): @param conn: @param alter_database_sql: """ - conn.set_isolation_level(extensions.ISOLATION_LEVEL_AUTOCOMMIT) + conn.autocommit = True with conn.cursor() as cursor: cursor.execute(reindex_sql, []) cursor.execute(alter_database_sql, []) diff --git a/apps/application/models/application.py b/apps/application/models/application.py index dfe9534e82b..0032271a70b 100644 --- a/apps/application/models/application.py +++ b/apps/application/models/application.py @@ -11,7 +11,7 @@ from django.contrib.postgres.fields import ArrayField from django.db import models from langchain.schema import HumanMessage, AIMessage - +from django.utils.translation import gettext as _ from common.encoder.encoder import SystemEncoder from common.mixins.app_model_mixin import AppModelMixin from dataset.models.data_set import DataSet @@ -167,7 +167,11 @@ def get_human_message(self): return HumanMessage(content=self.problem_text) def get_ai_message(self): - return AIMessage(content=self.answer_text) + answer_text = self.answer_text + if answer_text is None or len(str(answer_text).strip()) == 0: + answer_text = _( + 'Sorry, no relevant content was found. 
Please re-describe your problem or provide more information. ') + return AIMessage(content=answer_text) def get_node_details_runtime_node_id(self, runtime_node_id): return self.details.get(runtime_node_id, None) diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 3792076be7c..b898100160a 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -16,6 +16,7 @@ import uuid from functools import reduce from typing import Dict, List + from django.contrib.postgres.fields import ArrayField from django.core import cache, validators from django.core import signing @@ -24,8 +25,8 @@ from django.db.models.expressions import RawSQL from django.http import HttpResponse from django.template import Template, Context +from django.utils.translation import gettext_lazy as _, get_language, to_locale from langchain_mcp_adapters.client import MultiServerMCPClient -from mcp.client.sse import sse_client from rest_framework import serializers, status from rest_framework.utils.formatting import lazy_format @@ -38,7 +39,7 @@ from common.constants.authentication_type import AuthenticationType from common.db.search import get_dynamics_model, native_search, native_page_search from common.db.sql_execute import select_list -from common.exception.app_exception import AppApiException, NotFound404, AppUnauthorizedFailed, ChatException +from common.exception.app_exception import AppApiException, NotFound404, AppUnauthorizedFailed from common.field.common import UploadedImageField, UploadedFileField from common.models.db_model_manage import DBModelManage from common.response import result @@ -57,7 +58,6 @@ from setting.serializers.provider_serializers import ModelSerializer from smartdoc.conf import PROJECT_DIR from users.models import User -from django.utils.translation import gettext_lazy as _, get_language, to_locale chat_cache = 
cache.caches['chat_cache'] @@ -148,10 +148,12 @@ class ModelSettingSerializer(serializers.Serializer): error_messages=ErrMessage.char(_("Thinking process switch"))) reasoning_content_start = serializers.CharField(required=False, allow_null=True, default="", allow_blank=True, max_length=256, + trim_whitespace=False, error_messages=ErrMessage.char( _("The thinking process begins to mark"))) reasoning_content_end = serializers.CharField(required=False, allow_null=True, allow_blank=True, default="", max_length=256, + trim_whitespace=False, error_messages=ErrMessage.char(_("End of thinking process marker"))) @@ -162,7 +164,7 @@ class ApplicationWorkflowSerializer(serializers.Serializer): max_length=256, min_length=1, error_messages=ErrMessage.char(_("Application Description"))) work_flow = serializers.DictField(required=False, error_messages=ErrMessage.dict(_("Workflow Objects"))) - prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=4096, + prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=102400, error_messages=ErrMessage.char(_("Opening remarks"))) @staticmethod @@ -225,7 +227,7 @@ class ApplicationSerializer(serializers.Serializer): min_value=0, max_value=1024, error_messages=ErrMessage.integer(_("Historical chat records"))) - prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=4096, + prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=102400, error_messages=ErrMessage.char(_("Opening remarks"))) dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True), allow_null=True, @@ -320,6 +322,7 @@ def get_embed(self, with_valid=True, params=None): def get_query_api_input(self, application, params): query = '' + is_asker = False if application.work_flow is not None: work_flow = application.work_flow if work_flow is not None: @@ -331,8 +334,10 @@ def 
get_query_api_input(self, application, params): if input_field_list is not None: for field in input_field_list: if field['assignment_method'] == 'api_input' and field['variable'] in params: + if field['variable'] == 'asker': + is_asker = True query += f"&{field['variable']}={params[field['variable']]}" - if 'asker' in params: + if 'asker' in params and not is_asker: query += f"&asker={params.get('asker')}" return query @@ -493,7 +498,7 @@ class Edit(serializers.Serializer): min_value=0, max_value=1024, error_messages=ErrMessage.integer(_("Historical chat records"))) - prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=4096, + prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=102400, error_messages=ErrMessage.char(_("Opening remarks"))) dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True), error_messages=ErrMessage.list(_("Related Knowledge Base")) @@ -1010,7 +1015,8 @@ def profile(self, with_valid=True): 'stt_autosend': application.stt_autosend, 'file_upload_enable': application.file_upload_enable, 'file_upload_setting': application.file_upload_setting, - 'work_flow': application.work_flow, + 'work_flow': {'nodes': [node for node in ((application.work_flow or {}).get('nodes', []) or []) if + node.get('id') == 'base-node']}, 'show_source': application_access_token.show_source, 'language': application_access_token.language, **application_setting_dict}) @@ -1071,6 +1077,7 @@ def edit(self, instance: Dict, with_valid=True): for update_key in update_keys: if update_key in instance and instance.get(update_key) is not None: application.__setattr__(update_key, instance.get(update_key)) + print(application.name) application.save() if 'dataset_id_list' in instance: @@ -1089,6 +1096,7 @@ def edit(self, instance: Dict, with_valid=True): chat_cache.clear_by_application_id(application_id) application_access_token = 
QuerySet(ApplicationAccessToken).filter(application_id=application_id).first() # 更新缓存数据 + print(application.name) get_application_access_token(application_access_token.access_token, False) return self.one(with_valid=False) @@ -1141,6 +1149,8 @@ def get_work_flow_model(instance): instance['file_upload_enable'] = node_data['file_upload_enable'] if 'file_upload_setting' in node_data: instance['file_upload_setting'] = node_data['file_upload_setting'] + if 'name' in node_data: + instance['name'] = node_data['name'] break def speech_to_text(self, file, with_valid=True): @@ -1318,7 +1328,12 @@ class McpServers(serializers.Serializer): def get_mcp_servers(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) + if '"stdio"' in self.data.get('mcp_servers'): + raise AppApiException(500, _('stdio is not supported')) servers = json.loads(self.data.get('mcp_servers')) + for server, config in servers.items(): + if config.get('transport') not in ['sse', 'streamable_http']: + raise AppApiException(500, _('Only support transport=sse or transport=streamable_http')) async def get_mcp_tools(servers): async with MultiServerMCPClient(servers) as client: diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index 2194028e6dd..e0ea7e9f555 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -213,12 +213,21 @@ def get_message(instance): return instance.get('messages')[-1].get('content') @staticmethod - def generate_chat(chat_id, application_id, message, client_id): + def generate_chat(chat_id, application_id, message, client_id, asker=None): if chat_id is None: chat_id = str(uuid.uuid1()) chat = QuerySet(Chat).filter(id=chat_id).first() if chat is None: - Chat(id=chat_id, application_id=application_id, abstract=message[0:1024], client_id=client_id).save() + asker_dict = {'user_name': '游客'} + if asker is not None: + 
if isinstance(asker, str): + asker_dict = { + 'user_name': asker + } + elif isinstance(asker, dict): + asker_dict = asker + Chat(id=chat_id, application_id=application_id, abstract=message[0:1024], client_id=client_id, + asker=asker_dict).save() return chat_id def chat(self, instance: Dict, with_valid=True): @@ -232,7 +241,8 @@ def chat(self, instance: Dict, with_valid=True): application_id = self.data.get('application_id') client_id = self.data.get('client_id') client_type = self.data.get('client_type') - chat_id = self.generate_chat(chat_id, application_id, message, client_id) + chat_id = self.generate_chat(chat_id, application_id, message, client_id, + asker=instance.get('form_data', {}).get("asker")) return ChatMessageSerializer( data={ 'chat_id': chat_id, 'message': message, @@ -245,6 +255,7 @@ def chat(self, instance: Dict, with_valid=True): 'image_list': instance.get('image_list', []), 'document_list': instance.get('document_list', []), 'audio_list': instance.get('audio_list', []), + 'other_list': instance.get('other_list', []), } ).chat(base_to_response=OpenaiToResponse()) @@ -274,6 +285,7 @@ class ChatMessageSerializer(serializers.Serializer): image_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("picture"))) document_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("document"))) audio_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("Audio"))) + other_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("Other"))) child_node = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict(_("Child Nodes"))) @@ -372,6 +384,7 @@ def chat_work_flow(self, chat_info: ChatInfo, base_to_response): image_list = self.data.get('image_list') document_list = self.data.get('document_list') audio_list = self.data.get('audio_list') + other_list = self.data.get('other_list') user_id = chat_info.application.user_id chat_record_id = 
self.data.get('chat_record_id') chat_record = None @@ -382,13 +395,14 @@ def chat_work_flow(self, chat_info: ChatInfo, base_to_response): work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow), {'history_chat_record': history_chat_record, 'question': message, 'chat_id': chat_info.chat_id, 'chat_record_id': str( - uuid.uuid1()) if chat_record is None else chat_record.id, + uuid.uuid1()) if chat_record is None else str(chat_record.id), 'stream': stream, 're_chat': re_chat, 'client_id': client_id, 'client_type': client_type, 'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type), base_to_response, form_data, image_list, document_list, audio_list, + other_list, self.data.get('runtime_node_id'), self.data.get('node_data'), chat_record, self.data.get('child_node')) r = work_flow_manage.run() diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index b90194d5ae2..bc397fecf4a 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -13,8 +13,9 @@ from functools import reduce from io import BytesIO from typing import Dict -import pytz + import openpyxl +import pytz from django.core import validators from django.core.cache import caches from django.db import transaction, models @@ -33,8 +34,8 @@ ModelSettingSerializer from application.serializers.chat_message_serializers import ChatInfo from common.constants.permission_constants import RoleConstants -from common.db.search import native_search, native_page_search, page_search, get_dynamics_model -from common.exception.app_exception import AppApiException +from common.db.search import native_search, native_page_search, page_search, get_dynamics_model, native_page_handler +from common.exception.app_exception import AppApiException, AppUnauthorizedFailed from common.util.common import post from common.util.field_message import ErrMessage from 
common.util.file_util import get_file_content @@ -144,7 +145,8 @@ def get_query_set(self, select_ids=None): 'trample_num': models.IntegerField(), 'comparer': models.CharField(), 'application_chat.update_time': models.DateTimeField(), - 'application_chat.id': models.UUIDField(), })) + 'application_chat.id': models.UUIDField(), + 'application_chat_record_temp.id': models.UUIDField()})) base_query_dict = {'application_chat.application_id': self.data.get("application_id"), 'application_chat.update_time__gte': start_time, @@ -174,7 +176,14 @@ def get_query_set(self, select_ids=None): condition = base_condition & min_trample_query else: condition = base_condition - return query_set.filter(condition).order_by("-application_chat.update_time") + inner_queryset = QuerySet(Chat).filter(application_id=self.data.get("application_id")) + if 'abstract' in self.data and self.data.get('abstract') is not None: + inner_queryset = inner_queryset.filter(abstract__icontains=self.data.get('abstract')) + + return { + 'inner_queryset': inner_queryset, + 'default_queryset': query_set.filter(condition).order_by("-application_chat.update_time") + } def list(self, with_valid=True): if with_valid: @@ -215,7 +224,8 @@ def to_row(row: Dict): reference_paragraph, "\n".join([ f"{improve_paragraph_list[index].get('title')}\n{improve_paragraph_list[index].get('content')}" - for index in range(len(improve_paragraph_list))]), + for index in range(len(improve_paragraph_list)) + ]) if improve_paragraph_list is not None else "", row.get('asker').get('user_name'), row.get('message_tokens') + row.get('answer_tokens'), row.get('run_time'), str(row.get('create_time').astimezone(pytz.timezone(TIME_ZONE)).strftime('%Y-%m-%d %H:%M:%S') @@ -225,55 +235,90 @@ def export(self, data, with_valid=True): if with_valid: self.is_valid(raise_exception=True) - data_list = native_search(self.get_query_set(data.get('select_ids')), - select_string=get_file_content( - os.path.join(PROJECT_DIR, "apps", "application", 'sql', - 
'export_application_chat.sql')), - with_table_name=False) + batch_size = 2000 - batch_size = 500 + select_sql = get_file_content( + os.path.join( + PROJECT_DIR, + "apps", + "application", + "sql", + "export_application_chat.sql" + ) + ) def stream_response(): - workbook = openpyxl.Workbook() - worksheet = workbook.active - worksheet.title = 'Sheet1' - - headers = [gettext('Conversation ID'), gettext('summary'), gettext('User Questions'), - gettext('Problem after optimization'), - gettext('answer'), gettext('User feedback'), - gettext('Reference segment number'), - gettext('Section title + content'), - gettext('Annotation'), gettext('USER'), gettext('Consuming tokens'), - gettext('Time consumed (s)'), - gettext('Question Time')] - for col_idx, header in enumerate(headers, 1): - cell = worksheet.cell(row=1, column=col_idx) - cell.value = header - - for i in range(0, len(data_list), batch_size): - batch_data = data_list[i:i + batch_size] - - for row_idx, row in enumerate(batch_data, start=i + 2): - for col_idx, value in enumerate(self.to_row(row), 1): - cell = worksheet.cell(row=row_idx, column=col_idx) - if isinstance(value, str): - value = re.sub(ILLEGAL_CHARACTERS_RE, '', value) - if isinstance(value, datetime.datetime): - eastern = pytz.timezone(TIME_ZONE) - c = datetime.timezone(eastern._utcoffset) - value = value.astimezone(c) - cell.value = value - - output = BytesIO() - workbook.save(output) - output.seek(0) - yield output.getvalue() - output.close() - workbook.close() - - response = StreamingHttpResponse(stream_response(), - content_type='application/vnd.open.xmlformats-officedocument.spreadsheetml.sheet') + import tempfile + + headers = [ + gettext('Conversation ID'), + gettext('summary'), + gettext('User Questions'), + gettext('Problem after optimization'), + gettext('answer'), + gettext('User feedback'), + gettext('Reference segment number'), + gettext('Section title + content'), + gettext('Annotation'), + gettext('USER'), + gettext('Consuming tokens'), + 
gettext('Time consumed (s)'), + gettext('Question Time') + ] + + with tempfile.NamedTemporaryFile(suffix=".xlsx") as tmp: + + workbook = openpyxl.Workbook(write_only=True) + worksheet = workbook.create_sheet(title="Sheet1") + + # 写表头 + worksheet.append(headers) + + for data_list in native_page_handler( + batch_size, + self.get_query_set(data.get('select_ids')), + primary_key='application_chat_record_temp.id', + primary_queryset='default_queryset', + get_primary_value=lambda item: item.get('id'), + select_string=select_sql, + with_table_name=False + ): + + for row in data_list: + + row_values = [] + for value in self.to_row(row): + + if isinstance(value, str): + value = re.sub(ILLEGAL_CHARACTERS_RE, '', value) + + elif isinstance(value, datetime.datetime): + eastern = pytz.timezone(TIME_ZONE) + c = datetime.timezone(eastern._utcoffset) + value = value.astimezone(c) + + row_values.append(value) + + worksheet.append(row_values) + + workbook.save(tmp.name) + workbook.close() + + # 分块返回文件 + with open(tmp.name, "rb") as f: + while True: + chunk = f.read(8192) + if not chunk: + break + yield chunk + + response = StreamingHttpResponse( + stream_response(), + content_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" + ) + response['Content-Disposition'] = 'attachment; filename="data.xlsx"' + return response def page(self, current_page: int, page_size: int, with_valid=True): @@ -476,6 +521,13 @@ class Query(serializers.Serializer): chat_id = serializers.UUIDField(required=True) order_asc = serializers.BooleanField(required=False, allow_null=True) + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + exist = QuerySet(Chat).filter(id=self.data.get("chat_id"), + application_id=self.data.get("application_id")).exists() + if not exist: + raise AppUnauthorizedFailed(403, _('No permission to access')) + def list(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) diff --git 
a/apps/application/sql/export_application_chat.sql b/apps/application/sql/export_application_chat.sql index bb265ea5b02..99c8f7a3172 100644 --- a/apps/application/sql/export_application_chat.sql +++ b/apps/application/sql/export_application_chat.sql @@ -1,38 +1,39 @@ -SELECT - application_chat."id" as chat_id, - application_chat.abstract as abstract, - application_chat_record_temp.problem_text as problem_text, - application_chat_record_temp.answer_text as answer_text, - application_chat_record_temp.message_tokens as message_tokens, - application_chat_record_temp.answer_tokens as answer_tokens, - application_chat_record_temp.run_time as run_time, - application_chat_record_temp.details::JSON as details, - application_chat_record_temp."index" as "index", - application_chat_record_temp.improve_paragraph_list as improve_paragraph_list, - application_chat_record_temp.vote_status as vote_status, - application_chat_record_temp.create_time as create_time, - to_json(application_chat.asker) as asker -FROM - application_chat application_chat - LEFT JOIN ( - SELECT COUNT - ( "id" ) AS chat_record_count, - SUM ( CASE WHEN "vote_status" = '0' THEN 1 ELSE 0 END ) AS star_num, - SUM ( CASE WHEN "vote_status" = '1' THEN 1 ELSE 0 END ) AS trample_num, - SUM ( CASE WHEN array_length( application_chat_record.improve_paragraph_id_list, 1 ) IS NULL THEN 0 ELSE array_length( application_chat_record.improve_paragraph_id_list, 1 ) END ) AS mark_sum, - chat_id - FROM - application_chat_record - GROUP BY - application_chat_record.chat_id - ) chat_record_temp ON application_chat."id" = chat_record_temp.chat_id - LEFT JOIN ( - SELECT - *, - CASE - WHEN array_length( application_chat_record.improve_paragraph_id_list, 1 ) IS NULL THEN - '{}' ELSE ( SELECT ARRAY_AGG ( row_to_json ( paragraph ) ) FROM paragraph WHERE "id" = ANY ( application_chat_record.improve_paragraph_id_list ) ) - END as improve_paragraph_list - FROM - application_chat_record application_chat_record - ) 
application_chat_record_temp ON application_chat_record_temp.chat_id = application_chat."id" \ No newline at end of file +SELECT application_chat_record_temp.id AS id, + application_chat."id" as chat_id, + application_chat.abstract as abstract, + application_chat_record_temp.problem_text as problem_text, + application_chat_record_temp.answer_text as answer_text, + application_chat_record_temp.message_tokens as message_tokens, + application_chat_record_temp.answer_tokens as answer_tokens, + application_chat_record_temp.run_time as run_time, + application_chat_record_temp.details::JSON as details, application_chat_record_temp."index" as "index", + application_chat_record_temp.improve_paragraph_list as improve_paragraph_list, + application_chat_record_temp.vote_status as vote_status, + application_chat_record_temp.create_time as create_time, + to_json(application_chat.asker) as asker +FROM application_chat application_chat + LEFT JOIN (SELECT COUNT + ("id") AS chat_record_count, + SUM(CASE WHEN "vote_status" = '0' THEN 1 ELSE 0 END) AS star_num, + SUM(CASE WHEN "vote_status" = '1' THEN 1 ELSE 0 END) AS trample_num, + SUM(CASE + WHEN array_length(application_chat_record.improve_paragraph_id_list, 1) IS NULL + THEN 0 + ELSE array_length(application_chat_record.improve_paragraph_id_list, 1) END) AS mark_sum, + chat_id + FROM application_chat_record + WHERE chat_id IN (SELECT id + FROM application_chat ${inner_queryset}) + GROUP BY application_chat_record.chat_id) chat_record_temp + ON application_chat."id" = chat_record_temp.chat_id + LEFT JOIN (SELECT *, + CASE + WHEN array_length(application_chat_record.improve_paragraph_id_list, 1) IS NULL THEN + '{}' + ELSE (SELECT ARRAY_AGG(row_to_json(paragraph)) + FROM paragraph + WHERE "id" = ANY (application_chat_record.improve_paragraph_id_list)) + END as improve_paragraph_list + FROM application_chat_record application_chat_record) application_chat_record_temp + ON application_chat_record_temp.chat_id = application_chat."id" + 
${default_queryset} \ No newline at end of file diff --git a/apps/application/sql/list_application_chat.sql b/apps/application/sql/list_application_chat.sql index 7f3e1680c99..c9f83c6b7c3 100644 --- a/apps/application/sql/list_application_chat.sql +++ b/apps/application/sql/list_application_chat.sql @@ -11,6 +11,9 @@ FROM chat_id FROM application_chat_record + WHERE chat_id IN ( + SELECT id FROM application_chat ${inner_queryset}) GROUP BY application_chat_record.chat_id - ) chat_record_temp ON application_chat."id" = chat_record_temp.chat_id \ No newline at end of file + ) chat_record_temp ON application_chat."id" = chat_record_temp.chat_id +${default_queryset} \ No newline at end of file diff --git a/apps/application/swagger_api/application_api.py b/apps/application/swagger_api/application_api.py index 2c9cbd86bf4..024279832b1 100644 --- a/apps/application/swagger_api/application_api.py +++ b/apps/application/swagger_api/application_api.py @@ -61,8 +61,6 @@ def get_response_body_api(): 'user_id': openapi.Schema(type=openapi.TYPE_STRING, title=_("Affiliation user"), description=_("Affiliation user")), - 'status': openapi.Schema(type=openapi.TYPE_BOOLEAN, title=_("Is publish"), description=_('Is publish')), - 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title=_("Creation time"), description=_('Creation time')), @@ -302,7 +300,19 @@ def get_request_body_api(): 'no_references_prompt': openapi.Schema(type=openapi.TYPE_STRING, title=_("No citation segmentation prompt"), default="{question}", - description=_("No citation segmentation prompt")) + description=_("No citation segmentation prompt")), + 'reasoning_content_enable': openapi.Schema(type=openapi.TYPE_BOOLEAN, + title=_("Reasoning enable"), + default=False, + description=_("Reasoning enable")), + 'reasoning_content_end': openapi.Schema(type=openapi.TYPE_STRING, + title=_("Reasoning end tag"), + default="", + description=_("Reasoning end tag")), + "reasoning_content_start": 
openapi.Schema(type=openapi.TYPE_STRING, + title=_("Reasoning start tag"), + default="", + description=_("Reasoning start tag")) } ) diff --git a/apps/application/swagger_api/chat_api.py b/apps/application/swagger_api/chat_api.py index 54b5678f747..f27a19c200e 100644 --- a/apps/application/swagger_api/chat_api.py +++ b/apps/application/swagger_api/chat_api.py @@ -326,11 +326,6 @@ def get_request_params_api(): type=openapi.TYPE_STRING, required=True, description=_('Application ID')), - openapi.Parameter(name='history_day', - in_=openapi.IN_QUERY, - type=openapi.TYPE_NUMBER, - required=True, - description=_('Historical days')), openapi.Parameter(name='abstract', in_=openapi.IN_QUERY, type=openapi.TYPE_STRING, required=False, description=_("abstract")), openapi.Parameter(name='min_star', in_=openapi.IN_QUERY, type=openapi.TYPE_INTEGER, required=False, diff --git a/apps/application/views/application_version_views.py b/apps/application/views/application_version_views.py index de900936268..1cd42a643a0 100644 --- a/apps/application/views/application_version_views.py +++ b/apps/application/views/application_version_views.py @@ -48,7 +48,11 @@ class Page(APIView): ApplicationVersionApi.Query.get_request_params_api()), responses=result.get_page_api_response(ApplicationVersionApi.get_response_body_api()), tags=[_('Application/Version')]) - @has_permissions(PermissionConstants.APPLICATION_READ, compare=CompareConstants.AND) + @has_permissions(PermissionConstants.APPLICATION_READ, + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND), compare=CompareConstants.AND) def get(self, request: Request, application_id: str, current_page: int, page_size: int): return result.success( ApplicationVersionSerializer.Query( @@ -65,7 +69,14 @@ class Operate(APIView): 
manual_parameters=ApplicationVersionApi.Operate.get_request_params_api(), responses=result.get_api_response(ApplicationVersionApi.get_response_body_api()), tags=[_('Application/Version')]) - @has_permissions(PermissionConstants.APPLICATION_READ, compare=CompareConstants.AND) + @has_permissions(PermissionConstants.APPLICATION_READ, ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission( + group=Group.APPLICATION, + operate=Operate.USE, + dynamic_tag=keywords.get( + 'application_id'))], + compare=CompareConstants.AND), + compare=CompareConstants.AND) def get(self, request: Request, application_id: str, work_flow_version_id: str): return result.success( ApplicationVersionSerializer.Operate( diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py index f16041d1de3..8c3e8059bcb 100644 --- a/apps/application/views/application_views.py +++ b/apps/application/views/application_views.py @@ -7,16 +7,6 @@ @desc: """ -from django.core import cache -from django.http import HttpResponse -from django.utils.translation import gettext_lazy as _, gettext -from drf_yasg.utils import swagger_auto_schema -from langchain_core.prompts import PromptTemplate -from rest_framework.decorators import action -from rest_framework.parsers import MultiPartParser -from rest_framework.request import Request -from rest_framework.views import APIView - from application.serializers.application_serializers import ApplicationSerializer from application.serializers.application_statistics_serializers import ApplicationStatisticsSerializer from application.swagger_api.application_api import ApplicationApi @@ -31,6 +21,14 @@ from common.swagger_api.common_api import CommonApi from common.util.common import query_params_to_single_dict from dataset.serializers.dataset_serializers import DataSetSerializers +from django.core import cache +from django.http import HttpResponse +from django.utils.translation import gettext_lazy as _ 
+from drf_yasg.utils import swagger_auto_schema +from rest_framework.decorators import action +from rest_framework.parsers import MultiPartParser +from rest_framework.request import Request +from rest_framework.views import APIView chat_cache = cache.caches['chat_cache'] @@ -494,7 +492,7 @@ def get(self, request: Request): class HitTest(APIView): authentication_classes = [TokenAuth] - @action(methods="GET", detail=False) + @action(methods="PUT", detail=False) @swagger_auto_schema(operation_summary=_("Hit Test List"), operation_id=_("Hit Test List"), manual_parameters=CommonApi.HitTestApi.get_request_params_api(), responses=result.get_api_array_response(CommonApi.HitTestApi.get_response_body_api()), @@ -505,15 +503,15 @@ class HitTest(APIView): [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, dynamic_tag=keywords.get('application_id'))], compare=CompareConstants.AND)) - def get(self, request: Request, application_id: str): - return result.success( - ApplicationSerializer.HitTest(data={'id': application_id, 'user_id': request.user.id, - "query_text": request.query_params.get("query_text"), - "top_number": request.query_params.get("top_number"), - 'similarity': request.query_params.get('similarity'), - 'search_mode': request.query_params.get( - 'search_mode')}).hit_test( - )) + def put(self, request: Request, application_id: str): + return result.success(ApplicationSerializer.HitTest(data={ + 'id': application_id, + 'user_id': request.user.id, + "query_text": request.data.get("query_text"), + "top_number": request.data.get("top_number"), + 'similarity': request.data.get('similarity'), + 'search_mode': request.data.get('search_mode')} + ).hit_test()) class Publish(APIView): authentication_classes = [TokenAuth] diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py index 0415f8208dc..30d54fa65a4 100644 --- a/apps/application/views/chat_views.py +++ b/apps/application/views/chat_views.py @@ -59,7 +59,8 @@ 
class Export(APIView): @has_permissions( ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, - dynamic_tag=keywords.get('application_id'))]) + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND) ) @log(menu='Conversation Log', operate="Export conversation", get_operation_object=lambda r, k: get_application_operation_object(k.get('application_id'))) @@ -144,6 +145,8 @@ def post(self, request: Request, chat_id: str): 'document_list') if 'document_list' in request.data else [], 'audio_list': request.data.get( 'audio_list') if 'audio_list' in request.data else [], + 'other_list': request.data.get( + 'other_list') if 'other_list' in request.data else [], 'client_type': request.auth.client_type, 'node_id': request.data.get('node_id', None), 'runtime_node_id': request.data.get('runtime_node_id', None), @@ -162,7 +165,9 @@ def post(self, request: Request, chat_id: str): @has_permissions( ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, - dynamic_tag=keywords.get('application_id'))]) + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND + ) ) def get(self, request: Request, application_id: str): return result.success(ChatSerializers.Query( @@ -180,8 +185,7 @@ class Operate(APIView): [RoleConstants.ADMIN, RoleConstants.USER], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE, dynamic_tag=keywords.get('application_id'))], - compare=CompareConstants.AND), - compare=CompareConstants.AND) + compare=CompareConstants.AND)) @log(menu='Conversation Log', operate="Delete a conversation", get_operation_object=lambda r, k: get_application_operation_object(k.get('application_id'))) def delete(self, request: Request, application_id: str, chat_id: str): @@ -204,7 +208,8 @@ class 
ClientChatHistoryPage(APIView): @has_permissions( ViewPermission([RoleConstants.APPLICATION_ACCESS_TOKEN], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, - dynamic_tag=keywords.get('application_id'))]) + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND) ) def get(self, request: Request, application_id: str, current_page: int, page_size: int): return result.success(ChatSerializers.ClientChatHistory( @@ -239,7 +244,7 @@ def delete(self, request: Request, application_id: str, chat_id: str): request_body=ChatClientHistoryApi.Operate.ReAbstract.get_request_body_api(), tags=[_("Application/Conversation Log")]) @has_permissions(ViewPermission( - [RoleConstants.APPLICATION_ACCESS_TOKEN], + [RoleConstants.APPLICATION_ACCESS_TOKEN, RoleConstants.ADMIN, RoleConstants.USER], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, dynamic_tag=keywords.get('application_id'))], compare=CompareConstants.AND), @@ -265,7 +270,8 @@ class Page(APIView): @has_permissions( ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, - dynamic_tag=keywords.get('application_id'))]) + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND) ) def get(self, request: Request, application_id: str, current_page: int, page_size: int): return result.success(ChatSerializers.Query( @@ -290,7 +296,8 @@ class Operate(APIView): ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY, RoleConstants.APPLICATION_ACCESS_TOKEN], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, - dynamic_tag=keywords.get('application_id'))]) + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND) ) def get(self, request: Request, application_id: str, chat_id: str, chat_record_id: str): return result.success(ChatRecordSerializer.Operate( 
@@ -308,7 +315,8 @@ def get(self, request: Request, application_id: str, chat_id: str, chat_record_i @has_permissions( ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, - dynamic_tag=keywords.get('application_id'))]) + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND) ) def get(self, request: Request, application_id: str, chat_id: str): return result.success(ChatRecordSerializer.Query( @@ -327,9 +335,11 @@ class Page(APIView): tags=[_("Application/Conversation Log")] ) @has_permissions( - ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY], + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY, + RoleConstants.APPLICATION_ACCESS_TOKEN], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, - dynamic_tag=keywords.get('application_id'))]) + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND) ) def get(self, request: Request, application_id: str, chat_id: str, current_page: int, page_size: int): return result.success(ChatRecordSerializer.Query( @@ -352,7 +362,8 @@ class Vote(APIView): ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY, RoleConstants.APPLICATION_ACCESS_TOKEN], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, - dynamic_tag=keywords.get('application_id'))]) + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND) ) @log(menu='Conversation Log', operate="Like, Dislike", get_operation_object=lambda r, k: get_application_operation_object(k.get('application_id'))) @@ -375,7 +386,7 @@ class ChatRecordImprove(APIView): ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, dynamic_tag=keywords.get('application_id'))] - )) + , 
compare=CompareConstants.AND)) def get(self, request: Request, application_id: str, chat_id: str, chat_record_id: str): return result.success(ChatRecordSerializer.ChatRecordImprove( data={'chat_id': chat_id, 'chat_record_id': chat_record_id}).get()) @@ -395,7 +406,7 @@ class Improve(APIView): ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, dynamic_tag=keywords.get('application_id'))], - + compare=CompareConstants.AND ), ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], [lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE, @@ -422,6 +433,7 @@ def put(self, request: Request, application_id: str, chat_id: str, chat_record_i ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND ), ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], [lambda r, keywords: Permission(group=Group.DATASET, @@ -449,6 +461,7 @@ class Operate(APIView): ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND ), ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], [lambda r, keywords: Permission(group=Group.DATASET, @@ -497,7 +510,8 @@ class UploadFile(APIView): ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY, RoleConstants.APPLICATION_ACCESS_TOKEN], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, - dynamic_tag=keywords.get('application_id'))]) + dynamic_tag=keywords.get('application_id'))] + , compare=CompareConstants.AND) ) def post(self, request: Request, application_id: str, chat_id: str): files = request.FILES.getlist('file') diff --git a/apps/common/auth/handle/impl/user_token.py 
b/apps/common/auth/handle/impl/user_token.py index dbb6bd2b51a..bdb041f9f79 100644 --- a/apps/common/auth/handle/impl/user_token.py +++ b/apps/common/auth/handle/impl/user_token.py @@ -6,18 +6,18 @@ @date:2024/3/14 03:02 @desc: 用户认证 """ +from django.core import cache from django.db.models import QuerySet +from django.utils.translation import gettext_lazy as _ from common.auth.handle.auth_base_handle import AuthBaseHandle from common.constants.authentication_type import AuthenticationType from common.constants.permission_constants import RoleConstants, get_permission_list_by_role, Auth from common.exception.app_exception import AppAuthenticationFailed -from smartdoc.settings import JWT_AUTH +from smartdoc.const import CONFIG from users.models import User -from django.core import cache - from users.models.user import get_user_dynamics_permission -from django.utils.translation import gettext_lazy as _ + token_cache = cache.caches['token_cache'] @@ -35,7 +35,7 @@ def handle(self, request, token: str, get_token_details): auth_details = get_token_details() user = QuerySet(User).get(id=auth_details['id']) # 续期 - token_cache.touch(token, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA'].total_seconds()) + token_cache.touch(token, timeout=CONFIG.get_session_timeout()) rule = RoleConstants[user.role] permission_list = get_permission_list_by_role(RoleConstants[user.role]) # 获取用户的应用和知识库的权限 diff --git a/apps/common/config/embedding_config.py b/apps/common/config/embedding_config.py index a6e9ab9aa9b..69081be055d 100644 --- a/apps/common/config/embedding_config.py +++ b/apps/common/config/embedding_config.py @@ -11,35 +11,50 @@ from common.cache.mem_cache import MemCache -lock = threading.Lock() +_lock = threading.Lock() +locks = {} class ModelManage: cache = MemCache('model', {}) up_clear_time = time.time() + @staticmethod + def _get_lock(_id): + lock = locks.get(_id) + if lock is None: + with _lock: + lock = locks.get(_id) + if lock is None: + lock = threading.Lock() + locks[_id] = 
lock + + return lock + @staticmethod def get_model(_id, get_model): - # 获取锁 - lock.acquire() - try: - model_instance = ModelManage.cache.get(_id) - if model_instance is None or not model_instance.is_cache_model(): + model_instance = ModelManage.cache.get(_id) + if model_instance is None: + lock = ModelManage._get_lock(_id) + with lock: + model_instance = ModelManage.cache.get(_id) + if model_instance is None: + model_instance = get_model(_id) + ModelManage.cache.set(_id, model_instance, timeout=60 * 60 * 8) + else: + if model_instance.is_cache_model(): + ModelManage.cache.touch(_id, timeout=60 * 60 * 8) + else: model_instance = get_model(_id) - ModelManage.cache.set(_id, model_instance, timeout=60 * 30) - return model_instance - # 续期 - ModelManage.cache.touch(_id, timeout=60 * 30) - ModelManage.clear_timeout_cache() - return model_instance - finally: - # 释放锁 - lock.release() + ModelManage.cache.set(_id, model_instance, timeout=60 * 60 * 8) + ModelManage.clear_timeout_cache() + return model_instance @staticmethod def clear_timeout_cache(): - if time.time() - ModelManage.up_clear_time > 60: - ModelManage.cache.clear_timeout_data() + if time.time() - ModelManage.up_clear_time > 60 * 60: + threading.Thread(target=lambda: ModelManage.cache.clear_timeout_data()).start() + ModelManage.up_clear_time = time.time() @staticmethod def delete_key(_id): diff --git a/apps/common/db/search.py b/apps/common/db/search.py index bef42a1414a..07ecd1b0262 100644 --- a/apps/common/db/search.py +++ b/apps/common/db/search.py @@ -170,6 +170,51 @@ def native_page_search(current_page: int, page_size: int, queryset: QuerySet | D return Page(total.get("count"), list(map(post_records_handler, result)), current_page, page_size) +def native_page_handler(page_size: int, + queryset: QuerySet | Dict[str, QuerySet], + select_string: str, + field_replace_dict=None, + with_table_name=False, + primary_key=None, + get_primary_value=None, + primary_queryset: str = None, + ): + if isinstance(queryset, 
Dict): + exec_sql, exec_params = generate_sql_by_query_dict({**queryset, + primary_queryset: queryset[primary_queryset].order_by( + primary_key)}, select_string, field_replace_dict, with_table_name) + else: + exec_sql, exec_params = generate_sql_by_query(queryset.order_by( + primary_key), select_string, field_replace_dict, with_table_name) + total_sql = "SELECT \"count\"(*) FROM (%s) temp" % exec_sql + total = select_one(total_sql, exec_params) + processed_count = 0 + last_id = None + while processed_count < total.get("count"): + if last_id is not None: + if isinstance(queryset, Dict): + exec_sql, exec_params = generate_sql_by_query_dict({**queryset, + primary_queryset: queryset[primary_queryset].filter( + **{f"{primary_key}__gt": last_id}).order_by( + primary_key)}, + select_string, field_replace_dict, + with_table_name) + else: + exec_sql, exec_params = generate_sql_by_query( + queryset.filter(**{f"{primary_key}__gt": last_id}).order_by( + primary_key), + select_string, field_replace_dict, + with_table_name) + limit_sql = connections[DEFAULT_DB_ALIAS].ops.limit_offset_sql( + 0, page_size + ) + page_sql = exec_sql + " " + limit_sql + result = select_list(page_sql, exec_params) + yield result + processed_count += page_size + last_id = get_primary_value(result[-1]) + + def get_field_replace_dict(queryset: QuerySet): """ 获取需要替换的字段 默认 “xxx.xxx”需要被替换成 “xxx”."xxx" diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py index 72d16ebb523..6899c31f33e 100644 --- a/apps/common/event/listener_manage.py +++ b/apps/common/event/listener_manage.py @@ -24,6 +24,7 @@ from common.util.lock import try_lock, un_lock from common.util.page_utils import page_desc from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping, TaskType, State +from dataset.serializers.common_serializers import create_dataset_index from embedding.models import SourceType, SearchMode from smartdoc.conf import PROJECT_DIR from django.utils.translation 
import gettext_lazy as _ @@ -238,11 +239,8 @@ def update_status(query_set: QuerySet, taskType: TaskType, state: State): for key in params_dict: _value_ = params_dict[key] exec_sql = exec_sql.replace(key, str(_value_)) - lock.acquire() - try: + with lock: native_update(query_set, exec_sql) - finally: - lock.release() @staticmethod def embedding_by_document(document_id, embedding_model: Embeddings, state_list=None): @@ -272,7 +270,6 @@ def is_the_task_interrupted(): ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, State.STARTED) - # 根据段落进行向量化处理 page_desc(QuerySet(Paragraph) .annotate( @@ -285,6 +282,8 @@ def is_the_task_interrupted(): ListenerManagement.get_aggregation_document_status( document_id)), is_the_task_interrupted) + # 检查是否存在索引 + create_dataset_index(document_id=document_id) except Exception as e: max_kb_error.error(_('Vectorized document: {document_id} error {error} {traceback}').format( document_id=document_id, error=str(e), traceback=traceback.format_exc())) diff --git a/apps/common/forms/__init__.py b/apps/common/forms/__init__.py index 6095421935b..251f01df092 100644 --- a/apps/common/forms/__init__.py +++ b/apps/common/forms/__init__.py @@ -22,3 +22,4 @@ from .radio_card_field import * from .label import * from .slider_field import * +from .switch_field import * diff --git a/apps/common/forms/switch_field.py b/apps/common/forms/switch_field.py index 9fa176beea0..ea119c3ecfb 100644 --- a/apps/common/forms/switch_field.py +++ b/apps/common/forms/switch_field.py @@ -28,6 +28,6 @@ def __init__(self, label: str or BaseLabel, @param props_info: """ - super().__init__('Switch', label, required, default_value, relation_show_field_dict, + super().__init__('SwitchInput', label, required, default_value, relation_show_field_dict, {}, TriggerType.OPTION_LIST, attrs, props_info) diff --git a/apps/common/handle/impl/doc_split_handle.py b/apps/common/handle/impl/doc_split_handle.py index 1df7b6a66e0..4161f13a19d 100644 
--- a/apps/common/handle/impl/doc_split_handle.py +++ b/apps/common/handle/impl/doc_split_handle.py @@ -112,11 +112,7 @@ def get_image_id(image_id): title_font_list = [ [36, 100], - [26, 36], - [24, 26], - [22, 24], - [18, 22], - [16, 18] + [30, 36] ] @@ -130,7 +126,7 @@ def get_title_level(paragraph: Paragraph): if len(paragraph.runs) == 1: font_size = paragraph.runs[0].font.size pt = font_size.pt - if pt >= 16: + if pt >= 30: for _value, index in zip(title_font_list, range(len(title_font_list))): if pt >= _value[0] and pt < _value[1]: return index + 1 diff --git a/apps/common/handle/impl/table/xls_parse_table_handle.py b/apps/common/handle/impl/table/xls_parse_table_handle.py index 5609e3e8835..897e347e8a8 100644 --- a/apps/common/handle/impl/table/xls_parse_table_handle.py +++ b/apps/common/handle/impl/table/xls_parse_table_handle.py @@ -82,7 +82,10 @@ def get_content(self, file, save_image): for row in data: # 将每个单元格中的内容替换换行符为
以保留原始格式 md_table += '| ' + ' | '.join( - [str(cell).replace('\n', '
') if cell else '' for cell in row]) + ' |\n' + [str(cell) + .replace('\r\n', '
') + .replace('\n', '
') + if cell else '' for cell in row]) + ' |\n' md_tables += md_table + '\n\n' return md_tables diff --git a/apps/common/handle/impl/table/xlsx_parse_table_handle.py b/apps/common/handle/impl/table/xlsx_parse_table_handle.py index abaec05769a..a68eb14f1a1 100644 --- a/apps/common/handle/impl/table/xlsx_parse_table_handle.py +++ b/apps/common/handle/impl/table/xlsx_parse_table_handle.py @@ -19,36 +19,24 @@ def support(self, file, get_buffer): def fill_merged_cells(self, sheet, image_dict): data = [] - - # 获取第一行作为标题行 - headers = [] - for idx, cell in enumerate(sheet[1]): - if cell.value is None: - headers.append(' ' * (idx + 1)) - else: - headers.append(cell.value) - # 从第二行开始遍历每一行 - for row in sheet.iter_rows(min_row=2, values_only=False): - row_data = {} + for row in sheet.iter_rows(values_only=False): + row_data = [] for col_idx, cell in enumerate(row): cell_value = cell.value - - # 如果单元格为空,并且该单元格在合并单元格内,获取合并单元格的值 - if cell_value is None: - for merged_range in sheet.merged_cells.ranges: - if cell.coordinate in merged_range: - cell_value = sheet[merged_range.min_row][merged_range.min_col - 1].value - break - image = image_dict.get(cell_value, None) if image is not None: cell_value = f'![](/api/image/{image.id})' # 使用标题作为键,单元格的值作为值存入字典 - row_data[headers[col_idx]] = cell_value + row_data.insert(col_idx, cell_value) data.append(row_data) + for merged_range in sheet.merged_cells.ranges: + cell_value = data[merged_range.min_row - 1][merged_range.min_col - 1] + for row_index in range(merged_range.min_row, merged_range.max_row + 1): + for col_index in range(merged_range.min_col, merged_range.max_col + 1): + data[row_index - 1][col_index - 1] = cell_value return data def handle(self, file, get_buffer, save_image): @@ -65,11 +53,13 @@ def handle(self, file, get_buffer, save_image): paragraphs = [] ws = wb[sheetname] data = self.fill_merged_cells(ws, image_dict) - - for row in data: - row_output = "; ".join([f"{key}: {value}" for key, value in row.items()]) - # 
print(row_output) - paragraphs.append({'title': '', 'content': row_output}) + if len(data) >= 2: + head_list = data[0] + for row_index in range(1, len(data)): + row_output = "; ".join( + [f"{head_list[col_index]}: {data[row_index][col_index]}" for col_index in + range(0, len(data[row_index]))]) + paragraphs.append({'title': '', 'content': row_output}) result.append({'name': sheetname, 'paragraphs': paragraphs}) @@ -78,7 +68,6 @@ def handle(self, file, get_buffer, save_image): return [{'name': file.name, 'paragraphs': []}] return result - def get_content(self, file, save_image): try: # 加载 Excel 文件 @@ -94,18 +83,18 @@ def get_content(self, file, save_image): # 如果未指定 sheet_name,则使用第一个工作表 for sheetname in workbook.sheetnames: sheet = workbook[sheetname] if sheetname else workbook.active - rows = self.fill_merged_cells(sheet, image_dict) - if len(rows) == 0: + data = self.fill_merged_cells(sheet, image_dict) + if len(data) == 0: continue # 提取表头和内容 - headers = [f"{key}" for key, value in rows[0].items()] + headers = [f"{value}" for value in data[0]] # 构建 Markdown 表格 md_table = '| ' + ' | '.join(headers) + ' |\n' md_table += '| ' + ' | '.join(['---'] * len(headers)) + ' |\n' - for row in rows: - r = [f'{value}' for key, value in row.items()] + for row_index in range(1, len(data)): + r = [f'{value}' for value in data[row_index]] md_table += '| ' + ' | '.join( [str(cell).replace('\n', '
') if cell is not None else '' for cell in r]) + ' |\n' diff --git a/apps/common/handle/impl/xls_split_handle.py b/apps/common/handle/impl/xls_split_handle.py index 3d8afdf62de..dbdcc95506d 100644 --- a/apps/common/handle/impl/xls_split_handle.py +++ b/apps/common/handle/impl/xls_split_handle.py @@ -14,7 +14,7 @@ def post_cell(cell_value): - return cell_value.replace('\n', '
').replace('|', '|') + return cell_value.replace('\r\n', '
').replace('\n', '
').replace('|', '|') def row_to_md(row): diff --git a/apps/common/management/commands/services/services/gunicorn.py b/apps/common/management/commands/services/services/gunicorn.py index cc42c4f7cb3..a32220ab881 100644 --- a/apps/common/management/commands/services/services/gunicorn.py +++ b/apps/common/management/commands/services/services/gunicorn.py @@ -16,13 +16,14 @@ def cmd(self): log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s ' bind = f'{HTTP_HOST}:{HTTP_PORT}' + max_requests = 10240 if int(self.worker) > 1 else 0 cmd = [ 'gunicorn', 'smartdoc.wsgi:application', '-b', bind, '-k', 'gthread', '--threads', '200', '-w', str(self.worker), - '--max-requests', '10240', + '--max-requests', str(max_requests), '--max-requests-jitter', '2048', '--access-logformat', log_format, '--access-logfile', '-' diff --git a/apps/common/management/commands/services/services/local_model.py b/apps/common/management/commands/services/services/local_model.py index 4511f8f5fee..db11d2d404f 100644 --- a/apps/common/management/commands/services/services/local_model.py +++ b/apps/common/management/commands/services/services/local_model.py @@ -24,13 +24,15 @@ def cmd(self): os.environ.setdefault('SERVER_NAME', 'local_model') log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s ' bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' + worker = CONFIG.get("LOCAL_MODEL_HOST_WORKER", 1) + max_requests = 10240 if int(worker) > 1 else 0 cmd = [ 'gunicorn', 'smartdoc.wsgi:application', '-b', bind, '-k', 'gthread', '--threads', '200', - '-w', "1", - '--max-requests', '10240', + '-w', str(worker), + '--max-requests', str(max_requests), '--max-requests-jitter', '2048', '--access-logformat', log_format, '--access-logfile', '-' diff --git a/apps/common/middleware/doc_headers_middleware.py b/apps/common/middleware/doc_headers_middleware.py index d818b842ca5..83419b19fb0 100644 --- a/apps/common/middleware/doc_headers_middleware.py +++ 
b/apps/common/middleware/doc_headers_middleware.py @@ -9,43 +9,102 @@ from django.http import HttpResponse from django.utils.deprecation import MiddlewareMixin +from common.auth import handles, TokenDetails + content = """ - + Document + + + + - - + """ @@ -54,9 +113,18 @@ class DocHeadersMiddleware(MiddlewareMixin): def process_response(self, request, response): if request.path.startswith('/doc/') or request.path.startswith('/doc/chat/'): - HTTP_REFERER = request.META.get('HTTP_REFERER') - if HTTP_REFERER is None: + auth = request.COOKIES.get('Authorization') + if auth is None: return HttpResponse(content) - if HTTP_REFERER == request._current_scheme_host + request.path: - return response + else: + try: + token = auth + token_details = TokenDetails(token) + for handle in handles: + if handle.support(request, token, token_details.get_token_details): + handle.handle(request, token, token_details.get_token_details) + return response + return HttpResponse(content) + except Exception as e: + return HttpResponse(content) return response diff --git a/apps/common/swagger_api/common_api.py b/apps/common/swagger_api/common_api.py index 3134db0d083..9e7d1976298 100644 --- a/apps/common/swagger_api/common_api.py +++ b/apps/common/swagger_api/common_api.py @@ -15,33 +15,21 @@ class CommonApi: class HitTestApi(ApiMixin): @staticmethod - def get_request_params_api(): - return [ - openapi.Parameter(name='query_text', - in_=openapi.IN_QUERY, - type=openapi.TYPE_STRING, - required=True, - description=_('query text')), - openapi.Parameter(name='top_number', - in_=openapi.IN_QUERY, - type=openapi.TYPE_NUMBER, - default=10, - required=True, - description='topN'), - openapi.Parameter(name='similarity', - in_=openapi.IN_QUERY, - type=openapi.TYPE_NUMBER, - default=0.6, - required=True, - description=_('similarity')), - openapi.Parameter(name='search_mode', - in_=openapi.IN_QUERY, - type=openapi.TYPE_STRING, - default="embedding", - required=True, - description=_('Retrieval pattern 
embedding|keywords|blend') - ) - ] + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['query_text', 'top_number', 'similarity', 'search_mode'], + properties={ + 'query_text': openapi.Schema(type=openapi.TYPE_STRING, title=_('query text'), + description=_('query text')), + 'top_number': openapi.Schema(type=openapi.TYPE_NUMBER, title=_('top number'), + description=_('top number')), + 'similarity': openapi.Schema(type=openapi.TYPE_NUMBER, title=_('similarity'), + description=_('similarity')), + 'search_mode': openapi.Schema(type=openapi.TYPE_STRING, title=_('search mode'), + description=_('search mode')) + } + ) @staticmethod def get_response_body_api(): diff --git a/apps/common/util/common.py b/apps/common/util/common.py index b0111029af9..8583a1c989f 100644 --- a/apps/common/util/common.py +++ b/apps/common/util/common.py @@ -11,6 +11,7 @@ import io import mimetypes import pickle +import random import re import shutil from functools import reduce @@ -297,3 +298,14 @@ def markdown_to_plain_text(md: str) -> str: # 去除首尾空格 text = text.strip() return text + + +SAFE_CHAR_SET = ( + [chr(i) for i in range(65, 91) if chr(i) not in {'I', 'O'}] + # 大写字母 A-H, J-N, P-Z + [chr(i) for i in range(97, 123) if chr(i) not in {'i', 'l', 'o'}] + # 小写字母 a-h, j-n, p-z + [str(i) for i in range(10) if str(i) not in {'0', '1', '7'}] # 数字 2-6, 8-9 +) + + +def get_random_chars(number=4): + return ''.join(random.choices(SAFE_CHAR_SET, k=number)) diff --git a/apps/common/util/fork.py b/apps/common/util/fork.py index 4405b9b76e4..dc27ccf1982 100644 --- a/apps/common/util/fork.py +++ b/apps/common/util/fork.py @@ -3,6 +3,7 @@ import re import traceback from functools import reduce +from pathlib import Path from typing import List, Set from urllib.parse import urljoin, urlparse, ParseResult, urlsplit, urlunparse @@ -52,6 +53,28 @@ def remove_fragment(url: str) -> str: return urlunparse(modified_url) +def remove_last_path_robust(url): + 
"""健壮地删除URL的最后一个路径部分""" + parsed = urlparse(url) + + # 分割路径并过滤空字符串 + paths = [p for p in parsed.path.split('/') if p] + + if paths: + paths.pop() # 移除最后一个路径 + + # 重建路径 + new_path = '/' + '/'.join(paths) if paths else '/' + + # 重建URL + return urlunparse(( + parsed.scheme, + parsed.netloc, + new_path, + parsed.params, + parsed.query, + parsed.fragment + )) class Fork: class Response: def __init__(self, content: str, child_link_list: List[ChildLink], status, message: str): @@ -70,6 +93,8 @@ def error(message: str): def __init__(self, base_fork_url: str, selector_list: List[str]): base_fork_url = remove_fragment(base_fork_url) + if any([True for end_str in ['index.html', '.htm', '.html'] if base_fork_url.endswith(end_str)]): + base_fork_url =remove_last_path_robust(base_fork_url) self.base_fork_url = urljoin(base_fork_url if base_fork_url.endswith("/") else base_fork_url + '/', '.') parsed = urlsplit(base_fork_url) query = parsed.query @@ -137,18 +162,30 @@ def get_beautiful_soup(response): html_content = response.content.decode(encoding) beautiful_soup = BeautifulSoup(html_content, "html.parser") meta_list = beautiful_soup.find_all('meta') - charset_list = [meta.attrs.get('charset') for meta in meta_list if - meta.attrs is not None and 'charset' in meta.attrs] + charset_list = Fork.get_charset_list(meta_list) if len(charset_list) > 0: charset = charset_list[0] if charset != encoding: try: - html_content = response.content.decode(charset) + html_content = response.content.decode(charset, errors='replace') except Exception as e: - logging.getLogger("max_kb").error(f'{e}') + logging.getLogger("max_kb").error(f'{e}: {traceback.format_exc()}') return BeautifulSoup(html_content, "html.parser") return beautiful_soup + @staticmethod + def get_charset_list(meta_list): + charset_list = [] + for meta in meta_list: + if meta.attrs is not None: + if 'charset' in meta.attrs: + charset_list.append(meta.attrs.get('charset')) + elif meta.attrs.get('http-equiv', '').lower() == 
'content-type' and 'content' in meta.attrs: + match = re.search(r'charset=([^\s;]+)', meta.attrs['content'], re.I) + if match: + charset_list.append(match.group(1)) + return charset_list + def fork(self): try: @@ -175,4 +212,4 @@ def fork(self): def handler(base_url, response: Fork.Response): print(base_url.url, base_url.tag.text if base_url.tag else None, response.content) -# ForkManage('https://bbs.fit2cloud.com/c/de/6', ['.md-content']).fork(3, set(), handler) +# ForkManage('https://hzqcgc.htc.edu.cn/jxky.htm', ['.md-content']).fork(3, set(), handler) diff --git a/apps/common/util/function_code.py b/apps/common/util/function_code.py index 30ce3a33d20..3a877a62367 100644 --- a/apps/common/util/function_code.py +++ b/apps/common/util/function_code.py @@ -7,13 +7,12 @@ @desc: """ import os +import pickle import subprocess import sys import uuid from textwrap import dedent -from diskcache import Cache - from smartdoc.const import BASE_DIR from smartdoc.const import PROJECT_DIR @@ -37,6 +36,8 @@ def _createdir(self): old_mask = os.umask(0o077) try: os.makedirs(self.sandbox_path, 0o700, exist_ok=True) + os.makedirs(os.path.join(self.sandbox_path, 'execute'), 0o700, exist_ok=True) + os.makedirs(os.path.join(self.sandbox_path, 'result'), 0o700, exist_ok=True) finally: os.umask(old_mask) @@ -44,10 +45,11 @@ def exec_code(self, code_str, keywords): _id = str(uuid.uuid1()) success = '{"code":200,"msg":"成功","data":exec_result}' err = '{"code":500,"msg":str(e),"data":None}' - path = r'' + self.sandbox_path + '' + result_path = f'{self.sandbox_path}/result/{_id}.result' _exec_code = f""" try: import os + import pickle env = dict(os.environ) for key in list(env.keys()): if key in os.environ and (key.startswith('MAXKB') or key.startswith('POSTGRES') or key.startswith('PG')): @@ -60,13 +62,11 @@ def exec_code(self, code_str, keywords): for local in locals_v: globals_v[local] = locals_v[local] exec_result=f(**keywords) - from diskcache import Cache - cache = Cache({path!a}) - 
cache.set({_id!a},{success}) + with open({result_path!a}, 'wb') as file: + file.write(pickle.dumps({success})) except Exception as e: - from diskcache import Cache - cache = Cache({path!a}) - cache.set({_id!a},{err}) + with open({result_path!a}, 'wb') as file: + file.write(pickle.dumps({err})) """ if self.sandbox: subprocess_result = self._exec_sandbox(_exec_code, _id) @@ -74,18 +74,18 @@ def exec_code(self, code_str, keywords): subprocess_result = self._exec(_exec_code) if subprocess_result.returncode == 1: raise Exception(subprocess_result.stderr) - cache = Cache(self.sandbox_path) - result = cache.get(_id) - cache.delete(_id) + with open(result_path, 'rb') as file: + result = pickle.loads(file.read()) + os.remove(result_path) if result.get('code') == 200: return result.get('data') raise Exception(result.get('msg')) def _exec_sandbox(self, _code, _id): - exec_python_file = f'{self.sandbox_path}/{_id}.py' + exec_python_file = f'{self.sandbox_path}/execute/{_id}.py' with open(exec_python_file, 'w') as file: file.write(_code) - os.system(f"chown {self.user}:{self.user} {exec_python_file}") + os.system(f"chown {self.user}:root {exec_python_file}") kwargs = {'cwd': BASE_DIR} subprocess_result = subprocess.run( ['su', '-s', python_directory, '-c', "exec(open('" + exec_python_file + "').read())", self.user], diff --git a/apps/common/util/rsa_util.py b/apps/common/util/rsa_util.py index 00301867208..452ca678d9e 100644 --- a/apps/common/util/rsa_util.py +++ b/apps/common/util/rsa_util.py @@ -40,15 +40,12 @@ def generate(): def get_key_pair(): rsa_value = rsa_cache.get(cache_key) if rsa_value is None: - lock.acquire() - rsa_value = rsa_cache.get(cache_key) - if rsa_value is not None: - return rsa_value - try: + with lock: + rsa_value = rsa_cache.get(cache_key) + if rsa_value is not None: + return rsa_value rsa_value = get_key_pair_by_sql() rsa_cache.set(cache_key, rsa_value) - finally: - lock.release() return rsa_value diff --git 
a/apps/dataset/serializers/common_serializers.py b/apps/dataset/serializers/common_serializers.py index 856f3da1584..edf064236b2 100644 --- a/apps/dataset/serializers/common_serializers.py +++ b/apps/dataset/serializers/common_serializers.py @@ -18,13 +18,13 @@ from common.config.embedding_config import ModelManage from common.db.search import native_search -from common.db.sql_execute import update_execute +from common.db.sql_execute import update_execute, sql_execute from common.exception.app_exception import AppApiException from common.mixins.api_mixin import ApiMixin from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from common.util.fork import Fork -from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet, File, Image +from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet, File, Image, Document from setting.models_provider import get_model from smartdoc.conf import PROJECT_DIR from django.utils.translation import gettext_lazy as _ @@ -224,6 +224,46 @@ def get_embedding_model_id_by_dataset_id_list(dataset_id_list: List): return str(dataset_list[0].embedding_mode_id) + +def create_dataset_index(dataset_id=None, document_id=None): + if dataset_id is None and document_id is None: + raise AppApiException(500, _('Dataset ID or Document ID must be provided')) + + if dataset_id is not None: + k_id = dataset_id + else: + document = QuerySet(Document).filter(id=document_id).first() + k_id = document.dataset_id + + sql = f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = 'embedding' AND indexname = 'embedding_hnsw_idx_{k_id}'" + index = sql_execute(sql, []) + if not index: + sql = f"SELECT vector_dims(embedding) AS dims FROM embedding WHERE dataset_id = '{k_id}' LIMIT 1" + result = sql_execute(sql, []) + if len(result) == 0: + return + dims = result[0]['dims'] + sql = f"""CREATE INDEX "embedding_hnsw_idx_{k_id}" ON embedding USING hnsw 
((embedding::vector({dims})) vector_cosine_ops) WHERE dataset_id = '{k_id}'""" + update_execute(sql, []) + + +def drop_dataset_index(dataset_id=None, document_id=None): + if dataset_id is None and document_id is None: + raise AppApiException(500, _('Dataset ID or Document ID must be provided')) + + if dataset_id is not None: + k_id = dataset_id + else: + document = QuerySet(Document).filter(id=document_id).first() + k_id = document.dataset_id + + sql = f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = 'embedding' AND indexname = 'embedding_hnsw_idx_{k_id}'" + index = sql_execute(sql, []) + if index: + sql = f'DROP INDEX "embedding_hnsw_idx_{k_id}"' + update_execute(sql, []) + + class GenerateRelatedSerializer(ApiMixin, serializers.Serializer): model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_('Model id'))) prompt = serializers.CharField(required=True, error_messages=ErrMessage.uuid(_('Prompt word'))) diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py index 895443d997f..b92ecb33a54 100644 --- a/apps/dataset/serializers/dataset_serializers.py +++ b/apps/dataset/serializers/dataset_serializers.py @@ -44,7 +44,7 @@ State, File, Image from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \ get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id, write_image, zip_dir, \ - GenerateRelatedSerializer + GenerateRelatedSerializer, drop_dataset_index from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer from dataset.task import sync_web_dataset, sync_replace_web_dataset, generate_related_by_dataset_id from embedding.models import SearchMode @@ -526,7 +526,7 @@ def get_response_body_api(): def get_request_body_api(): return openapi.Schema( type=openapi.TYPE_OBJECT, - required=['name', 'desc'], + required=['name', 'desc', 'embedding_mode_id'], properties={ 
'name': openapi.Schema(type=openapi.TYPE_STRING, title=_('dataset name'), description=_('dataset name')), @@ -788,6 +788,7 @@ def delete(self): QuerySet(ProblemParagraphMapping).filter(dataset=dataset).delete() QuerySet(Paragraph).filter(dataset=dataset).delete() QuerySet(Problem).filter(dataset=dataset).delete() + drop_dataset_index(dataset_id=dataset.id) dataset.delete() delete_embedding_by_dataset(self.data.get('id')) return True diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index 5915877fc7c..94f1b9db6ea 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -23,6 +23,8 @@ from django.db.models import QuerySet, Count from django.db.models.functions import Substr, Reverse from django.http import HttpResponse +from django.utils.translation import get_language +from django.utils.translation import gettext_lazy as _, gettext, to_locale from drf_yasg import openapi from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE from rest_framework import serializers @@ -64,8 +66,6 @@ embedding_by_document_list from setting.models import Model from smartdoc.conf import PROJECT_DIR -from django.utils.translation import gettext_lazy as _, gettext, to_locale -from django.utils.translation import get_language parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle(), ZipParseQAHandle()] parse_table_handle_list = [CsvSplitTableHandle(), XlsSplitTableHandle(), XlsxSplitTableHandle()] @@ -141,7 +141,8 @@ def is_valid(self, *, document: Document = None): if 'meta' in self.data and self.data.get('meta') is not None: dataset_meta_valid_map = self.get_meta_valid_map() valid_class = dataset_meta_valid_map.get(document.type) - valid_class(data=self.data.get('meta')).is_valid(raise_exception=True) + if valid_class is not None: + valid_class(data=self.data.get('meta')).is_valid(raise_exception=True) class 
DocumentWebInstanceSerializer(ApiMixin, serializers.Serializer): @@ -661,6 +662,8 @@ def get_workbook(data_dict, document_dict): cell = worksheet.cell(row=row_idx + 1, column=col_idx + 1) if isinstance(col, str): col = re.sub(ILLEGAL_CHARACTERS_RE, '', col) + if col.startswith(('=', '+', '-', '@')): + col = '\ufeff' + col cell.value = col # 创建HttpResponse对象返回Excel文件 return workbook @@ -806,27 +809,40 @@ def delete(self): def get_response_body_api(): return openapi.Schema( type=openapi.TYPE_OBJECT, - required=['id', 'name', 'char_length', 'user_id', 'paragraph_count', 'is_active' - 'update_time', 'create_time'], + required=['create_time', 'update_time', 'id', 'name', 'char_length', 'status', 'is_active', + 'type', 'meta', 'dataset_id', 'hit_handling_method', 'directly_return_similarity', + 'status_meta', 'paragraph_count'], properties={ + 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title=_('create time'), + description=_('create time'), + default="1970-01-01 00:00:00"), + 'update_time': openapi.Schema(type=openapi.TYPE_STRING, title=_('update time'), + description=_('update time'), + default="1970-01-01 00:00:00"), 'id': openapi.Schema(type=openapi.TYPE_STRING, title="id", description="id", default="xx"), 'name': openapi.Schema(type=openapi.TYPE_STRING, title=_('name'), description=_('name'), default="xx"), 'char_length': openapi.Schema(type=openapi.TYPE_INTEGER, title=_('char length'), description=_('char length'), default=10), - 'user_id': openapi.Schema(type=openapi.TYPE_STRING, title=_('user id'), description=_('user id')), - 'paragraph_count': openapi.Schema(type=openapi.TYPE_INTEGER, title="_('document count')", - description="_('document count')", default=1), + 'status':openapi.Schema(type=openapi.TYPE_STRING, title=_('status'), + description=_('status'), default="xx"), 'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title=_('Is active'), description=_('Is active'), default=True), - 'update_time': openapi.Schema(type=openapi.TYPE_STRING, 
title=_('update time'), - description=_('update time'), - default="1970-01-01 00:00:00"), - 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title=_('create time'), - description=_('create time'), - default="1970-01-01 00:00:00" - ) + 'type': openapi.Schema(type=openapi.TYPE_STRING, title=_('type'), + description=_('type'), default="xx"), + 'meta': openapi.Schema(type=openapi.TYPE_OBJECT, title=_('meta'), + description=_('meta'), default="{}"), + 'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title=_('dataset_id'), + description=_('dataset_id'), default="xx"), + 'hit_handling_method': openapi.Schema(type=openapi.TYPE_STRING, title=_('hit_handling_method'), + description=_('hit_handling_method'), default="xx"), + 'directly_return_similarity': openapi.Schema(type=openapi.TYPE_NUMBER, title=_('directly_return_similarity'), + description=_('directly_return_similarity'), default="xx"), + 'status_meta': openapi.Schema(type=openapi.TYPE_OBJECT, title=_('status_meta'), + description=_('status_meta'), default="{}"), + 'paragraph_count': openapi.Schema(type=openapi.TYPE_INTEGER, title="_('document count')", + description="_('document count')", default=1), } ) @@ -853,7 +869,7 @@ def get_request_body_api(): class Create(ApiMixin, serializers.Serializer): dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char( - _('document id'))) + _('dataset id'))) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) @@ -981,7 +997,7 @@ def get_request_params_api(): in_=openapi.IN_PATH, type=openapi.TYPE_STRING, required=True, - description=_('document id')) + description=_('dataset id')) ] class Split(ApiMixin, serializers.Serializer): diff --git a/apps/dataset/serializers/file_serializers.py b/apps/dataset/serializers/file_serializers.py index 37f72fc8429..899c8a088de 100644 --- a/apps/dataset/serializers/file_serializers.py +++ b/apps/dataset/serializers/file_serializers.py @@ -28,6 +28,9 @@ "woff2": "font/woff2", 
"jar": "application/java-archive", "war": "application/java-archive", "ear": "application/java-archive", "json": "application/json", "hqx": "application/mac-binhex40", "doc": "application/msword", "pdf": "application/pdf", "ps": "application/postscript", + "docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + "pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation", "eps": "application/postscript", "ai": "application/postscript", "rtf": "application/rtf", "m3u8": "application/vnd.apple.mpegurl", "kml": "application/vnd.google-earth.kml+xml", "kmz": "application/vnd.google-earth.kmz", "xls": "application/vnd.ms-excel", @@ -87,4 +90,4 @@ def get(self, with_valid=True): 'Content-Disposition': 'attachment; filename="{}"'.format( file.file_name)}) return HttpResponse(file.get_byte(), status=200, - headers={'Content-Type': mime_types.get(file.file_name.split(".")[-1], 'text/plain')}) + headers={'Content-Type': mime_types.get(file_type, 'text/plain')}) diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py index 3a63fd95cd0..9b6e096ba00 100644 --- a/apps/dataset/serializers/paragraph_serializers.py +++ b/apps/dataset/serializers/paragraph_serializers.py @@ -226,6 +226,14 @@ def is_valid(self, *, raise_exception=True): def association(self, with_valid=True, with_embedding=True): if with_valid: self.is_valid(raise_exception=True) + # 已关联则直接返回 + if QuerySet(ProblemParagraphMapping).filter( + dataset_id=self.data.get('dataset_id'), + document_id=self.data.get('document_id'), + paragraph_id=self.data.get('paragraph_id'), + problem_id=self.data.get('problem_id') + ).exists(): + return True problem = QuerySet(Problem).filter(id=self.data.get("problem_id")).first() problem_paragraph_mapping = ProblemParagraphMapping(id=uuid.uuid1(), document_id=self.data.get('document_id'), diff --git 
a/apps/dataset/sql/update_document_char_length.sql b/apps/dataset/sql/update_document_char_length.sql index 4a4060cd9d4..2781809b23d 100644 --- a/apps/dataset/sql/update_document_char_length.sql +++ b/apps/dataset/sql/update_document_char_length.sql @@ -2,6 +2,7 @@ UPDATE "document" SET "char_length" = ( SELECT CASE WHEN "sum" ( "char_length" ( "content" ) ) IS NULL THEN 0 ELSE "sum" ( "char_length" ( "content" ) ) - END FROM paragraph WHERE "document_id" = %s ) + END FROM paragraph WHERE "document_id" = %s ), + "update_time" = CURRENT_TIMESTAMP WHERE "id" = %s \ No newline at end of file diff --git a/apps/dataset/views/dataset.py b/apps/dataset/views/dataset.py index bbb9e033980..aeb1af28932 100644 --- a/apps/dataset/views/dataset.py +++ b/apps/dataset/views/dataset.py @@ -7,13 +7,13 @@ @desc: """ +from django.utils.translation import gettext_lazy as _ from drf_yasg.utils import swagger_auto_schema from rest_framework.decorators import action from rest_framework.parsers import MultiPartParser from rest_framework.views import APIView from rest_framework.views import Request -import dataset.models from common.auth import TokenAuth, has_permissions from common.constants.permission_constants import PermissionConstants, CompareConstants, Permission, Group, Operate, \ ViewPermission, RoleConstants @@ -25,7 +25,6 @@ from dataset.serializers.dataset_serializers import DataSetSerializers from dataset.views.common import get_dataset_operation_object from setting.serializers.provider_serializers import ModelSerializer -from django.utils.translation import gettext_lazy as _ class Dataset(APIView): @@ -141,21 +140,22 @@ def post(self, request: Request): class HitTest(APIView): authentication_classes = [TokenAuth] - @action(methods="GET", detail=False) + @action(methods="PUT", detail=False) @swagger_auto_schema(operation_summary=_('Hit test list'), operation_id=_('Hit test list'), - manual_parameters=CommonApi.HitTestApi.get_request_params_api(), + 
request_body=CommonApi.HitTestApi.get_request_body_api(), responses=result.get_api_array_response(CommonApi.HitTestApi.get_response_body_api()), tags=[_('Knowledge Base')]) @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE, dynamic_tag=keywords.get('dataset_id'))) - def get(self, request: Request, dataset_id: str): - return result.success( - DataSetSerializers.HitTest(data={'id': dataset_id, 'user_id': request.user.id, - "query_text": request.query_params.get("query_text"), - "top_number": request.query_params.get("top_number"), - 'similarity': request.query_params.get('similarity'), - 'search_mode': request.query_params.get('search_mode')}).hit_test( - )) + def put(self, request: Request, dataset_id: str): + return result.success(DataSetSerializers.HitTest(data={ + 'id': dataset_id, + 'user_id': request.user.id, + "query_text": request.data.get("query_text"), + "top_number": request.data.get("top_number"), + 'similarity': request.data.get('similarity'), + 'search_mode': request.data.get('search_mode')} + ).hit_test()) class Embedding(APIView): authentication_classes = [TokenAuth] diff --git a/apps/embedding/sql/blend_search.sql b/apps/embedding/sql/blend_search.sql index afb1f0040d1..c70e66464ee 100644 --- a/apps/embedding/sql/blend_search.sql +++ b/apps/embedding/sql/blend_search.sql @@ -5,15 +5,17 @@ SELECT FROM ( SELECT DISTINCT ON - ( "paragraph_id" ) ( similarity ),* , - similarity AS comprehensive_score + ( "paragraph_id" ) ( 1 - distince + ts_similarity ) as similarity, *, + (1 - distince + ts_similarity) AS comprehensive_score FROM ( SELECT *, - (( 1 - ( embedding.embedding <=> %s ) )+ts_rank_cd( embedding.search_vector, websearch_to_tsquery('simple', %s ), 32 )) AS similarity + (embedding.embedding::vector(%s) <=> %s) as distince, + (ts_rank_cd( embedding.search_vector, websearch_to_tsquery('simple', %s ), 32 )) AS ts_similarity FROM embedding ${embedding_query} + ORDER BY distince ) TEMP ORDER BY paragraph_id, diff 
--git a/apps/embedding/sql/embedding_search.sql b/apps/embedding/sql/embedding_search.sql index ce3d4a580d5..1b5689959b8 100644 --- a/apps/embedding/sql/embedding_search.sql +++ b/apps/embedding/sql/embedding_search.sql @@ -5,12 +5,12 @@ SELECT FROM ( SELECT DISTINCT ON - ("paragraph_id") ( similarity ),* ,similarity AS comprehensive_score + ("paragraph_id") ( 1 - distince ),* ,(1 - distince) AS comprehensive_score FROM - ( SELECT *, ( 1 - ( embedding.embedding <=> %s ) ) AS similarity FROM embedding ${embedding_query}) TEMP + ( SELECT *, ( embedding.embedding::vector(%s) <=> %s ) AS distince FROM embedding ${embedding_query} ORDER BY distince) TEMP ORDER BY paragraph_id, - similarity DESC + distince ) DISTINCT_TEMP WHERE comprehensive_score>%s ORDER BY comprehensive_score DESC diff --git a/apps/embedding/task/embedding.py b/apps/embedding/task/embedding.py index 3b26bd7a1db..48846750006 100644 --- a/apps/embedding/task/embedding.py +++ b/apps/embedding/task/embedding.py @@ -17,6 +17,7 @@ from common.event import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingDatasetIdArgs, \ UpdateEmbeddingDocumentIdArgs from dataset.models import Document, TaskType, State +from dataset.serializers.common_serializers import drop_dataset_index from ops import celery_app from setting.models import Model from setting.models_provider import get_model @@ -110,6 +111,7 @@ def embedding_by_dataset(dataset_id, model_id): max_kb.info(_('Start--->Vectorized dataset: {dataset_id}').format(dataset_id=dataset_id)) try: ListenerManagement.delete_embedding_by_dataset(dataset_id) + drop_dataset_index(dataset_id=dataset_id) document_list = QuerySet(Document).filter(dataset_id=dataset_id) max_kb.info(_('Dataset documentation: {document_names}').format( document_names=", ".join([d.name for d in document_list]))) diff --git a/apps/embedding/vector/pg_vector.py b/apps/embedding/vector/pg_vector.py index 7929685a37c..af9ff7e4ca3 100644 --- a/apps/embedding/vector/pg_vector.py +++ 
b/apps/embedding/vector/pg_vector.py @@ -12,7 +12,6 @@ from abc import ABC, abstractmethod from typing import Dict, List -import jieba from django.contrib.postgres.search import SearchVector from django.db.models import QuerySet, Value from langchain_core.embeddings import Embeddings @@ -169,8 +168,13 @@ def handle(self, os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', 'embedding_search.sql')), with_table_name=True) - embedding_model = select_list(exec_sql, - [json.dumps(query_embedding), *exec_params, similarity, top_number]) + embedding_model = select_list(exec_sql, [ + len(query_embedding), + json.dumps(query_embedding), + *exec_params, + similarity, + top_number + ]) return embedding_model def support(self, search_mode: SearchMode): @@ -190,8 +194,12 @@ def handle(self, os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', 'keywords_search.sql')), with_table_name=True) - embedding_model = select_list(exec_sql, - [to_query(query_text), *exec_params, similarity, top_number]) + embedding_model = select_list(exec_sql, [ + to_query(query_text), + *exec_params, + similarity, + top_number + ]) return embedding_model def support(self, search_mode: SearchMode): @@ -211,9 +219,14 @@ def handle(self, os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', 'blend_search.sql')), with_table_name=True) - embedding_model = select_list(exec_sql, - [json.dumps(query_embedding), to_query(query_text), *exec_params, similarity, - top_number]) + embedding_model = select_list(exec_sql, [ + len(query_embedding), + json.dumps(query_embedding), + to_query(query_text), + *exec_params, + similarity, + top_number + ]) return embedding_model def support(self, search_mode: SearchMode): diff --git a/apps/function_lib/migrations/0004_functionlib_decimal_date.py b/apps/function_lib/migrations/0004_functionlib_decimal_date.py new file mode 100644 index 00000000000..82e4a6d029a --- /dev/null +++ b/apps/function_lib/migrations/0004_functionlib_decimal_date.py @@ -0,0 +1,127 @@ +# Generated by 
Django 4.2.15 on 2025-03-13 07:21 + +from django.db import migrations +from django.db.models import Q + +mysql_template = """ +def query_mysql(host,port, user, password, database, sql): + import pymysql + import json + from pymysql.cursors import DictCursor + from datetime import datetime, date + + def default_serializer(obj): + from decimal import Decimal + if isinstance(obj, (datetime, date)): + return obj.isoformat() # 将 datetime/date 转换为 ISO 格式字符串 + elif isinstance(obj, Decimal): + return float(obj) # 将 Decimal 转换为 float + raise TypeError(f"Type {type(obj)} not serializable") + + try: + # 创建连接 + db = pymysql.connect( + host=host, + port=int(port), + user=user, + password=password, + database=database, + cursorclass=DictCursor # 使用字典游标 + ) + + # 使用 cursor() 方法创建一个游标对象 cursor + cursor = db.cursor() + + # 使用 execute() 方法执行 SQL 查询 + cursor.execute(sql) + + # 使用 fetchall() 方法获取所有数据 + data = cursor.fetchall() + + # 处理 bytes 类型的数据 + for row in data: + for key, value in row.items(): + if isinstance(value, bytes): + row[key] = value.decode("utf-8") # 转换为字符串 + + # 将数据序列化为 JSON + json_data = json.dumps(data, default=default_serializer, ensure_ascii=False) + return json_data + + # 关闭数据库连接 + db.close() + + except Exception as e: + print(f"Error while connecting to MySQL: {e}") + raise e +""" + +pgsql_template = """ +def queryPgSQL(database, user, password, host, port, query): + import psycopg2 + import json + from datetime import datetime + + # 自定义 JSON 序列化函数 + def default_serializer(obj): + from decimal import Decimal + if isinstance(obj, datetime): + return obj.isoformat() # 将 datetime 转换为 ISO 格式字符串 + elif isinstance(obj, Decimal): + return float(obj) # 将 Decimal 转换为 float + raise TypeError(f"Type {type(obj)} not serializable") + + # 数据库连接信息 + conn_params = { + "dbname": database, + "user": user, + "password": password, + "host": host, + "port": port + } + try: + # 建立连接 + conn = psycopg2.connect(**conn_params) + print("连接成功!") + # 创建游标对象 + cursor = conn.cursor() + # 
执行查询语句 + cursor.execute(query) + # 获取查询结果 + rows = cursor.fetchall() + # 处理 bytes 类型的数据 + columns = [desc[0] for desc in cursor.description] + result = [dict(zip(columns, row)) for row in rows] + # 转换为 JSON 格式 + json_result = json.dumps(result, default=default_serializer, ensure_ascii=False) + return json_result + except Exception as e: + print(f"发生错误:{e}") + raise e + finally: + # 关闭游标和连接 + if cursor: + cursor.close() + if conn: + conn.close() +""" + + +def fix_type(apps, schema_editor): + FunctionLib = apps.get_model('function_lib', 'FunctionLib') + FunctionLib.objects.filter( + Q(id='22c21b76-0308-11f0-9694-5618c4394482') | Q(template_id='22c21b76-0308-11f0-9694-5618c4394482') + ).update(code=mysql_template) + FunctionLib.objects.filter( + Q(id='bd1e8b88-0302-11f0-87bb-5618c4394482') | Q(template_id='bd1e8b88-0302-11f0-87bb-5618c4394482') + ).update(code=pgsql_template) + + +class Migration(migrations.Migration): + dependencies = [ + ('function_lib', '0003_functionlib_function_type_functionlib_icon_and_more'), + ] + + operations = [ + migrations.RunPython(fix_type) + ] diff --git a/apps/function_lib/serializers/function_lib_serializer.py b/apps/function_lib/serializers/function_lib_serializer.py index 440eb22c786..ad7ff3cce61 100644 --- a/apps/function_lib/serializers/function_lib_serializer.py +++ b/apps/function_lib/serializers/function_lib_serializer.py @@ -33,11 +33,13 @@ function_executor = FunctionExecutor(CONFIG.get('SANDBOX')) + class FlibInstance: def __init__(self, function_lib: dict, version: str): self.function_lib = function_lib self.version = version + def encryption(message: str): """ 加密敏感字段数据 加密方式是 如果密码是 1234567890 那么给前端则是 123******890 @@ -68,7 +70,8 @@ def encryption(message: str): class FunctionLibModelSerializer(serializers.ModelSerializer): class Meta: model = FunctionLib - fields = ['id', 'name', 'icon', 'desc', 'code', 'input_field_list','init_field_list', 'init_params', 'permission_type', 'is_active', 'user_id', 'template_id', + fields = 
['id', 'name', 'icon', 'desc', 'code', 'input_field_list', 'init_field_list', 'init_params', + 'permission_type', 'is_active', 'user_id', 'template_id', 'create_time', 'update_time'] @@ -148,7 +151,6 @@ class Query(serializers.Serializer): select_user_id = serializers.CharField(required=False, allow_null=True, allow_blank=True) function_type = serializers.CharField(required=False, allow_null=True, allow_blank=True) - def get_query_set(self): query_set = QuerySet(FunctionLib).filter( (Q(user_id=self.data.get('user_id')) | Q(permission_type='PUBLIC'))) @@ -269,7 +271,7 @@ class Operate(serializers.Serializer): def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) - if not QuerySet(FunctionLib).filter(id=self.data.get('id')).exists(): + if not QuerySet(FunctionLib).filter(user_id=self.data.get('user_id'), id=self.data.get('id')).exists(): raise AppApiException(500, _('Function does not exist')) def delete(self, with_valid=True): @@ -285,7 +287,8 @@ def edit(self, instance, with_valid=True): if with_valid: self.is_valid(raise_exception=True) EditFunctionLib(data=instance).is_valid(raise_exception=True) - edit_field_list = ['name', 'desc', 'code', 'icon', 'input_field_list', 'init_field_list', 'init_params', 'permission_type', 'is_active'] + edit_field_list = ['name', 'desc', 'code', 'icon', 'input_field_list', 'init_field_list', 'init_params', + 'permission_type', 'is_active'] edit_dict = {field: instance.get(field) for field in edit_field_list if ( field in instance and instance.get(field) is not None)} @@ -317,7 +320,8 @@ def one(self, with_valid=True): if function_lib.init_params: function_lib.init_params = json.loads(rsa_long_decrypt(function_lib.init_params)) if function_lib.init_field_list: - password_fields = [i["field"] for i in function_lib.init_field_list if i.get("input_type") == "PasswordInput"] + password_fields = [i["field"] for i in function_lib.init_field_list if + i.get("input_type") == "PasswordInput"] if 
function_lib.init_params: for k in function_lib.init_params: if k in password_fields and function_lib.init_params[k]: diff --git a/apps/locales/en_US/LC_MESSAGES/django.po b/apps/locales/en_US/LC_MESSAGES/django.po index d13912928b9..9b83be9686d 100644 --- a/apps/locales/en_US/LC_MESSAGES/django.po +++ b/apps/locales/en_US/LC_MESSAGES/django.po @@ -7238,7 +7238,7 @@ msgstr "" msgid "" "The confirmation password must be 6-20 characters long and must be a " "combination of letters, numbers, and special characters." -msgstr "" +msgstr "The confirmation password must be 6-20 characters long and must be a combination of letters, numbers, and special characters.(Special character support:_、!、@、#、$、(、) ……)" #: community/apps/users/serializers/user_serializers.py:380 #, python-brace-format @@ -7490,4 +7490,22 @@ msgid "Field: {name} No value set" msgstr "" msgid "Generate related" +msgstr "" + +msgid "Obtain graphical captcha" +msgstr "" + +msgid "Captcha code error or expiration" +msgstr "" + +msgid "captcha" +msgstr "" + +msgid "Reasoning enable" +msgstr "" + +msgid "Reasoning start tag" +msgstr "" + +msgid "Reasoning end tag" msgstr "" \ No newline at end of file diff --git a/apps/locales/zh_CN/LC_MESSAGES/django.po b/apps/locales/zh_CN/LC_MESSAGES/django.po index b0ab7871bf6..9500103c702 100644 --- a/apps/locales/zh_CN/LC_MESSAGES/django.po +++ b/apps/locales/zh_CN/LC_MESSAGES/django.po @@ -4536,7 +4536,7 @@ msgstr "修改知识库信息" #: community/apps/dataset/views/document.py:463 #: community/apps/dataset/views/document.py:464 msgid "Get the knowledge base paginated list" -msgstr "获取知识库分页列表" +msgstr "获取知识库文档分页列表" #: community/apps/dataset/views/document.py:31 #: community/apps/dataset/views/document.py:32 @@ -7395,7 +7395,7 @@ msgstr "语言只支持:" msgid "" "The confirmation password must be 6-20 characters long and must be a " "combination of letters, numbers, and special characters." 
-msgstr "确认密码长度6-20个字符,必须字母、数字、特殊字符组合" +msgstr "确认密码长度6-20个字符,必须字母、数字、特殊字符组合(特殊字符支持:_、!、@、#、$、(、) ……)" #: community/apps/users/serializers/user_serializers.py:380 #, python-brace-format @@ -7653,4 +7653,22 @@ msgid "Field: {name} No value set" msgstr "字段: {name} 未设置值" msgid "Generate related" -msgstr "生成问题" \ No newline at end of file +msgstr "生成问题" + +msgid "Obtain graphical captcha" +msgstr "获取图形验证码" + +msgid "Captcha code error or expiration" +msgstr "验证码错误或过期" + +msgid "captcha" +msgstr "验证码" + +msgid "Reasoning enable" +msgstr "开启思考过程" + +msgid "Reasoning start tag" +msgstr "思考过程开始标签" + +msgid "Reasoning end tag" +msgstr "思考过程结束标签" \ No newline at end of file diff --git a/apps/locales/zh_Hant/LC_MESSAGES/django.po b/apps/locales/zh_Hant/LC_MESSAGES/django.po index dab1d176c26..6993cdd2161 100644 --- a/apps/locales/zh_Hant/LC_MESSAGES/django.po +++ b/apps/locales/zh_Hant/LC_MESSAGES/django.po @@ -4545,7 +4545,7 @@ msgstr "修改知識庫信息" #: community/apps/dataset/views/document.py:463 #: community/apps/dataset/views/document.py:464 msgid "Get the knowledge base paginated list" -msgstr "獲取知識庫分頁列表" +msgstr "獲取知識庫文档分頁列表" #: community/apps/dataset/views/document.py:31 #: community/apps/dataset/views/document.py:32 @@ -7405,7 +7405,7 @@ msgstr "語言只支持:" msgid "" "The confirmation password must be 6-20 characters long and must be a " "combination of letters, numbers, and special characters." 
-msgstr "確認密碼長度6-20個字符,必須字母、數字、特殊字符組合" +msgstr "確認密碼長度6-20個字符,必須字母、數字、特殊字符組合(特殊字元支持:_、!、@、#、$、(、) ……)" #: community/apps/users/serializers/user_serializers.py:380 #, python-brace-format @@ -7663,4 +7663,22 @@ msgid "Field: {name} No value set" msgstr "欄位: {name} 未設定值" msgid "Generate related" -msgstr "生成問題" \ No newline at end of file +msgstr "生成問題" + +msgid "Obtain graphical captcha" +msgstr "獲取圖形驗證碼" + +msgid "Captcha code error or expiration" +msgstr "驗證碼錯誤或過期" + +msgid "captcha" +msgstr "驗證碼" + +msgid "Reasoning enable" +msgstr "開啟思考過程" + +msgid "Reasoning start tag" +msgstr "思考過程開始標籤" + +msgid "Reasoning end tag" +msgstr "思考過程結束標籤" \ No newline at end of file diff --git a/apps/setting/migrations/0011_refresh_collation_reindex.py b/apps/setting/migrations/0011_refresh_collation_reindex.py new file mode 100644 index 00000000000..0f93d4ad481 --- /dev/null +++ b/apps/setting/migrations/0011_refresh_collation_reindex.py @@ -0,0 +1,61 @@ +import logging + +import psycopg +from django.db import migrations + +from smartdoc.const import CONFIG + + +def get_connect(db_name): + conn_params = { + "dbname": db_name, + "user": CONFIG.get('DB_USER'), + "password": CONFIG.get('DB_PASSWORD'), + "host": CONFIG.get('DB_HOST'), + "port": CONFIG.get('DB_PORT') + } + # 建立连接 + connect = psycopg.connect(**conn_params) + return connect + + +def sql_execute(conn, reindex_sql: str, alter_database_sql: str): + """ + 执行一条sql + @param reindex_sql: + @param conn: + @param alter_database_sql: + """ + conn.autocommit = True + with conn.cursor() as cursor: + cursor.execute(reindex_sql, []) + cursor.execute(alter_database_sql, []) + cursor.close() + +def re_index(apps, schema_editor): + app_db_name = CONFIG.get('DB_NAME') + try: + re_index_database(app_db_name) + except Exception as e: + logging.error(f'reindex database {app_db_name}发送错误:{str(e)}') + try: + re_index_database('root') + except Exception as e: + logging.error(f'reindex database root 发送错误:{str(e)}') + + +def 
re_index_database(db_name): + db_conn = get_connect(db_name) + sql_execute(db_conn, f'REINDEX DATABASE "{db_name}";', f'ALTER DATABASE "{db_name}" REFRESH COLLATION VERSION;') + db_conn.close() + + +class Migration(migrations.Migration): + + dependencies = [ + ('setting', '0010_log'), + ] + + operations = [ + migrations.RunPython(re_index, atomic=False) + ] diff --git a/apps/setting/models_provider/base_model_provider.py b/apps/setting/models_provider/base_model_provider.py index 622be703dad..2b02bdc1fb1 100644 --- a/apps/setting/models_provider/base_model_provider.py +++ b/apps/setting/models_provider/base_model_provider.py @@ -106,7 +106,10 @@ def filter_optional_params(model_kwargs): optional_params = {} for key, value in model_kwargs.items(): if key not in ['model_id', 'use_local', 'streaming', 'show_ref_label']: - optional_params[key] = value + if key == 'extra_body' and isinstance(value, dict): + optional_params = {**optional_params, **value} + else: + optional_params[key] = value return optional_params diff --git a/apps/setting/models_provider/constants/model_provider_constants.py b/apps/setting/models_provider/constants/model_provider_constants.py index e6bf698b01a..e68b9361f0b 100644 --- a/apps/setting/models_provider/constants/model_provider_constants.py +++ b/apps/setting/models_provider/constants/model_provider_constants.py @@ -19,6 +19,8 @@ from setting.models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider from setting.models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import QwenModelProvider +from setting.models_provider.impl.regolo_model_provider.regolo_model_provider import \ + RegoloModelProvider from setting.models_provider.impl.siliconCloud_model_provider.siliconCloud_model_provider import \ SiliconCloudModelProvider from 
setting.models_provider.impl.tencent_cloud_model_provider.tencent_cloud_model_provider import \ @@ -55,3 +57,4 @@ class ModelProvideConstants(Enum): aliyun_bai_lian_model_provider = AliyunBaiLianModelProvider() model_anthropic_provider = AnthropicModelProvider() model_siliconCloud_provider = SiliconCloudModelProvider() + model_regolo_provider = RegoloModelProvider() diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py index 8c5031f08f2..b1d72f0869a 100644 --- a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py @@ -51,6 +51,23 @@ _("Universal text vector is Tongyi Lab's multi-language text unified vector model based on the LLM base. It provides high-level vector services for multiple mainstream languages around the world and helps developers quickly convert text data into high-quality vector data."), ModelTypeConst.EMBEDDING, aliyun_bai_lian_embedding_model_credential, AliyunBaiLianEmbedding), + ModelInfo('qwen3-0.6b', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential, + BaiLianChatModel), + ModelInfo('qwen3-1.7b', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential, + BaiLianChatModel), + ModelInfo('qwen3-4b', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential, + BaiLianChatModel), + ModelInfo('qwen3-8b', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential, + BaiLianChatModel), + ModelInfo('qwen3-14b', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential, + BaiLianChatModel), + ModelInfo('qwen3-32b', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential, + BaiLianChatModel), + ModelInfo('qwen3-30b-a3b', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential, + BaiLianChatModel), + 
ModelInfo('qwen3-235b-a22b', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential, + BaiLianChatModel), + ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential, BaiLianChatModel), ModelInfo('qwen-plus', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential, diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/llm.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/llm.py index f316a0c6d1c..9da30b72796 100644 --- a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/llm.py @@ -30,6 +30,29 @@ class BaiLianLLMModelParams(BaseForm): precision=0) +class BaiLianLLMStreamModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel(_('Temperature'), + _('Higher values make the output more random, while lower values make it more focused and deterministic')), + required=True, default_value=0.7, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel(_('Output the maximum Tokens'), + _('Specify the maximum number of tokens that the model can generate')), + required=True, default_value=800, + _min=1, + _max=100000, + _step=1, + precision=0) + + stream = forms.SwitchField(label=TooltipLabel(_('Is the answer in streaming mode'), + _('Is the answer in streaming mode')), + required=True, default_value=True) + + class BaiLianLLMModelCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, @@ -47,7 +70,11 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje return False try: model = provider.get_model(model_type, model_name, model_credential, **model_params) - model.invoke([HumanMessage(content=gettext('Hello'))]) + if model_params.get('stream'): + for res in 
model.stream([HumanMessage(content=gettext('Hello'))]): + pass + else: + model.invoke([HumanMessage(content=gettext('Hello'))]) except Exception as e: traceback.print_exc() if isinstance(e, AppApiException): @@ -68,4 +95,6 @@ def encryption_dict(self, model: Dict[str, object]): api_key = forms.PasswordInputField('API Key', required=True) def get_model_params_setting_form(self, model_name): + if 'qwen3' in model_name: + return BaiLianLLMStreamModelParams() return BaiLianLLMModelParams() diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/image.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/image.py index 2b1fe31f228..7cda97f2388 100644 --- a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/image.py +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/image.py @@ -15,9 +15,8 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** model_name=model_name, openai_api_key=model_credential.get('api_key'), openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1', - # stream_options={"include_usage": True}, streaming=True, stream_usage=True, - **optional_params, + extra_body=optional_params ) return chat_tong_yi diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/llm.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/llm.py index d914f7c8ad6..ee3ee6488c2 100644 --- a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/llm.py @@ -20,5 +20,5 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** model=model_name, openai_api_base=model_credential.get('api_base'), openai_api_key=model_credential.get('api_key'), - **optional_params + extra_body=optional_params ) diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py 
b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py index ef1c133378e..7b0088a4ab4 100644 --- a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py @@ -1,10 +1,12 @@ import os import re -from typing import Dict +from typing import Dict, List from botocore.config import Config from langchain_community.chat_models import BedrockChat +from langchain_core.messages import BaseMessage, get_buffer_string +from common.config.tokenizer_manage_config import TokenizerManage from setting.models_provider.base_model_provider import MaxKBBaseModel @@ -72,6 +74,20 @@ def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[s config=config ) + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + try: + return super().get_num_tokens_from_messages(messages) + except Exception as e: + tokenizer = TokenizerManage.get_tokenizer() + return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) + + def get_num_tokens(self, text: str) -> int: + try: + return super().get_num_tokens(text) + except Exception as e: + tokenizer = TokenizerManage.get_tokenizer() + return len(tokenizer.encode(text)) + def _update_aws_credentials(profile_name, access_key_id, secret_access_key): credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials") diff --git a/apps/setting/models_provider/impl/base_chat_open_ai.py b/apps/setting/models_provider/impl/base_chat_open_ai.py index 54076b7efda..626a751f740 100644 --- a/apps/setting/models_provider/impl/base_chat_open_ai.py +++ b/apps/setting/models_provider/impl/base_chat_open_ai.py @@ -1,15 +1,16 @@ # coding=utf-8 -import warnings -from typing import List, Dict, Optional, Any, Iterator, cast, Type, Union +from typing import Dict, Optional, Any, Iterator, cast, Union, Sequence, Callable, Mapping -import openai -from langchain_core.callbacks import 
CallbackManagerForLLMRun from langchain_core.language_models import LanguageModelInput -from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, AIMessageChunk -from langchain_core.outputs import ChatGenerationChunk, ChatGeneration +from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, HumanMessageChunk, AIMessageChunk, \ + SystemMessageChunk, FunctionMessageChunk, ChatMessageChunk +from langchain_core.messages.ai import UsageMetadata +from langchain_core.messages.tool import tool_call_chunk, ToolMessageChunk +from langchain_core.outputs import ChatGenerationChunk from langchain_core.runnables import RunnableConfig, ensure_config -from langchain_core.utils.pydantic import is_basemodel_subclass +from langchain_core.tools import BaseTool from langchain_openai import ChatOpenAI +from langchain_openai.chat_models.base import _create_usage_metadata from common.config.tokenizer_manage_config import TokenizerManage @@ -19,6 +20,65 @@ def custom_get_token_ids(text: str): return tokenizer.encode(text) +def _convert_delta_to_message_chunk( + _dict: Mapping[str, Any], default_class: type[BaseMessageChunk] +) -> BaseMessageChunk: + id_ = _dict.get("id") + role = cast(str, _dict.get("role")) + content = cast(str, _dict.get("content") or "") + additional_kwargs: dict = {} + if 'reasoning_content' in _dict: + additional_kwargs['reasoning_content'] = _dict.get('reasoning_content') + if _dict.get("function_call"): + function_call = dict(_dict["function_call"]) + if "name" in function_call and function_call["name"] is None: + function_call["name"] = "" + additional_kwargs["function_call"] = function_call + tool_call_chunks = [] + if raw_tool_calls := _dict.get("tool_calls"): + additional_kwargs["tool_calls"] = raw_tool_calls + try: + tool_call_chunks = [ + tool_call_chunk( + name=rtc["function"].get("name"), + args=rtc["function"].get("arguments"), + id=rtc.get("id"), + index=rtc["index"], + ) + for rtc in raw_tool_calls + 
] + except KeyError: + pass + + if role == "user" or default_class == HumanMessageChunk: + return HumanMessageChunk(content=content, id=id_) + elif role == "assistant" or default_class == AIMessageChunk: + return AIMessageChunk( + content=content, + additional_kwargs=additional_kwargs, + id=id_, + tool_call_chunks=tool_call_chunks, # type: ignore[arg-type] + ) + elif role in ("system", "developer") or default_class == SystemMessageChunk: + if role == "developer": + additional_kwargs = {"__openai_role__": "developer"} + else: + additional_kwargs = {} + return SystemMessageChunk( + content=content, id=id_, additional_kwargs=additional_kwargs + ) + elif role == "function" or default_class == FunctionMessageChunk: + return FunctionMessageChunk(content=content, name=_dict["name"], id=id_) + elif role == "tool" or default_class == ToolMessageChunk: + return ToolMessageChunk( + content=content, tool_call_id=_dict["tool_call_id"], id=id_ + ) + elif role or default_class == ChatMessageChunk: + return ChatMessageChunk(content=content, role=role, id=id_) + else: + return default_class(content=content, id=id_) # type: ignore + + class BaseChatOpenAI(ChatOpenAI): usage_metadata: dict = {} custom_get_token_ids = custom_get_token_ids @@ -26,14 +86,20 @@ class BaseChatOpenAI(ChatOpenAI): def get_last_generation_info(self) -> Optional[Dict[str, Any]]: return self.usage_metadata - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + def get_num_tokens_from_messages( + self, + messages: list[BaseMessage], + tools: Optional[ + Sequence[Union[dict[str, Any], type, Callable, BaseTool]] + ] = None, + ) -> int: if self.usage_metadata is None or self.usage_metadata == {}: try: return super().get_num_tokens_from_messages(messages) except Exception as e: tokenizer = TokenizerManage.get_tokenizer() return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) - return self.usage_metadata.get('input_tokens', 0) + return 
self.usage_metadata.get('input_tokens', self.usage_metadata.get('prompt_tokens', 0)) def get_num_tokens(self, text: str) -> int: if self.usage_metadata is None or self.usage_metadata == {}: @@ -42,116 +108,80 @@ def get_num_tokens(self, text: str) -> int: except Exception as e: tokenizer = TokenizerManage.get_tokenizer() return len(tokenizer.encode(text)) - return self.get_last_generation_info().get('output_tokens', 0) + return self.get_last_generation_info().get('output_tokens', + self.get_last_generation_info().get('completion_tokens', 0)) + + def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]: + kwargs['stream_usage'] = True + for chunk in super()._stream(*args, **kwargs): + if chunk.message.usage_metadata is not None: + self.usage_metadata = chunk.message.usage_metadata + yield chunk - def _stream( + def _convert_chunk_to_generation_chunk( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: - kwargs["stream"] = True - kwargs["stream_options"] = {"include_usage": True} - """Set default stream_options.""" - stream_usage = self._should_stream_usage(kwargs.get('stream_usage'), **kwargs) - # Note: stream_options is not a valid parameter for Azure OpenAI. - # To support users proxying Azure through ChatOpenAI, here we only specify - # stream_options if include_usage is set to True. - # See https://learn.microsoft.com/en-us/azure/ai-services/openai/whats-new - # for release notes. - if stream_usage: - kwargs["stream_options"] = {"include_usage": stream_usage} - - payload = self._get_request_payload(messages, stop=stop, **kwargs) - default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk - base_generation_info = {} - - if "response_format" in payload and is_basemodel_subclass( - payload["response_format"] - ): - # TODO: Add support for streaming with Pydantic response_format. 
- warnings.warn("Streaming with Pydantic response_format not yet supported.") - chat_result = self._generate( - messages, stop, run_manager=run_manager, **kwargs - ) - msg = chat_result.generations[0].message - yield ChatGenerationChunk( - message=AIMessageChunk( - **msg.dict(exclude={"type", "additional_kwargs"}), - # preserve the "parsed" Pydantic object without converting to dict - additional_kwargs=msg.additional_kwargs, - ), - generation_info=chat_result.generations[0].generation_info, + chunk: dict, + default_chunk_class: type, + base_generation_info: Optional[dict], + ) -> Optional[ChatGenerationChunk]: + if chunk.get("type") == "content.delta": # from beta.chat.completions.stream + return None + token_usage = chunk.get("usage") + choices = ( + chunk.get("choices", []) + # from beta.chat.completions.stream + or chunk.get("chunk", {}).get("choices", []) + ) + + usage_metadata: Optional[UsageMetadata] = ( + _create_usage_metadata(token_usage) if token_usage and token_usage.get("prompt_tokens") else None + ) + if len(choices) == 0: + # logprobs is implicitly None + generation_chunk = ChatGenerationChunk( + message=default_chunk_class(content="", usage_metadata=usage_metadata) ) - return - if self.include_response_headers: - raw_response = self.client.with_raw_response.create(**payload) - response = raw_response.parse() - base_generation_info = {"headers": dict(raw_response.headers)} - else: - response = self.client.create(**payload) - with response: - is_first_chunk = True - for chunk in response: - if not isinstance(chunk, dict): - chunk = chunk.model_dump() - - generation_chunk = super()._convert_chunk_to_generation_chunk( - chunk, - default_chunk_class, - base_generation_info if is_first_chunk else {}, - ) - if generation_chunk is None: - continue - - # custom code - if len(chunk['choices']) > 0 and 'reasoning_content' in chunk['choices'][0]['delta']: - generation_chunk.message.additional_kwargs["reasoning_content"] = chunk['choices'][0]['delta'][ - 
'reasoning_content'] - - default_chunk_class = generation_chunk.message.__class__ - logprobs = (generation_chunk.generation_info or {}).get("logprobs") - if run_manager: - run_manager.on_llm_new_token( - generation_chunk.text, chunk=generation_chunk, logprobs=logprobs - ) - is_first_chunk = False - # custom code - if generation_chunk.message.usage_metadata is not None: - self.usage_metadata = generation_chunk.message.usage_metadata - yield generation_chunk - - def _create_chat_result(self, - response: Union[dict, openai.BaseModel], - generation_info: Optional[Dict] = None): - result = super()._create_chat_result(response, generation_info) - try: - reasoning_content = '' - reasoning_content_enable = False - for res in response.choices: - if 'reasoning_content' in res.message.model_extra: - reasoning_content_enable = True - _reasoning_content = res.message.model_extra.get('reasoning_content') - if _reasoning_content is not None: - reasoning_content += _reasoning_content - if reasoning_content_enable: - result.llm_output['reasoning_content'] = reasoning_content - except Exception as e: - pass - return result + return generation_chunk + + choice = choices[0] + if choice["delta"] is None: + return None + + message_chunk = _convert_delta_to_message_chunk( + choice["delta"], default_chunk_class + ) + generation_info = {**base_generation_info} if base_generation_info else {} + + if finish_reason := choice.get("finish_reason"): + generation_info["finish_reason"] = finish_reason + if model_name := chunk.get("model"): + generation_info["model_name"] = model_name + if system_fingerprint := chunk.get("system_fingerprint"): + generation_info["system_fingerprint"] = system_fingerprint + + logprobs = choice.get("logprobs") + if logprobs: + generation_info["logprobs"] = logprobs + + if usage_metadata and isinstance(message_chunk, AIMessageChunk): + message_chunk.usage_metadata = usage_metadata + + generation_chunk = ChatGenerationChunk( + message=message_chunk, 
generation_info=generation_info or None + ) + return generation_chunk def invoke( self, input: LanguageModelInput, config: Optional[RunnableConfig] = None, *, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, **kwargs: Any, ) -> BaseMessage: config = ensure_config(config) chat_result = cast( - ChatGeneration, + "ChatGeneration", self.generate_prompt( [self._convert_input(input)], stop=stop, @@ -162,7 +192,9 @@ def invoke( run_id=config.pop("run_id", None), **kwargs, ).generations[0][0], + ).message + self.usage_metadata = chat_result.response_metadata[ 'token_usage'] if 'token_usage' in chat_result.response_metadata else chat_result.usage_metadata return chat_result diff --git a/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py b/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py index 9db4faca7cc..081d648a716 100644 --- a/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py @@ -26,6 +26,6 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** model=model_name, openai_api_base='https://api.deepseek.com', openai_api_key=model_credential.get('api_key'), - **optional_params + extra_body=optional_params ) return deepseek_chat_open_ai diff --git a/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py b/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py index 4106cc1d6e3..af23d0341a4 100644 --- a/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py @@ -13,7 +13,7 @@ Tool as GoogleTool, ) from langchain_core.callbacks import CallbackManagerForLLMRun -from langchain_core.messages import BaseMessage +from langchain_core.messages import BaseMessage, get_buffer_string from langchain_core.outputs import ChatGenerationChunk from langchain_google_genai import 
ChatGoogleGenerativeAI from langchain_google_genai._function_utils import _ToolConfigDict, _ToolDict @@ -22,6 +22,8 @@ from langchain_google_genai._common import ( SafetySettingDict, ) + +from common.config.tokenizer_manage_config import TokenizerManage from setting.models_provider.base_model_provider import MaxKBBaseModel @@ -46,10 +48,18 @@ def get_last_generation_info(self) -> Optional[Dict[str, Any]]: return self.__dict__.get('_last_generation_info') def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - return self.get_last_generation_info().get('input_tokens', 0) + try: + return self.get_last_generation_info().get('input_tokens', 0) + except Exception as e: + tokenizer = TokenizerManage.get_tokenizer() + return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) def get_num_tokens(self, text: str) -> int: - return self.get_last_generation_info().get('output_tokens', 0) + try: + return self.get_last_generation_info().get('output_tokens', 0) + except Exception as e: + tokenizer = TokenizerManage.get_tokenizer() + return len(tokenizer.encode(text)) def _stream( self, diff --git a/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py b/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py index c389c177e4e..c0ce2ec029a 100644 --- a/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py @@ -21,11 +21,10 @@ def is_cache_model(): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) - kimi_chat_open_ai = KimiChatModel( openai_api_base=model_credential['api_base'], openai_api_key=model_credential['api_key'], model_name=model_name, - **optional_params + extra_body=optional_params, ) return kimi_chat_open_ai diff --git a/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py 
b/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py index 0194d1f0d27..add06621937 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py @@ -25,7 +25,7 @@ class OllamaLLMModelParams(BaseForm): _step=0.01, precision=2) - max_tokens = forms.SliderField( + num_predict = forms.SliderField( TooltipLabel(_('Output the maximum Tokens'), _('Specify the maximum number of tokens that the model can generate')), required=True, default_value=1024, diff --git a/apps/setting/models_provider/impl/ollama_model_provider/model/image.py b/apps/setting/models_provider/impl/ollama_model_provider/model/image.py index 4cf0f1d56fc..215ce0130d7 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/model/image.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/model/image.py @@ -28,5 +28,5 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** # stream_options={"include_usage": True}, streaming=True, stream_usage=True, - **optional_params, + extra_body=optional_params ) diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/image.py b/apps/setting/models_provider/impl/openai_model_provider/model/image.py index 731f476c45f..7ac0906a786 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/model/image.py +++ b/apps/setting/models_provider/impl/openai_model_provider/model/image.py @@ -16,5 +16,5 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** # stream_options={"include_usage": True}, streaming=True, stream_usage=True, - **optional_params, + extra_body=optional_params ) diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py index 2e6dd89ac93..1893852100b 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py +++ 
b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py @@ -9,7 +9,6 @@ from typing import List, Dict from langchain_core.messages import BaseMessage, get_buffer_string -from langchain_openai.chat_models import ChatOpenAI from common.config.tokenizer_manage_config import TokenizerManage from setting.models_provider.base_model_provider import MaxKBBaseModel @@ -35,9 +34,9 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** streaming = False azure_chat_open_ai = OpenAIChatModel( model=model_name, - openai_api_base=model_credential.get('api_base'), - openai_api_key=model_credential.get('api_key'), - **optional_params, + base_url=model_credential.get('api_base'), + api_key=model_credential.get('api_key'), + extra_body=optional_params, streaming=streaming, custom_get_token_ids=custom_get_token_ids ) diff --git a/apps/setting/models_provider/impl/qwen_model_provider/model/image.py b/apps/setting/models_provider/impl/qwen_model_provider/model/image.py index 97166757e67..bf3af0e3484 100644 --- a/apps/setting/models_provider/impl/qwen_model_provider/model/image.py +++ b/apps/setting/models_provider/impl/qwen_model_provider/model/image.py @@ -18,9 +18,8 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** model_name=model_name, openai_api_key=model_credential.get('api_key'), openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1', - # stream_options={"include_usage": True}, streaming=True, stream_usage=True, - **optional_params, + extra_body=optional_params ) return chat_tong_yi diff --git a/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py b/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py index 3b66ddfd62a..c4df28af9bb 100644 --- a/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py @@ -26,6 +26,6 @@ def new_instance(model_type, model_name, model_credential: 
Dict[str, object], ** openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1', streaming=True, stream_usage=True, - **optional_params, + extra_body=optional_params ) return chat_tong_yi diff --git a/apps/setting/models_provider/impl/regolo_model_provider/__init__.py b/apps/setting/models_provider/impl/regolo_model_provider/__init__.py new file mode 100644 index 00000000000..2dc4ab10db4 --- /dev/null +++ b/apps/setting/models_provider/impl/regolo_model_provider/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2024/3/28 16:25 + @desc: +""" diff --git a/apps/setting/models_provider/impl/regolo_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/regolo_model_provider/credential/embedding.py new file mode 100644 index 00000000000..ddea7fed52d --- /dev/null +++ b/apps/setting/models_provider/impl/regolo_model_provider/credential/embedding.py @@ -0,0 +1,52 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/7/12 16:45 + @desc: +""" +import traceback +from typing import Dict + +from django.utils.translation import gettext as _ + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class RegoloEmbeddingCredential(BaseForm, BaseModelCredential): + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, + raise_exception=True): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, + _('{model_type} Model type is not supported').format(model_type=model_type)) + + for key in ['api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, 
_('{key} is required').format(key=key)) + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.embed_query(_('Hello')) + except Exception as e: + traceback.print_exc() + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, + _('Verification failed, please check whether the parameters are correct: {error}').format( + error=str(e))) + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + api_key = forms.PasswordInputField('API Key', required=True) diff --git a/apps/setting/models_provider/impl/regolo_model_provider/credential/image.py b/apps/setting/models_provider/impl/regolo_model_provider/credential/image.py new file mode 100644 index 00000000000..5975c774806 --- /dev/null +++ b/apps/setting/models_provider/impl/regolo_model_provider/credential/image.py @@ -0,0 +1,74 @@ +# coding=utf-8 +import base64 +import os +import traceback +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +from django.utils.translation import gettext_lazy as _, gettext + + +class RegoloImageModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel(_('Temperature'), + _('Higher values make the output more random, while lower values make it more focused and deterministic')), + required=True, default_value=0.7, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel(_('Output the maximum Tokens'), + _('Specify the maximum number of tokens that the model can generate')), + required=True, default_value=800, + _min=1, + _max=100000, + _step=1, + 
precision=0) + + +class RegoloImageModelCredential(BaseForm, BaseModelCredential): + api_base = forms.TextInputField('API URL', required=True) + api_key = forms.PasswordInputField('API Key', required=True) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, + gettext('{model_type} Model type is not supported').format(model_type=model_type)) + + for key in ['api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key)) + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential, **model_params) + res = model.stream([HumanMessage(content=[{"type": "text", "text": gettext('Hello')}])]) + for chunk in res: + print(chunk) + except Exception as e: + traceback.print_exc() + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, + gettext( + 'Verification failed, please check whether the parameters are correct: {error}').format( + error=str(e))) + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + def get_model_params_setting_form(self, model_name): + return RegoloImageModelParams() diff --git a/apps/setting/models_provider/impl/regolo_model_provider/credential/llm.py b/apps/setting/models_provider/impl/regolo_model_provider/credential/llm.py new file mode 100644 index 00000000000..60eb4ff0abf --- /dev/null +++ b/apps/setting/models_provider/impl/regolo_model_provider/credential/llm.py @@ -0,0 +1,78 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: llm.py + 
@date:2024/7/11 18:32 + @desc: +""" +import traceback +from typing import Dict + +from django.utils.translation import gettext_lazy as _, gettext +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class RegoloLLMModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel(_('Temperature'), + _('Higher values make the output more random, while lower values make it more focused and deterministic')), + required=True, default_value=0.7, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel(_('Output the maximum Tokens'), + _('Specify the maximum number of tokens that the model can generate')), + required=True, default_value=800, + _min=1, + _max=100000, + _step=1, + precision=0) + + +class RegoloLLMModelCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, + gettext('{model_type} Model type is not supported').format(model_type=model_type)) + + for key in ['api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key)) + else: + return False + try: + + model = provider.get_model(model_type, model_name, model_credential, **model_params) + model.invoke([HumanMessage(content=gettext('Hello'))]) + except Exception as e: + traceback.print_exc() + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, + gettext( + 
'Verification failed, please check whether the parameters are correct: {error}').format( + error=str(e))) + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + api_key = forms.PasswordInputField('API Key', required=True) + + def get_model_params_setting_form(self, model_name): + return RegoloLLMModelParams() diff --git a/apps/setting/models_provider/impl/regolo_model_provider/credential/tti.py b/apps/setting/models_provider/impl/regolo_model_provider/credential/tti.py new file mode 100644 index 00000000000..88f46ce4143 --- /dev/null +++ b/apps/setting/models_provider/impl/regolo_model_provider/credential/tti.py @@ -0,0 +1,89 @@ +# coding=utf-8 +import traceback +from typing import Dict + +from django.utils.translation import gettext_lazy as _, gettext + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class RegoloTTIModelParams(BaseForm): + size = forms.SingleSelect( + TooltipLabel(_('Image size'), + _('The image generation endpoint allows you to create raw images based on text prompts. ')), + required=True, + default_value='1024x1024', + option_list=[ + {'value': '1024x1024', 'label': '1024x1024'}, + {'value': '1024x1792', 'label': '1024x1792'}, + {'value': '1792x1024', 'label': '1792x1024'}, + ], + text_field='label', + value_field='value' + ) + + quality = forms.SingleSelect( + TooltipLabel(_('Picture quality'), _(''' +By default, images are produced in standard quality. 
+ ''')), + required=True, + default_value='standard', + option_list=[ + {'value': 'standard', 'label': 'standard'}, + {'value': 'hd', 'label': 'hd'}, + ], + text_field='label', + value_field='value' + ) + + n = forms.SliderField( + TooltipLabel(_('Number of pictures'), + _('1 as default')), + required=True, default_value=1, + _min=1, + _max=10, + _step=1, + precision=0) + + +class RegoloTextToImageModelCredential(BaseForm, BaseModelCredential): + api_key = forms.PasswordInputField('API Key', required=True) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, + gettext('{model_type} Model type is not supported').format(model_type=model_type)) + + for key in ['api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key)) + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential, **model_params) + res = model.check_auth() + print(res) + except Exception as e: + traceback.print_exc() + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, + gettext( + 'Verification failed, please check whether the parameters are correct: {error}').format( + error=str(e))) + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + def get_model_params_setting_form(self, model_name): + return RegoloTTIModelParams() diff --git a/apps/setting/models_provider/impl/regolo_model_provider/icon/regolo_icon_svg b/apps/setting/models_provider/impl/regolo_model_provider/icon/regolo_icon_svg new file mode 
100644 index 00000000000..b69154451ad --- /dev/null +++ b/apps/setting/models_provider/impl/regolo_model_provider/icon/regolo_icon_svg @@ -0,0 +1,64 @@ + + + + + + + + + + + + + + diff --git a/apps/setting/models_provider/impl/regolo_model_provider/model/embedding.py b/apps/setting/models_provider/impl/regolo_model_provider/model/embedding.py new file mode 100644 index 00000000000..b067b8eff29 --- /dev/null +++ b/apps/setting/models_provider/impl/regolo_model_provider/model/embedding.py @@ -0,0 +1,23 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/7/12 17:44 + @desc: +""" +from typing import Dict + +from langchain_community.embeddings import OpenAIEmbeddings + +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class RegoloEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return RegoloEmbeddingModel( + api_key=model_credential.get('api_key'), + model=model_name, + openai_api_base="https://api.regolo.ai/v1", + ) diff --git a/apps/setting/models_provider/impl/regolo_model_provider/model/image.py b/apps/setting/models_provider/impl/regolo_model_provider/model/image.py new file mode 100644 index 00000000000..f16768fad1e --- /dev/null +++ b/apps/setting/models_provider/impl/regolo_model_provider/model/image.py @@ -0,0 +1,19 @@ +from typing import Dict + +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI + + +class RegoloImage(MaxKBBaseModel, BaseChatOpenAI): + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) + return RegoloImage( + model_name=model_name, + openai_api_base="https://api.regolo.ai/v1", + openai_api_key=model_credential.get('api_key'), + 
streaming=True, + stream_usage=True, + extra_body=optional_params + ) diff --git a/apps/setting/models_provider/impl/regolo_model_provider/model/llm.py b/apps/setting/models_provider/impl/regolo_model_provider/model/llm.py new file mode 100644 index 00000000000..126a756a20d --- /dev/null +++ b/apps/setting/models_provider/impl/regolo_model_provider/model/llm.py @@ -0,0 +1,38 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: llm.py + @date:2024/4/18 15:28 + @desc: +""" +from typing import List, Dict + +from langchain_core.messages import BaseMessage, get_buffer_string +from langchain_openai.chat_models import ChatOpenAI + +from common.config.tokenizer_manage_config import TokenizerManage +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI + + +def custom_get_token_ids(text: str): + tokenizer = TokenizerManage.get_tokenizer() + return tokenizer.encode(text) + + +class RegoloChatModel(MaxKBBaseModel, BaseChatOpenAI): + + @staticmethod + def is_cache_model(): + return False + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) + return RegoloChatModel( + model=model_name, + openai_api_base="https://api.regolo.ai/v1", + openai_api_key=model_credential.get('api_key'), + extra_body=optional_params + ) diff --git a/apps/setting/models_provider/impl/regolo_model_provider/model/tti.py b/apps/setting/models_provider/impl/regolo_model_provider/model/tti.py new file mode 100644 index 00000000000..a92527295ac --- /dev/null +++ b/apps/setting/models_provider/impl/regolo_model_provider/model/tti.py @@ -0,0 +1,58 @@ +from typing import Dict + +from openai import OpenAI + +from common.config.tokenizer_manage_config import TokenizerManage +from setting.models_provider.base_model_provider import MaxKBBaseModel +from 
setting.models_provider.impl.base_tti import BaseTextToImage + + +def custom_get_token_ids(text: str): + tokenizer = TokenizerManage.get_tokenizer() + return tokenizer.encode(text) + + +class RegoloTextToImage(MaxKBBaseModel, BaseTextToImage): + api_base: str + api_key: str + model: str + params: dict + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.api_key = kwargs.get('api_key') + self.api_base = "https://api.regolo.ai/v1" + self.model = kwargs.get('model') + self.params = kwargs.get('params') + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {'params': {'size': '1024x1024', 'quality': 'standard', 'n': 1}} + for key, value in model_kwargs.items(): + if key not in ['model_id', 'use_local', 'streaming']: + optional_params['params'][key] = value + return RegoloTextToImage( + model=model_name, + api_base="https://api.regolo.ai/v1", + api_key=model_credential.get('api_key'), + **optional_params, + ) + + def is_cache_model(self): + return False + + def check_auth(self): + chat = OpenAI(api_key=self.api_key, base_url=self.api_base) + response_list = chat.models.with_raw_response.list() + + # self.generate_image('生成一个小猫图片') + + def generate_image(self, prompt: str, negative_prompt: str = None): + chat = OpenAI(api_key=self.api_key, base_url=self.api_base) + res = chat.images.generate(model=self.model, prompt=prompt, **self.params) + file_urls = [] + for content in res.data: + url = content.url + file_urls.append(url) + + return file_urls diff --git a/apps/setting/models_provider/impl/regolo_model_provider/regolo_model_provider.py b/apps/setting/models_provider/impl/regolo_model_provider/regolo_model_provider.py new file mode 100644 index 00000000000..a5e7dc36550 --- /dev/null +++ b/apps/setting/models_provider/impl/regolo_model_provider/regolo_model_provider.py @@ -0,0 +1,89 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: openai_model_provider.py + 
@date:2024/3/28 16:26 + @desc: +""" +import os + +from common.util.file_util import get_file_content +from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \ + ModelTypeConst, ModelInfoManage +from setting.models_provider.impl.regolo_model_provider.credential.embedding import \ + RegoloEmbeddingCredential +from setting.models_provider.impl.regolo_model_provider.credential.llm import RegoloLLMModelCredential +from setting.models_provider.impl.regolo_model_provider.credential.tti import \ + RegoloTextToImageModelCredential +from setting.models_provider.impl.regolo_model_provider.model.embedding import RegoloEmbeddingModel +from setting.models_provider.impl.regolo_model_provider.model.llm import RegoloChatModel +from setting.models_provider.impl.regolo_model_provider.model.tti import RegoloTextToImage +from smartdoc.conf import PROJECT_DIR +from django.utils.translation import gettext as _ + +openai_llm_model_credential = RegoloLLMModelCredential() +openai_tti_model_credential = RegoloTextToImageModelCredential() +model_info_list = [ + ModelInfo('Phi-4', '', ModelTypeConst.LLM, + openai_llm_model_credential, RegoloChatModel + ), + ModelInfo('DeepSeek-R1-Distill-Qwen-32B', '', ModelTypeConst.LLM, + openai_llm_model_credential, + RegoloChatModel), + ModelInfo('maestrale-chat-v0.4-beta', '', + ModelTypeConst.LLM, openai_llm_model_credential, + RegoloChatModel), + ModelInfo('Llama-3.3-70B-Instruct', + '', + ModelTypeConst.LLM, openai_llm_model_credential, + RegoloChatModel), + ModelInfo('Llama-3.1-8B-Instruct', + '', + ModelTypeConst.LLM, openai_llm_model_credential, + RegoloChatModel), + ModelInfo('DeepSeek-Coder-6.7B-Instruct', '', + ModelTypeConst.LLM, openai_llm_model_credential, + RegoloChatModel) +] +open_ai_embedding_credential = RegoloEmbeddingCredential() +model_info_embedding_list = [ + ModelInfo('gte-Qwen2', '', + ModelTypeConst.EMBEDDING, open_ai_embedding_credential, + RegoloEmbeddingModel), +] + 
+model_info_tti_list = [ + ModelInfo('FLUX.1-dev', '', + ModelTypeConst.TTI, openai_tti_model_credential, + RegoloTextToImage), + ModelInfo('sdxl-turbo', '', + ModelTypeConst.TTI, openai_tti_model_credential, + RegoloTextToImage), +] +model_info_manage = ( + ModelInfoManage.builder() + .append_model_info_list(model_info_list) + .append_default_model_info( + ModelInfo('gpt-3.5-turbo', _('The latest gpt-3.5-turbo, updated with OpenAI adjustments'), ModelTypeConst.LLM, + openai_llm_model_credential, RegoloChatModel + )) + .append_model_info_list(model_info_embedding_list) + .append_default_model_info(model_info_embedding_list[0]) + .append_model_info_list(model_info_tti_list) + .append_default_model_info(model_info_tti_list[0]) + + .build() +) + + +class RegoloModelProvider(IModelProvider): + + def get_model_info_manage(self): + return model_info_manage + + def get_model_provide_info(self): + return ModelProvideInfo(provider='model_regolo_provider', name='Regolo', icon=get_file_content( + os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'regolo_model_provider', + 'icon', + 'regolo_icon_svg'))) diff --git a/apps/setting/models_provider/impl/siliconCloud_model_provider/model/image.py b/apps/setting/models_provider/impl/siliconCloud_model_provider/model/image.py index bb840f8c6dc..2ec0689d4d2 100644 --- a/apps/setting/models_provider/impl/siliconCloud_model_provider/model/image.py +++ b/apps/setting/models_provider/impl/siliconCloud_model_provider/model/image.py @@ -16,5 +16,5 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** # stream_options={"include_usage": True}, streaming=True, stream_usage=True, - **optional_params, + extra_body=optional_params ) diff --git a/apps/setting/models_provider/impl/siliconCloud_model_provider/model/llm.py b/apps/setting/models_provider/impl/siliconCloud_model_provider/model/llm.py index 9d79c6e0761..6fb0c7816fa 100644 --- 
a/apps/setting/models_provider/impl/siliconCloud_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/siliconCloud_model_provider/model/llm.py @@ -34,5 +34,5 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** model=model_name, openai_api_base=model_credential.get('api_base'), openai_api_key=model_credential.get('api_key'), - **optional_params + extra_body=optional_params ) diff --git a/apps/setting/models_provider/impl/tencent_cloud_model_provider/model/llm.py b/apps/setting/models_provider/impl/tencent_cloud_model_provider/model/llm.py index 7653cfc2f1f..cfcdf7aca21 100644 --- a/apps/setting/models_provider/impl/tencent_cloud_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/tencent_cloud_model_provider/model/llm.py @@ -33,21 +33,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** model=model_name, openai_api_base=model_credential.get('api_base'), openai_api_key=model_credential.get('api_key'), - **optional_params, + extra_body=optional_params, custom_get_token_ids=custom_get_token_ids ) return azure_chat_open_ai - - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - try: - return super().get_num_tokens_from_messages(messages) - except Exception as e: - tokenizer = TokenizerManage.get_tokenizer() - return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) - - def get_num_tokens(self, text: str) -> int: - try: - return super().get_num_tokens(text) - except Exception as e: - tokenizer = TokenizerManage.get_tokenizer() - return len(tokenizer.encode(text)) diff --git a/apps/setting/models_provider/impl/tencent_model_provider/model/image.py b/apps/setting/models_provider/impl/tencent_model_provider/model/image.py index 1b66ab6d23f..6800cdd567c 100644 --- a/apps/setting/models_provider/impl/tencent_model_provider/model/image.py +++ b/apps/setting/models_provider/impl/tencent_model_provider/model/image.py @@ -16,5 +16,5 @@ def 
new_instance(model_type, model_name, model_credential: Dict[str, object], ** # stream_options={"include_usage": True}, streaming=True, stream_usage=True, - **optional_params, + extra_body=optional_params ) diff --git a/apps/setting/models_provider/impl/vllm_model_provider/model/image.py b/apps/setting/models_provider/impl/vllm_model_provider/model/image.py index 4d5dda29dd7..c8cb0a84db9 100644 --- a/apps/setting/models_provider/impl/vllm_model_provider/model/image.py +++ b/apps/setting/models_provider/impl/vllm_model_provider/model/image.py @@ -19,7 +19,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** # stream_options={"include_usage": True}, streaming=True, stream_usage=True, - **optional_params, + extra_body=optional_params ) def is_cache_model(self): diff --git a/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py b/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py index 7d2a63acd08..4662a616965 100644 --- a/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py @@ -1,9 +1,10 @@ # coding=utf-8 -from typing import Dict, List +from typing import Dict, Optional, Sequence, Union, Any, Callable from urllib.parse import urlparse, ParseResult from langchain_core.messages import BaseMessage, get_buffer_string +from langchain_core.tools import BaseTool from common.config.tokenizer_manage_config import TokenizerManage from setting.models_provider.base_model_provider import MaxKBBaseModel @@ -31,13 +32,19 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** model=model_name, openai_api_base=model_credential.get('api_base'), openai_api_key=model_credential.get('api_key'), - **optional_params, streaming=True, stream_usage=True, + extra_body=optional_params ) return vllm_chat_open_ai - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + def 
get_num_tokens_from_messages( + self, + messages: list[BaseMessage], + tools: Optional[ + Sequence[Union[dict[str, Any], type, Callable, BaseTool]] + ] = None, + ) -> int: if self.usage_metadata is None or self.usage_metadata == {}: tokenizer = TokenizerManage.get_tokenizer() return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/image.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/image.py index 39446b4e19c..6e2517bd4ad 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/image.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/image.py @@ -16,5 +16,5 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** # stream_options={"include_usage": True}, streaming=True, stream_usage=True, - **optional_params, + extra_body=optional_params ) diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py index 181ad2971db..8f089f26988 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py @@ -17,5 +17,5 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** model=model_name, openai_api_base=model_credential.get('api_base'), openai_api_key=model_credential.get('api_key'), - **optional_params + extra_body=optional_params ) diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py index d963a144625..c9aaf06e0a1 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py +++ 
b/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py @@ -98,6 +98,7 @@ .append_default_model_info(model_info_list[2]) .append_default_model_info(model_info_list[3]) .append_default_model_info(model_info_list[4]) + .append_default_model_info(model_info_list[5]) .append_model_info_list(model_info_embedding_list) .append_default_model_info(model_info_embedding_list[0]) .build() diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py b/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py index 06ec94aae34..d4d379db3d5 100644 --- a/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py @@ -27,7 +27,7 @@ class WenxinLLMModelParams(BaseForm): _step=0.01, precision=2) - max_tokens = forms.SliderField( + max_output_tokens = forms.SliderField( TooltipLabel(_('Output the maximum Tokens'), _('Specify the maximum number of tokens that the model can generate')), required=True, default_value=1024, diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/image.py b/apps/setting/models_provider/impl/xinference_model_provider/model/image.py index a195b86491b..66a766ba8c0 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/model/image.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/model/image.py @@ -19,7 +19,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** # stream_options={"include_usage": True}, streaming=True, stream_usage=True, - **optional_params, + extra_body=optional_params ) def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py b/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py index d76979bd3a3..9c0316ad20a 100644 --- 
a/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py @@ -34,7 +34,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** model=model_name, openai_api_base=base_url, openai_api_key=model_credential.get('api_key'), - **optional_params + extra_body=optional_params ) def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/reranker.py b/apps/setting/models_provider/impl/xinference_model_provider/model/reranker.py index 8820a198607..28c8d267839 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/model/reranker.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/model/reranker.py @@ -22,6 +22,9 @@ class XInferenceReranker(MaxKBBaseModel, BaseDocumentCompressor): """UID of the launched model""" api_key: Optional[str] + @staticmethod + def is_cache_model(): + return False @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): return XInferenceReranker(server_url=model_credential.get('server_url'), model_uid=model_name, diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/model/image.py b/apps/setting/models_provider/impl/zhipu_model_provider/model/image.py index f13c7153803..6ac7830d8ff 100644 --- a/apps/setting/models_provider/impl/zhipu_model_provider/model/image.py +++ b/apps/setting/models_provider/impl/zhipu_model_provider/model/image.py @@ -16,5 +16,5 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** # stream_options={"include_usage": True}, streaming=True, stream_usage=True, - **optional_params, + extra_body=optional_params ) diff --git a/apps/smartdoc/conf.py b/apps/smartdoc/conf.py index de61cb8e339..630f32cb152 100644 --- a/apps/smartdoc/conf.py +++ b/apps/smartdoc/conf.py @@ -7,6 +7,7 @@ 2. 
程序需要, 用户不需要更改的写到settings中 3. 程序需要, 用户需要更改的写到本config中 """ +import datetime import errno import logging import os @@ -93,7 +94,8 @@ class Config(dict): 'SANDBOX': False, 'LOCAL_MODEL_HOST': '127.0.0.1', 'LOCAL_MODEL_PORT': '11636', - 'LOCAL_MODEL_PROTOCOL': "http" + 'LOCAL_MODEL_PROTOCOL': "http", + 'LOCAL_MODEL_HOST_WORKER': 1 } @@ -111,12 +113,19 @@ def get_db_setting(self) -> dict: "USER": self.get('DB_USER'), "PASSWORD": self.get('DB_PASSWORD'), "ENGINE": self.get('DB_ENGINE'), + "CONN_MAX_AGE": 0, "POOL_OPTIONS": { "POOL_SIZE": 20, - "MAX_OVERFLOW": int(self.get('DB_MAX_OVERFLOW')) + "MAX_OVERFLOW": int(self.get('DB_MAX_OVERFLOW')), + "RECYCLE": 1800, + "TIMEOUT": 30, + 'PRE_PING': True } } + def get_session_timeout(self): + return datetime.timedelta(seconds=int(self.get('SESSION_TIMEOUT', 60 * 60 * 2))) + def get_language_code(self): return self.get('LANGUAGE_CODE', 'zh-CN') diff --git a/apps/smartdoc/settings/base.py b/apps/smartdoc/settings/base.py index edf4586629d..de81420798a 100644 --- a/apps/smartdoc/settings/base.py +++ b/apps/smartdoc/settings/base.py @@ -126,6 +126,10 @@ "token_cache": { 'BACKEND': 'common.cache.file_cache.FileCache', 'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "token_cache") # 文件夹路径 + }, + 'captcha_cache': { + 'BACKEND': 'common.cache.file_cache.FileCache', + 'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "captcha_cache") # 文件夹路径 } } diff --git a/apps/users/serializers/user_serializers.py b/apps/users/serializers/user_serializers.py index 6093819a46a..cea2919f31d 100644 --- a/apps/users/serializers/user_serializers.py +++ b/apps/users/serializers/user_serializers.py @@ -6,18 +6,23 @@ @date:2023/9/5 16:32 @desc: """ +import base64 import datetime +import json import os import random import re import uuid +from captcha.image import ImageCaptcha from django.conf import settings from django.core import validators, signing, cache from django.core.mail import send_mail from django.core.mail.backends.smtp import 
EmailBackend from django.db import transaction from django.db.models import Q, QuerySet, Prefetch +from django.utils.translation import get_language +from django.utils.translation import gettext_lazy as _, to_locale from drf_yasg import openapi from rest_framework import serializers @@ -30,18 +35,39 @@ from common.mixins.api_mixin import ApiMixin from common.models.db_model_manage import DBModelManage from common.response.result import get_api_response -from common.util.common import valid_license +from common.util.common import valid_license, get_random_chars from common.util.field_message import ErrMessage from common.util.lock import lock +from common.util.rsa_util import decrypt, get_key_pair_by_sql from dataset.models import DataSet, Document, Paragraph, Problem, ProblemParagraphMapping from embedding.task import delete_embedding_by_dataset_id_list from function_lib.models.function import FunctionLib from setting.models import Team, SystemSetting, SettingType, Model, TeamMember, TeamMemberPermission from smartdoc.conf import PROJECT_DIR from users.models.user import User, password_encrypt, get_user_dynamics_permission -from django.utils.translation import gettext_lazy as _, gettext, to_locale -from django.utils.translation import get_language + user_cache = cache.caches['user_cache'] +captcha_cache = cache.caches['captcha_cache'] + + +class CaptchaSerializer(ApiMixin, serializers.Serializer): + @staticmethod + def get_response_body_api(): + return get_api_response(openapi.Schema( + type=openapi.TYPE_STRING, + title="captcha", + default="xxxx", + description="captcha" + )) + + @staticmethod + def generate(): + chars = get_random_chars() + image = ImageCaptcha() + data = image.generate(chars) + captcha = base64.b64encode(data.getbuffer()) + captcha_cache.set(f"LOGIN:{chars.lower()}", chars, timeout=5 * 60) + return 'data:image/png;base64,' + captcha.decode() class SystemSerializer(ApiMixin, serializers.Serializer): @@ -51,7 +77,8 @@ def get_profile(): 
xpack_cache = DBModelManage.get_model('xpack_cache') return {'version': version, 'IS_XPACK': hasattr(settings, 'IS_XPACK'), 'XPACK_LICENSE_IS_VALID': False if xpack_cache is None else xpack_cache.get('XPACK_LICENSE_IS_VALID', - False)} + False), + 'ras': get_key_pair_by_sql().get('key')} @staticmethod def get_response_body_api(): @@ -71,30 +98,14 @@ class LoginSerializer(ApiMixin, serializers.Serializer): password = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Password"))) - def is_valid(self, *, raise_exception=False): - """ - 校验参数 - :param raise_exception: Whether to throw an exception can only be True - :return: User information - """ - super().is_valid(raise_exception=True) - username = self.data.get("username") - password = password_encrypt(self.data.get("password")) - user = QuerySet(User).filter(Q(username=username, - password=password) | Q(email=username, - password=password)).first() - if user is None: - raise ExceptionCodeConstants.INCORRECT_USERNAME_AND_PASSWORD.value.to_app_api_exception() - if not user.is_active: - raise AppApiException(1005, _("The user has been disabled, please contact the administrator!")) - return user + captcha = serializers.CharField(required=True, error_messages=ErrMessage.char(_("captcha"))) + encryptedData = serializers.CharField(required=False, label=_('encryptedData'), allow_null=True, + allow_blank=True) - def get_user_token(self): + def get_user_token(self, user): """ - Get user token :return: User Token (authentication information) """ - user = self.is_valid() token = signing.dumps({'username': user.username, 'id': str(user.id), 'email': user.email, 'type': AuthenticationType.USER.value}) return token @@ -106,10 +117,13 @@ class Meta: def get_request_body_api(self): return openapi.Schema( type=openapi.TYPE_OBJECT, - required=['username', 'password'], + required=['username', 'encryptedData'], properties={ 'username': openapi.Schema(type=openapi.TYPE_STRING, title=_("Username"), 
description=_("Username")), - 'password': openapi.Schema(type=openapi.TYPE_STRING, title=_("Password"), description=_("Password")) + 'password': openapi.Schema(type=openapi.TYPE_STRING, title=_("Password"), description=_("Password")), + 'captcha': openapi.Schema(type=openapi.TYPE_STRING, title=_("captcha"), description=_("captcha")), + 'encryptedData': openapi.Schema(type=openapi.TYPE_STRING, title=_("encryptedData"), + description=_("encryptedData")) } ) @@ -121,6 +135,29 @@ def get_response_body_api(self): description="认证token" )) + @staticmethod + def login(instance): + username = instance.get("username", "") + encryptedData = instance.get("encryptedData", "") + if encryptedData: + json_data = json.loads(decrypt(encryptedData)) + instance.update(json_data) + LoginSerializer(data=instance).is_valid(raise_exception=True) + password = instance.get("password") + captcha = instance.get("captcha", "") + captcha_value = captcha_cache.get(f"LOGIN:{captcha.lower()}") + if captcha_value is None: + raise AppApiException(1005, _("Captcha code error or expiration")) + user = QuerySet(User).filter(Q(username=username, + password=password_encrypt(password)) | Q(email=username, + password=password_encrypt( + password))).first() + if user is None: + raise ExceptionCodeConstants.INCORRECT_USERNAME_AND_PASSWORD.value.to_app_api_exception() + if not user.is_active: + raise AppApiException(1005, _("The user has been disabled, please contact the administrator!")) + return user + class RegisterSerializer(ApiMixin, serializers.Serializer): """ diff --git a/apps/users/urls.py b/apps/users/urls.py index e5e2fe0dfb2..a9d1e134c90 100644 --- a/apps/users/urls.py +++ b/apps/users/urls.py @@ -6,6 +6,7 @@ urlpatterns = [ path('profile', views.Profile.as_view()), path('user', views.User.as_view(), name="profile"), + path('user/captcha', views.CaptchaView.as_view(), name='captcha'), path('user/language', views.SwitchUserLanguageView.as_view(), name='language'), path('user/list', 
views.User.Query.as_view()), path('user/login', views.Login.as_view(), name='login'), diff --git a/apps/users/views/user.py b/apps/users/views/user.py index 55d4b6b9ad9..c77dce5bbd1 100644 --- a/apps/users/views/user.py +++ b/apps/users/views/user.py @@ -22,11 +22,11 @@ from common.log.log import log from common.response import result from common.util.common import encryption -from smartdoc.settings import JWT_AUTH +from smartdoc.const import CONFIG from users.serializers.user_serializers import RegisterSerializer, LoginSerializer, CheckCodeSerializer, \ RePasswordSerializer, \ SendEmailSerializer, UserProfile, UserSerializer, UserManageSerializer, UserInstanceSerializer, SystemSerializer, \ - SwitchLanguageSerializer + SwitchLanguageSerializer, CaptchaSerializer from users.views.common import get_user_operation_object, get_re_password_details user_cache = cache.caches['user_cache'] @@ -84,7 +84,7 @@ class SwitchUserLanguageView(APIView): description=_("language")), } ), - responses=RePasswordSerializer().get_response_body_api(), + responses=result.get_default_response(), tags=[_("User management")]) @log(menu='User management', operate='Switch Language', get_operation_object=lambda r, k: {'name': r.user.username}) @@ -111,7 +111,7 @@ class ResetCurrentUserPasswordView(APIView): description=_("Password")) } ), - responses=RePasswordSerializer().get_response_body_api(), + responses=result.get_default_response(), tags=[_("User management")]) @log(menu='User management', operate='Modify current user password', get_operation_object=lambda r, k: {'name': r.user.username}, @@ -170,6 +170,18 @@ def _get_details(request): } +class CaptchaView(APIView): + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary=_("Obtain graphical captcha"), + operation_id=_("Obtain graphical captcha"), + responses=CaptchaSerializer().get_response_body_api(), + security=[], + tags=[_("User management")]) + def get(self, request: Request): + return 
result.success(CaptchaSerializer().generate()) + + class Login(APIView): @action(methods=['POST'], detail=False) @@ -183,11 +195,9 @@ class Login(APIView): get_details=_get_details, get_operation_object=lambda r, k: {'name': r.data.get('username')}) def post(self, request: Request): - login_request = LoginSerializer(data=request.data) - # 校验请求参数 - user = login_request.is_valid(raise_exception=True) - token = login_request.get_user_token() - token_cache.set(token, user, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA']) + user = LoginSerializer().login(request.data) + token = LoginSerializer().get_user_token(user) + token_cache.set(token, user, timeout=CONFIG.get_session_timeout()) return result.success(token) diff --git a/installer/Dockerfile b/installer/Dockerfile index d2c1eefb6fa..81db7241543 100644 --- a/installer/Dockerfile +++ b/installer/Dockerfile @@ -5,7 +5,7 @@ RUN cd ui && \ npm install && \ npm run build && \ rm -rf ./node_modules -FROM ghcr.io/1panel-dev/maxkb-python-pg:python3.11-pg15.8 AS stage-build +FROM ghcr.io/1panel-dev/maxkb-python-pg:python3.11-pg15.14 AS stage-build ARG DEPENDENCIES=" \ python3-pip" @@ -25,11 +25,11 @@ RUN python3 -m venv /opt/py3 && \ pip install poetry==1.8.5 --break-system-packages && \ poetry config virtualenvs.create false && \ . 
/opt/py3/bin/activate && \ - if [ "$(uname -m)" = "x86_64" ]; then sed -i 's/^torch.*/torch = {version = "^2.6.0+cpu", source = "pytorch"}/g' pyproject.toml; fi && \ + if [ "$(uname -m)" = "x86_64" ]; then sed -i 's/^torch.*/torch = {version = "2.6.0+cpu", source = "pytorch"}/g' pyproject.toml; fi && \ poetry install && \ export MAXKB_CONFIG_TYPE=ENV && python3 /opt/maxkb/app/apps/manage.py compilemessages -FROM ghcr.io/1panel-dev/maxkb-python-pg:python3.11-pg15.8 +FROM ghcr.io/1panel-dev/maxkb-python-pg:python3.11-pg15.14 ARG DOCKER_IMAGE_TAG=dev \ BUILD_AT \ GITHUB_COMMIT @@ -70,7 +70,9 @@ RUN chmod 755 /opt/maxkb/app/installer/run-maxkb.sh && \ useradd --no-create-home --home /opt/maxkb/app/sandbox sandbox -g root && \ chown -R sandbox:root /opt/maxkb/app/sandbox && \ chmod g-x /usr/local/bin/* /usr/bin/* /bin/* /usr/sbin/* /sbin/* /usr/lib/postgresql/15/bin/* && \ - chmod g+x /usr/local/bin/python* + chmod g+xr /usr/bin/ld.so && \ + chmod g+x /usr/local/bin/python* && \ + find /etc/ -type f ! -path '/etc/resolv.conf' ! 
-path '/etc/hosts' | xargs chmod g-rx EXPOSE 8080 diff --git a/installer/Dockerfile-python-pg b/installer/Dockerfile-python-pg index f871ac4ef4f..eb52eec17fe 100644 --- a/installer/Dockerfile-python-pg +++ b/installer/Dockerfile-python-pg @@ -1,5 +1,5 @@ -FROM python:3.11-slim-bullseye AS python-stage -FROM postgres:15.8-bullseye +FROM python:3.11-slim-trixie AS python-stage +FROM postgres:15.14-trixie ARG DEPENDENCIES=" \ libexpat1-dev \ diff --git a/pyproject.toml b/pyproject.toml index 35d74a52e95..2dc8337a646 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,65 +8,67 @@ package-mode = false [tool.poetry.dependencies] python = ">=3.11,<3.12" -django = "4.2.18" -djangorestframework = "^3.15.2" +django = "4.2.20" +djangorestframework = "3.16.0" drf-yasg = "1.21.7" django-filter = "23.2" -langchain-openai = "^0.3.0" -langchain-anthropic = "^0.3.0" -langchain-community = "^0.3.0" -langchain-deepseek = "^0.1.0" -langchain-google-genai = "^2.0.9" -langchain-mcp-adapters = "^0.0.5" -langchain-huggingface = "^0.1.2" -langchain-ollama = "^0.3.0" -langgraph = "^0.3.0" -mcp = "^1.4.1" -psycopg2-binary = "2.9.10" -jieba = "^0.42.1" -diskcache = "^5.6.3" -pillow = "^10.2.0" -filetype = "^1.2.0" +langchain = "0.3.23" +langchain-openai = "0.3.12" +langchain-anthropic = "0.3.12" +langchain-community = "0.3.21" +langchain-deepseek = "0.1.3" +langchain-google-genai = "2.1.2" +langchain-mcp-adapters = "0.0.11" +langchain-huggingface = "0.1.2" +langchain-ollama = "0.3.2" +langgraph = "0.3.27" +mcp = "1.8.0" +psycopg = { extras = ["binary"], version = "3.2.9" } +jieba = "0.42.1" +diskcache = "5.6.3" +pillow = "10.4.0" +filetype = "1.2.0" torch = "2.6.0" -sentence-transformers = "^4.0.2" -openai = "^1.13.3" -tiktoken = "^0.7.0" -qianfan = "^0.3.6.1" -pycryptodome = "^3.19.0" -beautifulsoup4 = "^4.12.2" -html2text = "^2024.2.26" -django-ipware = "^6.0.4" -django-apscheduler = "^0.6.2" +sentence-transformers = "4.0.2" +openai = "1.72.0" +tiktoken = "0.7.0" +qianfan = "0.3.18" 
+pycryptodome = "3.22.0" +beautifulsoup4 = "4.13.3" +html2text = "2024.2.26" +django-ipware = "6.0.5" +django-apscheduler = "0.6.2" pymupdf = "1.24.9" -pypdf = "4.3.1" +pypdf = "6.0.0" rapidocr-onnxruntime = "1.3.24" -python-docx = "^1.1.0" -xlwt = "^1.3.0" -dashscope = "^1.17.0" -zhipuai = "^2.0.1" -httpx = "^0.27.0" -httpx-sse = "^0.4.0" -websockets = "^13.0" -openpyxl = "^3.1.2" -xlrd = "^2.0.1" -gunicorn = "^23.0.0" +python-docx = "1.1.2" +xlwt = "1.3.0" +dashscope = "1.23.1" +zhipuai = "2.1.5.20250410" +httpx = "0.27.2" +httpx-sse = "0.4.0" +websockets = "13.1" +openpyxl = "3.1.5" +xlrd = "2.0.1" +gunicorn = "23.0.0" python-daemon = "3.0.1" -boto3 = "^1.34.160" -tencentcloud-sdk-python = "^3.0.1209" -xinference-client = "^1.3.0" -psutil = "^6.0.0" -celery = { extras = ["sqlalchemy"], version = "^5.4.0" } -django-celery-beat = "^2.6.0" -celery-once = "^3.0.1" -anthropic = "^0.49.0" -pylint = "3.1.0" -pydub = "^0.25.1" -cffi = "^1.17.1" -pysilk = "^0.0.1" -django-db-connection-pool = "^1.2.5" -opencv-python-headless = "^4.11.0.86" -pymysql = "^1.1.1" -accelerate = "^1.6.0" +boto3 = "1.37.31" +tencentcloud-sdk-python = "3.0.1357" +xinference-client = "1.4.1" +psutil = "6.1.1" +celery = { extras = ["sqlalchemy"], version = "5.5.1" } +django-celery-beat = "2.7.0" +celery-once = "3.0.1" +anthropic = "0.49.0" +pylint = "3.3.6" +pydub = "0.25.1" +cffi = "1.17.1" +pysilk = "0.0.1" +django-db-connection-pool = "1.2.6" +opencv-python-headless = "4.11.0.86" +pymysql = "1.1.1" +accelerate = "1.6.0" +captcha = "0.7.1" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" diff --git a/ui/package.json b/ui/package.json index cee7a41c8fd..32adfd27108 100644 --- a/ui/package.json +++ b/ui/package.json @@ -27,17 +27,19 @@ "cropperjs": "^1.6.2", "dingtalk-jsapi": "^2.15.6", "echarts": "^5.5.0", - "element-plus": "^2.9.1", + "element-plus": "^2.13.5", "file-saver": "^2.0.5", "highlight.js": "^11.9.0", "install": "^0.13.0", "katex": "^0.16.10", 
"lodash": "^4.17.21", "marked": "^12.0.2", - "md-editor-v3": "^4.16.7", + "md-editor-v3": "^5.8.2", "mermaid": "^10.9.0", "mitt": "^3.0.0", "moment": "^2.30.1", + "nanoid": "^5.1.5", + "node-forge": "^1.3.1", "npm": "^10.2.4", "nprogress": "^0.2.0", "pinia": "^2.1.6", @@ -53,8 +55,7 @@ "vue-draggable-plus": "^0.6.0", "vue-i18n": "^9.13.1", "vue-router": "^4.2.4", - "vue3-menus": "^1.1.2", - "vuedraggable": "^4.1.0" + "vue3-menus": "^1.1.2" }, "devDependencies": { "@rushstack/eslint-patch": "^1.3.2", @@ -62,6 +63,7 @@ "@types/file-saver": "^2.0.7", "@types/jsdom": "^21.1.1", "@types/node": "^18.17.5", + "@types/node-forge": "^1.3.14", "@types/nprogress": "^0.2.0", "@vitejs/plugin-vue": "^4.3.1", "@vue/eslint-config-prettier": "^8.0.0", diff --git a/ui/src/api/application.ts b/ui/src/api/application.ts index efd4a4985a8..bc903c957eb 100644 --- a/ui/src/api/application.ts +++ b/ui/src/api/application.ts @@ -227,7 +227,7 @@ const getApplicationHitTest: ( data: any, loading?: Ref ) => Promise>> = (application_id, data, loading) => { - return get(`${prefix}/${application_id}/hit_test`, data, loading) + return put(`${prefix}/${application_id}/hit_test`, data, undefined, loading) } /** diff --git a/ui/src/api/dataset.ts b/ui/src/api/dataset.ts index a5a663b03c7..83de865b3bc 100644 --- a/ui/src/api/dataset.ts +++ b/ui/src/api/dataset.ts @@ -186,7 +186,7 @@ const getDatasetHitTest: ( data: any, loading?: Ref ) => Promise>> = (dataset_id, data, loading) => { - return get(`${prefix}/${dataset_id}/hit_test`, data, loading) + return put(`${prefix}/${dataset_id}/hit_test`, data, undefined, loading) } /** diff --git a/ui/src/api/team.ts b/ui/src/api/team.ts index 82e8f986e46..462534b0eba 100644 --- a/ui/src/api/team.ts +++ b/ui/src/api/team.ts @@ -36,7 +36,7 @@ const getMemberPermissions: (member_id: String) => Promise> = (membe } /** - * 获取成员权限 + * 修改成员权限 * @param 参数 member_id * @param 参数 { "team_member_permission_list": [ diff --git a/ui/src/api/type/application.ts 
b/ui/src/api/type/application.ts index 077e230973e..c423f11105a 100644 --- a/ui/src/api/type/application.ts +++ b/ui/src/api/type/application.ts @@ -72,6 +72,7 @@ interface chatType { document_list: Array image_list: Array audio_list: Array + other_list: Array } } diff --git a/ui/src/api/type/user.ts b/ui/src/api/type/user.ts index a452673546a..197dba888c7 100644 --- a/ui/src/api/type/user.ts +++ b/ui/src/api/type/user.ts @@ -37,6 +37,11 @@ interface LoginRequest { * 密码 */ password: string + /** + * 验证码 + */ + captcha: string + encryptedData?: string } interface RegisterRequest { diff --git a/ui/src/api/user.ts b/ui/src/api/user.ts index eb12fd2ebf8..0d669705442 100644 --- a/ui/src/api/user.ts +++ b/ui/src/api/user.ts @@ -10,22 +10,26 @@ import type { } from '@/api/type/user' import type { Ref } from 'vue' + +const login: (request: LoginRequest, loading?: Ref) => Promise> = ( + request, + loading +) => { + return post('/user/login', request, undefined, loading) +} + +const ldapLogin: (request: LoginRequest, loading?: Ref) => Promise> = ( + request, + loading +) => { + return post('/LDAP/login', request, undefined, loading) +} /** - * 登录 - * @param auth_type - * @param request 登录接口请求表单 - * @param loading 接口加载器 - * @returns 认证数据 + * 获取图形验证码 + * @returns */ -const login: ( - auth_type: string, - request: LoginRequest, - loading?: Ref -) => Promise> = (auth_type, request, loading) => { - if (auth_type !== '') { - return post(`/${auth_type}/login`, request, undefined, loading) - } - return post('/user/login', request, undefined, loading) +const getCaptcha: () => Promise> = () => { + return get('user/captcha') } /** * 登出 @@ -226,5 +230,7 @@ export default { postLanguage, getDingOauth2Callback, getlarkCallback, - getQrSource + getQrSource, + getCaptcha, + ldapLogin } diff --git a/ui/src/components/ai-chat/ExecutionDetailDialog.vue b/ui/src/components/ai-chat/ExecutionDetailDialog.vue index 0f2296439ae..98ec1e6fa3e 100644 --- 
a/ui/src/components/ai-chat/ExecutionDetailDialog.vue +++ b/ui/src/components/ai-chat/ExecutionDetailDialog.vue @@ -125,6 +125,28 @@ +
+

+ {{ $t('common.fileUpload.document') }}: +

+ + + + +
diff --git a/ui/src/components/ai-chat/component/answer-content/index.vue b/ui/src/components/ai-chat/component/answer-content/index.vue index 7f09fa04c68..26cd8a0d06f 100644 --- a/ui/src/components/ai-chat/component/answer-content/index.vue +++ b/ui/src/components/ai-chat/component/answer-content/index.vue @@ -80,7 +80,7 @@ const props = defineProps<{ chatRecord: chatType application: any loading: boolean - sendMessage: (question: string, other_params_data?: any, chat?: chatType) => void + sendMessage: (question: string, other_params_data?: any, chat?: chatType) => Promise chatManagement: any type: 'log' | 'ai-chat' | 'debug-ai-chat' }>() @@ -98,9 +98,10 @@ const showUserAvatar = computed(() => { const chatMessage = (question: string, type: 'old' | 'new', other_params_data?: any) => { if (type === 'old') { add_answer_text_list(props.chatRecord.answer_text_list) - props.sendMessage(question, other_params_data, props.chatRecord) - props.chatManagement.open(props.chatRecord.id) - props.chatManagement.write(props.chatRecord.id) + props.sendMessage(question, other_params_data, props.chatRecord).then(() => { + props.chatManagement.open(props.chatRecord.id) + props.chatManagement.write(props.chatRecord.id) + }) } else { props.sendMessage(question, other_params_data) } diff --git a/ui/src/components/ai-chat/component/chat-input-operate/index.vue b/ui/src/components/ai-chat/component/chat-input-operate/index.vue index acf3085ed97..a2f95812365 100644 --- a/ui/src/components/ai-chat/component/chat-input-operate/index.vue +++ b/ui/src/components/ai-chat/component/chat-input-operate/index.vue @@ -10,7 +10,8 @@ uploadDocumentList.length || uploadImageList.length || uploadAudioList.length || - uploadVideoList.length + uploadVideoList.length || + uploadOtherList.length " > @@ -30,22 +31,62 @@ class="file cursor" >
+
+ +
+ {{ item && item?.name }} +
+
- +
- -
- {{ item && item?.name }} +
+ + + + +
+
+ +
+ {{ item && item?.name }} +
+
+
+ + +
@@ -63,23 +104,25 @@ >
+
+ +
+ {{ item && item?.name }} +
+
- +
- -
- {{ item && item?.name }} -
@@ -87,7 +130,7 @@ - + @@ -221,11 +268,11 @@ - - + +
@@ -241,7 +288,7 @@ diff --git a/ui/src/components/model-select/index.vue b/ui/src/components/model-select/index.vue index 116824e3c63..6a3b63acec1 100644 --- a/ui/src/components/model-select/index.vue +++ b/ui/src/components/model-select/index.vue @@ -72,7 +72,7 @@ @@ -82,8 +82,6 @@ import type { Provider } from '@/api/type/model' import { relatedObject } from '@/utils/utils' import CreateModelDialog from '@/views/template/component/CreateModelDialog.vue' import SelectProviderDialog from '@/views/template/component/SelectProviderDialog.vue' - -import { t } from '@/locales' import useStore from '@/stores' defineOptions({ name: 'ModelSelect' }) diff --git a/ui/src/layout/components/breadcrumb/index.vue b/ui/src/layout/components/breadcrumb/index.vue index 9140e8209a8..ae5011318d6 100644 --- a/ui/src/layout/components/breadcrumb/index.vue +++ b/ui/src/layout/components/breadcrumb/index.vue @@ -228,6 +228,11 @@ function getApplication() { } function refresh() { common.saveBreadcrumb(null) + if (isDataset.value) { + getDataset() + } else if (isApplication.value) { + getApplication() + } } onMounted(() => { if (!breadcrumbData.value) { diff --git a/ui/src/locales/lang/en-US/ai-chat.ts b/ui/src/locales/lang/en-US/ai-chat.ts index 3a52270977c..a857836869b 100644 --- a/ui/src/locales/lang/en-US/ai-chat.ts +++ b/ui/src/locales/lang/en-US/ai-chat.ts @@ -63,6 +63,9 @@ export default { limitMessage2: 'files', sizeLimit: 'Each file must not exceed', imageMessage: 'Please process the image content', + documentMessage: 'Please understand the content of the document', + audioMessage: 'Please understand the audio content', + otherMessage: 'Please understand the file content', errorMessage: 'Upload Failed' }, executionDetails: { diff --git a/ui/src/locales/lang/en-US/common.ts b/ui/src/locales/lang/en-US/common.ts index 96afd9916da..2fd0a30b32d 100644 --- a/ui/src/locales/lang/en-US/common.ts +++ b/ui/src/locales/lang/en-US/common.ts @@ -45,7 +45,10 @@ export default { document: 
'Documents', image: 'Image', audio: 'Audio', - video: 'Video' + video: 'Video', + other: 'Other', + addExtensions: 'Add suffix', + existingExtensionsTip: 'File suffix already exists', }, status: { label: 'Status', @@ -55,7 +58,7 @@ export default { param: { outputParam: 'Output Parameters', inputParam: 'Input Parameters', - initParam: 'Startup Parameters', + initParam: 'Startup Parameters' }, inputPlaceholder: 'Please input', diff --git a/ui/src/locales/lang/en-US/views/application-workflow.ts b/ui/src/locales/lang/en-US/views/application-workflow.ts index e4385ea3791..e1fd8009e68 100644 --- a/ui/src/locales/lang/en-US/views/application-workflow.ts +++ b/ui/src/locales/lang/en-US/views/application-workflow.ts @@ -104,7 +104,8 @@ export default { label: 'File types allowed for upload', documentText: 'Requires "Document Content Extraction" node to parse document content', imageText: 'Requires "Image Understanding" node to parse image content', - audioText: 'Requires "Speech-to-Text" node to parse audio content' + audioText: 'Requires "Speech-to-Text" node to parse audio content', + otherText: 'Need to parse this type of file by yourself' } } }, @@ -222,14 +223,14 @@ export default { }, mcpNode: { label: 'MCP Server', - text: 'Call MCP Tools through SSE', + text: 'Call MCP Tools through SSE/Streamable HTTP', getToolsSuccess: 'Get Tools Successfully', getTool: 'Get Tools', tool: 'Tool', toolParam: 'Tool Params', mcpServerTip: 'Please enter the JSON format of the MCP server config', mcpToolTip: 'Please select a tool', - configLabel: 'MCP Server Config (Only supports SSE call method)' + configLabel: 'MCP Server Config (Only supports SSE/Streamable HTTP call method)' }, imageGenerateNode: { label: 'Image Generation', diff --git a/ui/src/locales/lang/en-US/views/application.ts b/ui/src/locales/lang/en-US/views/application.ts index b69ede6d890..8247e48ec6b 100644 --- a/ui/src/locales/lang/en-US/views/application.ts +++ b/ui/src/locales/lang/en-US/views/application.ts @@ 
-139,7 +139,7 @@ Response requirements: hybridSearch: 'Hybrid Search', hybridSearchTooltip: 'Hybrid search is a retrieval method based on both vector and text similarity, suitable for medium data volumes in the knowledge.', - similarityThreshold: 'Similarity higher than', + similarityThreshold: 'Similarity not lower than', similarityTooltip: 'The higher the similarity, the stronger the correlation.', topReferences: 'Top N Segments', maxCharacters: 'Maximum Characters per Reference', diff --git a/ui/src/locales/lang/en-US/views/document.ts b/ui/src/locales/lang/en-US/views/document.ts index 9a3f1da7387..ff17a61340c 100644 --- a/ui/src/locales/lang/en-US/views/document.ts +++ b/ui/src/locales/lang/en-US/views/document.ts @@ -149,7 +149,7 @@ export default { tooltip: 'When user asks a question, handle matched segments according to the set method.' }, similarity: { - label: 'Similarity Higher Than', + label: 'Similarity not lower than', placeholder: 'Directly return segment content', requiredMessage: 'Please enter similarity value' } diff --git a/ui/src/locales/lang/en-US/views/system.ts b/ui/src/locales/lang/en-US/views/system.ts index 303d1175dcf..8d3e50ad74b 100644 --- a/ui/src/locales/lang/en-US/views/system.ts +++ b/ui/src/locales/lang/en-US/views/system.ts @@ -1,5 +1,6 @@ export default { title: 'System', + subTitle: 'Setting', test: 'Test Connection', testSuccess: 'Successful', testFailed: 'Test connection failed', @@ -76,8 +77,8 @@ export default { dingtalk: 'DingTalk', lark: 'Lark', effective: 'Effective', - alreadyTurnedOn: 'Turned On', - notEnabled: 'Not Enabled', + alreadyTurnedOn: 'Enabled', + notEnabled: 'Disabled', validate: 'Validate', validateSuccess: 'Successful', validateFailed: 'Validation failed', @@ -122,7 +123,7 @@ export default { websiteSlogan: 'Welcome Slogan', websiteSloganPlaceholder: 'Please enter the welcome slogan', websiteSloganTip: 'The welcome slogan below the product logo', - defaultSlogan: 'Ready-to-use, flexible RAG Chatbot', + 
defaultSlogan: 'Ready-to-use open-source AI assistant', defaultTip: 'The default is the MaxKB platform interface, supports custom settings', logoDefaultTip: 'The default is the MaxKB login interface, supports custom settings', platformSetting: 'Platform Settings', diff --git a/ui/src/locales/lang/en-US/views/user.ts b/ui/src/locales/lang/en-US/views/user.ts index ae41fd564c0..2bbc1404363 100644 --- a/ui/src/locales/lang/en-US/views/user.ts +++ b/ui/src/locales/lang/en-US/views/user.ts @@ -28,6 +28,10 @@ export default { requiredMessage: 'Please enter username', lengthMessage: 'Length must be between 6 and 20 words' }, + captcha: { + label: 'captcha', + placeholder: 'Please enter the captcha' + }, nick_name: { label: 'Name', placeholder: 'Please enter name' diff --git a/ui/src/locales/lang/zh-CN/ai-chat.ts b/ui/src/locales/lang/zh-CN/ai-chat.ts index 76bb53d4f53..de702d2347c 100644 --- a/ui/src/locales/lang/zh-CN/ai-chat.ts +++ b/ui/src/locales/lang/zh-CN/ai-chat.ts @@ -61,6 +61,9 @@ export default { limitMessage2: '个文件', sizeLimit: '单个文件大小不能超过', imageMessage: '请解析图片内容', + documentMessage: '请理解文档内容', + audioMessage: '请理解音频内容', + otherMessage: '请理解文件内容', errorMessage: '上传失败' }, executionDetails: { diff --git a/ui/src/locales/lang/zh-CN/common.ts b/ui/src/locales/lang/zh-CN/common.ts index 97e25b704cf..db1e7e7318a 100644 --- a/ui/src/locales/lang/zh-CN/common.ts +++ b/ui/src/locales/lang/zh-CN/common.ts @@ -45,7 +45,10 @@ export default { document: '文档', image: '图片', audio: '音频', - video: '视频' + video: '视频', + other: '其他文件', + addExtensions: '添加后缀名', + existingExtensionsTip: '文件后缀已存在', }, status: { label: '状态', diff --git a/ui/src/locales/lang/zh-CN/views/application-workflow.ts b/ui/src/locales/lang/zh-CN/views/application-workflow.ts index 4c5a19d7686..c7c6038cc5f 100644 --- a/ui/src/locales/lang/zh-CN/views/application-workflow.ts +++ b/ui/src/locales/lang/zh-CN/views/application-workflow.ts @@ -105,8 +105,10 @@ export default { label: '上传的文件类型', documentText: 
'需要使用“文档内容提取”节点解析文档内容', imageText: '需要使用“视觉模型”节点解析图片内容', - audioText: '需要使用“语音转文本”节点解析音频内容' - } + audioText: '需要使用“语音转文本”节点解析音频内容', + otherText: '需要自行解析该类型文件' + }, + } }, aiChatNode: { @@ -222,14 +224,14 @@ export default { }, mcpNode: { label: 'MCP 调用', - text: '通过SSE方式执行MCP服务中的工具', + text: '通过SSE/Streamable HTTP方式执行MCP服务中的工具', getToolsSuccess: '获取工具成功', getTool: '获取工具', tool: '工具', toolParam: '工具参数', mcpServerTip: '请输入JSON格式的MCP服务器配置', mcpToolTip: '请选择工具', - configLabel: 'MCP Server Config (仅支持SSE调用方式)' + configLabel: 'MCP Server Config (仅支持SSE/Streamable HTTP调用方式)' }, imageGenerateNode: { label: '图片生成', @@ -264,7 +266,7 @@ export default { label: '文本转语音', text: '将文本通过语音合成模型转换为音频', tts_model: { - label: '语音识别模型' + label: '语音合成模型' }, content: { label: '选择文本内容' diff --git a/ui/src/locales/lang/zh-CN/views/application.ts b/ui/src/locales/lang/zh-CN/views/application.ts index dc9b16216bc..99db3f1f5cd 100644 --- a/ui/src/locales/lang/zh-CN/views/application.ts +++ b/ui/src/locales/lang/zh-CN/views/application.ts @@ -130,7 +130,7 @@ export default { hybridSearch: '混合检索', hybridSearchTooltip: '混合检索是一种基于向量和文本相似度的检索方式,适用于知识库中的中等数据量场景。', - similarityThreshold: '相似度高于', + similarityThreshold: '相似度不低于', similarityTooltip: '相似度越高相关性越强。', topReferences: '引用分段数 TOP', maxCharacters: '最多引用字符数', diff --git a/ui/src/locales/lang/zh-CN/views/document.ts b/ui/src/locales/lang/zh-CN/views/document.ts index bfdf2907ea4..0f5b03ba8b3 100644 --- a/ui/src/locales/lang/zh-CN/views/document.ts +++ b/ui/src/locales/lang/zh-CN/views/document.ts @@ -147,7 +147,7 @@ export default { tooltip: '用户提问时,命中文档下的分段时按照设置的方式进行处理。' }, similarity: { - label: '相似度高于', + label: '相似度不低于', placeholder: '直接返回分段内容', requiredMessage: '请输入相似度' } diff --git a/ui/src/locales/lang/zh-CN/views/system.ts b/ui/src/locales/lang/zh-CN/views/system.ts index 9ce23d90d86..72624d26a48 100644 --- a/ui/src/locales/lang/zh-CN/views/system.ts +++ b/ui/src/locales/lang/zh-CN/views/system.ts @@ -1,5 +1,6 @@ export default { - title: 
'系统设置', + title: '系统管理', + subTitle: '系统设置', test: '测试连接', testSuccess: '测试连接成功', testFailed: '测试连接失败', @@ -120,7 +121,7 @@ export default { websiteSlogan: '欢迎语', websiteSloganPlaceholder: '请输入欢迎语', websiteSloganTip: '产品 Logo 下的欢迎语', - defaultSlogan: '欢迎使用 MaxKB 智能知识库问答系统', + defaultSlogan: '欢迎使用 MaxKB 开源 AI 助手', logoDefaultTip: '默认为 MaxKB 登录界面,支持自定义设置', defaultTip: '默认为 MaxKB 平台界面,支持自定义设置', platformSetting: '平台设置', diff --git a/ui/src/locales/lang/zh-CN/views/user.ts b/ui/src/locales/lang/zh-CN/views/user.ts index 4e2a8760f92..191074c0c06 100644 --- a/ui/src/locales/lang/zh-CN/views/user.ts +++ b/ui/src/locales/lang/zh-CN/views/user.ts @@ -25,6 +25,10 @@ export default { requiredMessage: '请输入用户名', lengthMessage: '长度在 6 到 20 个字符' }, + captcha: { + label: '验证码', + placeholder: '请输入验证码' + }, nick_name: { label: '姓名', placeholder: '请输入姓名' @@ -33,7 +37,7 @@ export default { label: '邮箱', placeholder: '请输入邮箱', requiredMessage: '请输入邮箱', - validatorEmail: '请输入有效邮箱格式!', + validatorEmail: '请输入有效邮箱格式!' 
}, phone: { label: '手机号', @@ -48,13 +52,13 @@ export default { new_password: { label: '新密码', placeholder: '请输入新密码', - requiredMessage: '请输入新密码', + requiredMessage: '请输入新密码' }, re_password: { label: '确认密码', placeholder: '请输入确认密码', requiredMessage: '请输入确认密码', - validatorMessage: '密码不一致', + validatorMessage: '密码不一致' } } }, diff --git a/ui/src/locales/lang/zh-Hant/ai-chat.ts b/ui/src/locales/lang/zh-Hant/ai-chat.ts index 75f9949a6dc..1c717335574 100644 --- a/ui/src/locales/lang/zh-Hant/ai-chat.ts +++ b/ui/src/locales/lang/zh-Hant/ai-chat.ts @@ -61,6 +61,9 @@ export default { limitMessage2: '個文件', sizeLimit: '單個文件大小不能超過', imageMessage: '請解析圖片內容', + documentMessage: '請理解檔案內容', + audioMessage: '請理解音訊內容', + otherMessage: '請理解檔案內容', errorMessage: '上傳失敗' }, executionDetails: { diff --git a/ui/src/locales/lang/zh-Hant/common.ts b/ui/src/locales/lang/zh-Hant/common.ts index 0ccbb5c1159..8e6293076c9 100644 --- a/ui/src/locales/lang/zh-Hant/common.ts +++ b/ui/src/locales/lang/zh-Hant/common.ts @@ -45,7 +45,10 @@ export default { document: '文檔', image: '圖片', audio: '音頻', - video: '視頻' + video: '視頻', + other: '其他文件', + addExtensions: '添加後綴名', + existingExtensionsTip: '文件後綴已存在', }, status: { label: '狀態', diff --git a/ui/src/locales/lang/zh-Hant/views/application-workflow.ts b/ui/src/locales/lang/zh-Hant/views/application-workflow.ts index 60269c021b2..f3a0a0c1e7d 100644 --- a/ui/src/locales/lang/zh-Hant/views/application-workflow.ts +++ b/ui/src/locales/lang/zh-Hant/views/application-workflow.ts @@ -105,7 +105,8 @@ export default { label: '上傳的文件類型', documentText: '需要使用「文檔內容提取」節點解析文檔內容', imageText: '需要使用「圖片理解」節點解析圖片內容', - audioText: '需要使用「語音轉文本」節點解析音頻內容' + audioText: '需要使用「語音轉文本」節點解析音頻內容', + otherText: '需要自行解析該類型文件' } } }, @@ -207,8 +208,8 @@ export default { text: '識別出圖片中的物件、場景等信息回答用戶問題', answer: 'AI 回答內容', model: { - label: '圖片理解模型', - requiredMessage: '請選擇圖片理解模型' + label: '視覺模型', + requiredMessage: '請選擇視覺模型' }, image: { label: '選擇圖片', @@ -222,14 +223,14 @@ export default { }, 
mcpNode: { label: 'MCP 調用', - text: '透過SSE方式執行MCP服務中的工具', + text: '透過SSE/Streamable HTTP方式執行MCP服務中的工具', getToolsSuccess: '獲取工具成功', getTool: '獲取工具', tool: '工具', toolParam: '工具變數', mcpServerTip: '請輸入JSON格式的MCP服務器配置', mcpToolTip: '請選擇工具', - configLabel: 'MCP Server Config (僅支持SSE調用方式)' + configLabel: 'MCP Server Config (僅支持SSE/Streamable HTTP調用方式)' }, imageGenerateNode: { label: '圖片生成', diff --git a/ui/src/locales/lang/zh-Hant/views/application.ts b/ui/src/locales/lang/zh-Hant/views/application.ts index 3b6f1756ed7..d0df9b6b906 100644 --- a/ui/src/locales/lang/zh-Hant/views/application.ts +++ b/ui/src/locales/lang/zh-Hant/views/application.ts @@ -129,7 +129,7 @@ export default { hybridSearch: '混合檢索', hybridSearchTooltip: '混合檢索是一種基於向量和文本相似度的檢索方式,適用於知識庫中的中等數據量場景。', - similarityThreshold: '相似度高於', + similarityThreshold: '相似度不低於', similarityTooltip: '相似度越高相關性越強。', topReferences: '引用分段數 TOP', maxCharacters: '最多引用字元數', diff --git a/ui/src/locales/lang/zh-Hant/views/document.ts b/ui/src/locales/lang/zh-Hant/views/document.ts index adfc8cc463b..d8406908e6a 100644 --- a/ui/src/locales/lang/zh-Hant/views/document.ts +++ b/ui/src/locales/lang/zh-Hant/views/document.ts @@ -146,7 +146,7 @@ export default { tooltip: '用戶提問時,命中文檔下的分段時按照設置的方式進行處理。' }, similarity: { - label: '相似度高于', + label: '相似度不低於', placeholder: '直接返回分段内容', requiredMessage: '请输入相似度' } diff --git a/ui/src/locales/lang/zh-Hant/views/system.ts b/ui/src/locales/lang/zh-Hant/views/system.ts index 10259390be1..1e33f22fb33 100644 --- a/ui/src/locales/lang/zh-Hant/views/system.ts +++ b/ui/src/locales/lang/zh-Hant/views/system.ts @@ -1,5 +1,6 @@ export default { - title: '系統設置', + title: '系統管理', + subTitle: '系統設置', test: '測試連線', testSuccess: '測試連線成功', testFailed: '測試連線失敗', @@ -122,7 +123,7 @@ export default { websiteSloganPlaceholder: '請輸入歡迎語', websiteSloganTip: '產品 Logo 下的歡迎語', logoDefaultTip: '默认为 MaxKB 登錄界面,支持自定义设置', - defaultSlogan: '歡迎使用 MaxKB 智能知識庫問答系統', + defaultSlogan: '歡迎使用 MaxKB 開源 AI 助手', defaultTip: '默認為 MaxKB 
平台界面,支持自定義設置', platformSetting: '平台設置', showUserManual: '顯示用戶手冊', diff --git a/ui/src/locales/lang/zh-Hant/views/template.ts b/ui/src/locales/lang/zh-Hant/views/template.ts index 241f9d8c516..05f24fed575 100644 --- a/ui/src/locales/lang/zh-Hant/views/template.ts +++ b/ui/src/locales/lang/zh-Hant/views/template.ts @@ -30,7 +30,7 @@ export default { RERANKER: '重排模型', STT: '語音辨識', TTS: '語音合成', - IMAGE: '圖片理解', + IMAGE: '視覺模型', TTI: '圖片生成' }, templateForm: { diff --git a/ui/src/locales/lang/zh-Hant/views/user.ts b/ui/src/locales/lang/zh-Hant/views/user.ts index 18ea3326acf..7b8f1a88000 100644 --- a/ui/src/locales/lang/zh-Hant/views/user.ts +++ b/ui/src/locales/lang/zh-Hant/views/user.ts @@ -26,6 +26,10 @@ export default { requiredMessage: '請輸入使用者名稱', lengthMessage: '長度須介於 6 到 20 個字元之間' }, + captcha: { + label: '驗證碼', + placeholder: '請輸入驗證碼' + }, nick_name: { label: '姓名', placeholder: '請輸入姓名' diff --git a/ui/src/request/index.ts b/ui/src/request/index.ts index 72588d2c6f2..a9f490149bf 100644 --- a/ui/src/request/index.ts +++ b/ui/src/request/index.ts @@ -11,7 +11,7 @@ import { ref, type WritableComputedRef } from 'vue' const axiosConfig = { baseURL: '/api', withCredentials: false, - timeout: 600000, + timeout: 1800000, headers: {} } diff --git a/ui/src/router/modules/setting.ts b/ui/src/router/modules/setting.ts index e97a658b02b..eaedb6a5f50 100644 --- a/ui/src/router/modules/setting.ts +++ b/ui/src/router/modules/setting.ts @@ -59,7 +59,7 @@ const settingRouter = { meta: { icon: 'app-setting', iconActive: 'app-setting-active', - title: 'common.setting', + title: 'views.system.subTitle', activeMenu: '/setting', parentPath: '/setting', parentName: 'setting', diff --git a/ui/src/stores/modules/user.ts b/ui/src/stores/modules/user.ts index b065d7596a5..c805715a662 100644 --- a/ui/src/stores/modules/user.ts +++ b/ui/src/stores/modules/user.ts @@ -8,6 +8,7 @@ import { useElementPlusTheme } from 'use-element-plus-theme' import { defaultPlatformSetting } from '@/utils/theme' 
import { useLocalStorage } from '@vueuse/core' import { localeConfigKey, getBrowserLang } from '@/locales/index' + export interface userStateTypes { userType: number // 1 系统操作者 2 对话用户 userInfo: User | null @@ -17,6 +18,7 @@ export interface userStateTypes { XPACK_LICENSE_IS_VALID: false isXPack: false themeInfo: any + rasKey: string } const useUserStore = defineStore({ @@ -29,7 +31,8 @@ const useUserStore = defineStore({ userAccessToken: '', XPACK_LICENSE_IS_VALID: false, isXPack: false, - themeInfo: null + themeInfo: null, + rasKey: '' }), actions: { getLanguage() { @@ -65,7 +68,7 @@ const useUserStore = defineStore({ if (token) { return token } - const local_token = localStorage.getItem(`${token}-accessToken`) + const local_token = localStorage.getItem(`${this.userAccessToken}-accessToken`) if (local_token) { return local_token } @@ -100,6 +103,7 @@ const useUserStore = defineStore({ this.version = ok.data?.version || '-' this.isXPack = ok.data?.IS_XPACK this.XPACK_LICENSE_IS_VALID = ok.data?.XPACK_LICENSE_IS_VALID + this.rasKey = ok.data?.ras || '' if (this.isEnterprise()) { await this.theme() @@ -135,8 +139,15 @@ const useUserStore = defineStore({ }) }, - async login(auth_type: string, username: string, password: string) { - return UserApi.login(auth_type, { username, password }).then((ok) => { + async login(data: any, loading?: Ref) { + return UserApi.login(data).then((ok) => { + this.token = ok.data + localStorage.setItem('token', ok.data) + return this.profile() + }) + }, + async asyncLdapLogin(data: any, loading?: Ref) { + return UserApi.ldapLogin(data).then((ok) => { this.token = ok.data localStorage.setItem('token', ok.data) return this.profile() diff --git a/ui/src/styles/element-plus.scss b/ui/src/styles/element-plus.scss index d1f067b18fd..9e5345cee32 100644 --- a/ui/src/styles/element-plus.scss +++ b/ui/src/styles/element-plus.scss @@ -62,7 +62,7 @@ } .el-form-item__label { font-weight: 400; - width: 100%; + width: 100% !important; } 
.el-form-item__error { diff --git a/ui/src/styles/md-editor.scss b/ui/src/styles/md-editor.scss index 6b117711412..c60f51f4e96 100644 --- a/ui/src/styles/md-editor.scss +++ b/ui/src/styles/md-editor.scss @@ -6,7 +6,8 @@ padding: 0; margin: 0; font-size: inherit; - table{ + word-break: break-word; + table { display: block; } p { diff --git a/ui/src/utils/utils.ts b/ui/src/utils/utils.ts index 44e68895c7f..7f76a93da3c 100644 --- a/ui/src/utils/utils.ts +++ b/ui/src/utils/utils.ts @@ -1,5 +1,5 @@ import { MsgError } from '@/utils/message' - +import { nanoid } from 'nanoid' export function toThousands(num: any) { return num?.toString().replace(/\d+/, function (n: any) { return n.replace(/(\d)(?=(?:\d{3})+$)/g, '$1,') @@ -25,7 +25,7 @@ export function filesize(size: number) { 随机id */ export const randomId = function () { - return Math.floor(Math.random() * 10000) + '' + return nanoid() } /* @@ -48,7 +48,9 @@ const typeList: any = { export function getImgUrl(name: string) { const list = Object.values(typeList).flat() - const type = list.includes(fileType(name).toLowerCase()) ? fileType(name).toLowerCase() : 'unknown' + const type = list.includes(fileType(name).toLowerCase()) + ? fileType(name).toLowerCase() + : 'unknown' return new URL(`../assets/fileType/${type}-icon.svg`, import.meta.url).href } // 是否是白名单后缀 diff --git a/ui/src/views/application-workflow/index.vue b/ui/src/views/application-workflow/index.vue index f9a30983943..e6a95cbf010 100644 --- a/ui/src/views/application-workflow/index.vue +++ b/ui/src/views/application-workflow/index.vue @@ -3,7 +3,7 @@
-

{{ detail?.name }}

+

{{ detail?.name }}

{{ $t('views.applicationWorkflow.info.previewVersion') }} @@ -101,7 +101,7 @@ />
-

+

{{ detail?.name || $t('views.application.applicationForm.form.appName.label') }}

@@ -279,7 +279,6 @@ async function publicHandle() { return } applicationApi.putPublishApplication(id as String, obj, loading).then(() => { - application.asyncGetApplicationDetail(id, loading).then((res: any) => { detail.value.name = res.data.name MsgSuccess(t('views.applicationWorkflow.tip.publicSuccess')) diff --git a/ui/src/views/application/ApplicationAccess.vue b/ui/src/views/application/ApplicationAccess.vue index ce2fe6aab82..8e1bf03b7e6 100644 --- a/ui/src/views/application/ApplicationAccess.vue +++ b/ui/src/views/application/ApplicationAccess.vue @@ -135,51 +135,4 @@ onMounted(() => { }) - + diff --git a/ui/src/views/application/component/CreateApplicationDialog.vue b/ui/src/views/application/component/CreateApplicationDialog.vue index 438bfe211a9..7415753c1af 100644 --- a/ui/src/views/application/component/CreateApplicationDialog.vue +++ b/ui/src/views/application/component/CreateApplicationDialog.vue @@ -242,6 +242,7 @@ const submitHandle = async (formEl: FormInstance | undefined) => { } applicationApi.postApplication(applicationForm.value, loading).then((res) => { MsgSuccess(t('common.createSuccess')) + emit('refresh') if (isWorkFlow(applicationForm.value.type)) { router.push({ path: `/application/${res.data.id}/workflow` }) } else { diff --git a/ui/src/views/application/component/ParamSettingDialog.vue b/ui/src/views/application/component/ParamSettingDialog.vue index cdae5bf6e85..bd0cb5545c3 100644 --- a/ui/src/views/application/component/ParamSettingDialog.vue +++ b/ui/src/views/application/component/ParamSettingDialog.vue @@ -11,7 +11,7 @@ >
- + { if (!bool) { - form.value = { - dataset_setting: { - search_mode: 'embedding', - top_n: 3, - similarity: 0.6, - max_paragraph_char_number: 5000, - no_references_setting: { - status: 'ai_questioning', - value: '{question}' - } - }, - problem_optimization: false, - problem_optimization_prompt: '' - } + // form.value = { + // dataset_setting: { + // search_mode: 'embedding', + // top_n: 3, + // similarity: 0.6, + // max_paragraph_char_number: 5000, + // no_references_setting: { + // status: 'ai_questioning', + // value: '{question}' + // } + // }, + // problem_optimization: false, + // problem_optimization_prompt: '' + // } noReferencesform.value = { ai_questioning: defaultValue['ai_questioning'], designated_answer: defaultValue['designated_answer'] diff --git a/ui/src/views/authentication/component/OIDC.vue b/ui/src/views/authentication/component/OIDC.vue index 2666bc6479d..d71158b9a8e 100644 --- a/ui/src/views/authentication/component/OIDC.vue +++ b/ui/src/views/authentication/component/OIDC.vue @@ -61,6 +61,15 @@ show-password /> + + + ({ state: '', clientId: '', clientSecret: '', + fieldMapping: '{"username": "preferred_username", "email": "email"}', redirectUrl: '' }, is_active: true @@ -156,6 +166,13 @@ const rules = reactive>({ trigger: 'blur' } ], + 'config_data.fieldMapping': [ + { + required: true, + message: t('views.system.authentication.oauth2.filedMappingPlaceholder'), + trigger: 'blur' + } + ], 'config_data.redirectUrl': [ { required: true, @@ -187,6 +204,12 @@ function getDetail() { authApi.getAuthSetting(form.value.auth_type, loading).then((res: any) => { if (res.data && JSON.stringify(res.data) !== '{}') { form.value = res.data + if ( + form.value.config_data.fieldMapping === '' || + form.value.config_data.fieldMapping === undefined + ) { + form.value.config_data.fieldMapping = '{"username": "preferred_username", "email": "email"}' + } } }) } diff --git a/ui/src/views/chat/base/index.vue b/ui/src/views/chat/base/index.vue index 
27be286f25a..7156f7d894a 100644 --- a/ui/src/views/chat/base/index.vue +++ b/ui/src/views/chat/base/index.vue @@ -42,7 +42,6 @@
-
diff --git a/ui/src/views/login/index.vue b/ui/src/views/login/index.vue index 714c439c6bb..d4264406a65 100644 --- a/ui/src/views/login/index.vue +++ b/ui/src/views/login/index.vue @@ -34,6 +34,27 @@ +
+ +
+ + + + +
+
+
(false) const { user } = useStore() const router = useRouter() +import forge from 'node-forge' + const loginForm = ref({ username: '', - password: '' + password: '', + captcha: '', + encryptedData: '' }) +const identifyCode = ref('') + +function makeCode() { + useApi.getCaptcha().then((res: any) => { + identifyCode.value = res.data + }) +} const rules = ref>({ username: [ @@ -137,6 +170,13 @@ const rules = ref>({ message: t('views.user.userForm.form.password.requiredMessage'), trigger: 'blur' } + ], + captcha: [ + { + required: true, + message: t('views.user.userForm.form.captcha.placeholder'), + trigger: 'blur' + } ] }) const loginFormRef = ref() @@ -222,22 +262,43 @@ function changeMode(val: string) { showQrCodeTab.value = false loginForm.value = { username: '', - password: '' + password: '', + captcha: '' } redirectAuth(val) loginFormRef.value?.clearValidate() } const login = () => { - loginFormRef.value?.validate().then(() => { - loading.value = true - user - .login(loginMode.value, loginForm.value.username, loginForm.value.password) - .then(() => { - locale.value = localStorage.getItem('MaxKB-locale') || getBrowserLang() || 'en-US' - router.push({ name: 'home' }) - }) - .finally(() => (loading.value = false)) + if (!loginFormRef.value) { + return + } + loginFormRef.value?.validate((valid) => { + if (valid) { + loading.value = true + if (loginMode.value === 'LDAP') { + user + .asyncLdapLogin(loginForm.value) + .then(() => { + locale.value = localStorage.getItem('MaxKB-locale') || getBrowserLang() || 'en-US' + router.push({ name: 'home' }) + }) + .finally(() => (loading.value = false)) + } else { + const publicKey = forge.pki.publicKeyFromPem(user.rasKey) + const jsonData = JSON.stringify(loginForm.value) + const utf8Bytes = forge.util.encodeUtf8(jsonData) + const encrypted = publicKey.encrypt(utf8Bytes, 'RSAES-PKCS1-V1_5') + const encryptedBase64 = forge.util.encode64(encrypted) + user + .login({ encryptedData: encryptedBase64, username: 
loginForm.value.username }) + .then(() => { + locale.value = localStorage.getItem('MaxKB-locale') || getBrowserLang() || 'en-US' + router.push({ name: 'home' }) + }) + .finally(() => (loading.value = false)) + } + } }) } @@ -285,6 +346,7 @@ onBeforeMount(() => { declare const window: any onMounted(() => { + makeCode() const route = useRoute() const currentUrl = ref(route.fullPath) const params = new URLSearchParams(currentUrl.value.split('?')[1]) diff --git a/ui/src/views/login/reset-password/index.vue b/ui/src/views/login/reset-password/index.vue index 2c2ff02576e..876afde1470 100644 --- a/ui/src/views/login/reset-password/index.vue +++ b/ui/src/views/login/reset-password/index.vue @@ -1,6 +1,10 @@ - + diff --git a/ui/src/workflow/common/NodeContainer.vue b/ui/src/workflow/common/NodeContainer.vue index b61fe25b61a..990eb96c143 100644 --- a/ui/src/workflow/common/NodeContainer.vue +++ b/ui/src/workflow/common/NodeContainer.vue @@ -301,7 +301,8 @@ function clickNodes(item: any) { type: 'app-edge', sourceNodeId: props.nodeModel.id, sourceAnchorId: anchorData.value?.id, - targetNodeId: nodeModel.id + targetNodeId: nodeModel.id, + targetAnchorId: nodeModel.id + '_left' }) closeNodeMenu() diff --git a/ui/src/workflow/nodes/application-node/index.vue b/ui/src/workflow/nodes/application-node/index.vue index 77bff4ac0ca..4fc9fba5483 100644 --- a/ui/src/workflow/nodes/application-node/index.vue +++ b/ui/src/workflow/nodes/application-node/index.vue @@ -238,7 +238,8 @@ const update_field = () => { const new_user_input_field_list = cloneDeep( ok.data.work_flow.nodes[0].properties.user_input_field_list ) - const merge_api_input_field_list = new_api_input_field_list.map((item: any) => { + + const merge_api_input_field_list = (new_api_input_field_list || []).map((item: any) => { const find_field = old_api_input_field_list.find( (old_item: any) => old_item.variable == item.variable ) @@ -258,7 +259,7 @@ const update_field = () => { 'api_input_field_list', 
merge_api_input_field_list ) - const merge_user_input_field_list = new_user_input_field_list.map((item: any) => { + const merge_user_input_field_list = (new_user_input_field_list || []).map((item: any) => { const find_field = old_user_input_field_list.find( (old_item: any) => old_item.field == item.field ) @@ -294,6 +295,7 @@ const update_field = () => { } }) .catch((err) => { + console.log(err) set(props.nodeModel.properties, 'status', 500) }) } diff --git a/ui/src/workflow/nodes/base-node/component/ApiInputFieldTable.vue b/ui/src/workflow/nodes/base-node/component/ApiInputFieldTable.vue index c81ebc94f72..b7bed17fe86 100644 --- a/ui/src/workflow/nodes/base-node/component/ApiInputFieldTable.vue +++ b/ui/src/workflow/nodes/base-node/component/ApiInputFieldTable.vue @@ -13,7 +13,7 @@ :data="props.nodeModel.properties.api_input_field_list" class="mb-16" ref="tableRef" - row-key="field" + row-key="variable" >
Feature