170 changes: 170 additions & 0 deletions beisi_rag/backend/backend.py
@@ -0,0 +1,170 @@
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import List, AsyncGenerator
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from dotenv import load_dotenv
import dashscope
from dashscope.aigc.generation import AioGeneration
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_core.documents import Document

# Load environment variables
load_dotenv(dotenv_path=Path(__file__).resolve().parents[1] / "config" / ".env")

# DashScope API configuration
DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY")
BASE_URL = os.getenv("OPENAI_BASE_URL")
MODEL_NAME = os.getenv("MODEL_NAME", "qwen-plus")
TOP_K_DEFAULT = int(os.getenv("TOP_K", "4"))

# Initialize the DashScope generation client
client = AioGeneration()

# Load the FAISS vector store
INDEX_DIR = Path(__file__).parent.parent / "vectordb"
emb = DashScopeEmbeddings(dashscope_api_key=DASHSCOPE_API_KEY, model="text-embedding-v3")
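# allow_dangerous_deserialization is needed because the FAISS docstore is pickled on disk;
# only enable it for index files you generated yourself.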
vectordb = FAISS.load_local(str(INDEX_DIR), emb, allow_dangerous_deserialization=True)

# Format retrieved documents into a numbered context block
def formatContext(docs: List[Document]) -> str:
parts = []
for i, d in enumerate(docs, 1):
src = (d.metadata or {}).get("source", "unknown")
txt = (d.page_content or "").replace("\n", " ")
if len(txt) > 500:
txt = txt[:500] + "…"
parts.append(f"[{i}] ({src}) {txt}")
return "\n".join(parts)

# Synchronous retrieval against the vector store
def retrieve(question: str, k: int) -> List[Document]:
retriever = vectordb.as_retriever(search_type="similarity", search_kwargs={"k": k})
return retriever.invoke(question)

SYSTEM_PROMPT = (
"你是严谨的中文检索增强助手。严格依据给定上下文回答;"
"若上下文没有答案,请明确说“不确定”,并给出你能确认的线索。"
)

# Build the user prompt
def buildUserPrompt(question: str, context: str) -> str:
return (
"结合<已检索上下文>作答:\n"
f"<已检索上下文>\n{context}\n</已检索上下文>\n\n"
f"用户问题:{question}\n"
"要求:若答案不在上下文里,明确说明不确定;用中文、分点作答,必要时给出引用的原句摘要。"
)

# FastAPI application setup
app = FastAPI(title="Qwen RAG QA Backend", version="1.0.0")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
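# Wide-open CORS is convenient for local development; restrict allow_origins before deploying.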

# Request / response models
class ChatReq(BaseModel):
question: str
top_k: int = Field(default=TOP_K_DEFAULT, ge=1, le=12)
model: str = Field(default=MODEL_NAME)
temperature: float = Field(default=0.3, ge=0, le=1)

class ChatResp(BaseModel):
answer: str

# Streaming endpoint (Server-Sent Events)
async def _sse_event(data: str) -> bytes:
    print(f"Sending SSE chunk: {data}")  # debug output
return f"data: {data}\n\n".encode("utf-8")

@app.post("/chat/stream")
async def chat_stream(req: ChatReq):
try:
docs = retrieve(req.question, req.top_k)
context = formatContext(docs)
prompt = buildUserPrompt(req.question, context)

async def event_gen() -> AsyncGenerator[bytes, None]:
yield await _sse_event("{\"type\":\"meta\",\"message\":\"stream-start\"}")

            # Debug: mark the start of streaming
            print("Streaming started...")

try:
                # Generate a streaming answer with DashScope
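                # Note: dashscope.Generation.call is synchronous, so iterating it here blocks the event loop
                # while chunks arrive; for higher concurrency consider the async client or a thread executor.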
stream = dashscope.Generation.call(
api_key=DASHSCOPE_API_KEY,
model=req.model,
messages=[{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt}],
result_format="message",
enable_thinking=True,
stream=True,
incremental_output=True,
)
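                # With enable_thinking=True the stream typically arrives in two phases:
                # reasoning_content chunks (the model's thinking) first, then content chunks (the final answer).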

reasoning_content = ""
answer_content = ""
is_answering = False

                # Process the stream chunk by chunk
                for chunk in stream:
                    print(f"Processing stream chunk: {chunk.output}")  # debug output

                    msg = chunk.output.choices[0].message
                    reasoning_delta = getattr(msg, "reasoning_content", "") or ""
                    answer_delta = msg.content or ""

                    if not reasoning_delta and not answer_delta:
                        continue
                    if reasoning_delta and not answer_delta:
                        # Thinking phase: forward the model's reasoning tokens as they arrive
                        reasoning_content += reasoning_delta
                        yield await _sse_event(reasoning_delta)
                    else:
                        # Answer phase: emit a separator once, then forward the answer tokens
                        if not is_answering:
                            yield await _sse_event("\n" + "=" * 20 + "完整回复" + "=" * 20)
                            is_answering = True
                        answer_content += answer_delta
                        yield await _sse_event(answer_delta)

yield await _sse_event("[DONE]")
            except Exception as ie:
                # Error event: report as JSON so the client can tell it apart from answer text
                err = str(ie).replace("\n", " ")
                yield await _sse_event(json.dumps({"type": "error", "message": err}, ensure_ascii=False))

return StreamingResponse(event_gen(), media_type="text/event-stream")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
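
# Example request against the streaming endpoint (illustrative only; adjust host, port and payload to your setup):
#   curl -N -X POST http://localhost:8000/chat/stream \
#        -H "Content-Type: application/json" \
#        -d '{"question": "your question here", "top_k": 4}'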


# Non-streaming endpoint
@app.post("/chat", response_model=ChatResp)
async def chat(req: ChatReq):
try:
docs = retrieve(req.question, req.top_k)
context = formatContext(docs)
prompt = buildUserPrompt(req.question, context)

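        # Non-streaming path: await the async DashScope client and return the complete answer in one response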
response = await AioGeneration.call(
api_key=DASHSCOPE_API_KEY,
model=req.model,
messages=[{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt}],
result_format="message",
)
answer = response.output.choices[0].message.content.strip()
return ChatResp(answer=answer)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
import uvicorn
HOST = os.getenv("HOST", "0.0.0.0")
PORT = int(os.getenv("PORT", "8000"))
uvicorn.run(app, host=HOST, port=PORT, reload=False)
7 changes: 7 additions & 0 deletions beisi_rag/config/.env
@@ -0,0 +1,7 @@
export DASHSCOPE_API_KEY="***************************************************"
export OPENAI_API_KEY="$DASHSCOPE_API_KEY"
export OPENAI_BASE_URL="https://dashscope-intl.aliyuncs.com/compatible-mode/v1"
export MODEL_NAME="qwen-plus"
export TOP_K="4"
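# Note: python-dotenv ignores a leading "export", so this file works both when sourced in a shell and when read via load_dotenv().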


14 changes: 14 additions & 0 deletions beisi_rag/requirements.txt
@@ -0,0 +1,14 @@
faiss-cpu==1.12.0
importlib-metadata==8.0.0
jaraco.collections==5.1.0
langchain==1.0.3
langchain-community==0.4.1
langchain-openai==1.0.1
pip-chill==1.0.3
platformdirs==4.2.2
tomli==2.0.1
uvicorn==0.38.0
fastapi
python-dotenv
dashscope
161 changes: 161 additions & 0 deletions beisi_rag/src/chat_qwen_rag.py
@@ -0,0 +1,161 @@
# /Users/profighted/beisi-tech/docs/RAG-Anything/beisi_rag/chat_qwen_rag.py
import os
from pathlib import Path
from dotenv import load_dotenv

from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, BaseMessage

from openai import OpenAI

load_dotenv(dotenv_path=Path(__file__).resolve().parents[1] / "config" / ".env")

# Local vector store directory (must match the one used at the ingest stage)
INDEX_DIR = Path(__file__).parent.parent / "vectordb"

# ====== Configure the Qwen OpenAI-compatible endpoint ======
DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
assert DASHSCOPE_API_KEY, "Please export DASHSCOPE_API_KEY=<your DashScope API key> first"

# For the Singapore region, switch to https://dashscope-intl.aliyuncs.com/compatible-mode/v1
DASHSCOPE_BASE_URL = os.environ.get(
"DASHSCOPE_BASE_URL",
"https://dashscope.aliyuncs.com/compatible-mode/v1",
)

# Chat model & embedding model (adjust to whatever your account has access to)
CHAT_MODEL = os.environ.get("QWEN_CHAT_MODEL", "qwen-plus")  # alternatives: qwen2.5-7b-instruct / qwen-turbo, etc.
EMBEDDING_MODEL = os.environ.get("QWEN_EMBED_MODEL", "text-embedding-v3")


def load_retriever():
"""加载 FAISS 检索器(与 ingest 使用同一 Embedding 模型)"""
embeddings = DashScopeEmbeddings(
dashscope_api_key=DASHSCOPE_API_KEY,
model=EMBEDDING_MODEL,
)
vectordb = FAISS.load_local(str(INDEX_DIR), embeddings, allow_dangerous_deserialization=True)
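    # "similarity" search returns the k nearest chunks; raising k widens recall at the cost of a longer prompt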
return vectordb.as_retriever(search_type="similarity", search_kwargs={"k": 4})


def build_llm_runnable():
"""
用 OpenAI 兼容端点(Qwen)构建一个 LangChain Runnable。
避免使用 langchain_openai,降低版本依赖冲突风险。
"""
client = OpenAI(api_key=DASHSCOPE_API_KEY, base_url=DASHSCOPE_BASE_URL)

def _lc_to_openai_messages(prompt_value) -> list[dict]:
"""
将 LangChain 的 PromptValue / BaseMessage 列表转换为 OpenAI 兼容 messages。
"""
if hasattr(prompt_value, "to_messages"):
msgs = prompt_value.to_messages() # List[BaseMessage]
elif isinstance(prompt_value, list) and all(isinstance(m, BaseMessage) for m in prompt_value):
msgs = prompt_value
else:
            # Fallback: treat it as a single-turn user input
msgs = [HumanMessage(content=str(prompt_value))]

out = []
for m in msgs:
if isinstance(m, SystemMessage):
role = "system"
elif isinstance(m, HumanMessage):
role = "user"
elif isinstance(m, AIMessage):
role = "assistant"
else:
role = "user"
out.append({"role": role, "content": m.content})
return out

def _invoke(prompt_value: BaseMessage | list[BaseMessage] | str) -> str:
messages = _lc_to_openai_messages(prompt_value)
resp = client.chat.completions.create(
model=CHAT_MODEL,
messages=messages,
temperature=0.3,
)
return resp.choices[0].message.content

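    # Wrapping the plain function in RunnableLambda lets it slot into the LCEL pipe below (prompt | llm | parser)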
return RunnableLambda(_invoke)



RAG_PROMPT = ChatPromptTemplate.from_template(
    """你是严谨的检索增强助手。结合<已检索上下文>回答用户问题,请**用你自己的话**综合回答,禁止大段原文粘贴。
规则:
1) 先总结,再给出处;答案主体必须是**你自己的表述**。
2) 如需引用原句,每处引用≤50字,并用引号与编号标注,如 “……”[1]。
3) 如果上下文没有明确答案,请说“不确定”,并给出可验证的线索。
4) 输出中文、结构化要点,并在末尾列出参考编号。

<已检索上下文>
{context}
</已检索上下文>

用户问题:{question}
"""
)



def format_docs(docs):
out = []
for i, d in enumerate(docs, 1):
meta = d.metadata or {}
src = meta.get("source", "unknown")
        # Cap each snippet at 500 characters to keep the prompt from getting too long
out.append(f"[{i}] ({src}) {d.page_content[:500]}")
return "\n\n".join(out)


def main():
retriever = load_retriever()
llm_runnable = build_llm_runnable()

    # RAG chain: retrieve → format context → prompt → Qwen (compatible endpoint) → parse to text
chain = (
RunnableParallel(context=retriever | format_docs, question=RunnablePassthrough())
| RAG_PROMPT
| llm_runnable
| StrOutputParser()
)
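    # Data flow: the question feeds both the retriever (then format_docs) and the {question} slot,
    # the prompt fills {context}/{question}, and the model's reply is parsed into a plain string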

print("💬 输入你的问题(Ctrl+C 退出)")
while True:
try:
q = input("> ").strip()
if not q:
continue
ans = chain.invoke(q)
print("\n" + ans + "\n")
except (EOFError, KeyboardInterrupt):
print("\n再见~")
break
except Exception as e:
print("❌ 出错:", e)


if __name__ == "__main__":
main()