From 51e7247aa9ca24bef95bbaa71636b8a1426cd345 Mon Sep 17 00:00:00 2001 From: ulleo Date: Mon, 11 May 2026 15:06:32 +0800 Subject: [PATCH] fix: Fix the data source-related issues in the PR --- backend/apps/chat/task/llm.py | 19 +++--- backend/apps/datasource/crud/datasource.py | 22 +++++-- .../apps/datasource/embedding/ds_embedding.py | 2 +- .../datasource/embedding/table_embedding.py | 4 +- backend/apps/system/crud/assistant.py | 4 +- .../chat/execution-component/LogWithAi.vue | 64 +++++++++---------- 6 files changed, 61 insertions(+), 54 deletions(-) diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index 4a407680f..85d89dbaf 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -380,7 +380,7 @@ def choose_table_schema(self, _session: Session): operate=OperationEnum.CHOOSE_TABLE, record_id=self.record.id, local_operation=True) - self.chat_question.db_schema = self.out_ds_instance.get_db_schema( + self.chat_question.db_schema, tables = self.out_ds_instance.get_db_schema( self.ds.id, self.chat_question.question) if self.out_ds_instance else get_table_schema( session=_session, current_user=self.current_user, @@ -392,7 +392,8 @@ def choose_table_schema(self, _session: Session): self.chat_question.sample_data = get_tables_sample_data( session=_session, current_user=self.current_user, - ds=self.ds) + ds=self.ds, + table_list=tables) self.current_logs[OperationEnum.CHOOSE_TABLE] = end_log(session=_session, log=self.current_logs[OperationEnum.CHOOSE_TABLE], @@ -508,7 +509,7 @@ def generate_recommend_questions_task(self, _session: Session): # get schema if self.ds and not self.chat_question.db_schema: - self.chat_question.db_schema = self.out_ds_instance.get_db_schema( + self.chat_question.db_schema, tables = self.out_ds_instance.get_db_schema( self.ds.id, self.chat_question.question) if self.out_ds_instance else get_table_schema( session=_session, current_user=self.current_user, ds=self.ds, @@ -516,11 +517,11 @@ def generate_recommend_questions_task(self, _session: Session): embedding=False) # Get sample data for all tables - if not self.out_ds_instance: - self.chat_question.sample_data = get_tables_sample_data( - session=_session, - current_user=self.current_user, - ds=self.ds) + # if not self.out_ds_instance: + # self.chat_question.sample_data = get_tables_sample_data( + # session=_session, + # current_user=self.current_user, + # ds=self.ds) guess_msg: List[Union[BaseMessage, dict[str, Any]]] = [] guess_msg.append(SystemPromptMessage(content=self.chat_question.guess_sys_question(self.articles_number))) @@ -1356,7 +1357,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True, return # generate chart - used_tables_schema = self.out_ds_instance.get_db_schema( + used_tables_schema, used_tables = self.out_ds_instance.get_db_schema( self.ds.id, self.chat_question.question, embedding=False, table_list=tables) if self.out_ds_instance else get_table_schema( session=_session, diff --git a/backend/apps/datasource/crud/datasource.py b/backend/apps/datasource/crud/datasource.py index 033a96073..372e4ee37 100644 --- a/backend/apps/datasource/crud/datasource.py +++ b/backend/apps/datasource/crud/datasource.py @@ -65,6 +65,7 @@ def check_name(session: SessionDep, trans: Trans, user: CurrentUser, ds: CoreDat if ds_list is not None and len(ds_list) > 0: raise HTTPException(status_code=500, detail=trans('i18n_ds_name_exist')) + @clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.DS_ID_LIST, keyExpression="user.oid") async def create_ds(session: SessionDep, trans: Trans, user: CurrentUser, create_ds: CreateDatasource): ds = CoreDatasource() @@ -490,7 +491,7 @@ def get_table_sample_data(ds: CoreDatasource, table_name: str, fields: list) -> def get_tables_sample_data(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, - table_list: list[str] = None) -> str: + table_list: list[str] = None) -> str: """Get sample data (3 rows) for all tables to help AI understand the data""" table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds) if len(table_objs) == 0: @@ -508,15 +509,16 @@ def get_tables_sample_data(session: SessionDep, current_user: CurrentUser, ds: C def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, question: str, - embedding: bool = True, table_list: list[str] = None) -> str: + embedding: bool = True, table_list: list[str] = None) -> tuple[str, list]: schema_str = "" table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds) if len(table_objs) == 0: - return schema_str + return schema_str, [] db_name = table_objs[0].schema schema_str += f"【DB_ID】 {db_name}\n【Schema】\n" tables = [] all_tables = [] # temp save all tables + table_name_list = [] for obj in table_objs: # 如果传入了table_list,则只处理在列表中的表 if table_list is not None and obj.table.table_name not in table_list: @@ -546,13 +548,14 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat schema_table += ",\n".join(field_list) schema_table += '\n]\n' - t_obj = {"id": obj.table.id, "schema_table": schema_table, "embedding": obj.table.embedding} + t_obj = {"id": obj.table.id, "table_name": obj.table.table_name, "schema_table": schema_table, + "embedding": obj.table.embedding} tables.append(t_obj) all_tables.append(t_obj) # 如果没有符合过滤条件的表,直接返回 if not tables: - return schema_str + return schema_str, [] # do table embedding if embedding and tables and settings.TABLE_EMBEDDING_ENABLED: @@ -561,6 +564,7 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat if tables: for s in tables: schema_str += s.get('schema_table') + table_name_list.append(s.get('table_name')) # field relation if tables and ds.table_relation: @@ -592,6 +596,7 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat if lost_tables: for s in lost_tables: schema_str += s.get('schema_table') + table_name_list.append(s.get('table_name')) # get field dict relation_field_ids = [] @@ -609,13 +614,16 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat for ele in all_relations: schema_str += f"{table_dict.get(int(ele.get('source').get('cell')))}.{field_dict.get(int(ele.get('source').get('port')))}={table_dict.get(int(ele.get('target').get('cell')))}.{field_dict.get(int(ele.get('target').get('port')))}\n" - return schema_str + return schema_str, table_name_list + @cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.DS_ID_LIST, keyExpression="oid") async def get_ws_ds(session, oid) -> list: stmt = select(CoreDatasource.id).distinct().where(CoreDatasource.oid == oid) db_list = session.exec(stmt).all() return db_list + + @clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.DS_ID_LIST, keyExpression="oid") async def clear_ws_ds_cache(oid): - SQLBotLogUtil.info(f"ds cache for ws [{oid}] has been cleaned") \ No newline at end of file + SQLBotLogUtil.info(f"ds cache for ws [{oid}] has been cleaned") diff --git a/backend/apps/datasource/embedding/ds_embedding.py b/backend/apps/datasource/embedding/ds_embedding.py index 19feada2d..9bfe4a48e 100644 --- a/backend/apps/datasource/embedding/ds_embedding.py +++ b/backend/apps/datasource/embedding/ds_embedding.py @@ -23,7 +23,7 @@ def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, o if out_ds.ds_list: for _ds in out_ds.ds_list: ds = out_ds.get_ds(_ds.id) - table_schema = out_ds.get_db_schema(_ds.id, question, embedding=False) + table_schema, tables = out_ds.get_db_schema(_ds.id, question, embedding=False) ds_info = f"{ds.name}, {ds.description}\n" ds_schema = ds_info + table_schema _list.append({"id": ds.id, "ds_schema": ds_schema, "cosine_similarity": 0.0, "ds": ds}) diff --git a/backend/apps/datasource/embedding/table_embedding.py b/backend/apps/datasource/embedding/table_embedding.py index c467ecd8d..186debec4 100644 --- a/backend/apps/datasource/embedding/table_embedding.py +++ b/backend/apps/datasource/embedding/table_embedding.py @@ -45,7 +45,7 @@ def calc_table_embedding(tables: list[dict], question: str): for table in tables: _list.append( {"id": table.get('id'), "schema_table": table.get('schema_table'), "embedding": table.get('embedding'), - "cosine_similarity": 0.0}) + "cosine_similarity": 0.0, "table_name": table.get('table_name')}) if _list: try: @@ -70,7 +70,7 @@ def calc_table_embedding(tables: list[dict], question: str): end_time = time.time() SQLBotLogUtil.info(str(end_time - start_time)) SQLBotLogUtil.info(json.dumps([{"id": ele.get('id'), "schema_table": ele.get('schema_table'), - "cosine_similarity": ele.get('cosine_similarity')} + "cosine_similarity": ele.get('cosine_similarity'), "table_name": ele.get('table_name')} for ele in _list])) return _list except Exception: diff --git a/backend/apps/system/crud/assistant.py b/backend/apps/system/crud/assistant.py index 130ef9f74..1fa5eb27c 100644 --- a/backend/apps/system/crud/assistant.py +++ b/backend/apps/system/crud/assistant.py @@ -181,7 +181,7 @@ def get_simple_ds_list(self): raise Exception("Datasource list is not found.") def get_db_schema(self, ds_id: int, question: str = '', embedding: bool = True, - table_list: list[str] = None) -> str: + table_list: list[str] = None) -> tuple[str, list]: ds = self.get_ds(ds_id) schema_str = "" db_name = ds.db_schema if ds.db_schema is not None and ds.db_schema != "" else ds.dataBase @@ -222,7 +222,7 @@ def get_db_schema(self, ds_id: int, question: str = '', embedding: bool = True, for s in tables: schema_str += s.get('schema_table') - return schema_str + return schema_str, [] def get_ds(self, ds_id: int, trans: Trans = None): if self.ds_list: diff --git a/frontend/src/views/chat/execution-component/LogWithAi.vue b/frontend/src/views/chat/execution-component/LogWithAi.vue index 95087d130..b4c456b20 100644 --- a/frontend/src/views/chat/execution-component/LogWithAi.vue +++ b/frontend/src/views/chat/execution-component/LogWithAi.vue @@ -51,46 +51,44 @@ const recordsBeforeCurrentQuestion = computed(() => - +