diff --git a/docs/en_US/release_notes_9_14.rst b/docs/en_US/release_notes_9_14.rst index f841dcd5cd5..d919be63b58 100644 --- a/docs/en_US/release_notes_9_14.rst +++ b/docs/en_US/release_notes_9_14.rst @@ -29,5 +29,6 @@ Bug fixes ********* | `Issue #9279 <https://github.com/pgadmin-org/pgadmin4/issues/9279>`_ - Fixed an issue where OAuth2 authentication fails with 'object has no attribute' if OAUTH2_AUTO_CREATE_USER is False. + | `Issue #9736 <https://github.com/pgadmin-org/pgadmin4/issues/9736>`_ - Fixed an issue where the AI Assistant was not retaining conversation context between messages, and added chat history compaction to keep conversations within the token budget. | `Issue #9392 <https://github.com/pgadmin-org/pgadmin4/issues/9392>`_ - Ensure that the Geometry Viewer refreshes when re-running queries or switching geometry columns, preventing stale data from being displayed. | `Issue #9721 <https://github.com/pgadmin-org/pgadmin4/issues/9721>`_ - Fixed an issue where permissions page is not completely accessible on full scroll. diff --git a/web/pgadmin/llm/compaction.py b/web/pgadmin/llm/compaction.py new file mode 100644 index 00000000000..5d93570f245 --- /dev/null +++ b/web/pgadmin/llm/compaction.py @@ -0,0 +1,429 @@ +########################################################################## +# +# pgAdmin 4 - PostgreSQL Tools +# +# Copyright (C) 2013 - 2026, The pgAdmin Development Team +# This software is released under the PostgreSQL Licence +# +########################################################################## + +"""Conversation history compaction for managing LLM token budgets. + +This module implements a compaction strategy to keep conversation history +within token limits. It classifies messages by importance and drops +lower-value messages first, while preserving tool call/result pairs +and recent conversation context. 
# Inspired by the approach described at:
# https://www.pgedge.com/blog/lessons-learned-writing-an-mcp-server-for-postgresql

import json
import re
from typing import Optional

from pgadmin.llm.models import Message, Role, ToolCall


# Token budget defaults
DEFAULT_MAX_TOKENS = 100000
DEFAULT_RECENT_WINDOW = 10

# Provider-specific characters-per-token ratios
CHARS_PER_TOKEN = {
    'anthropic': 3.8,
    'openai': 4.0,
    'ollama': 4.5,
    'docker': 4.0,
}

# SQL content is tokenized less efficiently
SQL_TOKEN_MULTIPLIER = 1.2

# Overhead per message (role markers, formatting, etc.)
MESSAGE_OVERHEAD_TOKENS = 10

# Importance tiers
CLASS_ANCHOR = 1.0       # Schema info, corrections - always keep
CLASS_IMPORTANT = 0.8    # Query analysis, errors, insights
CLASS_CONTEXTUAL = 0.6   # Detailed responses, tool results
CLASS_ROUTINE = 0.4      # Short responses, standard messages
CLASS_TRANSIENT = 0.1    # Acknowledgments, short phrases

# Patterns for classification
_SCHEMA_PATTERNS = re.compile(
    r'\b(CREATE|ALTER|DROP)\s+(TABLE|INDEX|VIEW|SCHEMA)\b'
    r'|PRIMARY\s+KEY|FOREIGN\s+KEY|CONSTRAINT\b',
    re.IGNORECASE
)

_QUERY_PATTERNS = re.compile(
    r'\bEXPLAIN\s+ANALYZE\b|execution\s+time\b'
    r'|seq\s+scan\b|index\s+scan\b|query\s+plan\b',
    re.IGNORECASE
)

_ERROR_PATTERNS = re.compile(
    r'\berror\b|\bfailed\b|\bsyntax\s+error\b'
    r'|\bpermission\s+denied\b|\bdoes\s+not\s+exist\b',
    re.IGNORECASE
)

# Keywords that make estimate_tokens() treat a text as SQL-heavy.
# Compiled once here instead of being looked up on every call.
_SQL_KEYWORD_PATTERN = re.compile(
    r'\b(SELECT|INSERT|UPDATE|DELETE|CREATE|ALTER)\b',
    re.IGNORECASE
)


def estimate_tokens(text: str, provider: str = 'openai') -> int:
    """Estimate the number of tokens in a text string.

    Uses provider-specific character-per-token ratios and applies
    a multiplier for SQL-heavy content.

    Args:
        text: The text to estimate tokens for. Empty or None yields 0.
        provider: The LLM provider name for ratio selection. Unknown
            providers fall back to 4.0 characters per token.

    Returns:
        Estimated token count, including MESSAGE_OVERHEAD_TOKENS for
        any non-empty text.
    """
    if not text:
        return 0

    chars_per_token = CHARS_PER_TOKEN.get(provider, 4.0)
    base_tokens = len(text) / chars_per_token

    # Apply SQL multiplier if content looks like SQL
    if _SQL_KEYWORD_PATTERN.search(text):
        base_tokens *= SQL_TOKEN_MULTIPLIER

    return int(base_tokens) + MESSAGE_OVERHEAD_TOKENS


def estimate_message_tokens(message: Message, provider: str = 'openai') -> int:
    """Estimate token count for a single Message object.

    The message content, every tool call (name and JSON-encoded
    arguments) and every tool result are estimated separately and
    summed. Each non-empty part carries its own overhead allowance,
    so the figure is deliberately on the high side.

    Args:
        message: The Message to estimate.
        provider: The LLM provider name.

    Returns:
        Estimated token count.
    """
    total = estimate_tokens(message.content, provider)

    # Account for tool call arguments
    for tc in message.tool_calls:
        try:
            arguments_text = json.dumps(tc.arguments)
        except (TypeError, ValueError):
            # An estimate must never abort the request just because
            # the arguments are not JSON-serialisable.
            arguments_text = str(tc.arguments)
        total += estimate_tokens(arguments_text, provider)
        total += estimate_tokens(tc.name, provider)

    # Account for tool results
    for tr in message.tool_results:
        total += estimate_tokens(tr.content, provider)

    return total


def estimate_history_tokens(
    messages: list[Message], provider: str = 'openai'
) -> int:
    """Estimate total token count for a conversation history.

    Args:
        messages: List of Message objects.
        provider: The LLM provider name.

    Returns:
        Estimated total token count.
    """
    return sum(estimate_message_tokens(m, provider) for m in messages)


def _classify_tool_result(content: Optional[str]) -> float:
    """Return the importance tier of one tool result's content."""
    text = content or ''
    if _SCHEMA_PATTERNS.search(text):
        return CLASS_ANCHOR
    if _ERROR_PATTERNS.search(text):
        return CLASS_IMPORTANT
    # Large tool results are contextual
    if len(text) > 500:
        return CLASS_CONTEXTUAL
    return CLASS_ROUTINE


def _classify_message(message: Message) -> float:
    """Classify a message by importance for compaction decisions.

    Tool messages take the highest tier found among all of their
    results, so a schema dump is not demoted by an error result that
    happens to be listed before it.

    Args:
        message: The message to classify.

    Returns:
        Importance score from 0.0 to 1.0.
    """
    # Tool results containing schema info are anchors
    if message.role == Role.TOOL:
        return max(
            (_classify_tool_result(tr.content)
             for tr in message.tool_results),
            default=CLASS_ROUTINE
        )

    # Assistant messages with tool calls are important (they reference tools)
    if message.role == Role.ASSISTANT and message.tool_calls:
        return CLASS_IMPORTANT

    content = message.content or ''

    # Check content patterns
    if _SCHEMA_PATTERNS.search(content):
        return CLASS_ANCHOR
    if _ERROR_PATTERNS.search(content):
        return CLASS_IMPORTANT
    if _QUERY_PATTERNS.search(content):
        return CLASS_IMPORTANT

    # Short messages are transient
    if len(content) < 30:
        return CLASS_TRANSIENT

    # Medium messages are routine
    if len(content) < 100:
        return CLASS_ROUTINE

    return CLASS_CONTEXTUAL
def _find_tool_pair_indices(
    messages: list[Message]
) -> dict[int, frozenset[int]]:
    """Find indices of tool_call/tool_result groups that must stay together.

    An assistant message may contain multiple tool_calls, each with a
    corresponding tool result message. All messages in such a group
    must be dropped or kept together.

    Args:
        messages: The message list.

    Returns:
        Dict mapping index -> frozenset of all indices in the group.
        Indices that belong to no group are absent from the mapping.
    """
    groups: dict[int, frozenset[int]] = {}

    for i, msg in enumerate(messages):
        if msg.role != Role.ASSISTANT or not msg.tool_calls:
            continue

        call_ids = {tc.id for tc in msg.tool_calls}
        members = {i}
        for j in range(i + 1, len(messages)):
            later = messages[j]
            if later.role == Role.TOOL and any(
                tr.tool_call_id in call_ids for tr in later.tool_results
            ):
                members.add(j)

        group = frozenset(members)
        for idx in group:
            groups[idx] = group

    return groups


def _protected_indices(
    total: int,
    recent_window: int,
    tool_groups: dict[int, frozenset[int]]
) -> set[int]:
    """Return the indices that compaction must never drop.

    That is the first message, the last `recent_window` messages and
    every member of a tool group touching either, so that a group is
    never split across the protection boundary.
    """
    protected = {0} | set(range(max(1, total - recent_window), total))
    for idx in tuple(protected):
        protected |= tool_groups.get(idx, frozenset())
    return protected


def compact_history(
    messages: list[Message],
    max_tokens: int = DEFAULT_MAX_TOKENS,
    recent_window: int = DEFAULT_RECENT_WINDOW,
    provider: str = 'openai'
) -> list[Message]:
    """Compact conversation history to fit within a token budget.

    Strategy:
    1. Always keep the first message (provides original context)
    2. Always keep the last `recent_window` messages; the window is
       shrunk if the protected messages alone exceed the budget
    3. Among remaining messages, classify by importance and drop
       lowest-value messages first (oldest first within a tier)
    4. Keep tool_call/tool_result groups together: a group is dropped
       or kept as a whole, and a group touching a protected message
       is protected entirely, so no orphaned tool call or result is
       ever sent to the provider
    5. A group containing an anchor message is treated as an anchor
       and is only dropped once nothing less important is left

    Args:
        messages: Full conversation history.
        max_tokens: Maximum token budget for the history.
        recent_window: Number of recent messages to always preserve.
        provider: LLM provider name for token estimation.

    Returns:
        The original list when it already fits, otherwise a new list
        in the original order. The result can still exceed the budget
        when the protected messages alone are larger than max_tokens.
    """
    if not messages:
        return messages

    # Estimate every message once; the figures are reused below.
    costs = [estimate_message_tokens(m, provider) for m in messages]
    current_tokens = sum(costs)

    # Check if we're already within budget
    if current_tokens <= max_tokens:
        return messages

    total = len(messages)
    tool_groups = _find_tool_pair_indices(messages)

    # A window wider than the history protects nothing extra, so clamp
    # it; otherwise the shrink loop below would spin without effect.
    window = max(0, min(recent_window, total - 1))
    protected = _protected_indices(total, window, tool_groups)

    # If protected messages alone exceed the budget, shrink the
    # recent window until we have room for compaction candidates.
    while window > 0 and sum(costs[i] for i in protected) > max_tokens:
        window -= 1
        protected = _protected_indices(total, window, tool_groups)

    # Build the droppable units: single messages or whole tool groups.
    units = []
    seen: set[int] = set()
    for i in range(total):
        if i in protected or i in seen:
            continue

        members = tool_groups.get(i, frozenset((i,)))
        if not members.isdisjoint(protected):
            # Dropping this would split a partly protected group.
            continue

        scores = [_classify_message(messages[j]) for j in members]
        # The least important member decides when a group goes,
        # unless the group holds an anchor that has to survive.
        if max(scores) >= CLASS_ANCHOR:
            priority = CLASS_ANCHOR
        else:
            priority = min(scores)

        units.append((priority, i, members))
        seen |= members

    # Lowest importance first, oldest first within a tier. Anchors
    # sort last, so they are dropped only if we absolutely must.
    units.sort(key=lambda unit: (unit[0], unit[1]))

    dropped: set[int] = set()
    for _priority, _first, members in units:
        if current_tokens <= max_tokens:
            break
        fresh = members - dropped
        current_tokens -= sum(costs[j] for j in fresh)
        dropped |= fresh

    # Build the compacted message list preserving order
    return [msg for i, msg in enumerate(messages) if i not in dropped]


def _dict_entries(value) -> list[dict]:
    """Return the dict items of `value`, or [] if it is not a list."""
    if not isinstance(value, list):
        return []
    return [entry for entry in value if isinstance(entry, dict)]


def deserialize_history(
    history_data: list[dict]
) -> list[Message]:
    """Deserialize conversation history from JSON request data.

    Converts a list of message dictionaries (from the frontend) into
    Message objects suitable for passing to chat_with_database().

    The payload comes straight from the client, so malformed parts
    are skipped or replaced with empty defaults instead of raising.

    Args:
        history_data: List of dicts with 'role' and 'content' keys,
            and optionally 'tool_calls' and 'tool_results'.

    Returns:
        List of Message objects. Items that are not dicts or carry an
        unknown role are omitted; a non-list payload yields [].
    """
    # Imported once here rather than on every loop iteration.
    from pgadmin.llm.models import ToolResult

    if not isinstance(history_data, list):
        return []

    messages = []
    for item in history_data:
        if not isinstance(item, dict):
            continue

        try:
            role = Role(item.get('role', ''))
        except (ValueError, TypeError):
            continue  # Skip unknown roles

        # JSON null or a non-string content must not reach the regex
        # based classifier as anything but a string.
        content = item.get('content')
        if content is None:
            content = ''
        elif not isinstance(content, str):
            content = str(content)

        # Reconstruct tool calls if present
        tool_calls = []
        for tc_data in _dict_entries(item.get('tool_calls')):
            arguments = tc_data.get('arguments')
            if not isinstance(arguments, dict):
                arguments = {}
            tool_calls.append(ToolCall(
                id=tc_data.get('id') or '',
                name=tc_data.get('name') or '',
                arguments=arguments
            ))

        # Reconstruct tool results if present
        tool_results = []
        for tr_data in _dict_entries(item.get('tool_results')):
            result_content = tr_data.get('content')
            if result_content is None:
                result_content = ''
            elif not isinstance(result_content, str):
                result_content = str(result_content)
            tool_results.append(ToolResult(
                tool_call_id=tr_data.get('tool_call_id') or '',
                content=result_content,
                is_error=bool(tr_data.get('is_error', False))
            ))

        messages.append(Message(
            role=role,
            content=content,
            tool_calls=tool_calls,
            tool_results=tool_results
        ))

    return messages


def filter_conversational(messages: list[Message]) -> list[Message]:
    """Filter history to only conversational messages for storage.

    Keeps user messages and final assistant responses (those without
    tool calls). Drops intermediate assistant messages that contain
    tool_use requests and all tool result messages, since these are
    internal to each turn and don't need to persist between turns.

    This dramatically reduces history size since tool results often
    contain large schema dumps and query results.

    Args:
        messages: Full message history including tool call internals.

    Returns:
        Filtered list with only user messages and final assistant
        responses.
    """
    return [
        msg for msg in messages
        if msg.role == Role.USER
        or (msg.role == Role.ASSISTANT and not msg.tool_calls)
    ]
##########################################################################
#
# pgAdmin 4 - PostgreSQL Tools
#
# Copyright (C) 2013 - 2026, The pgAdmin Development Team
# This software is released under the PostgreSQL Licence
#
##########################################################################

"""Tests for the conversation history compaction module."""

import json
from unittest.mock import patch

from pgadmin.utils.route import BaseTestGenerator
from pgadmin.llm.models import Message, Role, ToolCall, ToolResult
from pgadmin.llm.compaction import (
    estimate_tokens,
    estimate_message_tokens,
    estimate_history_tokens,
    compact_history,
    deserialize_history,
    filter_conversational,
)


class TokenEstimationTestCase(BaseTestGenerator):
    """Test cases for token estimation functions."""

    scenarios = [
        ('Token Estimation - Empty String', dict(
            text='',
            provider='openai',
            expected_tokens=0
        )),
        ('Token Estimation - Short Text', dict(
            text='Hello world',
            provider='openai',
            # 11 chars / 4.0 = 2.75 + 10 overhead = ~12
            expected_min=10,
            expected_max=15
        )),
        ('Token Estimation - SQL Content', dict(
            text='SELECT id, name FROM users WHERE active = true',
            provider='openai',
            # SQL gets a 1.2x multiplier
            expected_min=15,
            expected_max=30
        )),
        ('Token Estimation - Anthropic Provider', dict(
            text='Hello world test string',
            provider='anthropic',
            # 23 chars / 3.8 = ~6 + 10 overhead = ~16
            expected_min=14,
            expected_max=20
        )),
    ]

    def setUp(self):
        pass

    def runTest(self):
        """Test token estimation."""
        result = estimate_tokens(self.text, self.provider)

        # Scenarios give either an exact count or an inclusive range.
        if hasattr(self, 'expected_tokens'):
            self.assertEqual(result, self.expected_tokens)
        else:
            self.assertGreaterEqual(result, self.expected_min)
            self.assertLessEqual(result, self.expected_max)

    def tearDown(self):
        pass


class CompactHistoryTestCase(BaseTestGenerator):
    """Test cases for conversation history compaction."""

    scenarios = [
        ('Compact - Empty History', dict(
            test_method='test_empty_history'
        )),
        ('Compact - Within Budget', dict(
            test_method='test_within_budget'
        )),
        ('Compact - Preserves First And Recent', dict(
            test_method='test_preserves_first_and_recent'
        )),
        ('Compact - Drops Low Value Messages', dict(
            test_method='test_drops_low_value'
        )),
        ('Compact - Keeps Tool Pairs Together', dict(
            test_method='test_tool_pairs'
        )),
    ]

    def setUp(self):
        pass

    def runTest(self):
        """Run the specified test method."""
        getattr(self, self.test_method)()

    def test_empty_history(self):
        """Empty history should return empty list."""
        result = compact_history([], max_tokens=1000)
        self.assertEqual(result, [])

    def test_within_budget(self):
        """History within budget should be returned unchanged."""
        messages = [
            Message.user('Hello'),
            Message.assistant('Hi there!'),
        ]
        result = compact_history(messages, max_tokens=100000)
        self.assertEqual(
            [m.content for m in result], ['Hello', 'Hi there!']
        )

    def test_preserves_first_and_recent(self):
        """First message and recent window should always be preserved."""
        messages = [Message.user('First message')]
        for i in range(20):
            messages.append(Message.user(f'Message {i}'))
            messages.append(Message.assistant(f'Response {i}'))

        # Derive the budget from the estimator so that compaction is
        # guaranteed to run. A fixed budget of 500 was larger than the
        # estimated size of this history, so nothing was compacted.
        budget = estimate_history_tokens(messages) // 2
        result = compact_history(
            messages, max_tokens=budget, recent_window=4
        )

        # Compaction really happened and met the budget
        self.assertLess(len(result), len(messages))
        self.assertLessEqual(estimate_history_tokens(result), budget)
        # First message should be preserved
        self.assertEqual(result[0].content, 'First message')
        # Last 4 messages should be preserved
        self.assertEqual(
            [m.content for m in result[-4:]],
            [m.content for m in messages[-4:]]
        )

    def test_drops_low_value(self):
        """Low-value messages should be dropped first."""
        # Filler only on important messages to inflate token count;
        # keep transient messages short so they classify as low-value.
        filler = ' This is extra text to increase token count.' * 5
        transient = [
            Message.user('ok'),
            Message.assistant('ok'),
            Message.user('thanks'),
            Message.assistant('sure'),
        ]
        messages = [
            Message.user('First important query' + filler),
            *transient,
            # More substantial messages
            Message.user('Show me the schema with CREATE TABLE' + filler),
            Message.assistant(
                'Here is the schema with CREATE TABLE...' + filler
            ),
            # Recent messages
            Message.user('Final question' + filler),
            Message.assistant('Final answer with details' + filler),
        ]

        # Budget that is met exactly once every transient message is
        # gone, so nothing more valuable has to be dropped. A fixed
        # budget of 500 exceeded the size of this history.
        budget = (
            estimate_history_tokens(messages) -
            estimate_history_tokens(transient)
        )
        result = compact_history(
            messages, max_tokens=budget, recent_window=2
        )

        # Exactly the transient messages were removed
        self.assertEqual(len(result), len(messages) - len(transient))
        # First message preserved
        self.assertIn('First important query', result[0].content)
        # Last 2 preserved
        self.assertIn('Final question', result[-2].content)
        self.assertIn('Final answer with details', result[-1].content)
        # Transient messages should be dropped
        contents = [m.content for m in result]
        for short_msg in ['ok', 'thanks', 'sure']:
            self.assertNotIn(short_msg, contents)
        # Schema (anchor) messages survive
        self.assertTrue(
            any('Show me the schema' in c for c in contents)
        )
        self.assertTrue(
            any('Here is the schema' in c for c in contents)
        )

    def test_tool_pairs(self):
        """Tool call/result pairs should be dropped together."""
        tc = ToolCall(id='tc1', name='get_schema', arguments={})
        messages = [
            Message.user('Get schema'),
            Message.assistant('', tool_calls=[tc]),
            Message.tool_result(
                tool_call_id='tc1',
                content='{"tables": ["users", "orders"]}' * 100
            ),
            Message.assistant('Found tables: users and orders'),
            Message.user('Recent query'),
            Message.assistant('Recent response'),
        ]

        result = compact_history(
            messages, max_tokens=200, recent_window=2
        )

        # Every surviving tool result must answer a surviving tool
        # call and vice versa (nothing left orphaned).
        call_ids = {
            call.id
            for m in result if m.role == Role.ASSISTANT
            for call in m.tool_calls
        }
        result_ids = {
            tr.tool_call_id
            for m in result if m.role == Role.TOOL
            for tr in m.tool_results
        }
        self.assertEqual(call_ids, result_ids)
        # The tool result alone is far larger than the budget, so the
        # whole pair has to be gone.
        self.assertEqual(call_ids, set())
        # Protected messages are untouched
        self.assertEqual(result[0].content, 'Get schema')
        self.assertEqual(result[-1].content, 'Recent response')

    def tearDown(self):
        pass


class DeserializeHistoryTestCase(BaseTestGenerator):
    """Test cases for deserializing conversation history."""

    scenarios = [
        ('Deserialize - Empty', dict(
            test_method='test_empty'
        )),
        ('Deserialize - Basic Messages', dict(
            test_method='test_basic_messages'
        )),
        ('Deserialize - With Tool Calls', dict(
            test_method='test_with_tool_calls'
        )),
        ('Deserialize - Skips Unknown Roles', dict(
            test_method='test_unknown_roles'
        )),
    ]

    def setUp(self):
        pass

    def runTest(self):
        """Run the specified test method."""
        getattr(self, self.test_method)()

    def test_empty(self):
        """Empty list should return empty list."""
        result = deserialize_history([])
        self.assertEqual(result, [])

    def test_basic_messages(self):
        """Should deserialize user and assistant messages."""
        data = [
            {'role': 'user', 'content': 'Hello'},
            {'role': 'assistant', 'content': 'Hi there!'},
        ]
        result = deserialize_history(data)
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0].role, Role.USER)
        self.assertEqual(result[0].content, 'Hello')
        self.assertEqual(result[1].role, Role.ASSISTANT)
        self.assertEqual(result[1].content, 'Hi there!')

    def test_with_tool_calls(self):
        """Should deserialize messages with tool calls."""
        data = [
            {
                'role': 'assistant',
                'content': '',
                'tool_calls': [{
                    'id': 'tc1',
                    'name': 'get_schema',
                    'arguments': {'schema': 'public'}
                }]
            },
        ]
        result = deserialize_history(data)
        self.assertEqual(len(result), 1)
        self.assertEqual(result[0].role, Role.ASSISTANT)
        self.assertEqual(len(result[0].tool_calls), 1)
        tool_call = result[0].tool_calls[0]
        self.assertEqual(tool_call.id, 'tc1')
        self.assertEqual(tool_call.name, 'get_schema')
        self.assertEqual(tool_call.arguments, {'schema': 'public'})

    def test_unknown_roles(self):
        """Should skip messages with unknown roles."""
        data = [
            {'role': 'user', 'content': 'Hello'},
            {'role': 'unknown_role', 'content': 'Skip me'},
            {'role': 'assistant', 'content': 'Hi'},
        ]
        result = deserialize_history(data)
        # Only the unknown entry is gone and order is preserved
        self.assertEqual(
            [m.content for m in result], ['Hello', 'Hi']
        )

    def tearDown(self):
        pass


class FilterConversationalTestCase(BaseTestGenerator):
    """Test cases for filtering conversational messages."""

    scenarios = [
        ('Filter - Keeps User And Final Assistant', dict(
            test_method='test_keeps_conversational'
        )),
        ('Filter - Drops Tool Messages', dict(
            test_method='test_drops_tool_messages'
        )),
        ('Filter - Drops Intermediate Assistant', dict(
            test_method='test_drops_intermediate_assistant'
        )),
    ]

    def setUp(self):
        pass

    def runTest(self):
        """Run the specified test method."""
        getattr(self, self.test_method)()

    def test_keeps_conversational(self):
        """Should keep user messages and final assistant responses."""
        messages = [
            Message.user('Hello'),
            Message.assistant('Hi there!'),
            Message.user('Show me users'),
            Message.assistant('Here is the SQL'),
        ]
        result = filter_conversational(messages)
        self.assertEqual(
            [m.content for m in result],
            [m.content for m in messages]
        )

    def test_drops_tool_messages(self):
        """Should drop tool result messages."""
        tc = ToolCall(id='tc1', name='get_schema', arguments={})
        messages = [
            Message.user('Get schema'),
            Message.assistant('', tool_calls=[tc]),
            Message.tool_result(
                tool_call_id='tc1',
                content='{"tables": ["users"]}'
            ),
            Message.assistant('Found the users table.'),
        ]
        result = filter_conversational(messages)
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0].role, Role.USER)
        self.assertEqual(result[1].role, Role.ASSISTANT)
        self.assertEqual(result[1].content, 'Found the users table.')

    def test_drops_intermediate_assistant(self):
        """Should drop assistant messages that have tool calls."""
        tc1 = ToolCall(id='tc1', name='get_schema', arguments={})
        tc2 = ToolCall(id='tc2', name='execute_sql', arguments={})
        messages = [
            Message.user('Complex query'),
            Message.assistant('', tool_calls=[tc1]),
            Message.tool_result(
                tool_call_id='tc1', content='schema data'
            ),
            Message.assistant('', tool_calls=[tc2]),
            Message.tool_result(
                tool_call_id='tc2', content='query results'
            ),
            Message.assistant('Here are the final results.'),
        ]
        result = filter_conversational(messages)
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0].content, 'Complex query')
        self.assertEqual(result[1].content,
                         'Here are the final results.')
        # No tool-call carrying assistant message survives
        self.assertFalse(any(m.tool_calls for m in result))

    def tearDown(self):
        pass
tools and history + response_text, updated_history = chat_with_database( user_message=user_message, sid=trans_obj.sid, did=trans_obj.did, - system_prompt=NLQ_SYSTEM_PROMPT + system_prompt=NLQ_SYSTEM_PROMPT, + conversation_history=conversation_history ) # Try to parse the response as JSON @@ -2968,12 +2984,24 @@ def generate(): else: new_conversation_id = conversation_id + # Serialize updated history for the frontend. + # Only include conversational messages (user + final + # assistant responses) to keep history size manageable. + # Internal tool call/result messages are ephemeral to + # each turn and don't need to round-trip. + from pgadmin.llm.compaction import filter_conversational + serialized_history = [ + m.to_dict() for m in + filter_conversational(updated_history) + ] if updated_history else [] + # Send the final result yield _nlq_sse_event({ 'type': 'complete', 'sql': sql, 'explanation': explanation, - 'conversation_id': new_conversation_id + 'conversation_id': new_conversation_id, + 'history': serialized_history }) except Exception as e: diff --git a/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx b/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx index 5bbd2c413bd..30501f81223 100644 --- a/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx +++ b/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx @@ -272,6 +272,7 @@ export function NLQChatPanel() { const [inputValue, setInputValue] = useState(''); const [isLoading, setIsLoading] = useState(false); const [conversationId, setConversationId] = useState(null); + const [conversationHistory, setConversationHistory] = useState([]); const [thinkingMessageId, setThinkingMessageId] = useState(null); const [llmInfo, setLlmInfo] = useState({ provider: null, model: null }); @@ -291,6 +292,7 @@ export function NLQChatPanel() { const abortControllerRef = useRef(null); const readerRef = useRef(null); const stoppedRef = 
useRef(false); + const clearedRef = useRef(false); const eventBus = useContext(QueryToolEventsContext); const queryToolCtx = useContext(QueryToolContext); const editorPrefs = usePreferences().getPreferencesForModule('editor'); @@ -405,8 +407,21 @@ export function NLQChatPanel() { }; const handleClearConversation = () => { + // Mark as cleared so in-flight stream handlers ignore late events + clearedRef.current = true; + // Cancel any active stream + if (readerRef.current) { + readerRef.current.cancel(); + readerRef.current = null; + } + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + abortControllerRef.current = null; + } setMessages([]); setConversationId(null); + setConversationHistory([]); + setIsLoading(false); }; // Stop the current request @@ -444,8 +459,9 @@ export function NLQChatPanel() { const handleSubmit = async () => { if (!inputValue.trim() || isLoading) return; - // Reset stopped flag + // Reset stopped and cleared flags stoppedRef.current = false; + clearedRef.current = false; // Fetch latest LLM provider/model info before submitting fetchLlmInfo(); @@ -505,6 +521,7 @@ export function NLQChatPanel() { body: JSON.stringify({ message: userMessage, conversation_id: conversationId, + history: conversationHistory, }), signal: controller.signal, } @@ -545,8 +562,8 @@ export function NLQChatPanel() { readerRef.current = null; - // Check if user manually stopped - if (stoppedRef.current) { + // Check if user manually stopped (but not cleared) + if (stoppedRef.current && !clearedRef.current) { setMessages((prev) => [ ...prev.filter((m) => m.id !== thinkingId), { @@ -559,8 +576,10 @@ export function NLQChatPanel() { clearTimeout(timeoutId); abortControllerRef.current = null; readerRef.current = null; - // Show appropriate message based on error type - if (error.name === 'AbortError') { + // If conversation was cleared, ignore all late errors + if (clearedRef.current) { + // Do nothing - conversation was wiped + } else if (error.name 
=== 'AbortError') { // Check if this was a user-initiated stop or a timeout if (stoppedRef.current) { // User manually stopped @@ -630,6 +649,9 @@ export function NLQChatPanel() { if (event.conversation_id) { setConversationId(event.conversation_id); } + if (event.history) { + setConversationHistory(event.history); + } break; case 'error': diff --git a/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py b/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py index 6f1f3447990..cd9ca7ab0a3 100644 --- a/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py +++ b/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py @@ -47,6 +47,22 @@ class NLQChatTestCase(BaseTestGenerator): '"explanation": "Gets all users"}' ) )), + ('NLQ Chat - With History', dict( + llm_enabled=True, + valid_transaction=True, + message='Now filter by active users', + history=[ + {'role': 'user', 'content': 'Find all users'}, + {'role': 'assistant', + 'content': '{"sql": "SELECT * FROM users;", ' + '"explanation": "Gets all users"}'}, + ], + expected_error=False, + mock_response=( + '{"sql": "SELECT * FROM users WHERE active = true;", ' + '"explanation": "Gets active users"}' + ) + )), ] def setUp(self): @@ -93,12 +109,14 @@ def runTest(self): patches.append(mock_check_trans) # Mock chat_with_database + mock_chat_patcher = None + mock_chat_obj = None if hasattr(self, 'mock_response'): - mock_chat = patch( + mock_chat_patcher = patch( 'pgadmin.llm.chat.chat_with_database', return_value=(self.mock_response, []) ) - patches.append(mock_chat) + patches.append(mock_chat_patcher) # Mock CSRF protection mock_csrf = patch( @@ -108,15 +126,22 @@ def runTest(self): patches.append(mock_csrf) # Start all patches + started_mocks = [] for p in patches: - p.start() + m = p.start() + started_mocks.append(m) + if p is mock_chat_patcher: + mock_chat_obj = m try: # Make request message = getattr(self, 'message', 'test query') + request_data = {'message': message} + if hasattr(self, 'history'): + request_data['history'] = 
self.history response = self.tester.post( f'/sqleditor/nlq/chat/{trans_id}/stream', - data=json.dumps({'message': message}), + data=json.dumps(request_data), content_type='application/json', follow_redirects=True ) @@ -137,6 +162,19 @@ def runTest(self): self.assertEqual(response.status_code, 200) self.assertIn('text/event-stream', response.content_type) + # Verify history was passed to chat_with_database + if hasattr(self, 'history') and mock_chat_obj: + mock_chat_obj.assert_called_once() + call_kwargs = mock_chat_obj.call_args.kwargs + conv_hist = call_kwargs.get( + 'conversation_history', [] + ) + self.assertTrue( + len(conv_hist) > 0, + 'conversation_history should be non-empty ' + 'when history is provided' + ) + finally: # Stop all patches for p in patches: