-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathchat_interface.py
More file actions
177 lines (145 loc) · 7.25 KB
/
chat_interface.py
File metadata and controls
177 lines (145 loc) · 7.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import json
import re
from langchain_core.messages import HumanMessage, AIMessageChunk, ToolMessage
# Nodes whose raw token stream must NOT be appended to the final answer;
# their output is rendered through the system-node path instead.
SILENT_NODES = {"rewrite_query"}
# Graph nodes rendered as distinct, foldable "system" blocks in the UI.
SYSTEM_NODES = {"summarize_history", "rewrite_query"}
# Per-node UI configuration: the header title shown on each system block.
SYSTEM_NODE_CONFIG = {
    "rewrite_query": {"title": "🔍 Query Analysis & Rewriting"},
    "summarize_history": {"title": "📋 Chat History Summary"},
}
# --- Helpers ---
def make_message(content, *, title=None, node=None):
    """Build an assistant-role message dict in the UI's standard shape.

    When ``title`` (foldable-header text) or ``node`` (originating graph
    node) is provided, they are stored under a ``metadata`` key; falsy
    values are omitted entirely.
    """
    message = {"role": "assistant", "content": content}
    metadata = {}
    if title:
        metadata["title"] = title
    if node:
        metadata["node"] = node
    if metadata:
        message["metadata"] = metadata
    return message
def find_msg_idx(messages, node):
    """Return the index of the first message tagged with *node*, or None."""
    for idx, message in enumerate(messages):
        if message.get("metadata", {}).get("node") == node:
            return idx
    return None
def parse_rewrite_json(buffer):
    """Extract and decode the first ``{...}`` JSON span from raw LLM output.

    Returns the decoded object, or None when no braces are present or the
    captured span is not (yet) valid JSON — the normal case while tokens
    are still streaming in.
    """
    # Greedy DOTALL match grabs from the first '{' to the last '}', which
    # captures a complete top-level JSON object including nested braces.
    match = re.search(r"\{.*\}", buffer, re.DOTALL)
    if not match:
        return None
    try:
        return json.loads(match.group())
    # Only decoding can fail here; a bare `except Exception` would also
    # hide programming errors (e.g. a bad attribute), so keep it narrow.
    except json.JSONDecodeError:
        return None
def format_rewrite_content(buffer):
    """Render the rewrite node's streamed JSON as Markdown for the UI.

    Shows a progress placeholder until a complete JSON object has arrived,
    then either the rewritten queries (clear intent) or the model's
    clarification request (unclear intent).
    """
    data = parse_rewrite_json(buffer)
    if not data:
        return "⏳ Analyzing query..."
    if not data.get("is_clear"):
        lines = ["❓ **Query is unclear**"]
        clarification = data.get("clarification_needed", "")
        # A literal "no" (any case) means no clarification is actually needed.
        if clarification and clarification.strip().lower() != "no":
            lines.append(f"\nClarification needed: *{clarification}*")
        return "\n".join(lines)
    lines = ["✅ **Query is clear**"]
    questions = data.get("questions")
    if questions:
        lines.append("\n**Rewritten queries:**")
        lines.extend(f"- {q}" for q in questions)
    return "\n".join(lines)
# --- End of Helpers ---
class ChatInterface:
    """
    The UI-agnostic interface that translates raw LangGraph message streams
    into a format suitable for front-end displays (like Gradio).
    It handles foldable system messages, tool call previews, and streaming LLM tokens.
    """

    def __init__(self, rag_system):
        # rag_system is expected to expose: agent_graph (a compiled LangGraph),
        # get_config(), reset_thread(), and observability.flush() — per the
        # usage in chat()/clear_session() below.
        self.rag_system = rag_system

    def _handle_system_node(self, chunk, node, response_messages, system_node_buffer):
        """
        Accumulate a system node's streamed tokens and create/refresh its
        foldable UI block; for the rewrite node, also surface any
        clarification request to the user.
        """
        # Append the new tokens to this node's running text buffer.
        system_node_buffer[node] = system_node_buffer.get(node, "") + chunk.content
        buffer = system_node_buffer[node]
        title = SYSTEM_NODE_CONFIG[node]["title"]
        # The rewrite node streams JSON which we pretty-print; other system
        # nodes (e.g. the history summary) are shown verbatim.
        content = format_rewrite_content(buffer) if node == "rewrite_query" else buffer
        idx = find_msg_idx(response_messages, node)
        if idx is None:
            response_messages.append(make_message(content, title=title, node=node))
        else:
            response_messages[idx]["content"] = content
        if node == "rewrite_query":
            self._surface_clarification(buffer, response_messages)

    def _surface_clarification(self, buffer, response_messages):
        """Show a plain (non-foldable) prompt when the model asks the user to clarify."""
        data = parse_rewrite_json(buffer) or {}
        clarification = data.get("clarification_needed", "")
        # Only surface it when intent is unclear AND the model produced an
        # actual question (not empty, not a literal "no").
        if not data.get("is_clear") and clarification.strip().lower() not in ("", "no"):
            cidx = find_msg_idx(response_messages, "clarification")
            if cidx is None:
                response_messages.append(make_message(clarification, node="clarification"))
            else:
                response_messages[cidx]["content"] = clarification

    def _handle_tool_call(self, chunk, response_messages, active_tool_calls):
        """Append newly issued tool calls as foldable placeholder blocks."""
        for tc in chunk.tool_calls:
            # Streamed chunks may repeat a tool call; register each id once
            # and remember the index of its placeholder message.
            if tc.get("id") and tc["id"] not in active_tool_calls:
                response_messages.append(
                    make_message(f"Running `{tc['name']}`...", title=f"🛠️ {tc['name']}")
                )
                active_tool_calls[tc["id"]] = len(response_messages) - 1

    def _handle_tool_result(self, chunk, response_messages, active_tool_calls):
        """Replace a tool's placeholder with a 300-char preview of its output."""
        idx = active_tool_calls.get(chunk.tool_call_id)
        if idx is not None:
            # Compute str(...) once instead of twice as in the original.
            text = str(chunk.content)
            preview = text[:300]
            suffix = "\n..." if len(text) > 300 else ""
            response_messages[idx]["content"] = f"```\n{preview}{suffix}\n```"

    def _handle_llm_token(self, chunk, node, response_messages):
        """Append a streamed token to the final-answer message, creating it if needed."""
        last = response_messages[-1] if response_messages else None
        # The final answer is the only assistant message WITHOUT metadata;
        # start a fresh one if the last message is a tool/system block.
        if not (last and last.get("role") == "assistant" and "metadata" not in last):
            response_messages.append(make_message(""))
        response_messages[-1]["content"] += chunk.content

    def chat(self, message, history):
        """A streaming generator yielding UI payload dictionaries."""
        if not self.rag_system.agent_graph:
            yield "⚠️ System not initialized!"
            return
        config = self.rag_system.get_config()
        current_state = self.rag_system.agent_graph.get_state(config)
        try:
            if current_state.next:
                # The graph is paused mid-run (e.g. awaiting clarification):
                # inject the user's reply into the saved state and resume by
                # streaming with a None input instead of starting fresh.
                self.rag_system.agent_graph.update_state(
                    config, {"messages": [HumanMessage(content=message.strip())]}
                )
                stream_input = None
            else:
                stream_input = {"messages": [HumanMessage(content=message.strip())]}
            response_messages = []
            active_tool_calls = {}
            system_node_buffer = {}
            # Route each streamed token to the correct UI component based on
            # its source node and message type (final answer vs. system block
            # vs. tool call vs. tool result).  NOTE: in the original this
            # explanation was a bare triple-quoted string INSIDE the loop
            # body, evaluated as a no-op expression on every iteration.
            for chunk, metadata in self.rag_system.agent_graph.stream(
                stream_input, config=config, stream_mode="messages"
            ):
                node = metadata.get("langgraph_node", "")
                if node in SYSTEM_NODES and isinstance(chunk, AIMessageChunk) and chunk.content:
                    self._handle_system_node(chunk, node, response_messages, system_node_buffer)
                elif hasattr(chunk, "tool_calls") and chunk.tool_calls:
                    self._handle_tool_call(chunk, response_messages, active_tool_calls)
                elif isinstance(chunk, ToolMessage):
                    self._handle_tool_result(chunk, response_messages, active_tool_calls)
                elif isinstance(chunk, AIMessageChunk) and chunk.content and node not in SILENT_NODES:
                    self._handle_llm_token(chunk, node, response_messages)
                yield response_messages
        except Exception as e:
            # Top-level UI boundary: surface the failure rather than crash the app.
            yield f"❌ Error: {str(e)}"

    def clear_session(self):
        """Reset the conversation thread and flush pending observability data."""
        self.rag_system.reset_thread()
        self.rag_system.observability.flush()