Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions backend/apps/chat/task/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def choose_table_schema(self, _session: Session):
operate=OperationEnum.CHOOSE_TABLE,
record_id=self.record.id,
local_operation=True)
self.chat_question.db_schema = self.out_ds_instance.get_db_schema(
self.chat_question.db_schema, tables = self.out_ds_instance.get_db_schema(
self.ds.id, self.chat_question.question) if self.out_ds_instance else get_table_schema(
session=_session,
current_user=self.current_user,
Expand All @@ -392,7 +392,8 @@ def choose_table_schema(self, _session: Session):
self.chat_question.sample_data = get_tables_sample_data(
session=_session,
current_user=self.current_user,
ds=self.ds)
ds=self.ds,
table_list=tables)

self.current_logs[OperationEnum.CHOOSE_TABLE] = end_log(session=_session,
log=self.current_logs[OperationEnum.CHOOSE_TABLE],
Expand Down Expand Up @@ -508,19 +509,19 @@ def generate_recommend_questions_task(self, _session: Session):

# get schema
if self.ds and not self.chat_question.db_schema:
self.chat_question.db_schema = self.out_ds_instance.get_db_schema(
self.chat_question.db_schema, tables = self.out_ds_instance.get_db_schema(
self.ds.id, self.chat_question.question) if self.out_ds_instance else get_table_schema(
session=_session,
current_user=self.current_user, ds=self.ds,
question=self.chat_question.question,
embedding=False)

# Get sample data for all tables
if not self.out_ds_instance:
self.chat_question.sample_data = get_tables_sample_data(
session=_session,
current_user=self.current_user,
ds=self.ds)
# if not self.out_ds_instance:
# self.chat_question.sample_data = get_tables_sample_data(
# session=_session,
# current_user=self.current_user,
# ds=self.ds)

guess_msg: List[Union[BaseMessage, dict[str, Any]]] = []
guess_msg.append(SystemPromptMessage(content=self.chat_question.guess_sys_question(self.articles_number)))
Expand Down Expand Up @@ -1356,7 +1357,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
return

# generate chart
used_tables_schema = self.out_ds_instance.get_db_schema(
used_tables_schema, used_tables = self.out_ds_instance.get_db_schema(
self.ds.id, self.chat_question.question, embedding=False,
table_list=tables) if self.out_ds_instance else get_table_schema(
session=_session,
Expand Down
22 changes: 15 additions & 7 deletions backend/apps/datasource/crud/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def check_name(session: SessionDep, trans: Trans, user: CurrentUser, ds: CoreDat
if ds_list is not None and len(ds_list) > 0:
raise HTTPException(status_code=500, detail=trans('i18n_ds_name_exist'))


@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.DS_ID_LIST, keyExpression="user.oid")
async def create_ds(session: SessionDep, trans: Trans, user: CurrentUser, create_ds: CreateDatasource):
ds = CoreDatasource()
Expand Down Expand Up @@ -490,7 +491,7 @@ def get_table_sample_data(ds: CoreDatasource, table_name: str, fields: list) ->


def get_tables_sample_data(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource,
table_list: list[str] = None) -> str:
table_list: list[str] = None) -> str:
"""Get sample data (3 rows) for all tables to help AI understand the data"""
table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds)
if len(table_objs) == 0:
Expand All @@ -508,15 +509,16 @@ def get_tables_sample_data(session: SessionDep, current_user: CurrentUser, ds: C


def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, question: str,
embedding: bool = True, table_list: list[str] = None) -> str:
embedding: bool = True, table_list: list[str] = None) -> tuple[str, list]:
schema_str = ""
table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds)
if len(table_objs) == 0:
return schema_str
return schema_str, []
db_name = table_objs[0].schema
schema_str += f"【DB_ID】 {db_name}\n【Schema】\n"
tables = []
all_tables = [] # temp save all tables
table_name_list = []
for obj in table_objs:
# 如果传入了table_list,则只处理在列表中的表
if table_list is not None and obj.table.table_name not in table_list:
Expand Down Expand Up @@ -546,13 +548,14 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
schema_table += ",\n".join(field_list)
schema_table += '\n]\n'

t_obj = {"id": obj.table.id, "schema_table": schema_table, "embedding": obj.table.embedding}
t_obj = {"id": obj.table.id, "table_name": obj.table.table_name, "schema_table": schema_table,
"embedding": obj.table.embedding}
tables.append(t_obj)
all_tables.append(t_obj)

# 如果没有符合过滤条件的表,直接返回
if not tables:
return schema_str
return schema_str, []

# do table embedding
if embedding and tables and settings.TABLE_EMBEDDING_ENABLED:
Expand All @@ -561,6 +564,7 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
if tables:
for s in tables:
schema_str += s.get('schema_table')
table_name_list.append(s.get('table_name'))

# field relation
if tables and ds.table_relation:
Expand Down Expand Up @@ -592,6 +596,7 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
if lost_tables:
for s in lost_tables:
schema_str += s.get('schema_table')
table_name_list.append(s.get('table_name'))

# get field dict
relation_field_ids = []
Expand All @@ -609,13 +614,16 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
for ele in all_relations:
schema_str += f"{table_dict.get(int(ele.get('source').get('cell')))}.{field_dict.get(int(ele.get('source').get('port')))}={table_dict.get(int(ele.get('target').get('cell')))}.{field_dict.get(int(ele.get('target').get('port')))}\n"

return schema_str
return schema_str, table_name_list


@cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.DS_ID_LIST, keyExpression="oid")
async def get_ws_ds(session, oid) -> list:
stmt = select(CoreDatasource.id).distinct().where(CoreDatasource.oid == oid)
db_list = session.exec(stmt).all()
return db_list


@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.DS_ID_LIST, keyExpression="oid")
async def clear_ws_ds_cache(oid):
SQLBotLogUtil.info(f"ds cache for ws [{oid}] has been cleaned")
SQLBotLogUtil.info(f"ds cache for ws [{oid}] has been cleaned")
2 changes: 1 addition & 1 deletion backend/apps/datasource/embedding/ds_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, o
if out_ds.ds_list:
for _ds in out_ds.ds_list:
ds = out_ds.get_ds(_ds.id)
table_schema = out_ds.get_db_schema(_ds.id, question, embedding=False)
table_schema, tables = out_ds.get_db_schema(_ds.id, question, embedding=False)
ds_info = f"{ds.name}, {ds.description}\n"
ds_schema = ds_info + table_schema
_list.append({"id": ds.id, "ds_schema": ds_schema, "cosine_similarity": 0.0, "ds": ds})
Expand Down
4 changes: 2 additions & 2 deletions backend/apps/datasource/embedding/table_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def calc_table_embedding(tables: list[dict], question: str):
for table in tables:
_list.append(
{"id": table.get('id'), "schema_table": table.get('schema_table'), "embedding": table.get('embedding'),
"cosine_similarity": 0.0})
"cosine_similarity": 0.0, "table_name": table.get('table_name')})

if _list:
try:
Expand All @@ -70,7 +70,7 @@ def calc_table_embedding(tables: list[dict], question: str):
end_time = time.time()
SQLBotLogUtil.info(str(end_time - start_time))
SQLBotLogUtil.info(json.dumps([{"id": ele.get('id'), "schema_table": ele.get('schema_table'),
"cosine_similarity": ele.get('cosine_similarity')}
"cosine_similarity": ele.get('cosine_similarity'), "table_name": ele.get('table_name')}
for ele in _list]))
return _list
except Exception:
Expand Down
4 changes: 2 additions & 2 deletions backend/apps/system/crud/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def get_simple_ds_list(self):
raise Exception("Datasource list is not found.")

def get_db_schema(self, ds_id: int, question: str = '', embedding: bool = True,
table_list: list[str] = None) -> str:
table_list: list[str] = None) -> tuple[str, list]:
ds = self.get_ds(ds_id)
schema_str = ""
db_name = ds.db_schema if ds.db_schema is not None and ds.db_schema != "" else ds.dataBase
Expand Down Expand Up @@ -222,7 +222,7 @@ def get_db_schema(self, ds_id: int, question: str = '', embedding: bool = True,
for s in tables:
schema_str += s.get('schema_table')

return schema_str
return schema_str, []

def get_ds(self, ds_id: int, trans: Trans = None):
if self.ds_list:
Expand Down
64 changes: 31 additions & 33 deletions frontend/src/views/chat/execution-component/LogWithAi.vue
Original file line number Diff line number Diff line change
Expand Up @@ -51,46 +51,44 @@ const recordsBeforeCurrentQuestion = computed(() =>
<template v-if="item.error">
{{ error }}
</template>
<template v-else>
<div class="item-list">
<div class="inner-title">{{ t('chat.log_system') }}</div>
<div class="inner-item">
<div class="inner-item-title">
{{ systemRecord.type }}
</div>
<div class="inner-item-description">
<SQLComponent :sql="systemRecord.content" />
</div>
<div class="item-list">
<div class="inner-title">{{ t('chat.log_system') }}</div>
<div class="inner-item">
<div class="inner-item-title">
{{ systemRecord.type }}
</div>
<div class="inner-item-description">
<SQLComponent :sql="systemRecord.content" />
</div>
<template v-if="recordsBeforeCurrentQuestion.length > 0">
<div class="inner-title">{{ t('chat.log_history') }}</div>
<div class="inner-item">
<div v-for="(ele, index) in recordsBeforeCurrentQuestion" :key="index">
<div class="inner-item-title">
{{ ele.type }}
</div>
<div class="inner-item-description">
<SQLComponent :sql="ele.content" />
</div>
</div>
<template v-if="recordsBeforeCurrentQuestion.length > 0">
<div class="inner-title">{{ t('chat.log_history') }}</div>
<div class="inner-item">
<div v-for="(ele, index) in recordsBeforeCurrentQuestion" :key="index">
<div class="inner-item-title">
{{ ele.type }}
</div>
<div class="inner-item-description">
<SQLComponent :sql="ele.content" />
</div>
</div>
</template>
<div class="inner-title">{{ t('chat.log_question') }}</div>
</div>
</template>
<div class="inner-title">{{ t('chat.log_question') }}</div>
<div class="inner-item">
<div class="inner-item-description">
<SQLComponent :sql="lastHumanRecord.content" />
</div>
</div>
<template v-if="lastAiAfterHuman">
<div class="inner-title">{{ t('chat.log_answer') }}</div>
<div class="inner-item">
<div class="inner-item-description">
<SQLComponent :sql="lastHumanRecord.content" />
<SQLComponent :sql="lastAiAfterHuman.content" />
</div>
</div>
<template v-if="lastAiAfterHuman">
<div class="inner-title">{{ t('chat.log_answer') }}</div>
<div class="inner-item">
<div class="inner-item-description">
<SQLComponent :sql="lastAiAfterHuman.content" />
</div>
</div>
</template>
</div>
</template>
</template>
</div>
</BaseContent>
</template>

Expand Down
Loading