def chat_with_database_stream(
    user_message: str,
    sid: int,
    did: int,
    conversation_history: Optional[list[Message]] = None,
    system_prompt: Optional[str] = None,
    max_tool_iterations: Optional[int] = None,
    provider: Optional[str] = None,
    model: Optional[str] = None
) -> Generator[
    Union[str, tuple[str, list[str]], tuple[str, list[Message]]],
    None, None
]:
    """
    Stream an LLM chat conversation with database tool access.

    Like chat_with_database, but yields text chunks as the final
    response streams in. During tool-use iterations, no text is
    yielded (tools are executed silently).

    Args:
        user_message: The user's message to send to the LLM.
        sid: Server id used when executing database tools.
        did: Database id used when executing database tools.
        conversation_history: Prior Message objects, if any.
        system_prompt: Overrides DEFAULT_SYSTEM_PROMPT when given.
        max_tool_iterations: Overrides the preference-backed limit.
        provider: Optional LLM provider override.
        model: Optional model name override.

    Yields:
        str: Text content chunks from the final LLM response.
        tuple: ('tool_use', [tool_name, ...]) whenever the model
            requests tool calls, so the caller can reset its
            streaming state and show a thinking indicator.

        The last item yielded is a tuple of
        (final_response_text, updated_conversation_history).

    Raises:
        LLMClientError: If the LLM request fails.
        RuntimeError: If LLM is not available or max iterations exceeded.
    """
    if not is_llm_available():
        raise RuntimeError("LLM is not configured. Please configure an LLM "
                           "provider in Preferences > AI.")

    client = get_llm_client(provider=provider, model=model)
    if not client:
        raise RuntimeError("Failed to create LLM client")

    # Copy the history so the caller's list is never mutated.
    messages = list(conversation_history) if conversation_history else []
    messages.append(Message.user(user_message))

    if system_prompt is None:
        system_prompt = DEFAULT_SYSTEM_PROMPT

    if max_tool_iterations is None:
        max_tool_iterations = get_max_tool_iterations()

    iteration = 0
    while iteration < max_tool_iterations:
        iteration += 1

        # Stream the LLM response, yielding text chunks as they arrive.
        # The provider yields str chunks followed by one LLMResponse.
        response = None
        for item in client.chat_stream(
            messages=messages,
            tools=DATABASE_TOOLS,
            system_prompt=system_prompt
        ):
            if isinstance(item, LLMResponse):
                response = item
            elif isinstance(item, str):
                yield item

        if response is None:
            raise RuntimeError("No response received from LLM")

        messages.append(response.to_message())

        if response.stop_reason != StopReason.TOOL_USE:
            # Final response - yield the completion tuple
            yield (response.content, messages)
            return

        # Signal that tools are being executed so the caller can
        # reset streaming state and show a thinking indicator
        yield ('tool_use', [tc.name for tc in response.tool_calls])

        # Execute tool calls; every call produces a tool_result
        # message (success or error) so the conversation stays valid.
        tool_results = []
        for tool_call in response.tool_calls:
            try:
                result = execute_tool(
                    tool_name=tool_call.name,
                    arguments=tool_call.arguments,
                    sid=sid,
                    did=did
                )
                tool_results.append(Message.tool_result(
                    tool_call_id=tool_call.id,
                    content=json.dumps(result, default=str),
                    is_error=False
                ))
            except (DatabaseToolError, ValueError) as e:
                tool_results.append(Message.tool_result(
                    tool_call_id=tool_call.id,
                    content=json.dumps({"error": str(e)}),
                    is_error=True
                ))
            except Exception as e:
                # Report unexpected failures back to the model
                # instead of aborting the whole conversation.
                tool_results.append(Message.tool_result(
                    tool_call_id=tool_call.id,
                    content=json.dumps({
                        "error": f"Unexpected error: {str(e)}"
                    }),
                    is_error=True
                ))

        messages.extend(tool_results)

    raise RuntimeError(
        f"Exceeded maximum tool iterations ({max_tool_iterations})"
    )
+ """ + # Default: fall back to non-streaming + response = self.chat( + messages=messages, + tools=tools, + system_prompt=system_prompt, + max_tokens=max_tokens, + temperature=temperature, + **kwargs + ) + if response.content: + yield response.content + yield response + def validate_connection(self) -> tuple[bool, Optional[str]]: """ Validate the connection to the LLM provider. diff --git a/web/pgadmin/llm/prompts/nlq.py b/web/pgadmin/llm/prompts/nlq.py index 78dd337466a..3e89d57cdca 100644 --- a/web/pgadmin/llm/prompts/nlq.py +++ b/web/pgadmin/llm/prompts/nlq.py @@ -28,12 +28,10 @@ - Use explicit column names instead of SELECT * - For UPDATE/DELETE, always include WHERE clauses -Your response MUST be a JSON object in this exact format: -{"sql": "YOUR SQL QUERY HERE", "explanation": "Brief explanation"} - -Rules: -- Return ONLY the JSON object, nothing else -- No markdown code blocks -- If you need clarification, set "sql" to null and put \ -your question in "explanation" +Response format: +- Always put SQL in fenced code blocks with the sql language tag +- You may include multiple SQL blocks if the request needs \ +multiple statements +- Briefly explain what each query does +- If you need clarification, just ask — no code blocks needed """ diff --git a/web/pgadmin/llm/providers/anthropic.py b/web/pgadmin/llm/providers/anthropic.py index d2e6d4af4bd..cc6595c2c4c 100644 --- a/web/pgadmin/llm/providers/anthropic.py +++ b/web/pgadmin/llm/providers/anthropic.py @@ -10,10 +10,12 @@ """Anthropic Claude LLM client implementation.""" import json +import socket import ssl import urllib.request import urllib.error -from typing import Optional +from collections.abc import Generator +from typing import Optional, Union import uuid # Try to use certifi for proper SSL certificate handling @@ -274,3 +276,229 @@ def _parse_response(self, data: dict) -> LLMResponse: usage=usage, raw_response=data ) + + def chat_stream( + self, + messages: list[Message], + tools: Optional[list[Tool]] 
= None, + system_prompt: Optional[str] = None, + max_tokens: int = 4096, + temperature: float = 0.0, + **kwargs + ) -> Generator[Union[str, LLMResponse], None, None]: + """Stream a chat response from Anthropic.""" + payload = { + 'model': self._model, + 'max_tokens': max_tokens, + 'messages': self._convert_messages(messages), + 'stream': True + } + + if system_prompt: + payload['system'] = system_prompt + + if temperature > 0: + payload['temperature'] = temperature + + if tools: + payload['tools'] = self._convert_tools(tools) + + try: + yield from self._process_stream(payload) + except LLMClientError: + raise + except Exception as e: + raise LLMClientError(LLMError( + message=f"Streaming request failed: {str(e)}", + provider=self.provider_name + )) + + def _process_stream( + self, payload: dict + ) -> Generator[Union[str, LLMResponse], None, None]: + """Make a streaming request and yield chunks.""" + headers = { + 'Content-Type': 'application/json', + 'x-api-key': self._api_key, + 'anthropic-version': API_VERSION + } + + request = urllib.request.Request( + API_URL, + data=json.dumps(payload).encode('utf-8'), + headers=headers, + method='POST' + ) + + try: + response = urllib.request.urlopen( + request, timeout=120, context=SSL_CONTEXT + ) + except urllib.error.HTTPError as e: + error_body = e.read().decode('utf-8') + try: + error_data = json.loads(error_body) + error_msg = error_data.get( + 'error', {} + ).get('message', str(e)) + except json.JSONDecodeError: + error_msg = error_body or str(e) + raise LLMClientError(LLMError( + message=error_msg, + code=str(e.code), + provider=self.provider_name, + retryable=e.code in (429, 500, 502, 503, 504) + )) + except urllib.error.URLError as e: + raise LLMClientError(LLMError( + message=f"Connection error: {e.reason}", + provider=self.provider_name, + retryable=True + )) + except socket.timeout: + raise LLMClientError(LLMError( + message="Request timed out.", + code='timeout', + provider=self.provider_name, + retryable=True 
+ )) + + try: + yield from self._read_anthropic_stream(response) + finally: + response.close() + + def _read_anthropic_stream( + self, response + ) -> Generator[Union[str, LLMResponse], None, None]: + """Read and parse an Anthropic SSE stream. + + Uses readline() for incremental reading. + """ + content_parts = [] + tool_calls = [] + current_tool_block = None + tool_input_json = '' + stop_reason_str = None + model_name = self._model + usage = Usage() + in_text_block = False + + while True: + line_bytes = response.readline() + if not line_bytes: + break + + line = line_bytes.decode('utf-8', errors='replace').strip() + + if not line or line.startswith(':'): + continue + + if line.startswith('event: '): + continue + + if not line.startswith('data: '): + continue + + try: + data = json.loads(line[6:]) + except json.JSONDecodeError: + continue + + event_type = data.get('type', '') + + if event_type == 'message_start': + msg = data.get('message', {}) + model_name = msg.get('model', self._model) + u = msg.get('usage', {}) + usage = Usage( + input_tokens=u.get('input_tokens', 0), + output_tokens=u.get('output_tokens', 0), + total_tokens=( + u.get('input_tokens', 0) + + u.get('output_tokens', 0) + ) + ) + + elif event_type == 'content_block_start': + block = data.get('content_block', {}) + if block.get('type') == 'tool_use': + current_tool_block = { + 'id': block.get('id', str(uuid.uuid4())), + 'name': block.get('name', '') + } + tool_input_json = '' + elif block.get('type') == 'text': + # Emit a separator between text blocks to + # match _parse_response() which joins with '\n' + if in_text_block: + content_parts.append('\n') + yield '\n' + in_text_block = True + + elif event_type == 'content_block_delta': + delta = data.get('delta', {}) + if delta.get('type') == 'text_delta': + text = delta.get('text', '') + if text: + content_parts.append(text) + yield text + elif delta.get('type') == 'input_json_delta': + tool_input_json += delta.get( + 'partial_json', '' + ) + + elif 
event_type == 'content_block_stop': + if current_tool_block is not None: + try: + arguments = json.loads( + tool_input_json + ) if tool_input_json else {} + except json.JSONDecodeError: + arguments = {} + tool_calls.append(ToolCall( + id=current_tool_block['id'], + name=current_tool_block['name'], + arguments=arguments + )) + current_tool_block = None + tool_input_json = '' + + elif event_type == 'message_delta': + delta = data.get('delta', {}) + stop_reason_str = delta.get('stop_reason') + u = data.get('usage', {}) + if u: + usage = Usage( + input_tokens=usage.input_tokens, + output_tokens=u.get( + 'output_tokens', + usage.output_tokens + ), + total_tokens=( + usage.input_tokens + + u.get( + 'output_tokens', + usage.output_tokens + ) + ) + ) + + # Build final response + stop_reason_map = { + 'end_turn': StopReason.END_TURN, + 'tool_use': StopReason.TOOL_USE, + 'max_tokens': StopReason.MAX_TOKENS, + 'stop_sequence': StopReason.STOP_SEQUENCE + } + stop_reason = stop_reason_map.get( + stop_reason_str or '', StopReason.UNKNOWN + ) + + yield LLMResponse( + content=''.join(content_parts), + tool_calls=tool_calls, + stop_reason=stop_reason, + model=model_name, + usage=usage + ) diff --git a/web/pgadmin/llm/providers/docker.py b/web/pgadmin/llm/providers/docker.py index 2d65a21a46c..2c2efedea11 100644 --- a/web/pgadmin/llm/providers/docker.py +++ b/web/pgadmin/llm/providers/docker.py @@ -16,9 +16,11 @@ import json import socket import ssl +import urllib.parse import urllib.request import urllib.error -from typing import Optional +from collections.abc import Generator +from typing import Optional, Union import uuid # Try to use certifi for proper SSL certificate handling @@ -42,6 +44,25 @@ DEFAULT_API_URL = 'http://localhost:12434' DEFAULT_MODEL = 'ai/qwen3-coder' +# Allowed loopback hostnames for the Docker endpoint +_LOOPBACK_HOSTS = {'localhost', '127.0.0.1', '::1', '[::1]'} + + +def _validate_loopback_url(url: str) -> None: + """Ensure the URL uses HTTP(S) and points 
to a loopback address.""" + parsed = urllib.parse.urlparse(url) + if parsed.scheme not in ('http', 'https'): + raise ValueError( + f"Docker Model Runner URL must use http or https, " + f"got: {parsed.scheme}" + ) + hostname = (parsed.hostname or '').lower() + if hostname not in _LOOPBACK_HOSTS: + raise ValueError( + f"Docker Model Runner URL must point to a loopback address " + f"(localhost/127.0.0.1/::1), got: {hostname}" + ) + class DockerClient(LLMClient): """ @@ -63,6 +84,7 @@ def __init__( model: Optional model name. Defaults to ai/qwen3-coder. """ self._api_url = (api_url or DEFAULT_API_URL).rstrip('/') + _validate_loopback_url(self._api_url) self._model = model or DEFAULT_MODEL @property @@ -357,3 +379,216 @@ def _parse_response(self, data: dict) -> LLMResponse: usage=usage, raw_response=data ) + + def chat_stream( + self, + messages: list[Message], + tools: Optional[list[Tool]] = None, + system_prompt: Optional[str] = None, + max_tokens: int = 4096, + temperature: float = 0.0, + **kwargs + ) -> Generator[Union[str, LLMResponse], None, None]: + """Stream a chat response from Docker Model Runner.""" + converted_messages = self._convert_messages(messages) + + if system_prompt: + converted_messages.insert(0, { + 'role': 'system', + 'content': system_prompt + }) + + payload = { + 'model': self._model, + 'messages': converted_messages, + 'max_completion_tokens': max_tokens, + 'temperature': temperature, + 'stream': True, + 'stream_options': {'include_usage': True} + } + + if tools: + payload['tools'] = self._convert_tools(tools) + payload['tool_choice'] = 'auto' + + try: + yield from self._process_stream(payload) + except LLMClientError: + raise + except Exception as e: + raise LLMClientError(LLMError( + message=f"Streaming request failed: {str(e)}", + provider=self.provider_name + )) + + def _process_stream( + self, payload: dict + ) -> Generator[Union[str, LLMResponse], None, None]: + """Make a streaming request and yield chunks.""" + headers = { + 
'Content-Type': 'application/json' + } + + url = f'{self._api_url}/engines/v1/chat/completions' + + request = urllib.request.Request( + url, + data=json.dumps(payload).encode('utf-8'), + headers=headers, + method='POST' + ) + + try: + response = urllib.request.urlopen( + request, timeout=300, context=SSL_CONTEXT + ) + except urllib.error.HTTPError as e: + error_body = e.read().decode('utf-8') + try: + error_data = json.loads(error_body) + error_msg = error_data.get( + 'error', {} + ).get('message', str(e)) + except json.JSONDecodeError: + error_msg = error_body or str(e) + raise LLMClientError(LLMError( + message=error_msg, + code=str(e.code), + provider=self.provider_name, + retryable=e.code in (429, 500, 502, 503, 504) + )) + except urllib.error.URLError as e: + raise LLMClientError(LLMError( + message=f"Connection error: {e.reason}. " + f"Is Docker Model Runner running at " + f"{self._api_url}?", + provider=self.provider_name, + retryable=True + )) + except socket.timeout: + raise LLMClientError(LLMError( + message="Request timed out.", + code='timeout', + provider=self.provider_name, + retryable=True + )) + + try: + yield from self._read_openai_stream(response) + finally: + response.close() + + def _read_openai_stream( + self, response + ) -> Generator[Union[str, LLMResponse], None, None]: + """Read and parse an OpenAI-format SSE stream. + + Uses readline() for incremental reading. 
+ """ + content_parts = [] + tool_calls_data = {} + finish_reason = None + model_name = self._model + usage = Usage() + + while True: + line_bytes = response.readline() + if not line_bytes: + break + + line = line_bytes.decode('utf-8', errors='replace').strip() + + if not line or line.startswith(':'): + continue + + if line == 'data: [DONE]': + continue + + if not line.startswith('data: '): + continue + + try: + data = json.loads(line[6:]) + except json.JSONDecodeError: + continue + + if 'usage' in data and data['usage']: + u = data['usage'] + usage = Usage( + input_tokens=u.get('prompt_tokens', 0), + output_tokens=u.get('completion_tokens', 0), + total_tokens=u.get('total_tokens', 0) + ) + + if 'model' in data: + model_name = data['model'] + + choices = data.get('choices', []) + if not choices: + continue + + choice = choices[0] + delta = choice.get('delta', {}) + + if choice.get('finish_reason'): + finish_reason = choice['finish_reason'] + + text_chunk = delta.get('content') + if text_chunk: + content_parts.append(text_chunk) + yield text_chunk + + for tc_delta in delta.get('tool_calls', []): + idx = tc_delta.get('index', 0) + if idx not in tool_calls_data: + tool_calls_data[idx] = { + 'id': '', 'name': '', 'arguments': '' + } + tc = tool_calls_data[idx] + if 'id' in tc_delta: + tc['id'] = tc_delta['id'] + func = tc_delta.get('function', {}) + if 'name' in func: + tc['name'] = func['name'] + if 'arguments' in func: + tc['arguments'] += func['arguments'] + + content = ''.join(content_parts) + tool_calls = [] + for idx in sorted(tool_calls_data.keys()): + tc = tool_calls_data[idx] + try: + arguments = json.loads(tc['arguments']) \ + if tc['arguments'] else {} + except json.JSONDecodeError: + arguments = {} + tool_calls.append(ToolCall( + id=tc['id'] or str(uuid.uuid4()), + name=tc['name'], + arguments=arguments + )) + + stop_reason_map = { + 'stop': StopReason.END_TURN, + 'tool_calls': StopReason.TOOL_USE, + 'length': StopReason.MAX_TOKENS, + 'content_filter': 
StopReason.STOP_SEQUENCE + } + stop_reason = stop_reason_map.get( + finish_reason or '', StopReason.UNKNOWN + ) + + if not content and not tool_calls: + raise LLMClientError(LLMError( + message='No response content returned from API', + provider=self.provider_name, + retryable=False + )) + + yield LLMResponse( + content=content, + tool_calls=tool_calls, + stop_reason=stop_reason, + model=model_name, + usage=usage + ) diff --git a/web/pgadmin/llm/providers/ollama.py b/web/pgadmin/llm/providers/ollama.py index ad683109f72..02d1ca4b5c7 100644 --- a/web/pgadmin/llm/providers/ollama.py +++ b/web/pgadmin/llm/providers/ollama.py @@ -10,10 +10,10 @@ """Ollama LLM client implementation.""" import json -import re import urllib.request import urllib.error -from typing import Optional +from collections.abc import Generator +from typing import Optional, Union import uuid from pgadmin.llm.client import LLMClient, LLMClientError @@ -223,7 +223,7 @@ def _make_request(self, payload: dict) -> dict: message=error_msg, code=str(e.code), provider=self.provider_name, - retryable=e.code in (500, 502, 503, 504) + retryable=e.code in (429, 500, 502, 503, 504) )) except urllib.error.URLError as e: raise LLMClientError(LLMError( @@ -234,8 +234,6 @@ def _make_request(self, payload: dict) -> dict: def _parse_response(self, data: dict) -> LLMResponse: """Parse the Ollama API response into an LLMResponse.""" - import re - message = data.get('message', {}) content = message.get('content', '') tool_calls = [] @@ -288,3 +286,167 @@ def _parse_response(self, data: dict) -> LLMResponse: usage=usage, raw_response=data ) + + def chat_stream( + self, + messages: list[Message], + tools: Optional[list[Tool]] = None, + system_prompt: Optional[str] = None, + max_tokens: int = 4096, + temperature: float = 0.0, + **kwargs + ) -> Generator[Union[str, LLMResponse], None, None]: + """Stream a chat response from Ollama.""" + converted_messages = self._convert_messages(messages) + + if system_prompt: + 
converted_messages.insert(0, { + 'role': 'system', + 'content': system_prompt + }) + + payload = { + 'model': self._model, + 'messages': converted_messages, + 'stream': True, + 'options': { + 'num_predict': max_tokens, + 'temperature': temperature + } + } + + if tools: + payload['tools'] = self._convert_tools(tools) + + try: + yield from self._process_stream(payload) + except LLMClientError: + raise + except Exception as e: + raise LLMClientError(LLMError( + message=f"Streaming request failed: {str(e)}", + provider=self.provider_name + )) + + def _process_stream( + self, payload: dict + ) -> Generator[Union[str, LLMResponse], None, None]: + """Make a streaming request and yield chunks.""" + url = f'{self._api_url}/api/chat' + + request = urllib.request.Request( + url, + data=json.dumps(payload).encode('utf-8'), + headers={'Content-Type': 'application/json'}, + method='POST' + ) + + try: + response = urllib.request.urlopen(request, timeout=300) + except urllib.error.HTTPError as e: + error_body = e.read().decode('utf-8') + try: + error_data = json.loads(error_body) + error_msg = error_data.get('error', str(e)) + except json.JSONDecodeError: + error_msg = error_body or str(e) + raise LLMClientError(LLMError( + message=error_msg, + code=str(e.code), + provider=self.provider_name, + retryable=e.code in (429, 500, 502, 503, 504) + )) + except urllib.error.URLError as e: + raise LLMClientError(LLMError( + message=f"Cannot connect to Ollama: {e.reason}", + provider=self.provider_name, + retryable=True + )) + + try: + yield from self._read_ollama_stream(response) + finally: + response.close() + + def _read_ollama_stream( + self, response + ) -> Generator[Union[str, LLMResponse], None, None]: + """Read and parse an Ollama NDJSON stream. + + Uses readline() for incremental reading. 
+ """ + content_parts = [] + tool_calls = [] + done_reason = None + model_name = self._model + input_tokens = 0 + output_tokens = 0 + final_data = None + + while True: + line_bytes = response.readline() + if not line_bytes: + break + + line = line_bytes.decode('utf-8', errors='replace').strip() + + if not line: + continue + + try: + data = json.loads(line) + except json.JSONDecodeError: + continue + + msg = data.get('message', {}) + + # Text content + text = msg.get('content', '') + if text: + content_parts.append(text) + yield text + + # Tool calls (in final message) + for tc in msg.get('tool_calls', []): + func = tc.get('function', {}) + arguments = func.get('arguments', {}) + if isinstance(arguments, str): + try: + arguments = json.loads(arguments) + except json.JSONDecodeError: + arguments = {} + tool_calls.append(ToolCall( + id=str(uuid.uuid4()), + name=func.get('name', ''), + arguments=arguments + )) + + if data.get('done'): + final_data = data + done_reason = data.get('done_reason', '') + model_name = data.get('model', self._model) + input_tokens = data.get('prompt_eval_count', 0) + output_tokens = data.get('eval_count', 0) + + # Build final response + if tool_calls: + stop_reason = StopReason.TOOL_USE + elif done_reason == 'stop': + stop_reason = StopReason.END_TURN + elif done_reason == 'length': + stop_reason = StopReason.MAX_TOKENS + else: + stop_reason = StopReason.UNKNOWN + + yield LLMResponse( + content=''.join(content_parts), + tool_calls=tool_calls, + stop_reason=stop_reason, + model=model_name, + usage=Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=input_tokens + output_tokens + ), + raw_response=final_data + ) diff --git a/web/pgadmin/llm/providers/openai.py b/web/pgadmin/llm/providers/openai.py index 3e7c169af1e..1f4f81ae7f2 100644 --- a/web/pgadmin/llm/providers/openai.py +++ b/web/pgadmin/llm/providers/openai.py @@ -14,7 +14,8 @@ import ssl import urllib.request import urllib.error -from typing import Optional 
+from collections.abc import Generator +from typing import Optional, Union import uuid # Try to use certifi for proper SSL certificate handling @@ -343,3 +344,220 @@ def _parse_response(self, data: dict) -> LLMResponse: usage=usage, raw_response=data ) + + def chat_stream( + self, + messages: list[Message], + tools: Optional[list[Tool]] = None, + system_prompt: Optional[str] = None, + max_tokens: int = 4096, + temperature: float = 0.0, + **kwargs + ) -> Generator[Union[str, LLMResponse], None, None]: + """Stream a chat response from OpenAI.""" + converted_messages = self._convert_messages(messages) + + if system_prompt: + converted_messages.insert(0, { + 'role': 'system', + 'content': system_prompt + }) + + payload = { + 'model': self._model, + 'messages': converted_messages, + 'max_completion_tokens': max_tokens, + 'temperature': temperature, + 'stream': True, + 'stream_options': {'include_usage': True} + } + + if tools: + payload['tools'] = self._convert_tools(tools) + payload['tool_choice'] = 'auto' + + try: + yield from self._process_stream(payload) + except LLMClientError: + raise + except Exception as e: + raise LLMClientError(LLMError( + message=f"Streaming request failed: {str(e)}", + provider=self.provider_name + )) + + def _process_stream( + self, payload: dict + ) -> Generator[Union[str, LLMResponse], None, None]: + """Make a streaming request and yield chunks.""" + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {self._api_key}' + } + + request = urllib.request.Request( + API_URL, + data=json.dumps(payload).encode('utf-8'), + headers=headers, + method='POST' + ) + + try: + response = urllib.request.urlopen( + request, timeout=120, context=SSL_CONTEXT + ) + except urllib.error.HTTPError as e: + error_body = e.read().decode('utf-8') + try: + error_data = json.loads(error_body) + error_msg = error_data.get( + 'error', {} + ).get('message', str(e)) + except json.JSONDecodeError: + error_msg = error_body or str(e) + raise 
LLMClientError(LLMError( + message=error_msg, + code=str(e.code), + provider=self.provider_name, + retryable=e.code in (429, 500, 502, 503, 504) + )) + except urllib.error.URLError as e: + raise LLMClientError(LLMError( + message=f"Connection error: {e.reason}", + provider=self.provider_name, + retryable=True + )) + except socket.timeout: + raise LLMClientError(LLMError( + message="Request timed out.", + code='timeout', + provider=self.provider_name, + retryable=True + )) + + try: + yield from self._read_openai_stream(response) + finally: + response.close() + + def _read_openai_stream( + self, response + ) -> Generator[Union[str, LLMResponse], None, None]: + """Read and parse an OpenAI-format SSE stream. + + Uses readline() for incremental reading — it returns as soon + as a complete line arrives from the server, unlike read() + which blocks until a buffer fills up. + """ + content_parts = [] + # tool_calls_data: {index: {id, name, arguments_str}} + tool_calls_data = {} + finish_reason = None + model_name = self._model + usage = Usage() + + while True: + line_bytes = response.readline() + if not line_bytes: + break + + line = line_bytes.decode('utf-8', errors='replace').strip() + + if not line or line.startswith(':'): + continue + + if line == 'data: [DONE]': + continue + + if not line.startswith('data: '): + continue + + try: + data = json.loads(line[6:]) + except json.JSONDecodeError: + continue + + # Extract usage from the final chunk + if 'usage' in data and data['usage']: + u = data['usage'] + usage = Usage( + input_tokens=u.get('prompt_tokens', 0), + output_tokens=u.get('completion_tokens', 0), + total_tokens=u.get('total_tokens', 0) + ) + + if 'model' in data: + model_name = data['model'] + + choices = data.get('choices', []) + if not choices: + continue + + choice = choices[0] + delta = choice.get('delta', {}) + + if choice.get('finish_reason'): + finish_reason = choice['finish_reason'] + + # Text content + text_chunk = delta.get('content') + if text_chunk: 
+ content_parts.append(text_chunk) + yield text_chunk + + # Tool calls (accumulate) + for tc_delta in delta.get('tool_calls', []): + idx = tc_delta.get('index', 0) + if idx not in tool_calls_data: + tool_calls_data[idx] = { + 'id': '', 'name': '', 'arguments': '' + } + tc = tool_calls_data[idx] + if 'id' in tc_delta: + tc['id'] = tc_delta['id'] + func = tc_delta.get('function', {}) + if 'name' in func: + tc['name'] = func['name'] + if 'arguments' in func: + tc['arguments'] += func['arguments'] + + # Build final response + content = ''.join(content_parts) + tool_calls = [] + for idx in sorted(tool_calls_data.keys()): + tc = tool_calls_data[idx] + try: + arguments = json.loads(tc['arguments']) \ + if tc['arguments'] else {} + except json.JSONDecodeError: + arguments = {} + tool_calls.append(ToolCall( + id=tc['id'] or str(uuid.uuid4()), + name=tc['name'], + arguments=arguments + )) + + stop_reason_map = { + 'stop': StopReason.END_TURN, + 'tool_calls': StopReason.TOOL_USE, + 'length': StopReason.MAX_TOKENS, + 'content_filter': StopReason.STOP_SEQUENCE + } + stop_reason = stop_reason_map.get( + finish_reason or '', StopReason.UNKNOWN + ) + + if not content and not tool_calls: + raise LLMClientError(LLMError( + message='No response content returned from API', + provider=self.provider_name, + retryable=False + )) + + yield LLMResponse( + content=content, + tool_calls=tool_calls, + stop_reason=stop_reason, + model=model_name, + usage=usage + ) diff --git a/web/pgadmin/static/js/Theme/dark.js b/web/pgadmin/static/js/Theme/dark.js index 5deb73324d5..b087e49e0ad 100644 --- a/web/pgadmin/static/js/Theme/dark.js +++ b/web/pgadmin/static/js/Theme/dark.js @@ -89,6 +89,7 @@ export default function(basicSettings) { }, otherVars: { colorBrand: '#1b71b5', + hyperlinkColor: '#6CB4EE', borderColor: '#4a4a4a', inputBorderColor: '#6b6b6b', inputDisabledBg: 'inherit', diff --git a/web/pgadmin/static/js/Theme/high_contrast.js b/web/pgadmin/static/js/Theme/high_contrast.js index 
184153cb273..0a4a5083cf5 100644 --- a/web/pgadmin/static/js/Theme/high_contrast.js +++ b/web/pgadmin/static/js/Theme/high_contrast.js @@ -87,6 +87,7 @@ export default function(basicSettings) { }, otherVars: { colorBrand: '#84D6FF', + hyperlinkColor: '#84D6FF', borderColor: '#A6B7C8', inputBorderColor: '#8B9CAD', inputDisabledBg: '#1F2932', diff --git a/web/pgadmin/static/js/Theme/light.js b/web/pgadmin/static/js/Theme/light.js index 093928cfef1..00847f91425 100644 --- a/web/pgadmin/static/js/Theme/light.js +++ b/web/pgadmin/static/js/Theme/light.js @@ -89,6 +89,7 @@ export default function(basicSettings) { }, otherVars: { colorBrand: '#326690', + hyperlinkColor: '#1a0dab', iconLoaderUrl: 'url("data:image/svg+xml,%3C%3Fxml version=\'1.0\' encoding=\'utf-8\'%3F%3E%3C!-- Generator: Adobe Illustrator 23.1.1, SVG Export Plug-In . SVG Version: 6.00 Build 0) --%3E%3Csvg version=\'1.1\' id=\'Layer_1\' xmlns=\'http://www.w3.org/2000/svg\' xmlns:xlink=\'http://www.w3.org/1999/xlink\' x=\'0px\' y=\'0px\' viewBox=\'0 0 38 38\' style=\'enable-background:new 0 0 38 38;\' xml:space=\'preserve\'%3E%3Cstyle type=\'text/css\'%3E .st0%7Bfill:none;stroke:%23EBEEF3;stroke-width:5;%7D .st1%7Bfill:none;stroke:%23326690;stroke-width:5;%7D%0A%3C/style%3E%3Cg%3E%3Cg transform=\'translate(1 1)\'%3E%3Ccircle class=\'st0\' cx=\'18\' cy=\'18\' r=\'16\'/%3E%3Cpath class=\'st1\' d=\'M34,18c0-8.8-7.2-16-16-16 \'%3E%3CanimateTransform accumulate=\'none\' additive=\'replace\' attributeName=\'transform\' calcMode=\'linear\' dur=\'0.7s\' fill=\'remove\' from=\'0 18 18\' repeatCount=\'indefinite\' restart=\'always\' to=\'360 18 18\' type=\'rotate\'%3E%3C/animateTransform%3E%3C/path%3E%3C/g%3E%3C/g%3E%3C/svg%3E%0A");', iconLoaderSmall: 'url("data:image/svg+xml,%3C%3Fxml version=\'1.0\' encoding=\'utf-8\'%3F%3E%3C!-- Generator: Adobe Illustrator 23.1.1, SVG Export Plug-In . 
SVG Version: 6.00 Build 0) --%3E%3Csvg version=\'1.1\' id=\'Layer_1\' xmlns=\'http://www.w3.org/2000/svg\' xmlns:xlink=\'http://www.w3.org/1999/xlink\' x=\'0px\' y=\'0px\' viewBox=\'0 0 38 38\' style=\'enable-background:new 0 0 38 38;\' xml:space=\'preserve\'%3E%3Cstyle type=\'text/css\'%3E .st0%7Bfill:none;stroke:%23EBEEF3;stroke-width:5;%7D .st1%7Bfill:none;stroke:%23326690;stroke-width:5;%7D%0A%3C/style%3E%3Cg%3E%3Cg transform=\'translate(1 1)\'%3E%3Ccircle class=\'st0\' cx=\'18\' cy=\'18\' r=\'16\'/%3E%3Cpath class=\'st1\' d=\'M34,18c0-8.8-7.2-16-16-16 \'%3E%3CanimateTransform accumulate=\'none\' additive=\'replace\' attributeName=\'transform\' calcMode=\'linear\' dur=\'0.7s\' fill=\'remove\' from=\'0 18 18\' repeatCount=\'indefinite\' restart=\'always\' to=\'360 18 18\' type=\'rotate\'%3E%3C/animateTransform%3E%3C/path%3E%3C/g%3E%3C/g%3E%3C/svg%3E%0A")', dashboardPgDoc: 'url("data:image/svg+xml,%3C%3Fxml version=\'1.0\' encoding=\'utf-8\'%3F%3E%3C!-- Generator: Adobe Illustrator 22.1.0, SVG Export Plug-In . 
SVG Version: 6.00 Build 0) --%3E%3Csvg version=\'1.1\' id=\'Layer_1\' xmlns=\'http://www.w3.org/2000/svg\' xmlns:xlink=\'http://www.w3.org/1999/xlink\' x=\'0px\' y=\'0px\' viewBox=\'0 0 42 42\' style=\'enable-background:new 0 0 42 42;\' xml:space=\'preserve\'%3E%3Cstyle type=\'text/css\'%3E .st0%7Bstroke:%23000000;stroke-width:3.3022;%7D .st1%7Bfill:%23336791;%7D .st2%7Bfill:none;stroke:%23FFFFFF;stroke-width:1.1007;stroke-linecap:round;stroke-linejoin:round;%7D .st3%7Bfill:none;stroke:%23FFFFFF;stroke-width:1.1007;stroke-linecap:round;stroke-linejoin:bevel;%7D .st4%7Bfill:%23FFFFFF;stroke:%23FFFFFF;stroke-width:0.3669;%7D .st5%7Bfill:%23FFFFFF;stroke:%23FFFFFF;stroke-width:0.1835;%7D .st6%7Bfill:none;stroke:%23FFFFFF;stroke-width:0.2649;stroke-linecap:round;stroke-linejoin:round;%7D%0A%3C/style%3E%3Cg id=\'orginal\'%3E%3C/g%3E%3Cg id=\'Layer_x0020_3\'%3E%3Cpath class=\'st0\' d=\'M31.3,30c0.3-2.1,0.2-2.4,1.7-2.1l0.4,0c1.2,0.1,2.8-0.2,3.7-0.6c2-0.9,3.1-2.4,1.2-2 c-4.4,0.9-4.7-0.6-4.7-0.6c4.7-7,6.7-15.8,5-18c-4.6-5.9-12.6-3.1-12.7-3l0,0c-0.9-0.2-1.9-0.3-3-0.3c-2,0-3.5,0.5-4.7,1.4 c0,0-14.3-5.9-13.6,7.4c0.1,2.8,4,21.3,8.7,15.7c1.7-2,3.3-3.8,3.3-3.8c0.8,0.5,1.8,0.8,2.8,0.7l0.1-0.1c0,0.3,0,0.5,0,0.8 c-1.2,1.3-0.8,1.6-3.2,2.1c-2.4,0.5-1,1.4-0.1,1.6c1.1,0.3,3.7,0.7,5.5-1.8l-0.1,0.3c0.5,0.4,0.4,2.7,0.5,4.4 c0.1,1.7,0.2,3.2,0.5,4.1c0.3,0.9,0.7,3.3,3.9,2.6C29.1,38.3,31.1,37.5,31.3,30\'/%3E%3Cpath class=\'st1\' d=\'M38.3,25.3c-4.4,0.9-4.7-0.6-4.7-0.6c4.7-7,6.7-15.8,5-18c-4.6-5.9-12.6-3.1-12.7-3l0,0 c-0.9-0.2-1.9-0.3-3-0.3c-2,0-3.5,0.5-4.7,1.4c0,0-14.3-5.9-13.6,7.4c0.1,2.8,4,21.3,8.7,15.7c1.7-2,3.3-3.8,3.3-3.8 c0.8,0.5,1.8,0.8,2.8,0.7l0.1-0.1c0,0.3,0,0.5,0,0.8c-1.2,1.3-0.8,1.6-3.2,2.1c-2.4,0.5-1,1.4-0.1,1.6c1.1,0.3,3.7,0.7,5.5-1.8 l-0.1,0.3c0.5,0.4,0.8,2.4,0.7,4.3c-0.1,1.9-0.1,3.2,0.3,4.2c0.4,1,0.7,3.3,3.9,2.6c2.6-0.6,4-2,4.2-4.5c0.1-1.7,0.4-1.5,0.5-3 l0.2-0.7c0.3-2.3,0-3.1,1.7-2.8l0.4,0c1.2,0.1,2.8-0.2,3.7-0.6C39,26.4,40.2,24.9,38.3,25.3L38.3,25.3z\'/%3E%3Cpath class=\'st2\' 
d=\'M21.8,26.6c-0.1,4.4,0,8.8,0.5,9.8c0.4,1.1,1.3,3.2,4.5,2.5c2.6-0.6,3.6-1.7,4-4.1c0.3-1.8,0.9-6.7,1-7.7\'/%3E%3Cpath class=\'st2\' d=\'M18,4.7c0,0-14.3-5.8-13.6,7.4c0.1,2.8,4,21.3,8.7,15.7c1.7-2,3.2-3.7,3.2-3.7\'/%3E%3Cpath class=\'st2\' d=\'M25.7,3.6c-0.5,0.2,7.9-3.1,12.7,3c1.7,2.2-0.3,11-5,18\'/%3E%3Cpath class=\'st3\' d=\'M33.5,24.6c0,0,0.3,1.5,4.7,0.6c1.9-0.4,0.8,1.1-1.2,2c-1.6,0.8-5.3,0.9-5.3-0.1 C31.6,24.5,33.6,25.3,33.5,24.6c-0.1-0.6-1.1-1.2-1.7-2.7c-0.5-1.3-7.3-11.2,1.9-9.7c0.3-0.1-2.4-8.7-11-8.9 c-8.6-0.1-8.3,10.6-8.3,10.6\'/%3E%3Cpath class=\'st2\' d=\'M19.4,25.6c-1.2,1.3-0.8,1.6-3.2,2.1c-2.4,0.5-1,1.4-0.1,1.6c1.1,0.3,3.7,0.7,5.5-1.8c0.5-0.8,0-2-0.7-2.3 C20.5,25.1,20,24.9,19.4,25.6L19.4,25.6z\'/%3E%3Cpath class=\'st2\' d=\'M19.3,25.5c-0.1-0.8,0.3-1.7,0.7-2.8c0.6-1.6,2-3.3,0.9-8.5c-0.8-3.9-6.5-0.8-6.5-0.3c0,0.5,0.3,2.7-0.1,5.2 c-0.5,3.3,2.1,6,5,5.7\'/%3E%3Cpath class=\'st4\' d=\'M18,13.8c0,0.2,0.3,0.7,0.8,0.7c0.5,0.1,0.9-0.3,0.9-0.5c0-0.2-0.3-0.4-0.8-0.4C18.4,13.6,18,13.7,18,13.8 L18,13.8z\'/%3E%3Cpath class=\'st5\' d=\'M32,13.5c0,0.2-0.3,0.7-0.8,0.7c-0.5,0.1-0.9-0.3-0.9-0.5c0-0.2,0.3-0.4,0.8-0.4C31.6,13.2,32,13.3,32,13.5 L32,13.5z\'/%3E%3Cpath class=\'st2\' d=\'M33.7,12.2c0.1,1.4-0.3,2.4-0.4,3.9c-0.1,2.2,1,4.7-0.6,7.2\'/%3E%3Cpath class=\'st6\' d=\'M2.7,6.6\'/%3E%3C/g%3E%3C/svg%3E%0A")', diff --git a/web/pgadmin/tools/sqleditor/__init__.py b/web/pgadmin/tools/sqleditor/__init__.py index ce2cd3fe0b4..3bdfe35ed4f 100644 --- a/web/pgadmin/tools/sqleditor/__init__.py +++ b/web/pgadmin/tools/sqleditor/__init__.py @@ -2842,7 +2842,7 @@ def nlq_chat_stream(trans_id): """ from flask import stream_with_context from pgadmin.llm.utils import is_llm_enabled - from pgadmin.llm.chat import chat_with_database + from pgadmin.llm.chat import chat_with_database_stream from pgadmin.llm.prompts.nlq import NLQ_SYSTEM_PROMPT # Check if LLM is configured @@ -2893,74 +2893,52 @@ def generate(): 'message': gettext('Analyzing your request...') }) - # Call the LLM with database 
tools - response_text, _ = chat_with_database( + # Stream the LLM response with database tools + response_text = '' + for item in chat_with_database_stream( user_message=user_message, sid=trans_obj.sid, did=trans_obj.did, system_prompt=NLQ_SYSTEM_PROMPT - ) - - # Try to parse the response as JSON - sql = None - explanation = '' - - # First, try to extract JSON from markdown code blocks - json_text = response_text.strip() - - # Look for ```json ... ``` blocks - json_match = re.search( - r'```json\s*\n?(.*?)\n?```', - json_text, + ): + if isinstance(item, str): + # Text chunk from streaming LLM response + yield _nlq_sse_event({ + 'type': 'text_delta', + 'content': item + }) + elif isinstance(item, tuple) and \ + item[0] == 'tool_use': + # Tool execution in progress - reset streaming + yield _nlq_sse_event({ + 'type': 'thinking', + 'message': gettext( + 'Querying the database...' + ) + }) + elif isinstance(item, tuple): + # Final result: (response_text, messages) + response_text = item[0] + + # Extract SQL from markdown code fences + sql_blocks = re.findall( + r'```(?:sql|pgsql|postgresql)\s*\n(.*?)```', + response_text, re.DOTALL ) - if json_match: - json_text = json_match.group(1).strip() - else: - # Also try to find a plain JSON object in the response - # Look for {"sql": ... 
} pattern anywhere in the text - sql_pattern = ( - r'\{["\']?sql["\']?\s*:\s*' - r'(?:null|"[^"]*"|\'[^\']*\').*?\}' - ) - plain_json_match = re.search(sql_pattern, json_text, re.DOTALL) - if plain_json_match: - json_text = plain_json_match.group(0) - - try: - result = json.loads(json_text) - sql = result.get('sql') - explanation = result.get('explanation', '') - except (json.JSONDecodeError, TypeError): - # If not valid JSON, try to extract SQL from the response - # Look for SQL code blocks first - sql_match = re.search( - r'```sql\s*\n?(.*?)\n?```', - response_text, - re.DOTALL - ) - if sql_match: - sql = sql_match.group(1).strip() - else: - # Check for malformed tool call text patterns - # Some models output tool calls as text instead of - # proper tool use blocks - tool_call_match = re.search( - r'\s*' - r'\s*(.*?)\s*', - response_text, - re.DOTALL - ) - if tool_call_match: - sql = tool_call_match.group(1).strip() - explanation = gettext( - 'Generated SQL query from your request.' - ) - else: - # No parseable JSON or SQL block found - # Treat the response as an explanation/error message - explanation = response_text.strip() - # Don't set sql - leave it as None + sql = ';\n\n'.join( + block.strip().rstrip(';') for block in sql_blocks + ) if sql_blocks else None + + # Fallback: try JSON format in case LLM ignored + # the markdown instruction + if sql is None: + try: + result = json.loads(response_text.strip()) + if isinstance(result, dict): + sql = result.get('sql') + except (json.JSONDecodeError, TypeError): + pass # Generate a conversation ID if not provided if not conversation_id: @@ -2968,11 +2946,11 @@ def generate(): else: new_conversation_id = conversation_id - # Send the final result + # Send the final result with full response content yield _nlq_sse_event({ 'type': 'complete', 'sql': sql, - 'explanation': explanation, + 'content': response_text, 'conversation_id': new_conversation_id }) diff --git 
a/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx b/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx index 5bbd2c413bd..8b2125fb558 100644 --- a/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx +++ b/web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx @@ -23,6 +23,8 @@ import AddIcon from '@mui/icons-material/Add'; import ClearAllIcon from '@mui/icons-material/ClearAll'; import AutoFixHighIcon from '@mui/icons-material/AutoFixHigh'; import { format as formatSQL } from 'sql-formatter'; +import { marked } from 'marked'; +import DOMPurify from 'dompurify'; import gettext from 'sources/gettext'; import url_for from 'sources/url_for'; import getApiInstance from '../../../../../../static/js/api_instance'; @@ -104,7 +106,6 @@ const SQLPreviewBox = styled(Box)(({ theme }) => ({ borderRadius: theme.spacing(0.5), overflow: 'auto', '& .cm-editor': { - minHeight: '60px', maxHeight: '250px', }, '& .cm-scroller': { @@ -130,15 +131,171 @@ const ThinkingIndicator = styled(Box)(({ theme }) => ({ color: theme.palette.text.secondary, })); +const MarkdownContent = styled(Box)(({ theme }) => ({ + fontSize: theme.typography.body2.fontSize, + lineHeight: theme.typography.body2.lineHeight, + '& p': { margin: `${theme.spacing(0.5)} 0` }, + '& p:first-of-type': { marginTop: 0 }, + '& p:last-of-type': { marginBottom: 0 }, + '& code': { + backgroundColor: theme.palette.action.hover, + padding: '1px 4px', + borderRadius: 3, + fontSize: '0.85em', + fontFamily: 'monospace', + }, + '& pre': { + backgroundColor: theme.palette.action.hover, + padding: theme.spacing(1), + borderRadius: 4, + overflow: 'auto', + '& code': { + backgroundColor: 'transparent', + padding: 0, + }, + }, + '& h1, & h2, & h3, & h4, & h5, & h6': { + margin: `${theme.spacing(1)} 0 ${theme.spacing(0.5)} 0`, + lineHeight: 1.3, + }, + '& h1': { fontSize: '1.3em' }, + '& h2': { fontSize: '1.2em' }, + '& h3': { fontSize: '1.1em' }, 
+ '& ul': { + margin: `${theme.spacing(0.5)} 0`, + paddingLeft: theme.spacing(2.5), + listStyleType: 'disc !important', + }, + '& ol': { + margin: `${theme.spacing(0.5)} 0`, + paddingLeft: theme.spacing(2.5), + listStyleType: 'decimal !important', + }, + '& li': { + margin: `${theme.spacing(0.25)} 0`, + display: 'list-item !important', + listStyle: 'inherit !important', + }, + '& ul ul': { listStyleType: 'circle !important' }, + '& ul ul ul': { listStyleType: 'square !important' }, + '& table': { + borderCollapse: 'collapse', + margin: `${theme.spacing(0.5)} 0`, + width: '100%', + }, + '& th, & td': { + border: `1px solid ${theme.otherVars.borderColor}`, + padding: `${theme.spacing(0.25)} ${theme.spacing(0.75)}`, + textAlign: 'left', + }, + '& th': { + backgroundColor: theme.palette.action.hover, + fontWeight: 600, + }, + '& blockquote': { + borderLeft: `3px solid ${theme.otherVars.borderColor}`, + margin: `${theme.spacing(0.5)} 0`, + paddingLeft: theme.spacing(1), + opacity: 0.85, + }, + '& strong': { fontWeight: 600 }, + '& a': { + color: theme.otherVars.hyperlinkColor, + textDecoration: 'underline', + }, +})); + // Message types const MESSAGE_TYPES = { USER: 'user', ASSISTANT: 'assistant', SQL: 'sql', THINKING: 'thinking', + STREAMING: 'streaming', ERROR: 'error', }; +/** + * Incrementally parse streaming markdown text into an ordered list of + * segments. Each segment is: + * { type: 'text', content: string } + * { type: 'code', language: string, content: string, complete: boolean } + * + * Handles ```language fenced code blocks. Segments appear in the order + * the LLM streams them so the renderer can map straight over the array. 
+ */ +function parseMarkdownSegments(text) { + const segments = []; + let pos = 0; + + while (pos < text.length) { + const fenceIdx = text.indexOf('```', pos); + + if (fenceIdx === -1) { + // No more fences — rest is text + const content = text.substring(pos); + if (content) segments.push({ type: 'text', content }); + break; + } + + // Text before the fence + if (fenceIdx > pos) { + segments.push({ type: 'text', content: text.substring(pos, fenceIdx) }); + } + + // Parse opening fence line: ```language\n + const afterFence = text.substring(fenceIdx + 3); + const langMatch = /^([a-zA-Z]*)\n/.exec(afterFence); + if (!langMatch) { + // Language line not complete yet — wait for more tokens + break; + } + + const language = langMatch[1].toLowerCase(); + const codeStart = fenceIdx + 3 + langMatch[0].length; + + // Find closing fence + const closeIdx = text.indexOf('```', codeStart); + if (closeIdx === -1) { + // Still streaming code block content + segments.push({ + type: 'code', language, + content: text.substring(codeStart), + complete: false, + }); + break; + } + + // Complete code block — trim trailing newline before closing fence + let codeContent = text.substring(codeStart, closeIdx); + if (codeContent.endsWith('\n')) { + codeContent = codeContent.slice(0, -1); + } + segments.push({ + type: 'code', language, + content: codeContent, + complete: true, + }); + + // Move past closing ``` and optional trailing newline + pos = closeIdx + 3; + if (pos < text.length && text[pos] === '\n') pos++; + } + + return segments; +} + +/** + * Render a markdown text fragment to sanitized HTML. + * Uses marked for inline formatting (bold, italic, code, lists, tables, etc.) + * and DOMPurify to prevent XSS. 
+ */ +function renderMarkdownText(text) { + if (!text) return ''; + const html = marked.parse(text, { gfm: true, breaks: true }); + return DOMPurify.sanitize(html); +} + // Elephant/PostgreSQL-themed processing messages const THINKING_MESSAGES = [ 'Consulting the elephant...', @@ -164,7 +321,7 @@ function getRandomThinkingMessage() { } // Single chat message component -function ChatMessage({ message, onInsertSQL, onReplaceSQL, textColors, cmKey }) { +function ChatMessage({ message, onInsertSQL, onReplaceSQL, textColors, cmKey, formatSqlWithPrefs }) { if (message.type === MESSAGE_TYPES.USER) { return ( @@ -174,58 +331,117 @@ function ChatMessage({ message, onInsertSQL, onReplaceSQL, textColors, cmKey }) } if (message.type === MESSAGE_TYPES.SQL) { + const segments = message.content + ? parseMarkdownSegments(message.content) : []; + + // Fallback for messages without markdown content (old format) + if (segments.length === 0 && message.sql) { + return ( + + + + + {gettext('Generated SQL')} + + + + onInsertSQL(message.sql)}> + + + + + onReplaceSQL(message.sql)}> + + + + + navigator.clipboard.writeText(message.sql)}> + + + + + + + + + + {message.explanation && ( + {message.explanation} + )} + + ); + } + + // Render markdown segments with action buttons on code blocks return ( - {message.explanation && ( - - {message.explanation} - - )} - - - - {gettext('Generated SQL')} - - - - onInsertSQL(message.sql)} - > - - - - - onReplaceSQL(message.sql)} - > - - - - - navigator.clipboard.writeText(message.sql)} - > - - - - - - - - - + {segments.map((seg, idx) => { + if (seg.type === 'text') { + const content = seg.content?.trim(); + if (!content) return null; + return ( + 0 ? 1 : 0 }} + dangerouslySetInnerHTML={{ __html: renderMarkdownText(content) }} + /> + ); + } + + if (seg.type === 'code') { + const isSql = ['sql', 'pgsql', 'postgresql'].includes(seg.language); + const formattedCode = isSql ? 
formatSqlWithPrefs(seg.content) : seg.content; + + return ( + + + + {seg.language || gettext('Code')} + + + {isSql && ( + <> + + onInsertSQL(formattedCode)}> + + + + + onReplaceSQL(formattedCode)}> + + + + + )} + + navigator.clipboard.writeText(formattedCode)}> + + + + + + + + + + ); + } + + return null; + })} ); } @@ -246,6 +462,106 @@ function ChatMessage({ message, onInsertSQL, onReplaceSQL, textColors, cmKey }) ); } + if (message.type === MESSAGE_TYPES.STREAMING) { + const segments = parseMarkdownSegments(message.content); + const BlinkingCursor = ( + + ); + + // No segments parsed yet — show raw text or spinner + if (segments.length === 0) { + return ( + + {message.content ? ( + + {message.content} + {BlinkingCursor} + + ) : ( + + + + {gettext('Generating response...')} + + + )} + + ); + } + + // Render markdown segments in order + const lastIdx = segments.length - 1; + return ( + + {segments.map((seg, idx) => { + const isLast = idx === lastIdx; + const cursor = isLast && !seg.complete ? BlinkingCursor : null; + + if (seg.type === 'code') { + return ( + + + + {seg.complete + ? (seg.language || gettext('Code')) + : gettext('Generating...')} + + + + + {seg.content} + {cursor} + + + + ); + } + + const content = seg.content?.trim(); + if (!content && !cursor) return null; + return ( + 0 ? 
1 : 0, display: 'inline' }}> + + {cursor} + + ); + })} + + ); + } + if (message.type === MESSAGE_TYPES.ERROR) { return ( - {message.content} + ); } @@ -291,6 +609,8 @@ export function NLQChatPanel() { const abortControllerRef = useRef(null); const readerRef = useRef(null); const stoppedRef = useRef(false); + const streamingTextRef = useRef(''); + const streamingIdRef = useRef(null); const eventBus = useContext(QueryToolEventsContext); const queryToolCtx = useContext(QueryToolContext); const editorPrefs = usePreferences().getPreferencesForModule('editor'); @@ -444,8 +764,10 @@ export function NLQChatPanel() { const handleSubmit = async () => { if (!inputValue.trim() || isLoading) return; - // Reset stopped flag + // Reset stopped flag and streaming state stoppedRef.current = false; + streamingTextRef.current = ''; + streamingIdRef.current = null; // Fetch latest LLM provider/model info before submitting fetchLlmInfo(); @@ -547,34 +869,59 @@ export function NLQChatPanel() { // Check if user manually stopped if (stoppedRef.current) { - setMessages((prev) => [ - ...prev.filter((m) => m.id !== thinkingId), - { - type: MESSAGE_TYPES.ASSISTANT, - content: gettext('Generation stopped.'), - }, - ]); + const streamId = streamingIdRef.current; + // If we have partial streaming content, show it as-is + if (streamingTextRef.current) { + setMessages((prev) => [ + ...prev.filter((m) => m.id !== thinkingId && m.id !== streamId), + { + type: MESSAGE_TYPES.ASSISTANT, + content: streamingTextRef.current + '\n\n' + gettext('(Generation stopped)'), + }, + ]); + } else { + setMessages((prev) => [ + ...prev.filter((m) => m.id !== thinkingId), + { + type: MESSAGE_TYPES.ASSISTANT, + content: gettext('Generation stopped.'), + }, + ]); + } + streamingTextRef.current = ''; + streamingIdRef.current = null; } } catch (error) { clearTimeout(timeoutId); abortControllerRef.current = null; readerRef.current = null; + const streamId = streamingIdRef.current; // Show appropriate message based on 
error type if (error.name === 'AbortError') { // Check if this was a user-initiated stop or a timeout if (stoppedRef.current) { - // User manually stopped - setMessages((prev) => [ - ...prev.filter((m) => m.id !== thinkingId), - { - type: MESSAGE_TYPES.ASSISTANT, - content: gettext('Generation stopped.'), - }, - ]); + // User manually stopped - show partial content if any + if (streamingTextRef.current) { + setMessages((prev) => [ + ...prev.filter((m) => m.id !== thinkingId && m.id !== streamId), + { + type: MESSAGE_TYPES.ASSISTANT, + content: streamingTextRef.current + '\n\n' + gettext('(Generation stopped)'), + }, + ]); + } else { + setMessages((prev) => [ + ...prev.filter((m) => m.id !== thinkingId), + { + type: MESSAGE_TYPES.ASSISTANT, + content: gettext('Generation stopped.'), + }, + ]); + } } else { // Timeout occurred setMessages((prev) => [ - ...prev.filter((m) => m.id !== thinkingId), + ...prev.filter((m) => m.id !== thinkingId && m.id !== streamId), { type: MESSAGE_TYPES.ERROR, content: gettext('Request timed out. The query may be too complex. Please try a simpler request.'), @@ -583,13 +930,15 @@ export function NLQChatPanel() { } } else { setMessages((prev) => [ - ...prev.filter((m) => m.id !== thinkingId), + ...prev.filter((m) => m.id !== thinkingId && m.id !== streamId), { type: MESSAGE_TYPES.ERROR, content: gettext('Failed to generate SQL: ') + error.message, }, ]); } + streamingTextRef.current = ''; + streamingIdRef.current = null; } finally { setIsLoading(false); setThinkingMessageId(null); @@ -598,48 +947,100 @@ export function NLQChatPanel() { const handleSSEEvent = (event, thinkingId) => { switch (event.type) { - case 'thinking': - setMessages((prev) => - prev.map((m) => - m.id === thinkingId ? 
{ ...m, content: event.message } : m - ) - ); + case 'thinking': { + const streamId = streamingIdRef.current; + if (streamId) { + // Transition from streaming back to thinking (tool use) + // Remove streaming message and re-add thinking indicator + streamingTextRef.current = ''; + streamingIdRef.current = null; + setMessages((prev) => [ + ...prev.filter((m) => m.id !== streamId), + { + type: MESSAGE_TYPES.THINKING, + content: event.message, + id: thinkingId, + }, + ]); + setThinkingMessageId(thinkingId); + } else { + setMessages((prev) => + prev.map((m) => + m.id === thinkingId ? { ...m, content: event.message } : m + ) + ); + } break; + } - case 'sql': - case 'complete': - // If sql is null/empty, show as regular assistant message (e.g., clarification questions) - if (!event.sql) { + case 'text_delta': + streamingTextRef.current += event.content; + if (!streamingIdRef.current) { + // First text chunk: replace thinking with streaming message + streamingIdRef.current = Date.now(); setMessages((prev) => [ ...prev.filter((m) => m.id !== thinkingId), { - type: MESSAGE_TYPES.ASSISTANT, - content: event.explanation || gettext('I need more information to generate the SQL.'), + type: MESSAGE_TYPES.STREAMING, + content: streamingTextRef.current, + id: streamingIdRef.current, }, ]); } else { + // Update existing streaming message + const sid = streamingIdRef.current; + setMessages((prev) => + prev.map((m) => + m.id === sid ? 
{ ...m, content: streamingTextRef.current } : m + ) + ); + } + break; + + case 'sql': + case 'complete': { + const streamId = streamingIdRef.current; + const content = event.content || event.explanation + || gettext('I need more information to generate the SQL.'); + // Use SQL type if there's SQL or any code fences in the response + const hasCodeBlocks = event.sql || (content && content.includes('```')); + if (hasCodeBlocks) { setMessages((prev) => [ - ...prev.filter((m) => m.id !== thinkingId), + ...prev.filter((m) => m.id !== thinkingId && m.id !== streamId), { type: MESSAGE_TYPES.SQL, - sql: formatSqlWithPrefs(event.sql), - explanation: event.explanation, + content, + sql: event.sql, + }, + ]); + } else { + setMessages((prev) => [ + ...prev.filter((m) => m.id !== thinkingId && m.id !== streamId), + { + type: MESSAGE_TYPES.ASSISTANT, + content, }, ]); } if (event.conversation_id) { setConversationId(event.conversation_id); } + // Reset streaming state + streamingTextRef.current = ''; + streamingIdRef.current = null; break; + } case 'error': setMessages((prev) => [ - ...prev.filter((m) => m.id !== thinkingId), + ...prev.filter((m) => m.id !== thinkingId && m.id !== streamingIdRef.current), { type: MESSAGE_TYPES.ERROR, content: event.message, }, ]); + streamingTextRef.current = ''; + streamingIdRef.current = null; break; } }; @@ -733,6 +1134,7 @@ export function NLQChatPanel() { onReplaceSQL={handleReplaceSQL} textColors={textColors} cmKey={cmKey} + formatSqlWithPrefs={formatSqlWithPrefs} /> )) )} diff --git a/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py b/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py index 6f1f3447990..f797963b80d 100644 --- a/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py +++ b/web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py @@ -10,6 +10,7 @@ """Tests for the NLQ (Natural Language Query) chat endpoint.""" import json +import re from unittest.mock import patch, MagicMock from pgadmin.utils.route import BaseTestGenerator @@ -43,8 
+44,9 @@ class NLQChatTestCase(BaseTestGenerator): message='Find all users', expected_error=False, mock_response=( - '{"sql": "SELECT * FROM users;", ' - '"explanation": "Gets all users"}' + 'Here are all users:\n\n' + '```sql\nSELECT * FROM users;\n```\n\n' + 'This retrieves all rows from the users table.' ) )), ] @@ -92,11 +94,14 @@ def runTest(self): ) patches.append(mock_check_trans) - # Mock chat_with_database + # Mock chat_with_database_stream if hasattr(self, 'mock_response'): + def mock_stream_gen(*args, **kwargs): + yield self.mock_response + yield (self.mock_response, []) mock_chat = patch( - 'pgadmin.llm.chat.chat_with_database', - return_value=(self.mock_response, []) + 'pgadmin.llm.chat.chat_with_database_stream', + side_effect=mock_stream_gen ) patches.append(mock_chat) @@ -171,3 +176,248 @@ def runTest(self): def tearDown(self): pass + + +class NLQSqlExtractionTestCase(BaseTestGenerator): + """Test cases for SQL extraction from markdown responses""" + + scenarios = [ + ('SQL Extraction - Single SQL block', dict( + response_text=( + 'Here is the query:\n\n' + '```sql\nSELECT * FROM users;\n```\n\n' + 'This returns all users.' + ), + expected_sql='SELECT * FROM users' + )), + ('SQL Extraction - Multiple SQL blocks', dict( + response_text=( + 'First get users:\n\n' + '```sql\nSELECT * FROM users;\n```\n\n' + 'Then get orders:\n\n' + '```sql\nSELECT * FROM orders;\n```' + ), + expected_sql='SELECT * FROM users;\n\nSELECT * FROM orders' + )), + ('SQL Extraction - pgsql language tag', dict( + response_text='```pgsql\nSELECT 1;\n```', + expected_sql='SELECT 1' + )), + ('SQL Extraction - postgresql language tag', dict( + response_text='```postgresql\nSELECT 1;\n```', + expected_sql='SELECT 1' + )), + ('SQL Extraction - No SQL blocks', dict( + response_text=( + 'I cannot generate a query without ' + 'knowing your table structure.'
+ ), + expected_sql=None + )), + ('SQL Extraction - Non-SQL code block only', dict( + response_text=( + 'Here is some Python:\n\n' + '```python\nprint("hello")\n```' + ), + expected_sql=None + )), + ('SQL Extraction - JSON fallback', dict( + response_text='{"sql": "SELECT 1;", "explanation": "test"}', + expected_sql='SELECT 1;' + )), + ('SQL Extraction - Multiline SQL', dict( + response_text=( + '```sql\n' + 'SELECT u.name, o.total\n' + 'FROM users u\n' + 'JOIN orders o ON u.id = o.user_id\n' + 'WHERE o.total > 100;\n' + '```' + ), + expected_sql=( + 'SELECT u.name, o.total\n' + 'FROM users u\n' + 'JOIN orders o ON u.id = o.user_id\n' + 'WHERE o.total > 100' + ) + )), + ] + + def setUp(self): + pass + + def runTest(self): + """Test SQL extraction from markdown response text""" + response_text = self.response_text + + # Extract SQL using the same logic as the endpoint + sql_blocks = re.findall( + r'```(?:sql|pgsql|postgresql)\s*\n(.*?)```', + response_text, + re.DOTALL + ) + sql = ';\n\n'.join( + block.strip().rstrip(';') for block in sql_blocks + ) if sql_blocks else None + + # JSON fallback + if sql is None: + try: + result = json.loads(response_text.strip()) + if isinstance(result, dict): + sql = result.get('sql') + except (json.JSONDecodeError, TypeError): + pass + + self.assertEqual(sql, self.expected_sql) + + def tearDown(self): + pass + + +class NLQStreamingSSETestCase(BaseTestGenerator): + """Test cases for SSE event format in streaming responses""" + + scenarios = [ + ('SSE - Text with SQL produces complete event', dict( + mock_response=( + '```sql\nSELECT 1;\n```' + ), + check_complete_has_sql=True + )), + ('SSE - Text without SQL has no sql field', dict( + mock_response='I need more information about your schema.', + check_complete_has_sql=False + )), + ] + + def setUp(self): + pass + + def runTest(self): + """Test SSE events from NLQ streaming endpoint""" + trans_id = 12345 + + patches = [] + + mock_llm_enabled = patch( + 'pgadmin.llm.utils.is_llm_enabled',
return_value=True + ) + patches.append(mock_llm_enabled) + + mock_trans_obj = MagicMock() + mock_trans_obj.sid = 1 + mock_trans_obj.did = 1 + + mock_conn = MagicMock() + mock_conn.connected.return_value = True + + mock_session = {'sid': 1, 'did': 1} + + mock_check_trans = patch( + 'pgadmin.tools.sqleditor.check_transaction_status', + return_value=( + True, None, mock_conn, mock_trans_obj, mock_session + ) + ) + patches.append(mock_check_trans) + + def mock_stream_gen(*args, **kwargs): + # Yield text chunks + for chunk in [self.mock_response[i:i + 10] + for i in range(0, len(self.mock_response), 10)]: + yield chunk + # Yield final tuple + yield (self.mock_response, []) + + mock_chat = patch( + 'pgadmin.llm.chat.chat_with_database_stream', + side_effect=mock_stream_gen + ) + patches.append(mock_chat) + + mock_csrf = patch( + 'pgadmin.authenticate.mfa.utils.mfa_required', + lambda f: f + ) + patches.append(mock_csrf) + + for p in patches: + p.start() + + try: + response = self.tester.post( + f'/sqleditor/nlq/chat/{trans_id}/stream', + data=json.dumps({'message': 'test query'}), + content_type='application/json', + follow_redirects=True + ) + + self.assertEqual(response.status_code, 200) + self.assertIn('text/event-stream', response.content_type) + + # Parse SSE events + events = [] + raw = response.data.decode('utf-8') + for line in raw.split('\n'): + if line.startswith('data: '): + try: + events.append(json.loads(line[6:])) + except json.JSONDecodeError: + pass + + # Should have at least one text_delta and one complete + event_types = [e.get('type') for e in events] + self.assertIn('text_delta', event_types) + self.assertIn('complete', event_types) + + # Check the complete event + complete_events = [ + e for e in events if e.get('type') == 'complete' + ] + self.assertEqual(len(complete_events), 1) + complete = complete_events[0] + + # Verify content is present + self.assertIn('content', complete) + self.assertEqual(complete['content'], self.mock_response) + + # 
Verify SQL extraction + if self.check_complete_has_sql: + self.assertIsNotNone(complete.get('sql')) + else: + self.assertIsNone(complete.get('sql')) + + finally: + for p in patches: + p.stop() + + def tearDown(self): + pass + + +class NLQPromptMarkdownFormatTestCase(BaseTestGenerator): + """Test that NLQ prompt instructs markdown code fences""" + + scenarios = [ + ('NLQ Prompt - Markdown format', dict()), + ] + + def setUp(self): + pass + + def runTest(self): + """Test NLQ prompt requires markdown SQL code fences""" + from pgadmin.llm.prompts.nlq import NLQ_SYSTEM_PROMPT + + # Prompt should instruct use of fenced code blocks + self.assertIn('fenced code block', NLQ_SYSTEM_PROMPT.lower()) + self.assertIn('sql', NLQ_SYSTEM_PROMPT.lower()) + + # Should NOT instruct JSON format + self.assertNotIn('"sql":', NLQ_SYSTEM_PROMPT) + self.assertNotIn('"explanation":', NLQ_SYSTEM_PROMPT) + + def tearDown(self): + pass