From 5ce177551c9097db2a351179fa83a13b52b3713c Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Wed, 20 May 2026 14:41:18 +0800 Subject: [PATCH] fix: clear deleted knowledge bases from sessions --- astrbot/dashboard/routes/knowledge_base.py | 38 +++++++++- tests/test_kb_import.py | 82 ++++++++++++++++++++++ 2 files changed, 119 insertions(+), 1 deletion(-) diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index 1b6f7a435d..d5b4079e2e 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -9,7 +9,7 @@ import aiofiles from quart import request -from astrbot.core import logger +from astrbot.core import logger, sp from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider from astrbot.core.utils.astrbot_path import get_astrbot_temp_path @@ -69,6 +69,36 @@ def __init__( def _get_kb_manager(self): return self.core_lifecycle.kb_manager + @staticmethod + async def _remove_kb_from_session_configs(kb_id: str) -> int: + prefs = await sp.session_get(None, "kb_config") + if not isinstance(prefs, list): + return 0 + + updated = 0 + + for pref in prefs: + scope_id = getattr(pref, "scope_id", None) + if not isinstance(scope_id, str): + continue + + value = await sp.session_get(scope_id, "kb_config") + if not isinstance(value, dict): + continue + + kb_ids = value.get("kb_ids") + if not isinstance(kb_ids, list) or kb_id not in kb_ids: + continue + + new_value = { + **value, + "kb_ids": [item for item in kb_ids if item != kb_id], + } + await sp.session_put(scope_id, "kb_config", new_value) + updated += 1 + + return updated + def _init_task(self, task_id: str, status: str = "pending") -> None: self.upload_tasks[task_id] = { "status": status, @@ -569,6 +599,12 @@ async def delete_kb(self): if not success: return Response().error("知识库不存在").__dict__ + updated_sessions = await self._remove_kb_from_session_configs(kb_id) + if updated_sessions: + logger.info( + f"已从 {updated_sessions} 个会话配置中移除已删除知识库 {kb_id}", + ) + return Response().ok(message="删除知识库成功").__dict__ except ValueError as e: diff --git a/tests/test_kb_import.py b/tests/test_kb_import.py index 8795b06da1..03fb6f903f 100644 --- a/tests/test_kb_import.py +++ b/tests/test_kb_import.py @@ -1,10 +1,12 @@ import asyncio +from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock import pytest import pytest_asyncio from quart import Quart +import astrbot.dashboard.routes.knowledge_base as knowledge_base_route_module from astrbot.core import LogBroker from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.db.sqlite import SQLiteDatabase @@ -117,6 +119,86 @@ async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecyc return {"Authorization": f"Bearer {token}"} +@pytest.mark.asyncio +async def test_remove_deleted_kb_from_session_configs(monkeypatch: pytest.MonkeyPatch): + updates = [] + configs = { + "platform:GroupMessage:group!alice": { + "kb_ids": ["kb-old", "kb-keep"], + "top_k": 3, + }, + "platform:GroupMessage:group!bob": {"kb_ids": ["kb-old"]}, + "platform:FriendMessage:charlie": {"kb_ids": ["kb-keep"]}, + "platform:FriendMessage:broken": {"kb_ids": "kb-old"}, + } + + class FakeSharedPreferences: + async def session_get(self, umo, key): + assert key == "kb_config" + if umo is not None: + return configs.get(umo) + + return [ + SimpleNamespace( + scope_id="platform:GroupMessage:group!alice", + value={"val": configs["platform:GroupMessage:group!alice"]}, + ), + SimpleNamespace( + scope_id="platform:GroupMessage:group!bob", + value={"val": configs["platform:GroupMessage:group!bob"]}, + ), + SimpleNamespace( + scope_id="platform:FriendMessage:charlie", + value={"val": configs["platform:FriendMessage:charlie"]}, + ), + SimpleNamespace( + scope_id="platform:FriendMessage:broken", + value={"val": configs["platform:FriendMessage:broken"]}, + ), + ] + + async def session_put(self, umo, key, value): + updates.append((umo, key, value)) + + monkeypatch.setattr(knowledge_base_route_module, "sp", FakeSharedPreferences()) + + updated = await KnowledgeBaseRoute._remove_kb_from_session_configs("kb-old") + + assert updated == 2 + assert updates == [ + ( + "platform:GroupMessage:group!alice", + "kb_config", + {"kb_ids": ["kb-keep"], "top_k": 3}, + ), + ( + "platform:GroupMessage:group!bob", + "kb_config", + {"kb_ids": []}, + ), + ] + + +@pytest.mark.asyncio +async def test_remove_deleted_kb_ignores_missing_session_list( + monkeypatch: pytest.MonkeyPatch, +): + class FakeSharedPreferences: + async def session_get(self, umo, key): + assert umo is None + assert key == "kb_config" + return None + + async def session_put(self, umo, key, value): + raise AssertionError("session_put should not be called") + + monkeypatch.setattr(knowledge_base_route_module, "sp", FakeSharedPreferences()) + + updated = await KnowledgeBaseRoute._remove_kb_from_session_configs("kb-old") + + assert updated == 0 + + @pytest.mark.asyncio async def test_import_documents( app: Quart, authenticated_header: dict, core_lifecycle_td: AstrBotCoreLifecycle