From 8a066e1df7c868562d7987952944f3a0daf8aa54 Mon Sep 17 00:00:00 2001 From: Srimon Date: Sun, 24 May 2026 19:31:05 +0530 Subject: [PATCH 1/4] feat: add async execution and batched programmatic APIs Add async QQL support with AsyncConnection and AsyncExecutor, plus sync and async batching helpers for running multiple statements through one programmatic API. Introduce BEGIN BATCH syntax, parameterized query helpers, and optional gRPC connection settings. Refactor shared parser and executor logic into qql.utils so sync and async paths can reuse filter conversion, vector shaping, topology parsing, batch grouping, and search parsing helpers. Update tests and docs for async usage, batching, parameterized queries, gRPC configuration, and batch block execution. --- README.md | 33 +- docs/getting-started.md | 11 +- docs/programmatic.md | 129 ++- docs/reference.md | 39 +- docs/scripts.md | 46 +- src/qql/__init__.py | 10 +- src/qql/ast_nodes.py | 6 + src/qql/async_connection.py | 211 +++++ src/qql/async_executor.py | 1402 ++++++++++++++++++++++++++++++++ src/qql/connection.py | 89 +- src/qql/executor.py | 529 ++++++------ src/qql/lexer.py | 11 + src/qql/parser.py | 211 ++--- src/qql/script.py | 15 +- src/qql/utils.py | 562 +++++++++++++ tests/test_async_connection.py | 389 +++++++++ tests/test_connection.py | 104 +++ 17 files changed, 3350 insertions(+), 447 deletions(-) create mode 100644 src/qql/async_connection.py create mode 100644 src/qql/async_executor.py create mode 100644 src/qql/utils.py create mode 100644 tests/test_async_connection.py diff --git a/README.md b/README.md index 7b15af8..25ba17b 100644 --- a/README.md +++ b/README.md @@ -5,9 +5,9 @@ [![PyPI version](https://img.shields.io/pypi/v/qql-cli?color=blue&label=PyPI)](https://pypi.org/project/qql-cli/) [![Python 3.12+](https://img.shields.io/pypi/pyversions/qql-cli)](https://pypi.org/project/qql-cli/) [![MIT License](https://img.shields.io/badge/license-MIT-green)](LICENSE) -[![Tests](https://img.shields.io/badge/tests-549%20passing-brightgreen)](tests/) +[![Tests](https://img.shields.io/badge/tests-635%20passing-brightgreen)](tests/) -Write `INSERT`, `SELECT`, `SEARCH`, `SCROLL`, `RECOMMEND`, `UPDATE`, `DELETE`, and `CREATE COLLECTION` statements instead of Python SDK calls. Supports hybrid dense+sparse vector search, grouped search (GROUP BY), cross-encoder reranking, quantization (scalar, turbo, binary, product), SQL-style `WHERE` filters, script execution, and collection dump/restore. +Write `INSERT`, `SELECT`, `SEARCH`, `SCROLL`, `RECOMMEND`, `UPDATE`, `DELETE`, and `CREATE COLLECTION` statements instead of Python SDK calls. Supports hybrid dense+sparse vector search, grouped search (GROUP BY), cross-encoder reranking, quantization (scalar, turbo, binary, product), SQL-style `WHERE` filters, script execution, collection dump/restore, async execution, gRPC transport, parameterized queries, and batched query execution. ``` qql> INSERT INTO COLLECTION notes VALUES {'text': 'Qdrant is a vector database', 'author': 'alice', 'year': 2024} @@ -50,16 +50,23 @@ Your query string When you run `INSERT`, the `text` field is automatically converted into a dense vector using [Fastembed](https://github.com/qdrant/fastembed). In **hybrid mode** (`USING HYBRID`), a sparse BM25 vector is also generated alongside the dense vector, and searches use Qdrant's Reciprocal Rank Fusion (RRF) by default to merge the results of both retrieval methods. You can switch hybrid search to DBSF with `FUSION 'dbsf'`. -QQL also exposes a **programmatic API** for use inside Python applications — no CLI required: +QQL also exposes a **programmatic API** for use inside Python applications — no CLI required. Use `Connection` for sync code, `AsyncConnection` for async apps, and batch helpers when you want QQL to combine compatible operations into fewer Qdrant requests: ```python -from qql import Connection +from qql import Connection, QQLBatch with Connection("http://localhost:6333") as conn: conn.run_query("INSERT INTO COLLECTION notes VALUES {'text': 'Qdrant is fast'}") - result = conn.run_query("SEARCH notes SIMILAR TO 'vector database' LIMIT 5") - for hit in result.data: - print(hit["score"], hit["payload"]) + result = conn.run_parameterized_query( + "SEARCH notes SIMILAR TO :query LIMIT 5", + {"query": "vector database"}, + ) + + with QQLBatch(conn) as batch: + neurology = batch.add("SEARCH notes SIMILAR TO 'neurology' LIMIT 5") + cardiology = batch.add("SEARCH notes SIMILAR TO 'cardiology' LIMIT 5") + + print(neurology.result.data, cardiology.result.data) ``` --- @@ -97,8 +104,8 @@ Full documentation lives in the [`docs/`](docs/) folder and at **[pavanjava.gith | [SEARCH / SELECT / SCROLL / RECOMMEND / Hybrid / GROUP BY / RERANK](docs/search.md) | Semantic search, grouped search, point retrieval, pagination, hybrid, reranking, recommendations | | [WHERE Filters](docs/filters.md) | Full SQL-style filter operators | | [Collections & Quantization](docs/collections.md) | SHOW, CREATE, DROP, QUANTIZE (scalar/turbo/binary/product), CREATE INDEX, UPDATE VECTOR, UPDATE PAYLOAD | -| [Scripts: EXECUTE / DUMP](docs/scripts.md) | Script files, collection backup/restore | -| [Programmatic Usage](docs/programmatic.md) | Use QQL as a Python library via `Connection` or `run_query()` | +| [Scripts: EXECUTE / DUMP](docs/scripts.md) | Script files, `BEGIN BATCH` blocks, collection backup/restore | +| [Programmatic Usage](docs/programmatic.md) | Sync/async Python APIs, parameterized queries, batching, gRPC | | [Reference: Models / Config / Errors](docs/reference.md) | Embedding models, config file, error reference | --- @@ -170,6 +177,12 @@ DELETE FROM articles WHERE year < 2020 -- Scripts EXECUTE /path/to/script.qql DUMP articles /path/to/backup.qql + +-- Batch block +BEGIN BATCH; + SEARCH articles SIMILAR TO 'query one' LIMIT 5; + SEARCH articles SIMILAR TO 'query two' LIMIT 5; +END BATCH ``` --- @@ -182,7 +195,7 @@ Tests do not require a running Qdrant instance — the Qdrant client is mocked. pytest tests/ -v ``` -Expected: **549 tests passing**. +Expected: **635 tests passing**. --- diff --git a/docs/getting-started.md b/docs/getting-started.md index bff4473..ab64e9d 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -5,7 +5,7 @@ title: "Getting Started" # Getting Started with QQL -QQL is a SQL-like query language and CLI for [Qdrant](https://qdrant.tech). Instead of writing Python SDK calls you write natural query statements to insert, search, manage, and delete vector data. +QQL is a SQL-like query language and CLI for [Qdrant](https://qdrant.tech). Instead of writing Python SDK calls you write natural query statements to insert, search, manage, and delete vector data. It can also be used as a sync or async Python library with batching, parameterized queries, and optional gRPC transport. --- @@ -154,6 +154,12 @@ SHOW COLLECTION notes -- Retrieve a point by ID SELECT * FROM notes WHERE id = 1 + +-- Run compatible queries as one batch +BEGIN BATCH; + SEARCH notes SIMILAR TO 'vector databases' LIMIT 5; + SEARCH notes SIMILAR TO 'semantic search' LIMIT 5; +END BATCH ``` --- @@ -164,5 +170,6 @@ SELECT * FROM notes WHERE id = 1 - [SEARCH / SELECT / SCROLL / RECOMMEND / Hybrid / RERANK](search.md) — querying - [WHERE Filters](filters.md) — payload filtering - [Collections & Quantization](collections.md) — managing collections -- [Scripts: EXECUTE / DUMP](scripts.md) — automating with script files +- [Scripts: EXECUTE / DUMP](scripts.md) — automating with script files and batch blocks +- [Programmatic Usage](programmatic.md) — sync/async APIs, batching, parameterized queries, gRPC - [Embedding Models](reference.md#embedding-models) — model reference diff --git a/docs/programmatic.md b/docs/programmatic.md index d632a0d..8db6599 100644 --- a/docs/programmatic.md +++ b/docs/programmatic.md @@ -16,6 +16,8 @@ single connection to Qdrant once and reuses it for every `run_query()` call — more efficient than the legacy `run_query()` function, which creates a new client on every invocation. +Use `AsyncConnection` when your application already runs on `asyncio`. + ### Basic usage ```python @@ -70,6 +72,22 @@ with Connection("https://.qdrant.io", secret="") as print(result.data) ``` +### gRPC transport + +QQL can ask the Qdrant client to prefer gRPC for lower request overhead: + +```python +from qql import Connection + +with Connection( + "http://localhost:6333", + prefer_grpc=True, + grpc_port=6334, +) as conn: + result = conn.run_query("SHOW COLLECTIONS") + print(result.data) +``` + ### Custom embedding model ```python @@ -155,9 +173,117 @@ with Connection("http://localhost:6333") as conn: | `url` | `str` | `"http://localhost:6333"` | Qdrant instance URL | | `secret` | `str \| None` | `None` | API key; `None` for unauthenticated | | `default_model` | `str \| None` | `None` → `sentence-transformers/all-MiniLM-L6-v2` | Dense embedding model used when no `USING MODEL` clause is given | +| `prefer_grpc` | `bool` | `False` | Passes `prefer_grpc=True` to the Qdrant client | +| `grpc_port` | `int` | `6334` | gRPC port used when `prefer_grpc=True` | | `default_dense_vector_name` | `str` | `"dense"` | Dense vector name used when QQL creates a collection and no explicit `USING VECTOR` name is given | | `default_sparse_vector_name` | `str` | `"sparse"` | Sparse vector name used when QQL creates a hybrid collection and no explicit sparse vector name is given | +--- + +## Parameterized Queries + +Parameterized helpers render `:name` placeholders before parsing the QQL statement. String values are quoted and escaped; booleans are rendered as `true` / `false`. + +```python +from qql import Connection + +with Connection("http://localhost:6333") as conn: + result = conn.run_parameterized_query( + "SEARCH notes SIMILAR TO :query LIMIT 5 WHERE author = :author", + {"query": "vector database", "author": "alice"}, + ) + + results = conn.run_parameterized_batch( + "SEARCH notes SIMILAR TO :query LIMIT 5 WHERE category = :category", + [ + {"query": "brain stroke", "category": "Neurology"}, + {"query": "heart attack", "category": "Cardiology"}, + ], + ) +``` + +Parameterized queries are a convenience for building QQL strings safely in application code; they are not sent to Qdrant as server-side prepared statements. + +--- + +## Batch Execution + +`run_queries_batch()` parses multiple QQL strings into a `BatchBlockStmt`. The executor groups compatible statements: + +- compatible `SEARCH` / `RECOMMEND` statements use Qdrant `query_batch_points` +- compatible `INSERT` statements become one `INSERT BULK` +- mixed or incompatible statements still execute in order + +```python +from qql import Connection + +with Connection("http://localhost:6333") as conn: + results = conn.run_queries_batch([ + "SEARCH docs SIMILAR TO 'neurology' LIMIT 5", + "SEARCH docs SIMILAR TO 'cardiology' LIMIT 5", + ]) + + for result in results: + print(result.message) +``` + +For ergonomic batching in application code, use `QQLBatch`: + +```python +from qql import Connection, QQLBatch + +with Connection("http://localhost:6333") as conn: + with QQLBatch(conn) as batch: + neuro = batch.add("SEARCH docs SIMILAR TO 'neurology' LIMIT 5") + cardio = batch.add("SEARCH docs SIMILAR TO 'cardiology' LIMIT 5") + + print(neuro.result.data) + print(cardio.result.data) +``` + +Each proxy's `.result` becomes available after the context manager exits. + +--- + +## Async API + +`AsyncConnection` mirrors the sync API for `asyncio` applications and uses `AsyncQdrantClient` under the hood. + +```python +from qql import AsyncConnection + +async with AsyncConnection("http://localhost:6333") as conn: + await conn.run_query( + "INSERT INTO COLLECTION notes VALUES {'text': 'async QQL'}" + ) + result = await conn.run_query( + "SEARCH notes SIMILAR TO 'async vector search' LIMIT 5" + ) + print(result.data) +``` + +Async batching and parameterized helpers are also available: + +```python +from qql import AsyncConnection, QQLAsyncBatch + +async with AsyncConnection("http://localhost:6333", prefer_grpc=True) as conn: + result = await conn.run_parameterized_query( + "SEARCH docs SIMILAR TO :query LIMIT 5", + {"query": "clinical notes"}, + ) + + async with QQLAsyncBatch(conn) as batch: + first = batch.add("SEARCH docs SIMILAR TO 'neurology' LIMIT 5") + second = batch.add("SEARCH docs SIMILAR TO 'cardiology' LIMIT 5") + + print(first.result.data, second.result.data) +``` + +The async executor preserves the same `ExecutionResult` shape as the sync executor. + +--- + ### Power-user: `executor` property For low-level access to the pipeline, use `conn.executor` directly: @@ -250,7 +376,8 @@ class ExecutionResult: |---|---| | INSERT (dense) | `{"id": int \| "", "collection": ""}` | | INSERT (hybrid) | `{"id": int \| "", "collection": ""}` | -| INSERT BULK | `None` (count in `result.message`) | +| INSERT BULK | `{"ids": [int \| "", ...]}` | +| BEGIN BATCH / programmatic batch | `[ExecutionResult, ...]` | | SELECT | `{"id": str, "payload": dict}` or `None` when not found | | SEARCH | `[{"id": str, "score": float, "payload": dict}, ...]` | | SCROLL | `{"points": [{"id": str, "payload": dict}, ...], "next_offset": str \| int \| None}` | diff --git a/docs/reference.md b/docs/reference.md index ce9734e..dda9a45 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -5,7 +5,7 @@ title: "Reference" # Reference — Models, Config, Project Structure, Errors -Default embedding models, configuration parameters, project layout, and common error codes for troubleshooting. +Default embedding models, configuration parameters, public APIs, project layout, and common error codes for troubleshooting. --- @@ -147,6 +147,28 @@ You can edit this file directly to change the default model without reconnecting --- +## Public Python API + +| API | Description | +|---|---| +| `Connection` | Stateful sync QQL client backed by `QdrantClient` | +| `AsyncConnection` | Stateful async QQL client backed by `AsyncQdrantClient` | +| `QQLBatch` | Sync context manager for collecting statements and resolving per-statement results after execution | +| `QQLAsyncBatch` | Async context manager equivalent of `QQLBatch` | +| `Executor` | Low-level sync AST executor | +| `AsyncExecutor` | Low-level async AST executor | +| `ExecutionResult` | Standard result object returned by all operations | + +Both sync and async connections support: + +- `run_query(query)` +- `run_queries_batch([query, ...])` +- `run_parameterized_query(template, params)` +- `run_parameterized_batch(template, [params, ...])` +- `prefer_grpc=True` and `grpc_port=` connection options + +--- + ## Project Structure ``` @@ -154,16 +176,19 @@ qql/ ├── pyproject.toml # Package config; installs the `qql` CLI command ├── src/ │ └── qql/ -│ ├── __init__.py # Public API: Connection, run_query() +│ ├── __init__.py # Public API exports: sync, async, batching, parser/executor │ ├── cli.py # CLI entry point: connect, disconnect, execute, dump, REPL │ ├── config.py # QQLConfig dataclass + ~/.qql/config.json I/O -│ ├── connection.py # Connection class — stateful programmatic API +│ ├── connection.py # Sync Connection, QQLBatch, parameterized query helpers +│ ├── async_connection.py # AsyncConnection and QQLAsyncBatch │ ├── exceptions.py # QQLError, QQLSyntaxError, QQLRuntimeError │ ├── lexer.py # Tokenizer: string → List[Token] │ ├── ast_nodes.py # Frozen dataclasses for each statement and filter type │ ├── parser.py # Recursive descent parser: tokens → AST node │ ├── embedder.py # Embedder (dense) + SparseEmbedder (BM25) + CrossEncoderEmbedder (rerank) -│ ├── executor.py # AST node → Qdrant client call + filter + hybrid search +│ ├── executor.py # Sync AST node → Qdrant client call +│ ├── async_executor.py # Async AST node → AsyncQdrantClient call +│ ├── utils.py # Shared pure helpers for parsing, filters, batching, vectors │ ├── script.py # Script runner: parse and execute .qql files statement by statement │ └── dumper.py # Collection exporter: scroll all points → .qql INSERT BULK script └── tests/ @@ -171,6 +196,7 @@ qql/ ├── test_parser.py # Parser unit tests ├── test_executor.py # Executor unit tests (mocked Qdrant client) ├── test_connection.py # Connection class unit tests (mocked Qdrant client) + ├── test_async_connection.py # AsyncConnection / AsyncExecutor tests ├── test_script.py # Script runner unit tests └── test_dumper.py # Dumper unit tests ``` @@ -185,7 +211,7 @@ Tests do not require a running Qdrant instance — the Qdrant client is mocked. pytest tests/ -v ``` -Expected output: **604 tests passing**. +Expected output: **635 tests passing**. --- @@ -218,3 +244,6 @@ Expected output: **604 tests passing**. | `Unknown index type '...'` | Invalid schema type in CREATE INDEX | Use one of: `keyword`, `integer`, `float`, `bool`, `text`, `geo`, `datetime`, `uuid` | | `Unknown CREATE INDEX option '...'` | Unsupported advanced option for the chosen payload index type | Check which `WITH { ... }` keys are supported for `keyword`, `uuid`, or `text` | | `Qdrant error during CREATE INDEX: ...` | Qdrant rejected the index creation | Check field name and collection state | +| `Unterminated batch block; expected END BATCH` | A `BEGIN BATCH` block was not closed | Add `END BATCH` at the end of the block | +| `Batch has not been executed yet.` | Read a `QQLBatch` proxy result before leaving the context manager | Access `.result` only after the `with QQLBatch(...)` block exits | +| `AsyncBatch has not been executed yet.` | Read a `QQLAsyncBatch` proxy result before leaving the async context manager | Access `.result` only after the `async with QQLAsyncBatch(...)` block exits | diff --git a/docs/scripts.md b/docs/scripts.md index 1fd512f..6938eb8 100644 --- a/docs/scripts.md +++ b/docs/scripts.md @@ -5,7 +5,7 @@ title: "Scripts: EXECUTE / DUMP" # Script Files — EXECUTE and DUMP -QQL supports reading from and writing to `.qql` script files, making it easy to automate bulk operations, seed databases, and back up collections. +QQL supports reading from and writing to `.qql` script files, making it easy to automate bulk operations, seed databases, and back up collections. Scripts can contain regular statements or explicit `BEGIN BATCH ... END BATCH` blocks. --- @@ -51,10 +51,54 @@ SHOW COLLECTIONS **Rules:** - `--` to end-of-line is a comment and is ignored (inline or full-line) - Statements can span multiple lines (e.g. `INSERT BULK ... VALUES [...]`) +- `BEGIN BATCH ... END BATCH` is treated as one statement by the script splitter +- Semicolons are optional between top-level statements, but useful inside batch blocks - `RECOMMEND` statements work in `.qql` files the same way they do in the REPL - Blank lines between statements are ignored - By default all statements run even if one fails; use `--stop-on-error` to halt early +--- + +## BEGIN BATCH — group statements for fewer Qdrant calls + +Use `BEGIN BATCH ... END BATCH` when you want QQL to parse several statements as one executable batch. The executor keeps statement order in the returned results while grouping compatible operations internally. + +```sql +BEGIN BATCH; + SEARCH articles SIMILAR TO 'stroke symptoms' LIMIT 5 WHERE department = 'neurology'; + SEARCH articles SIMILAR TO 'cardiac markers' LIMIT 5 WHERE department = 'cardiology'; + RECOMMEND FROM articles POSITIVE IDS (1001, 1002) LIMIT 5; +END BATCH +``` + +Batch execution rules: + +- compatible `SEARCH` / `RECOMMEND` statements for the same collection use Qdrant `query_batch_points` +- compatible `INSERT` statements are combined into one bulk insert +- incompatible or mutation statements still execute in order +- each child statement produces its own `ExecutionResult` + +You can also use batch blocks directly in the REPL or through `Connection.run_query()`. + +```python +from qql import Connection + +with Connection("http://localhost:6333") as conn: + result = conn.run_query(""" + BEGIN BATCH; + SEARCH articles SIMILAR TO 'neurology' LIMIT 5; + SEARCH articles SIMILAR TO 'cardiology' LIMIT 5; + END BATCH + """) + + for child in result.data: + print(child.message) +``` + +Programmatic callers can use `run_queries_batch()` or `QQLBatch` instead of writing a batch block by hand. See [Programmatic Usage](programmatic.md#batch-execution). + +--- + **Included examples:** - [`resources/sample.qql`](../resources/sample.qql) seeds the demo medical dataset - [`resources/sample_v2.qql`](../resources/sample_v2.qql) is a compact end-to-end example with explicit IDs and runnable `RECOMMEND` statements diff --git a/src/qql/__init__.py b/src/qql/__init__.py index deeb737..296dbfe 100644 --- a/src/qql/__init__.py +++ b/src/qql/__init__.py @@ -12,15 +12,22 @@ QQLConfig, load_config, ) -from .connection import Connection +from .connection import Connection, QQLBatch, OperationProxy +from .async_connection import AsyncConnection, QQLAsyncBatch, AsyncOperationProxy from .exceptions import QQLError, QQLRuntimeError, QQLSyntaxError from .executor import ExecutionResult, Executor +from .async_executor import AsyncExecutor from .lexer import Lexer from .parser import Parser __all__ = [ "__version__", "Connection", + "QQLBatch", + "OperationProxy", + "AsyncConnection", + "QQLAsyncBatch", + "AsyncOperationProxy", "DEFAULT_DENSE_VECTOR_NAME", "DEFAULT_MODEL", "DEFAULT_SPARSE_VECTOR_NAME", @@ -30,6 +37,7 @@ "QQLSyntaxError", "ExecutionResult", "Executor", + "AsyncExecutor", "Lexer", "Parser", "load_config", diff --git a/src/qql/ast_nodes.py b/src/qql/ast_nodes.py index 5f0b2f2..f50d92b 100644 --- a/src/qql/ast_nodes.py +++ b/src/qql/ast_nodes.py @@ -340,6 +340,11 @@ class UpdatePayloadStmt: query_filter: FilterExpr | None = None +@dataclass(frozen=True) +class BatchBlockStmt: + statements: tuple[ASTNode, ...] + + # Union type for all top-level statement nodes ASTNode = ( InsertStmt @@ -357,4 +362,5 @@ class UpdatePayloadStmt: | DeleteStmt | UpdateVectorStmt | UpdatePayloadStmt + | BatchBlockStmt ) diff --git a/src/qql/async_connection.py b/src/qql/async_connection.py new file mode 100644 index 0000000..1125e48 --- /dev/null +++ b/src/qql/async_connection.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +from typing import Any +from qdrant_client import AsyncQdrantClient + +from .config import DEFAULT_MODEL, QQLConfig +from .async_executor import AsyncExecutor +from .executor import ExecutionResult +from .lexer import Lexer +from .parser import Parser +from .utils import render_parameterized_query + + +class AsyncConnection: + """Stateful asynchronous connection to a Qdrant instance. + + Creates a single ``AsyncQdrantClient`` and ``AsyncExecutor`` once and reuses + them for every :meth:`run_query` call — more efficient than the standalone + one-shot helpers, which create a fresh client on every invocation. + + **Basic usage**:: + + conn = AsyncConnection("http://localhost:6333", secret="my-key") + result = await conn.run_query( + "INSERT INTO COLLECTION docs VALUES {'text': 'hello world'}" + ) + result = await conn.run_query("SEARCH docs SIMILAR TO 'hello' LIMIT 5") + await conn.close() + + **Context manager (preferred)**:: + + async with AsyncConnection("http://localhost:6333") as conn: + result = await conn.run_query("SHOW COLLECTIONS") + print(result.data) + + **Qdrant Cloud**:: + + async with AsyncConnection("https://.qdrant.io", secret="") as conn: + result = await conn.run_query("SHOW COLLECTIONS") + + **Custom embedding model**:: + + async with AsyncConnection( + "http://localhost:6333", + default_model="BAAI/bge-base-en-v1.5", + ) as conn: + result = await conn.run_query( + "INSERT INTO COLLECTION docs VALUES {'text': 'hello'}" + ) + """ + + def __init__( + self, + url: str = "http://localhost:6333", + secret: str | None = None, + default_model: str | None = None, + prefer_grpc: bool = False, + grpc_port: int = 6334, + ) -> None: + """Create an asynchronous connection to a Qdrant instance. + + Args: + url: Base URL of the Qdrant instance (default: ``http://localhost:6333``). + secret: API key for authenticated instances; ``None`` for unauthenticated. + default_model: Dense embedding model used when no ``USING MODEL`` clause + is specified. Defaults to ``sentence-transformers/all-MiniLM-L6-v2``. + prefer_grpc: Whether to connect via fast gRPC transport. + grpc_port: The gRPC port of Qdrant instance (default: 6334). + """ + self._config = QQLConfig( + url=url, + secret=secret, + default_model=default_model or DEFAULT_MODEL, + ) + client_kwargs = {"url": url, "api_key": secret} + if prefer_grpc: + client_kwargs["prefer_grpc"] = True + client_kwargs["grpc_port"] = grpc_port + self._client = AsyncQdrantClient(**client_kwargs) + self._executor = AsyncExecutor(self._client, self._config) + + # ── Public API ──────────────────────────────────────────────────────── + + async def run_query(self, query: str) -> ExecutionResult: + """Parse and execute a single QQL statement asynchronously. + + Args: + query: A QQL query string, e.g. ``"SEARCH docs SIMILAR TO 'hello' LIMIT 5"``. + + Returns: + An :class:`~qql.ExecutionResult` with ``success``, ``message``, and ``data`` fields. + + Raises: + QQLSyntaxError: The query string could not be parsed. + QQLRuntimeError: The query parsed correctly but Qdrant rejected it. + """ + tokens = Lexer().tokenize(query) + node = Parser(tokens).parse() + return await self._executor.execute(node) + + async def run_queries_batch(self, queries: list[str]) -> list[ExecutionResult]: + """Parse and execute a batch of QQL statements asynchronously. + + Combines compatible operations (such as SEARCH queries) to execute in + a single network request. + """ + from .ast_nodes import BatchBlockStmt + nodes = [] + for q in queries: + tokens = Lexer().tokenize(q) + node = Parser(tokens).parse() + nodes.append(node) + + batch_node = BatchBlockStmt(statements=tuple(nodes)) + res = await self._executor.execute(batch_node) + return res.data + + async def run_parameterized_query(self, template: str, params: dict[str, Any]) -> ExecutionResult: + """Execute one QQL query template with named parameters asynchronously. + + Uses named placeholders prefixed with ':' (e.g. :query, :category). + """ + return await self.run_query(render_parameterized_query(template, params)) + + async def run_parameterized_batch(self, template: str, params: list[dict[str, Any]]) -> list[ExecutionResult]: + """Execute a single QQL query template with a batch of parameters asynchronously. + + Uses named placeholders prefixed with ':' (e.g. :query, :category). + """ + queries = [render_parameterized_query(template, p) for p in params] + return await self.run_queries_batch(queries) + + async def close(self) -> None: + """Close the underlying Qdrant asynchronous connection pool.""" + await self._client.close() + + # ── Context manager ─────────────────────────────────────────────────── + + async def __aenter__(self) -> AsyncConnection: + return self + + async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: + await self.close() + + # ── Power-user properties ───────────────────────────────────────────── + + @property + def config(self) -> QQLConfig: + """The :class:`~qql.QQLConfig` in use (url, secret, default_model).""" + return self._config + + @property + def executor(self) -> AsyncExecutor: + """Direct access to the :class:`~qql.AsyncExecutor` for low-level use. + + Example — run multiple statements sharing a pre-built AST node:: + + from qql.lexer import Lexer + from qql.parser import Parser + + conn = AsyncConnection("http://localhost:6333") + tokens = Lexer().tokenize("SHOW COLLECTIONS") + node = Parser(tokens).parse() + result = await conn.executor.execute(node) + """ + return self._executor + + +class QQLAsyncBatch: + """Asynchronous session context manager for executing batch queries and mutations in QQL.""" + + def __init__(self, connection: AsyncConnection) -> None: + self.connection = connection + self._queries: list[str] = [] + self._proxies: list[AsyncOperationProxy] = [] + + def add(self, query: str) -> AsyncOperationProxy: + """Queue a QQL statement for batch execution.""" + self._queries.append(query) + proxy = AsyncOperationProxy() + self._proxies.append(proxy) + return proxy + + async def __aenter__(self) -> QQLAsyncBatch: + return self + + async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: + if exc_type is not None: + return + if not self._queries: + return + results = await self.connection.run_queries_batch(self._queries) + for proxy, res in zip(self._proxies, results): + proxy._resolve(res) + + +class AsyncOperationProxy: + """Proxy handle that resolves to an ExecutionResult after QQLAsyncBatch exits.""" + + def __init__(self) -> None: + self._result: ExecutionResult | None = None + + def _resolve(self, result: ExecutionResult) -> None: + self._result = result + + @property + def result(self) -> ExecutionResult: + """The resolved ExecutionResult.""" + if self._result is None: + raise RuntimeError("AsyncBatch has not been executed yet.") + return self._result diff --git a/src/qql/async_executor.py b/src/qql/async_executor.py new file mode 100644 index 0000000..7d7ec36 --- /dev/null +++ b/src/qql/async_executor.py @@ -0,0 +1,1402 @@ +from __future__ import annotations + +import time +import asyncio +from typing import Any + +from qdrant_client import AsyncQdrantClient +from qdrant_client.http.exceptions import UnexpectedResponse +from qdrant_client.models import ( + Distance, + Filter, + FusionQuery, + LookupLocation, + Modifier, + PointStruct, + PointVectors, + Prefetch, + RecommendInput, + RecommendQuery, + QueryRequest, + SearchParams, + SparseVector, + SparseVectorParams, + VectorParams, + PayloadSchemaType, +) + +from .ast_nodes import ( + ASTNode, + AlterCollectionStmt, + CreateCollectionStmt, + CreateIndexStmt, + DeleteStmt, + DropCollectionStmt, + InsertBulkStmt, + InsertStmt, + RecommendStmt, + SelectStmt, + ScrollStmt, + SearchStmt, + ShowCollectionStmt, + ShowCollectionsStmt, + UpdateVectorStmt, + UpdatePayloadStmt, + BatchBlockStmt, +) +from .config import QQLConfig +from .embedder import Embedder, SparseEmbedder +from .exceptions import QQLRuntimeError +from .executor import Executor, ExecutionResult, CollectionTopology +from .utils import ( + build_bulk_insert_from_group, + build_dense_point_vector, + build_dense_query, + collection_topology_kwargs, + exclude_ids_from_filter, + extract_point_id_and_payload, + group_batch_statements, + has_mmr, + inserted_point_results, + parse_recommend_strategy, + resolve_hybrid_fusion, + validate_search_mmr_usage, +) + +_RERANK_FETCH_MULTIPLIER = 4 +_HYBRID_PREFETCH_MULTIPLIER = 4 +_COLLECTION_VISIBILITY_TIMEOUT_SECONDS = 5.0 +_COLLECTION_VISIBILITY_POLL_SECONDS = 0.05 + + +class AsyncExecutor(Executor): + """Asynchronous QQL execution engine for ``AsyncQdrantClient``. + + The async executor mirrors :class:`~qql.Executor` at the statement boundary: + every AST node supported by the sync executor has an async execution path + here. Pure parsing, validation, vector-shaping, filter-building, and result + formatting helpers live in ``qql.utils`` or are inherited from + :class:`~qql.Executor`; only Qdrant client calls and collection-creation + coordination are implemented with ``async``/``await`` in this module. + """ + + def __init__(self, client: AsyncQdrantClient, config: QQLConfig) -> None: + super().__init__(client=client, config=config) # type: ignore[arg-type] + self._client: AsyncQdrantClient = client + self._creation_lock = asyncio.Lock() + + async def execute(self, node: ASTNode) -> ExecutionResult: + if isinstance(node, InsertBulkStmt): + return await self._execute_insert_bulk(node) + if isinstance(node, InsertStmt): + return await self._execute_insert(node) + if isinstance(node, CreateCollectionStmt): + return await self._execute_create(node) + if isinstance(node, AlterCollectionStmt): + return await self._execute_alter_collection(node) + if isinstance(node, CreateIndexStmt): + return await self._execute_create_index(node) + if isinstance(node, DropCollectionStmt): + return await self._execute_drop(node) + if isinstance(node, ShowCollectionsStmt): + return await self._execute_show(node) + if isinstance(node, ShowCollectionStmt): + return await self._execute_show_collection(node) + if isinstance(node, ScrollStmt): + return await self._execute_scroll(node) + if isinstance(node, SelectStmt): + return await self._execute_select(node) + if isinstance(node, SearchStmt): + return await self._execute_search(node) + if isinstance(node, RecommendStmt): + return await self._execute_recommend(node) + if isinstance(node, DeleteStmt): + return await self._execute_delete(node) + if isinstance(node, UpdateVectorStmt): + return await self._execute_update_vector(node) + if isinstance(node, UpdatePayloadStmt): + return await self._execute_update_payload(node) + if isinstance(node, BatchBlockStmt): + return await self._execute_batch_block(node) + raise QQLRuntimeError(f"Unknown AST node type: {type(node)}") + + # ── Topology & Helper methods ───────────────────────────────────────── + + async def _resolve_topology(self, name: str) -> CollectionTopology: + if not await self._client.collection_exists(name): + return CollectionTopology(exists=False, is_named_dense=False) + + info = await self._client.get_collection(name) + params = info.config.params + vectors = params.vectors # type: ignore[union-attr] + sparse_vectors = params.sparse_vectors or {} + return CollectionTopology(**collection_topology_kwargs(vectors, sparse_vectors)) + + async def _ensure_collection( + self, + name: str, + vector_size: int, + topology: CollectionTopology, + explicit_vector: str | None, + ) -> None: + if topology.exists: + info = await self._client.get_collection(name) + vectors = info.config.params.vectors # type: ignore[union-attr] + if isinstance(vectors, dict): + vector_name = topology.dense_using(explicit_vector) + if vector_name is None: + raise QQLRuntimeError("Collection has no dense vector") + vector_config = vectors[vector_name] + expected_size = getattr(vector_config, "size", None) + if expected_size is not None and expected_size != vector_size: + raise QQLRuntimeError( + f"Vector dimension mismatch: collection '{name}' vector " + f"'{vector_name}' expects {expected_size} dims, but " + f"model produces {vector_size} dims. Specify a compatible " + "model with USING MODEL ''." + ) + elif vectors is not None: + if vectors.size != vector_size: + raise QQLRuntimeError( + f"Vector dimension mismatch: collection '{name}' expects " + f"{vectors.size} dims, but model produces {vector_size} dims. " + f"Specify a compatible model with USING MODEL ''." + ) + else: + raise QQLRuntimeError("Collection has no dense vector") + else: + async with self._creation_lock: + current_topology = await self._resolve_topology(name) + if current_topology.exists: + await self._ensure_collection(name, vector_size, current_topology, explicit_vector) + return + + await self._create_collection_and_wait( + collection_name=name, + vectors_config={ + explicit_vector or self._default_dense_vector_name(): VectorParams( + size=vector_size, distance=Distance.COSINE + ) + }, + ) + + async def _create_collection_and_wait(self, **kwargs: Any) -> None: + collection_name = kwargs["collection_name"] + await self._client.create_collection(**kwargs) + + deadline = time.monotonic() + _COLLECTION_VISIBILITY_TIMEOUT_SECONDS + while time.monotonic() < deadline: + if await self._client.collection_exists(collection_name): + return + await asyncio.sleep(_COLLECTION_VISIBILITY_POLL_SECONDS) + + raise QQLRuntimeError( + f"Collection '{collection_name}' was created but did not become visible in time" + ) + + async def _build_hybrid_vectors( + self, + query_text: str, + dense_model: str, + sparse_model_name: str, + ) -> tuple[list[float], SparseVector]: + dense_embedder = Embedder(dense_model) + sparse_embedder = SparseEmbedder(sparse_model_name) + + dense_vector = dense_embedder.embed(query_text) + sparse_obj = sparse_embedder.query_embed(query_text) + sparse_vector = SparseVector( + indices=sparse_obj["indices"], + values=sparse_obj["values"], + ) + return dense_vector, sparse_vector + + # ── Statement executors ─────────────────────────────────────────────── + + async def _execute_insert(self, node: InsertStmt) -> ExecutionResult: + if "text" not in node.values: + raise QQLRuntimeError("INSERT requires a 'text' field in VALUES") + + topology = await self._resolve_topology(node.collection) + use_hybrid = node.hybrid or (topology.exists and topology.is_hybrid) + + if use_hybrid: + dense_model = node.model or self._config.default_model + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + dense_embedder = Embedder(dense_model) + sparse_embedder = SparseEmbedder(sparse_model_name) + + dense_vector = dense_embedder.embed(node.values["text"]) + sparse_obj = sparse_embedder.embed(node.values["text"]) + + sparse_vector = SparseVector( + indices=sparse_obj["indices"], + values=sparse_obj["values"], + ) + + dense_name = node.dense_vector or self._default_dense_vector_name() + sparse_name = node.sparse_vector or self._default_sparse_vector_name() + + if topology.exists: + resolved_dense = topology.dense_using(node.dense_vector) + if resolved_dense is None: + raise QQLRuntimeError( + "Hybrid collections must use named dense vectors" + ) + dense_name = resolved_dense + sparse_name = topology.sparse_using(node.sparse_vector) + else: + async with self._creation_lock: + current_topology = await self._resolve_topology(node.collection) + if not current_topology.exists: + await self._create_collection_and_wait( + collection_name=node.collection, + vectors_config={ + dense_name: VectorParams( + size=len(dense_vector), distance=Distance.COSINE + ) + }, + sparse_vectors_config={ + sparse_name: SparseVectorParams(modifier=Modifier.IDF) + }, + ) + else: + dense_name = current_topology.dense_using(node.dense_vector) or dense_name + sparse_name = current_topology.sparse_using(node.sparse_vector) + + point_id, payload = extract_point_id_and_payload(node.values) + try: + await self._client.upsert( + collection_name=node.collection, + wait=True, + points=[ + PointStruct( + id=point_id, + vector={dense_name: dense_vector, sparse_name: sparse_vector}, + payload=payload, + ) + ], + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during INSERT: {e}") from e + + return ExecutionResult( + success=True, + message=f"Inserted 1 point [{point_id}] (hybrid)", + data={"id": point_id, "collection": node.collection}, + ) + + model_name = node.model or self._config.default_model + embedder = Embedder(model_name) + vector = embedder.embed(node.values["text"]) + + await self._ensure_collection( + node.collection, len(vector), topology, node.dense_vector + ) + point_vector = build_dense_point_vector( + topology, + vector, + node.dense_vector, + self._default_dense_vector_name(), + ) + + point_id, payload = extract_point_id_and_payload(node.values) + + try: + await self._client.upsert( + collection_name=node.collection, + wait=True, + points=[PointStruct(id=point_id, vector=point_vector, payload=payload)], + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during INSERT: {e}") from e + + return ExecutionResult( + success=True, + message=f"Inserted 1 point [{point_id}]", + data={"id": point_id, "collection": node.collection}, + ) + + async def _execute_insert_bulk(self, node: InsertBulkStmt) -> ExecutionResult: + if not node.values_list: + raise QQLRuntimeError("INSERT BULK VALUES list is empty") + for i, vals in enumerate(node.values_list): + if "text" not in vals: + raise QQLRuntimeError( + f"INSERT BULK: item at index {i} is missing required 'text' field" + ) + + topology = await self._resolve_topology(node.collection) + use_hybrid = node.hybrid or (topology.exists and topology.is_hybrid) + + if use_hybrid: + dense_model = node.model or self._config.default_model + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + dense_embedder = Embedder(dense_model) + sparse_embedder = SparseEmbedder(sparse_model_name) + dense_name = node.dense_vector or self._default_dense_vector_name() + sparse_name = node.sparse_vector or self._default_sparse_vector_name() + if topology.exists: + resolved_dense = topology.dense_using(node.dense_vector) + if resolved_dense is None: + raise QQLRuntimeError( + "Hybrid collections must use named dense vectors" + ) + dense_name = resolved_dense + sparse_name = topology.sparse_using(node.sparse_vector) + + dense_vectors = [ + dense_embedder.embed(vals["text"]) for vals in node.values_list + ] + sparse_objs = [sparse_embedder.embed(vals["text"]) for vals in node.values_list] + + first_dense_vector = dense_vectors[0] if dense_vectors else None + points: list[PointStruct] = [] + for idx, vals in enumerate(node.values_list): + point_id, payload = extract_point_id_and_payload(vals) + dense_vector = dense_vectors[idx] + sparse_obj = sparse_objs[idx] + sparse_vector = SparseVector( + indices=sparse_obj["indices"], values=sparse_obj["values"] + ) + points.append( + PointStruct( + id=point_id, + vector={dense_name: dense_vector, sparse_name: sparse_vector}, + payload=payload, + ) + ) + + if not topology.exists: + assert first_dense_vector is not None + async with self._creation_lock: + current_topology = await self._resolve_topology(node.collection) + if not current_topology.exists: + await self._create_collection_and_wait( + collection_name=node.collection, + vectors_config={ + dense_name: VectorParams(size=len(first_dense_vector), distance=Distance.COSINE) + }, + sparse_vectors_config={ + sparse_name: SparseVectorParams(modifier=Modifier.IDF) + }, + ) + else: + dense_name = current_topology.dense_using(node.dense_vector) or dense_name + sparse_name = current_topology.sparse_using(node.sparse_vector) + + try: + await self._client.upsert( + collection_name=node.collection, + wait=True, + points=points, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during INSERT BULK: {e}") from e + + return ExecutionResult( + success=True, + message=f"Inserted {len(points)} points (hybrid)", + data={"ids": [p.id for p in points]}, + ) + + model_name = node.model or self._config.default_model + embedder = Embedder(model_name) + + vectors = [embedder.embed(vals["text"]) for vals in node.values_list] + + first_vector = vectors[0] if vectors else None + points = [] + for idx, vals in enumerate(node.values_list): + vector = vectors[idx] + point_id, payload = extract_point_id_and_payload(vals) + point_vector = build_dense_point_vector( + topology, + vector, + node.dense_vector, + self._default_dense_vector_name(), + ) + points.append( + PointStruct(id=point_id, vector=point_vector, payload=payload) + ) + + assert first_vector is not None + await self._ensure_collection( + node.collection, len(first_vector), topology, node.dense_vector + ) + + try: + await self._client.upsert( + collection_name=node.collection, + wait=True, + points=points, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during INSERT BULK: {e}") from e + + return ExecutionResult( + success=True, + message=f"Inserted {len(points)} points", + data={"ids": [p.id for p in points]}, + ) + + async def _execute_create(self, node: CreateCollectionStmt) -> ExecutionResult: + if await self._client.collection_exists(node.collection): + return ExecutionResult( + success=True, + message=f"Collection '{node.collection}' already exists", + ) + + dense_model_name = node.model or self._config.default_model + + quant_config = ( + self._build_quantization_config(node.quantization) + if node.quantization is not None + else None + ) + quant_label = ( + f", {node.quantization.type.value} quantization" + if node.quantization is not None + else "" + ) + hnsw_config = self._build_hnsw_config(node.config) + optimizers_config = self._build_optimizers_config(node.config) + params_config = self._build_collection_params_create_kwargs(node.config) + config_label = self._describe_collection_config(node.config) + vector_on_disk = ( + node.config.vectors.on_disk + if node.config is not None and node.config.vectors is not None + else None + ) + + if node.hybrid: + embedder = Embedder(dense_model_name) + dims = embedder.dimensions + dense_name = node.dense_vector or self._default_dense_vector_name() + sparse_name = node.sparse_vector or self._default_sparse_vector_name() + create_kwargs: dict[str, Any] = { + "collection_name": node.collection, + "vectors_config": { + dense_name: VectorParams( + size=dims, + distance=Distance.COSINE, + on_disk=vector_on_disk, + ) + }, + "sparse_vectors_config": { + sparse_name: SparseVectorParams(modifier=Modifier.IDF) + }, + } + if quant_config is not None: + create_kwargs["quantization_config"] = quant_config + if hnsw_config is not None: + create_kwargs["hnsw_config"] = hnsw_config + if optimizers_config is not None: + create_kwargs["optimizers_config"] = optimizers_config + create_kwargs.update(params_config) + await self._create_collection_and_wait(**create_kwargs) + return ExecutionResult( + success=True, + message=( + f"Collection '{node.collection}' created " + f"(hybrid: {dims}-dim dense + BM25 sparse, cosine distance{quant_label}{config_label})" + ), + ) + + embedder = Embedder(dense_model_name) + dims = embedder.dimensions + dense_name = node.dense_vector or self._default_dense_vector_name() + create_kwargs = { + "collection_name": node.collection, + "vectors_config": { + dense_name: VectorParams( + size=dims, + distance=Distance.COSINE, + on_disk=vector_on_disk, + ) + }, + } + if quant_config is not None: + create_kwargs["quantization_config"] = quant_config + if hnsw_config is not None: + create_kwargs["hnsw_config"] = hnsw_config + if optimizers_config is not None: + create_kwargs["optimizers_config"] = optimizers_config + create_kwargs.update(params_config) + await self._create_collection_and_wait(**create_kwargs) + return ExecutionResult( + success=True, + message=f"Collection '{node.collection}' created ({dims}-dimensional vectors, cosine distance{quant_label}{config_label})", + ) + + async def _execute_alter_collection(self, node: AlterCollectionStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + topology = await self._resolve_topology(node.collection) + + update_kwargs: dict[str, Any] = {"collection_name": node.collection} + vectors_config = self._build_vectors_config_diff(topology, node.config) + hnsw_config = self._build_hnsw_config(node.config) + optimizers_config = self._build_optimizers_config(node.config) + collection_params = self._build_collection_params_diff(node.config) + quantization_config = self._build_alter_quantization_config(node.quantization) + + if vectors_config is not None: + update_kwargs["vectors_config"] = vectors_config + if hnsw_config is not None: + update_kwargs["hnsw_config"] = hnsw_config + if optimizers_config is not None: + update_kwargs["optimizers_config"] = optimizers_config + if collection_params is not None: + update_kwargs["collection_params"] = collection_params + if quantization_config is not None: + update_kwargs["quantization_config"] = quantization_config + + try: + await self._client.update_collection(**update_kwargs) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during ALTER COLLECTION: {e}") from e + + return ExecutionResult( + success=True, + message=( + f"Collection '{node.collection}' altered" + f"{self._describe_collection_config(node.config)}" + f"{self._describe_quantization_update(node.quantization)}" + ), + ) + + async def _execute_create_index(self, node: CreateIndexStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + + schema_map = { + "keyword": PayloadSchemaType.KEYWORD, + "integer": PayloadSchemaType.INTEGER, + "float": PayloadSchemaType.FLOAT, + "bool": PayloadSchemaType.BOOL, + "text": PayloadSchemaType.TEXT, + "geo": PayloadSchemaType.GEO, + "datetime": PayloadSchemaType.DATETIME, + "uuid": PayloadSchemaType.UUID, + } + try: + schema_map[node.schema] + except KeyError as e: + raise QQLRuntimeError( + "Unknown index type '" + f"{node.schema}'. Expected one of: keyword, integer, float, bool, text, geo, datetime, uuid" + ) from e + field_schema = self._build_payload_index_schema(node) + + try: + await self._client.create_payload_index( + collection_name=node.collection, + field_name=node.field_name, + field_schema=field_schema, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during CREATE INDEX: {e}") from e + + option_label = f" with options {node.options}" if node.options else "" + return ExecutionResult( + success=True, + message=( + f"Created index on '{node.collection}.{node.field_name}' as '{node.schema}'{option_label}" + ), + ) + + async def _execute_drop(self, node: DropCollectionStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + await self._client.delete_collection(node.collection) + return ExecutionResult( + success=True, + message=f"Collection '{node.collection}' dropped", + ) + + async def _execute_show(self, node: ShowCollectionsStmt) -> ExecutionResult: + response = await self._client.get_collections() + names = [c.name for c in response.collections] + return ExecutionResult( + success=True, + message=f"{len(names)} collection(s) found", + data=names, + ) + + async def _execute_show_collection(self, node: ShowCollectionStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + + info = await self._client.get_collection(node.collection) + config = info.config + params = config.params + + vectors = params.vectors # type: ignore[union-attr] + sparse_vector_params = params.sparse_vectors or {} + if isinstance(vectors, dict): + vector_details = {} + for vname, vconfig in vectors.items(): + vector_details[vname] = { + "size": vconfig.size, + "distance": str(vconfig.distance) if vconfig.distance else None, + "on_disk": vconfig.on_disk, + } + elif vectors is None: + raise QQLRuntimeError( + f"Collection '{node.collection}' has no vector configuration" + ) + else: + vector_details = { + "": { + "size": vectors.size, + "distance": str(vectors.distance) if vectors.distance else None, + "on_disk": vectors.on_disk, + } + } + topology = "hybrid" if sparse_vector_params else "dense" + + sparse_vectors = {} + if sparse_vector_params: + for sname, sconfig in sparse_vector_params.items(): + sparse_vectors[sname] = { + "modifier": str(sconfig.modifier) if sconfig.modifier else None, + } + + quant_config = config.quantization_config + quantization = None + if quant_config is not None: + qtype = type(quant_config).__name__ + if hasattr(quant_config, "scalar"): + quantization = "scalar" + elif hasattr(quant_config, "binary"): + quantization = "binary" + elif hasattr(quant_config, "product"): + quantization = "product" + elif hasattr(quant_config, "turbo"): + quantization = "turbo" + else: + quantization = qtype + + hnsw = { + "m": config.hnsw_config.m, + "ef_construct": config.hnsw_config.ef_construct, + } + if config.hnsw_config.full_scan_threshold is not None: + hnsw["full_scan_threshold"] = config.hnsw_config.full_scan_threshold + if config.hnsw_config.max_indexing_threads is not None: + hnsw["max_indexing_threads"] = config.hnsw_config.max_indexing_threads + if config.hnsw_config.on_disk is not None: + hnsw["on_disk"] = config.hnsw_config.on_disk + if config.hnsw_config.payload_m is not None: + hnsw["payload_m"] = config.hnsw_config.payload_m + if config.hnsw_config.inline_storage is not None: + hnsw["inline_storage"] = config.hnsw_config.inline_storage + + payload_indexes = {} + for field_name, idx_info in (info.payload_schema or {}).items(): + payload_indexes[field_name] = self._serialize_payload_index_info(idx_info) + + sharding = { + "shard_number": params.shard_number, + "replication_factor": params.replication_factor, + "write_consistency_factor": params.write_consistency_factor, + "read_fan_out_factor": params.read_fan_out_factor, + "read_fan_out_delay_ms": params.read_fan_out_delay_ms, + "on_disk_payload": params.on_disk_payload, + } + + data = { + "name": node.collection, + "status": str(info.status), + "points_count": info.points_count, + "indexed_vectors_count": info.indexed_vectors_count, + "segments_count": info.segments_count, + "topology": topology, + "vectors": vector_details, + "sparse_vectors": sparse_vectors or None, + "quantization": quantization, + "hnsw_config": hnsw, + "payload_schema": payload_indexes or None, + "sharding": sharding, + } + + return ExecutionResult( + success=True, + message=f"Collection '{node.collection}' diagnostics", + data=data, + ) + + async def _execute_scroll(self, node: ScrollStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + + scroll_filter: Filter | None = None + if node.query_filter is not None: + scroll_filter = self._wrap_as_filter( + self._build_qdrant_filter(node.query_filter) + ) + + try: + records, next_offset = await self._client.scroll( + collection_name=node.collection, + scroll_filter=scroll_filter, + limit=node.limit, + offset=node.after, + with_payload=True, + with_vectors=False, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during SCROLL: {e}") from e + + points = [ + {"id": str(rec.id), "payload": rec.payload or {}} + for rec in records + ] + return ExecutionResult( + success=True, + message=f"Scrolled {len(points)} point(s) from '{node.collection}'", + data={"points": points, "next_offset": next_offset}, + ) + + async def _execute_select(self, node: SelectStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + + try: + records = await self._client.retrieve( + collection_name=node.collection, + ids=[node.point_id], + with_payload=True, + with_vectors=False, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during SELECT: {e}") from e + + if not records: + return ExecutionResult( + success=True, + message=f"Point '{node.point_id}' not found in '{node.collection}'", + ) + + record = records[0] + return ExecutionResult( + success=True, + message=f"Retrieved point '{node.point_id}' from '{node.collection}'", + data={"id": str(record.id), "payload": record.payload or {}}, + ) + + async def _execute_search(self, node: SearchStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + topology = await self._resolve_topology(node.collection) + + qdrant_filter: Filter | None = None + if node.query_filter is not None: + qdrant_filter = self._wrap_as_filter( + self._build_qdrant_filter(node.query_filter) + ) + + search_params = self._build_search_params(node.with_clause) + validate_search_mmr_usage(node) + + fetch_limit = node.limit * _RERANK_FETCH_MULTIPLIER if node.rerank else node.limit + + lookup_from: LookupLocation | None = None + if node.lookup_from is not None: + lookup_from = LookupLocation( + collection=node.lookup_from[0], + vector=node.lookup_from[1], + ) + + if node.group_by is not None: + return await self._execute_search_groups( + node, qdrant_filter, search_params, topology + ) + + if node.hybrid: + dense_model = node.model or self._config.default_model + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + dense_vector, sparse_vector = await self._build_hybrid_vectors( + node.query_text, dense_model, sparse_model_name + ) + + try: + response = await self._client.query_points( + collection_name=node.collection, + prefetch=[ + Prefetch( + query=build_dense_query(dense_vector, node.with_clause), + using=topology.dense_using(node.dense_vector), + limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, + params=search_params, + ), + Prefetch( + query=sparse_vector, + using=topology.sparse_using(node.sparse_vector), + limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, + params=search_params, + ), + ], + query=FusionQuery(fusion=resolve_hybrid_fusion(node.fusion)), + limit=fetch_limit, + offset=node.offset or None, + query_filter=qdrant_filter, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during SEARCH: {e}") from e + + results = [ + {"id": str(h.id), "score": round(h.score, 4), "payload": h.payload} + for h in response.points + ] + + if node.rerank: + results = self._apply_reranking(node.query_text, results, node.limit, node.rerank_model) + return ExecutionResult( + success=True, + message=f"Found {len(results)} result(s) (hybrid, reranked)", + data=results, + ) + + return ExecutionResult( + success=True, + message=f"Found {len(results)} result(s) (hybrid)", + data=results, + ) + + if node.sparse_only: + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + sparse_embedder = SparseEmbedder(sparse_model_name) + sparse_obj = sparse_embedder.query_embed(node.query_text) + sparse_vector = SparseVector( + indices=sparse_obj["indices"], + values=sparse_obj["values"], + ) + + try: + response = await self._client.query_points( + collection_name=node.collection, + query=sparse_vector, + using=topology.sparse_using(node.sparse_vector), + limit=fetch_limit, + offset=node.offset or None, + query_filter=qdrant_filter, + search_params=search_params, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during SEARCH: {e}") from e + + results = [ + {"id": str(h.id), "score": round(h.score, 4), "payload": h.payload} + for h in response.points + ] + + if node.rerank: + results = self._apply_reranking(node.query_text, results, node.limit, node.rerank_model) + return ExecutionResult( + success=True, + message=f"Found {len(results)} result(s) (sparse, reranked)", + data=results, + ) + + return ExecutionResult( + success=True, + message=f"Found {len(results)} result(s) (sparse)", + data=results, + ) + + model_name = node.model or self._config.default_model + embedder = Embedder(model_name) + vector = embedder.embed(node.query_text) + + try: + query_using = topology.dense_using(node.dense_vector) + response = await self._client.query_points( + collection_name=node.collection, + query=build_dense_query(vector, node.with_clause), + using=query_using, + limit=fetch_limit, + offset=node.offset or None, + query_filter=qdrant_filter, + search_params=search_params, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during SEARCH: {e}") from e + + results = [ + {"id": str(h.id), "score": round(h.score, 4), "payload": h.payload} + for h in response.points + ] + + if node.rerank: + results = self._apply_reranking(node.query_text, results, node.limit, node.rerank_model) + return ExecutionResult( + success=True, + message=f"Found {len(results)} result(s) (reranked)", + data=results, + ) + + return ExecutionResult( + success=True, + message=f"Found {len(results)} result(s)", + data=results, + ) + + async def _execute_search_groups( + self, + node: SearchStmt, + qdrant_filter: Filter | None, + search_params: SearchParams | None, + topology: CollectionTopology, + ) -> ExecutionResult: + lookup_from: LookupLocation | None = None + if node.lookup_from is not None: + lookup_from = LookupLocation( + collection=node.lookup_from[0], + vector=node.lookup_from[1], + ) + + try: + if node.hybrid: + dense_model = node.model or self._config.default_model + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + dense_vector, sparse_vector = await self._build_hybrid_vectors( + node.query_text, dense_model, sparse_model_name + ) + response = await self._client.query_points_groups( + collection_name=node.collection, + group_by=node.group_by, + prefetch=[ + Prefetch( + query=build_dense_query(dense_vector, node.with_clause), + using=topology.dense_using(node.dense_vector), + limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, + params=search_params, + ), + Prefetch( + query=sparse_vector, + using=topology.sparse_using(node.sparse_vector), + limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, + params=search_params, + ), + ], + query=FusionQuery(fusion=resolve_hybrid_fusion(node.fusion)), + limit=node.limit, + group_size=node.group_size, + query_filter=qdrant_filter, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + ) + label = "hybrid, grouped" + elif node.sparse_only: + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + sparse_obj = SparseEmbedder(sparse_model_name).query_embed(node.query_text) + sparse_vector = SparseVector( + indices=sparse_obj["indices"], + values=sparse_obj["values"], + ) + response = await self._client.query_points_groups( + collection_name=node.collection, + group_by=node.group_by, + query=sparse_vector, + using=topology.sparse_using(node.sparse_vector), + limit=node.limit, + group_size=node.group_size, + query_filter=qdrant_filter, + search_params=search_params, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + ) + label = "sparse, grouped" + else: + model_name = node.model or self._config.default_model + vector = Embedder(model_name).embed(node.query_text) + query_using = topology.dense_using(node.dense_vector) + response = await self._client.query_points_groups( + collection_name=node.collection, + group_by=node.group_by, + query=build_dense_query(vector, node.with_clause), + using=query_using, + limit=node.limit, + group_size=node.group_size, + query_filter=qdrant_filter, + search_params=search_params, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + ) + label = "grouped" + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during GROUP BY SEARCH: {e}") from e + + groups = [ + { + "group_id": str(g.id), + "hits": [ + {"id": str(h.id), "score": round(h.score, 4), "payload": h.payload} + for h in g.hits + ], + } + for g in response.groups + ] + return ExecutionResult( + success=True, + message=f"Found {len(groups)} group(s) by '{node.group_by}' ({label})", + data=groups, + ) + + async def _execute_recommend(self, node: RecommendStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + + qdrant_filter: Filter | None = None + if node.query_filter is not None: + qdrant_filter = self._wrap_as_filter( + self._build_qdrant_filter(node.query_filter) + ) + qdrant_filter = exclude_ids_from_filter( + qdrant_filter, + [*node.positive_ids, *node.negative_ids], + ) + + recommend_input = RecommendInput( + positive=list(node.positive_ids), + negative=list(node.negative_ids) or None, + strategy=parse_recommend_strategy(node.strategy), + ) + + search_params = self._build_search_params(node.with_clause) + if has_mmr(node.with_clause): + raise QQLRuntimeError("MMR is supported only for SEARCH statements") + + lookup_from: LookupLocation | None = None + if node.lookup_from is not None: + lookup_from = LookupLocation( + collection=node.lookup_from[0], + vector=node.lookup_from[1], + ) + + try: + response = await self._client.query_points( + collection_name=node.collection, + query=RecommendQuery(recommend=recommend_input), + limit=node.limit, + offset=node.offset or None, + query_filter=qdrant_filter, + search_params=search_params, + score_threshold=node.score_threshold, + using=node.using, + lookup_from=lookup_from, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during RECOMMEND: {e}") from e + + results = [ + {"id": str(h.id), "score": round(h.score, 4), "payload": h.payload} + for h in response.points + ] + + return ExecutionResult( + success=True, + message=f"Found {len(results)} recommendation(s)", + data=results, + ) + + async def _execute_delete(self, node: DeleteStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + + try: + if node.query_filter is not None: + await self._client.delete( + collection_name=node.collection, + wait=True, + points_selector=self._wrap_as_filter( + self._build_qdrant_filter(node.query_filter) + ), + ) + return ExecutionResult( + success=True, + message=f"Deleted points from '{node.collection}' by filter", + ) + + from qdrant_client.models import PointIdsList + + if node.point_id is None: + raise QQLRuntimeError("DELETE requires either a point id or a filter") + + await self._client.delete( + collection_name=node.collection, + wait=True, + points_selector=PointIdsList(points=[node.point_id]), + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during DELETE: {e}") from e + + return ExecutionResult( + success=True, + message=f"Deleted point '{node.point_id}' from '{node.collection}'", + ) + + async def _execute_update_vector(self, node: UpdateVectorStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + topology = await self._resolve_topology(node.collection) + vector_name = topology.dense_payload_name(node.vector_name) + vector_struct: Any = ( + {vector_name: list(node.vector)} if vector_name else list(node.vector) + ) + try: + await self._client.update_vectors( + collection_name=node.collection, + points=[PointVectors(id=node.point_id, vector=vector_struct)], + wait=True, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during UPDATE VECTOR: {e}") from e + return ExecutionResult( + success=True, + message=f"Updated vector for point [{node.point_id}] in '{node.collection}'", + data=[], + ) + + async def _execute_update_payload(self, node: UpdatePayloadStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + try: + if node.query_filter is not None: + qdrant_filter = self._wrap_as_filter( + self._build_qdrant_filter(node.query_filter) + ) + await self._client.set_payload( + collection_name=node.collection, + payload=node.payload, + points=qdrant_filter, + wait=True, + ) + return ExecutionResult( + success=True, + message=f"Payload updated in '{node.collection}' (filter-based)", + data=[], + ) + await self._client.set_payload( + collection_name=node.collection, + payload=node.payload, + points=[node.point_id], + wait=True, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during UPDATE PAYLOAD: {e}") from e + return ExecutionResult( + success=True, + message=f"Payload updated for point [{node.point_id}] in '{node.collection}'", + data=[], + ) + + async def _execute_batch_block(self, node: BatchBlockStmt) -> ExecutionResult: + if not node.statements: + return ExecutionResult(success=True, message="Executed empty batch", data=[]) + + all_results = [] + succeeded_count = 0 + + for group in group_batch_statements(node.statements): + if group.kind == 'query': + res = await self._execute_query_batch(group.collection, group.statements) + all_results.extend(res) + succeeded_count += len([r for r in res if r.success]) + elif group.kind == 'insert': + bulk_node = build_bulk_insert_from_group( + group.collection, + group.statements, + ) + res = await self._execute_insert_bulk(bulk_node) + insert_results = inserted_point_results( + res, + group.statements, + ExecutionResult, + ) + all_results.extend(insert_results) + succeeded_count += len([r for r in insert_results if r.success]) + else: + for s in group.statements: + res = await self.execute(s) + all_results.append(res) + if res.success: + succeeded_count += 1 + + total_stmts = len(node.statements) + return ExecutionResult( + success=succeeded_count == total_stmts, + message=f"Batch executed {succeeded_count}/{total_stmts} statement(s) successfully", + data=all_results, + ) + + async def _execute_query_batch( + self, + collection_name: str, + nodes: list[SearchStmt | RecommendStmt], + ) -> list[ExecutionResult]: + if not await self._client.collection_exists(collection_name): + raise QQLRuntimeError(f"Collection '{collection_name}' does not exist") + + topology = await self._resolve_topology(collection_name) + requests = [] + + for node in nodes: + qdrant_filter = None + if node.query_filter is not None: + qdrant_filter = self._wrap_as_filter( + self._build_qdrant_filter(node.query_filter) + ) + + search_params = self._build_search_params(node.with_clause) + + lookup_from = None + if node.lookup_from is not None: + lookup_from = LookupLocation( + collection=node.lookup_from[0], + vector=node.lookup_from[1], + ) + + if isinstance(node, SearchStmt): + validate_search_mmr_usage(node) + fetch_limit = node.limit * _RERANK_FETCH_MULTIPLIER if node.rerank else node.limit + + if node.hybrid: + dense_model = node.model or self._config.default_model + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + dense_vector, sparse_vector = await self._build_hybrid_vectors( + node.query_text, dense_model, sparse_model_name + ) + + req = QueryRequest( + prefetch=[ + Prefetch( + query=build_dense_query(dense_vector, node.with_clause), + using=topology.dense_using(node.dense_vector), + limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, + params=search_params, + ), + Prefetch( + query=sparse_vector, + using=topology.sparse_using(node.sparse_vector), + limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, + params=search_params, + ), + ], + query=FusionQuery(fusion=resolve_hybrid_fusion(node.fusion)), + limit=fetch_limit, + offset=node.offset or None, + filter=qdrant_filter, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + with_payload=True, + with_vector=False, + ) + elif node.sparse_only: + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + sparse_embedder = SparseEmbedder(sparse_model_name) + sparse_obj = sparse_embedder.query_embed(node.query_text) + sparse_vector = SparseVector( + indices=sparse_obj["indices"], + values=sparse_obj["values"], + ) + + req = QueryRequest( + query=sparse_vector, + using=topology.sparse_using(node.sparse_vector), + limit=fetch_limit, + offset=node.offset or None, + filter=qdrant_filter, + params=search_params, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + with_payload=True, + with_vector=False, + ) + else: + model_name = node.model or self._config.default_model + embedder = Embedder(model_name) + vector = embedder.embed(node.query_text) + query_using = topology.dense_using(node.dense_vector) + + req = QueryRequest( + query=build_dense_query(vector, node.with_clause), + using=query_using, + limit=fetch_limit, + offset=node.offset or None, + filter=qdrant_filter, + params=search_params, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + with_payload=True, + with_vector=False, + ) + else: + qdrant_filter = exclude_ids_from_filter( + qdrant_filter, + [*node.positive_ids, *node.negative_ids], + ) + recommend_input = RecommendInput( + positive=list(node.positive_ids), + negative=list(node.negative_ids) or None, + strategy=parse_recommend_strategy(node.strategy), + ) + if has_mmr(node.with_clause): + raise QQLRuntimeError("MMR is supported only for SEARCH statements") + + req = QueryRequest( + query=RecommendQuery(recommend=recommend_input), + limit=node.limit, + offset=node.offset or None, + filter=qdrant_filter, + params=search_params, + score_threshold=node.score_threshold, + using=node.using, + lookup_from=lookup_from, + with_payload=True, + with_vector=False, + ) + + requests.append(req) + + try: + responses = await self._client.query_batch_points( + collection_name=collection_name, + requests=requests, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during Batch Query: {e}") from e + + execution_results = [] + for i, response in enumerate(responses): + node = nodes[i] + results = [ + {"id": str(h.id), "score": round(h.score, 4), "payload": h.payload} + for h in response.points + ] + + if isinstance(node, SearchStmt) and node.rerank: + results = self._apply_reranking(node.query_text, results, node.limit, node.rerank_model) + label = "hybrid, reranked" if node.hybrid else ("sparse, reranked" if node.sparse_only else "reranked") + msg = f"Found {len(results)} result(s) ({label})" + else: + if isinstance(node, SearchStmt): + label = "hybrid" if node.hybrid else ("sparse" if node.sparse_only else "") + label_suffix = f" ({label})" if label else "" + msg = f"Found {len(results)} result(s){label_suffix}" + else: + msg = f"Found {len(results)} recommendation(s)" + + execution_results.append( + ExecutionResult(success=True, message=msg, data=results) + ) + + return execution_results diff --git a/src/qql/connection.py b/src/qql/connection.py index e51f19f..3b26ace 100644 --- a/src/qql/connection.py +++ b/src/qql/connection.py @@ -1,9 +1,11 @@ from __future__ import annotations +from typing import Any from .config import DEFAULT_MODEL, QQLConfig from .executor import Executor, ExecutionResult from .lexer import Lexer from .parser import Parser +from .utils import render_parameterized_query class Connection: @@ -51,6 +53,8 @@ def __init__( url: str = "http://localhost:6333", secret: str | None = None, default_model: str | None = None, + prefer_grpc: bool = False, + grpc_port: int = 6334, ) -> None: """Create a connection to a Qdrant instance. @@ -60,6 +64,8 @@ def __init__( default_model: Dense embedding model used when no ``USING MODEL`` clause is specified. Defaults to ``sentence-transformers/all-MiniLM-L6-v2``. + prefer_grpc: Whether to connect via fast gRPC transport. + grpc_port: The gRPC port of Qdrant instance (default: 6334). """ from qdrant_client import QdrantClient @@ -68,7 +74,11 @@ def __init__( secret=secret, default_model=default_model or DEFAULT_MODEL, ) - self._client = QdrantClient(url=url, api_key=secret) + client_kwargs = {"url": url, "api_key": secret} + if prefer_grpc: + client_kwargs["prefer_grpc"] = True + client_kwargs["grpc_port"] = grpc_port + self._client = QdrantClient(**client_kwargs) self._executor = Executor(self._client, self._config) # ── Public API ──────────────────────────────────────────────────────── @@ -92,6 +102,38 @@ def run_query(self, query: str) -> ExecutionResult: node = Parser(tokens).parse() return self._executor.execute(node) + def run_queries_batch(self, queries: list[str]) -> list[ExecutionResult]: + """Parse and execute a batch of QQL statements. + + Combines compatible operations (such as SEARCH queries) to execute in + a single network request. + """ + from .ast_nodes import BatchBlockStmt + nodes = [] + for q in queries: + tokens = Lexer().tokenize(q) + node = Parser(tokens).parse() + nodes.append(node) + + batch_node = BatchBlockStmt(statements=tuple(nodes)) + res = self._executor.execute(batch_node) + return res.data + + def run_parameterized_query(self, template: str, params: dict[str, Any]) -> ExecutionResult: + """Execute one QQL query template with named parameters. + + Uses named placeholders prefixed with ':' (e.g. :query, :category). + """ + return self.run_query(render_parameterized_query(template, params)) + + def run_parameterized_batch(self, template: str, params: list[dict[str, Any]]) -> list[ExecutionResult]: + """Execute a single QQL query template with a batch of parameters. + + Uses named placeholders prefixed with ':' (e.g. :query, :category). + """ + queries = [render_parameterized_query(template, p) for p in params] + return self.run_queries_batch(queries) + def close(self) -> None: """Close the underlying Qdrant HTTP connection pool. @@ -130,3 +172,48 @@ def executor(self) -> Executor: result = conn.executor.execute(node) """ return self._executor + + +class QQLBatch: + """Session context manager for executing batch queries and mutations in QQL.""" + + def __init__(self, connection: Connection) -> None: + self.connection = connection + self._queries: list[str] = [] + self._proxies: list[OperationProxy] = [] + + def add(self, query: str) -> OperationProxy: + """Queue a QQL statement for batch execution.""" + self._queries.append(query) + proxy = OperationProxy() + self._proxies.append(proxy) + return proxy + + def __enter__(self) -> QQLBatch: + return self + + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: + if exc_type is not None: + return + if not self._queries: + return + results = self.connection.run_queries_batch(self._queries) + for proxy, res in zip(self._proxies, results): + proxy._resolve(res) + + +class OperationProxy: + """Proxy handle that resolves to an ExecutionResult after QQLBatch exits.""" + + def __init__(self) -> None: + self._result: ExecutionResult | None = None + + def _resolve(self, result: ExecutionResult) -> None: + self._result = result + + @property + def result(self) -> ExecutionResult: + """The resolved ExecutionResult.""" + if self._result is None: + raise RuntimeError("Batch has not been executed yet.") + return self._result diff --git a/src/qql/executor.py b/src/qql/executor.py index 975947d..fa6f6f6 100644 --- a/src/qql/executor.py +++ b/src/qql/executor.py @@ -1,7 +1,6 @@ from __future__ import annotations import time -import uuid from dataclasses import dataclass from typing import Any @@ -15,29 +14,15 @@ CompressionRatio, Distance, Disabled, - FieldCondition, Filter, - Fusion, FusionQuery, - HasIdCondition, HnswConfigDiff, - IsEmptyCondition, - IsNullCondition, KeywordIndexParams, KeywordIndexType, Language, LookupLocation, - MatchAny, - MatchExcept, - MatchPhrase, - MatchText, - MatchTextAny, - MatchValue, - Mmr, Modifier, - NearestQuery, OptimizersConfigDiff, - PayloadField, PayloadSchemaType, PointStruct, PointVectors, @@ -45,10 +30,9 @@ ProductQuantization, ProductQuantizationConfig, QuantizationSearchParams, - Range, RecommendInput, RecommendQuery, - RecommendStrategy, + QueryRequest, ScalarQuantization, ScalarQuantizationConfig, ScalarType, @@ -72,28 +56,14 @@ from .ast_nodes import ( ASTNode, AlterCollectionStmt, - AndExpr, - BetweenExpr, CollectionConfig, - CompareExpr, CreateCollectionStmt, CreateIndexStmt, DeleteStmt, DropCollectionStmt, FilterExpr, - InExpr, InsertBulkStmt, InsertStmt, - IsEmptyExpr, - IsNotEmptyExpr, - IsNotNullExpr, - IsNullExpr, - MatchAnyExpr, - MatchPhraseExpr, - MatchTextExpr, - NotExpr, - NotInExpr, - OrExpr, QuantizationUpdate, QuantizationConfig, QuantizationType, @@ -106,10 +76,27 @@ ShowCollectionsStmt, UpdateVectorStmt, UpdatePayloadStmt, + BatchBlockStmt, ) from .config import QQLConfig from .embedder import CrossEncoderEmbedder, Embedder, SparseEmbedder from .exceptions import QQLRuntimeError +from .utils import ( + build_bulk_insert_from_group, + build_dense_point_vector, + build_dense_query, + build_qdrant_filter, + collection_topology_kwargs, + exclude_ids_from_filter, + extract_point_id_and_payload, + group_batch_statements, + has_mmr, + inserted_point_results, + parse_recommend_strategy, + resolve_hybrid_fusion, + validate_search_mmr_usage, + wrap_as_filter, +) _RERANK_FETCH_MULTIPLIER = 4 _HYBRID_PREFETCH_MULTIPLIER = 4 @@ -236,6 +223,8 @@ def execute(self, node: ASTNode) -> ExecutionResult: return self._execute_update_vector(node) if isinstance(node, UpdatePayloadStmt): return self._execute_update_payload(node) + if isinstance(node, BatchBlockStmt): + return self._execute_batch_block(node) raise QQLRuntimeError(f"Unknown AST node type: {type(node)}") # ── Statement executors ─────────────────────────────────────────────── @@ -255,6 +244,10 @@ def _fetch_collection_info(self, name: str): raise QQLRuntimeError( f"Qdrant error fetching collection '{name}': {e}" ) from e + except ValueError as e: + if f"Collection {name} not found" in str(e): + return None + raise def _topology_from_collection_info(self, info: Any) -> CollectionTopology: """Parse a CollectionInfo object into a :class:`CollectionTopology`. @@ -265,40 +258,7 @@ def _topology_from_collection_info(self, info: Any) -> CollectionTopology: params = info.config.params vectors = params.vectors # type: ignore[union-attr] sparse_vectors = params.sparse_vectors or {} - - if isinstance(vectors, dict): - dense_names = tuple(vectors.keys()) - dense_sizes: tuple[tuple[str, int], ...] = tuple( - (k, v.size) - for k, v in vectors.items() - if getattr(v, "size", None) is not None - ) - has_unnamed_dense = False - is_named_dense = True - elif vectors is None: - dense_names = () - dense_sizes = () - has_unnamed_dense = False - is_named_dense = False - else: - # Single unnamed dense vector - dense_names = () - unnamed_size = getattr(vectors, "size", None) - dense_sizes = (("", unnamed_size),) if unnamed_size is not None else () - has_unnamed_dense = True - is_named_dense = False - - sparse_names = ( - tuple(sparse_vectors.keys()) if isinstance(sparse_vectors, dict) else () - ) - return CollectionTopology( - exists=True, - is_named_dense=is_named_dense, - has_unnamed_dense=has_unnamed_dense, - dense_names=dense_names, - sparse_names=sparse_names, - dense_sizes=dense_sizes, - ) + return CollectionTopology(**collection_topology_kwargs(vectors, sparse_vectors)) def _resolve_topology(self, name: str) -> CollectionTopology: """Return the topology for *name* using exactly one Qdrant API call. @@ -362,7 +322,7 @@ def _execute_insert(self, node: InsertStmt) -> ExecutionResult: }, ) - point_id, payload = self._extract_point_id_and_payload(node.values) + point_id, payload = extract_point_id_and_payload(node.values) try: self._client.upsert( collection_name=node.collection, @@ -392,9 +352,14 @@ def _execute_insert(self, node: InsertStmt) -> ExecutionResult: self._ensure_collection( node.collection, len(vector), topology, node.dense_vector ) - point_vector = self._build_dense_point_vector(topology, vector, node.dense_vector) + point_vector = build_dense_point_vector( + topology, + vector, + node.dense_vector, + self._default_dense_vector_name(), + ) - point_id, payload = self._extract_point_id_and_payload(node.values) + point_id, payload = extract_point_id_and_payload(node.values) try: self._client.upsert( @@ -443,7 +408,7 @@ def _execute_insert_bulk(self, node: InsertBulkStmt) -> ExecutionResult: first_dense_vector: list[float] | None = None points: list[PointStruct] = [] for vals in node.values_list: - point_id, payload = self._extract_point_id_and_payload(vals) + point_id, payload = extract_point_id_and_payload(vals) dense_vector = dense_embedder.embed(vals["text"]) if first_dense_vector is None: first_dense_vector = dense_vector @@ -483,6 +448,7 @@ def _execute_insert_bulk(self, node: InsertBulkStmt) -> ExecutionResult: return ExecutionResult( success=True, message=f"Inserted {len(points)} points (hybrid)", + data={"ids": [p.id for p in points]}, ) # ── Standard dense-only bulk INSERT ─────────────────────────────── @@ -495,9 +461,12 @@ def _execute_insert_bulk(self, node: InsertBulkStmt) -> ExecutionResult: vector = embedder.embed(vals["text"]) if first_vector is None: first_vector = vector - point_id, payload = self._extract_point_id_and_payload(vals) - point_vector = self._build_dense_point_vector( - topology, vector, node.dense_vector + point_id, payload = extract_point_id_and_payload(vals) + point_vector = build_dense_point_vector( + topology, + vector, + node.dense_vector, + self._default_dense_vector_name(), ) points.append( PointStruct(id=point_id, vector=point_vector, payload=payload) @@ -520,6 +489,7 @@ def _execute_insert_bulk(self, node: InsertBulkStmt) -> ExecutionResult: return ExecutionResult( success=True, message=f"Inserted {len(points)} points", + data={"ids": [p.id for p in points]}, ) def _execute_create(self, node: CreateCollectionStmt) -> ExecutionResult: @@ -889,7 +859,7 @@ def _execute_search(self, node: SearchStmt) -> ExecutionResult: ) search_params = self._build_search_params(node.with_clause) - self._validate_search_mmr_usage(node) + validate_search_mmr_usage(node) # When reranking is requested, fetch more candidates so the reranker has # enough material to reorder; only `node.limit` results are returned. @@ -921,7 +891,7 @@ def _execute_search(self, node: SearchStmt) -> ExecutionResult: collection_name=node.collection, prefetch=[ Prefetch( - query=self._build_dense_query(dense_vector, node.with_clause), + query=build_dense_query(dense_vector, node.with_clause), using=topology.dense_using(node.dense_vector), limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, params=search_params, @@ -933,7 +903,7 @@ def _execute_search(self, node: SearchStmt) -> ExecutionResult: params=search_params, ), ], - query=FusionQuery(fusion=self._resolve_hybrid_fusion(node.fusion)), + query=FusionQuery(fusion=resolve_hybrid_fusion(node.fusion)), limit=fetch_limit, offset=node.offset or None, query_filter=qdrant_filter, @@ -1015,7 +985,7 @@ def _execute_search(self, node: SearchStmt) -> ExecutionResult: query_using = topology.dense_using(node.dense_vector) response = self._client.query_points( collection_name=node.collection, - query=self._build_dense_query(vector, node.with_clause), + query=build_dense_query(vector, node.with_clause), using=query_using, limit=fetch_limit, offset=node.offset or None, @@ -1066,15 +1036,6 @@ def _build_hybrid_vectors( ) return dense_vector, sparse_vector - def _resolve_hybrid_fusion(self, fusion: str | None) -> Fusion: - if fusion is None or fusion == "rrf": - return Fusion.RRF - if fusion == "dbsf": - return Fusion.DBSF - raise QQLRuntimeError( - f"Unsupported hybrid fusion '{fusion}'; expected 'rrf' or 'dbsf'" - ) - def _execute_recommend(self, node: RecommendStmt) -> ExecutionResult: if not self._client.collection_exists(node.collection): raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") @@ -1084,7 +1045,7 @@ def _execute_recommend(self, node: RecommendStmt) -> ExecutionResult: qdrant_filter = self._wrap_as_filter( self._build_qdrant_filter(node.query_filter) ) - qdrant_filter = self._exclude_ids_from_filter( + qdrant_filter = exclude_ids_from_filter( qdrant_filter, [*node.positive_ids, *node.negative_ids], ) @@ -1092,11 +1053,11 @@ def _execute_recommend(self, node: RecommendStmt) -> ExecutionResult: recommend_input = RecommendInput( positive=list(node.positive_ids), negative=list(node.negative_ids) or None, - strategy=self._parse_recommend_strategy(node.strategy), + strategy=parse_recommend_strategy(node.strategy), ) search_params = self._build_search_params(node.with_clause) - if self._has_mmr(node.with_clause): + if has_mmr(node.with_clause): raise QQLRuntimeError("MMR is supported only for SEARCH statements") lookup_from: LookupLocation | None = None @@ -1501,107 +1462,6 @@ def _describe_quantization_update( return f", quantization={quantization.config.type.value}" return "" - def _has_mmr(self, with_clause: SearchWith | None) -> bool: - return with_clause is not None and ( - with_clause.mmr_diversity is not None or with_clause.mmr_candidates is not None - ) - - def _validate_search_mmr_usage(self, node: SearchStmt) -> None: - if not self._has_mmr(node.with_clause): - return - if node.sparse_only: - raise QQLRuntimeError("MMR is not supported with USING SPARSE yet") - - def _build_dense_query( - self, - vector: list[float], - with_clause: SearchWith | None, - ) -> list[float] | NearestQuery: - if not self._has_mmr(with_clause): - return vector - return NearestQuery( - nearest=vector, - mmr=Mmr( - diversity=with_clause.mmr_diversity, - candidates_limit=with_clause.mmr_candidates, - ), - ) - - def _parse_recommend_strategy( - self, strategy: str | None - ) -> RecommendStrategy | None: - if strategy is None: - return None - try: - return RecommendStrategy(strategy) - except ValueError as e: - raise QQLRuntimeError( - "Unknown recommend strategy " - f"'{strategy}'. Expected one of: average_vector, best_score, sum_scores" - ) from e - - def _exclude_ids_from_filter( - self, - query_filter: Filter | None, - point_ids: list[str | int], - ) -> Filter | None: - if not point_ids: - return query_filter - - exclude_condition = HasIdCondition(has_id=point_ids) - if query_filter is None: - return Filter(must_not=[exclude_condition]) - - return Filter( - must=list(query_filter.must or []), - should=list(query_filter.should or []), - must_not=[*(query_filter.must_not or []), exclude_condition], - min_should=query_filter.min_should, - ) - - def _extract_point_id_and_payload( - self, values: dict[str, Any] - ) -> tuple[str | int, dict[str, Any]]: - payload = dict(values) - if "id" not in payload: - return str(uuid.uuid4()), payload - - point_id = payload.pop("id") - if isinstance(point_id, bool): - raise QQLRuntimeError( - "INSERT id must be an unsigned integer or UUID string when provided" - ) - if isinstance(point_id, int): - if point_id < 0: - raise QQLRuntimeError( - "INSERT id must be an unsigned integer or UUID string when provided" - ) - return point_id, payload - if isinstance(point_id, str): - try: - uuid.UUID(point_id) - except ValueError as e: - raise QQLRuntimeError( - "INSERT id must be an unsigned integer or UUID string when provided" - ) from e - return point_id, payload - raise QQLRuntimeError( - "INSERT id must be an unsigned integer or UUID string when provided" - ) - - def _build_dense_point_vector( - self, - topology: CollectionTopology, - vector: list[float], - explicit_vector: str | None, - ) -> list[float] | dict[str, list[float]]: - if not topology.exists: - return {explicit_vector or self._default_dense_vector_name(): vector} - vector_name = topology.dense_payload_name(explicit_vector) - if vector_name is None: - return vector - return {vector_name: vector} - def _apply_reranking( self, query: str, @@ -1682,7 +1542,7 @@ def _execute_search_groups( group_by=node.group_by, prefetch=[ Prefetch( - query=self._build_dense_query(dense_vector, node.with_clause), + query=build_dense_query(dense_vector, node.with_clause), using=topology.dense_using(node.dense_vector), limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, params=search_params, @@ -1694,7 +1554,7 @@ def _execute_search_groups( params=search_params, ), ], - query=FusionQuery(fusion=self._resolve_hybrid_fusion(node.fusion)), + query=FusionQuery(fusion=resolve_hybrid_fusion(node.fusion)), limit=node.limit, group_size=node.group_size, query_filter=qdrant_filter, @@ -1729,7 +1589,7 @@ def _execute_search_groups( response = self._client.query_points_groups( collection_name=node.collection, group_by=node.group_by, - query=self._build_dense_query(vector, node.with_clause), + query=build_dense_query(vector, node.with_clause), using=query_using, limit=node.limit, group_size=node.group_size, @@ -1815,98 +1675,221 @@ def _execute_update_payload(self, node: UpdatePayloadStmt) -> ExecutionResult: data=[], ) - # ── Filter conversion ───────────────────────────────────────────────── - - def _build_qdrant_filter(self, expr: FilterExpr) -> Any: - """Convert a FilterExpr AST node into a Qdrant model object. - - Returns one of: Filter, FieldCondition, IsNullCondition, IsEmptyCondition. - Use _wrap_as_filter() to guarantee the top-level result is a Filter. - """ - # ── Logical combinators ─────────────────────────────────────────── - if isinstance(expr, AndExpr): - return Filter(must=[self._build_qdrant_filter(op) for op in expr.operands]) - - if isinstance(expr, OrExpr): - return Filter(should=[self._build_qdrant_filter(op) for op in expr.operands]) - - if isinstance(expr, NotExpr): - return Filter(must_not=[self._build_qdrant_filter(expr.operand)]) - - # ── Comparison ──────────────────────────────────────────────────── - if isinstance(expr, CompareExpr): - if expr.op == "=": - return FieldCondition( - key=expr.field, match=MatchValue(value=expr.value) + def _execute_batch_block(self, node: BatchBlockStmt) -> ExecutionResult: + if not node.statements: + return ExecutionResult(success=True, message="Executed empty batch", data=[]) + + all_results = [] + succeeded_count = 0 + + for group in group_batch_statements(node.statements): + if group.kind == 'query': + res = self._execute_query_batch(group.collection, group.statements) + all_results.extend(res) + succeeded_count += len([r for r in res if r.success]) + elif group.kind == 'insert': + bulk_node = build_bulk_insert_from_group( + group.collection, + group.statements, ) - if expr.op == "!=": - return Filter( - must_not=[ - FieldCondition(key=expr.field, match=MatchValue(value=expr.value)) - ] + res = self._execute_insert_bulk(bulk_node) + insert_results = inserted_point_results( + res, + group.statements, + ExecutionResult, ) - _range_key = {">": "gt", ">=": "gte", "<": "lt", "<=": "lte"}[expr.op] - return FieldCondition( - key=expr.field, range=Range(**{_range_key: expr.value}) - ) - - # ── BETWEEN ─────────────────────────────────────────────────────── - if isinstance(expr, BetweenExpr): - return FieldCondition( - key=expr.field, range=Range(gte=expr.low, lte=expr.high) - ) - - # ── IN / NOT IN ─────────────────────────────────────────────────── - if isinstance(expr, InExpr): - return FieldCondition( - key=expr.field, match=MatchAny(any=list(expr.values)) - ) - - if isinstance(expr, NotInExpr): - return FieldCondition( - key=expr.field, - match=MatchExcept(**{"except": list(expr.values)}), - ) - - # ── IS NULL / IS NOT NULL ───────────────────────────────────────── - if isinstance(expr, IsNullExpr): - return IsNullCondition(is_null=PayloadField(key=expr.field)) + all_results.extend(insert_results) + succeeded_count += len([r for r in insert_results if r.success]) + else: + for s in group.statements: + res = self.execute(s) + all_results.append(res) + if res.success: + succeeded_count += 1 + + total_stmts = len(node.statements) + return ExecutionResult( + success=succeeded_count == total_stmts, + message=f"Batch executed {succeeded_count}/{total_stmts} statement(s) successfully", + data=all_results, + ) - if isinstance(expr, IsNotNullExpr): - return Filter( - must_not=[IsNullCondition(is_null=PayloadField(key=expr.field))] + def _execute_query_batch( + self, + collection_name: str, + nodes: list[SearchStmt | RecommendStmt], + ) -> list[ExecutionResult]: + if not self._client.collection_exists(collection_name): + raise QQLRuntimeError(f"Collection '{collection_name}' does not exist") + + topology = self._resolve_topology(collection_name) + requests = [] + + for node in nodes: + qdrant_filter = None + if node.query_filter is not None: + qdrant_filter = self._wrap_as_filter( + self._build_qdrant_filter(node.query_filter) + ) + + search_params = self._build_search_params(node.with_clause) + + lookup_from = None + if node.lookup_from is not None: + lookup_from = LookupLocation( + collection=node.lookup_from[0], + vector=node.lookup_from[1], + ) + + if isinstance(node, SearchStmt): + validate_search_mmr_usage(node) + fetch_limit = node.limit * _RERANK_FETCH_MULTIPLIER if node.rerank else node.limit + + if node.hybrid: + dense_model = node.model or self._config.default_model + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + dense_vector, sparse_vector = self._build_hybrid_vectors( + node.query_text, dense_model, sparse_model_name + ) + + req = QueryRequest( + prefetch=[ + Prefetch( + query=build_dense_query(dense_vector, node.with_clause), + using=topology.dense_using(node.dense_vector), + limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, + params=search_params, + ), + Prefetch( + query=sparse_vector, + using=topology.sparse_using(node.sparse_vector), + limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, + params=search_params, + ), + ], + query=FusionQuery(fusion=resolve_hybrid_fusion(node.fusion)), + limit=fetch_limit, + offset=node.offset or None, + filter=qdrant_filter, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + with_payload=True, + with_vector=False, + ) + elif node.sparse_only: + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + sparse_embedder = SparseEmbedder(sparse_model_name) + sparse_obj = sparse_embedder.query_embed(node.query_text) + sparse_vector = SparseVector( + indices=sparse_obj["indices"], + values=sparse_obj["values"], + ) + + req = QueryRequest( + query=sparse_vector, + using=topology.sparse_using(node.sparse_vector), + limit=fetch_limit, + offset=node.offset or None, + filter=qdrant_filter, + params=search_params, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + with_payload=True, + with_vector=False, + ) + else: + model_name = node.model or self._config.default_model + embedder = Embedder(model_name) + vector = embedder.embed(node.query_text) + query_using = topology.dense_using(node.dense_vector) + + req = QueryRequest( + query=build_dense_query(vector, node.with_clause), + using=query_using, + limit=fetch_limit, + offset=node.offset or None, + filter=qdrant_filter, + params=search_params, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + with_payload=True, + with_vector=False, + ) + else: + qdrant_filter = exclude_ids_from_filter( + qdrant_filter, + [*node.positive_ids, *node.negative_ids], + ) + recommend_input = RecommendInput( + positive=list(node.positive_ids), + negative=list(node.negative_ids) or None, + strategy=parse_recommend_strategy(node.strategy), + ) + if has_mmr(node.with_clause): + raise QQLRuntimeError("MMR is supported only for SEARCH statements") + + req = QueryRequest( + query=RecommendQuery(recommend=recommend_input), + limit=node.limit, + offset=node.offset or None, + filter=qdrant_filter, + params=search_params, + score_threshold=node.score_threshold, + using=node.using, + lookup_from=lookup_from, + with_payload=True, + with_vector=False, + ) + + requests.append(req) + + try: + responses = self._client.query_batch_points( + collection_name=collection_name, + requests=requests, ) - - # ── IS EMPTY / IS NOT EMPTY ─────────────────────────────────────── - if isinstance(expr, IsEmptyExpr): - return IsEmptyCondition(is_empty=PayloadField(key=expr.field)) - - if isinstance(expr, IsNotEmptyExpr): - return Filter( - must_not=[IsEmptyCondition(is_empty=PayloadField(key=expr.field))] + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during Batch Query: {e}") from e + + execution_results = [] + for i, response in enumerate(responses): + node = nodes[i] + results = [ + {"id": str(h.id), "score": round(h.score, 4), "payload": h.payload} + for h in response.points + ] + + if isinstance(node, SearchStmt) and node.rerank: + results = self._apply_reranking(node.query_text, results, node.limit, node.rerank_model) + label = "hybrid, reranked" if node.hybrid else ("sparse, reranked" if node.sparse_only else "reranked") + msg = f"Found {len(results)} result(s) ({label})" + else: + if isinstance(node, SearchStmt): + label = "hybrid" if node.hybrid else ("sparse" if node.sparse_only else "") + label_suffix = f" ({label})" if label else "" + msg = f"Found {len(results)} result(s){label_suffix}" + else: + msg = f"Found {len(results)} recommendation(s)" + + execution_results.append( + ExecutionResult(success=True, message=msg, data=results) ) + + return execution_results - # ── Full-text MATCH ─────────────────────────────────────────────── - if isinstance(expr, MatchTextExpr): - return FieldCondition(key=expr.field, match=MatchText(text=expr.text)) - - if isinstance(expr, MatchAnyExpr): - return FieldCondition( - key=expr.field, match=MatchTextAny(text_any=expr.text) - ) + # ── Filter conversion ───────────────────────────────────────────────── - if isinstance(expr, MatchPhraseExpr): - return FieldCondition( - key=expr.field, match=MatchPhrase(phrase=expr.text) - ) + def _build_qdrant_filter(self, expr: FilterExpr) -> Any: + """Convert a FilterExpr AST node into a Qdrant model object. - raise QQLRuntimeError(f"Unknown filter expression type: {type(expr)}") + Returns one of: Filter, FieldCondition, IsNullCondition, IsEmptyCondition. + Use _wrap_as_filter() to guarantee the top-level result is a Filter. + """ + return build_qdrant_filter(expr) def _wrap_as_filter(self, qdrant_expr: Any) -> Filter: """Ensure the top-level expression is a Filter (required by query_points).""" - if isinstance(qdrant_expr, Filter): - return qdrant_expr - return Filter(must=[qdrant_expr]) + return wrap_as_filter(qdrant_expr) # ── Collection helpers ──────────────────────────────────────────────── diff --git a/src/qql/lexer.py b/src/qql/lexer.py index 39babee..1df588d 100644 --- a/src/qql/lexer.py +++ b/src/qql/lexer.py @@ -71,6 +71,10 @@ class TokenKind(Enum): UPDATE = auto() SET = auto() PAYLOAD = auto() + # ── Batch keywords ──────────────────────────────────────────────────── + BEGIN = auto() + BATCH = auto() + END = auto() # ── Filter keywords ─────────────────────────────────────────────────── AND = auto() OR = auto() @@ -99,6 +103,7 @@ class TokenKind(Enum): COLON = auto() COMMA = auto() EQUALS = auto() + SEMICOLON = auto() # ── Comparison operators ────────────────────────────────────────────── NOT_EQUALS = auto() # != GT = auto() # > @@ -176,6 +181,9 @@ class TokenKind(Enum): "UPDATE": TokenKind.UPDATE, "SET": TokenKind.SET, "PAYLOAD": TokenKind.PAYLOAD, + "BEGIN": TokenKind.BEGIN, + "BATCH": TokenKind.BATCH, + "END": TokenKind.END, # Filter keywords "AND": TokenKind.AND, "OR": TokenKind.OR, @@ -239,6 +247,9 @@ def tokenize(self, query: str) -> list[Token]: elif ch == ",": tokens.append(Token(TokenKind.COMMA, ",", i)) i += 1 + elif ch == ";": + tokens.append(Token(TokenKind.SEMICOLON, ";", i)) + i += 1 # ── Comparison operators (multi-char look-ahead) ────────────── elif ch == "=": diff --git a/src/qql/parser.py b/src/qql/parser.py index b62b9b0..7134fce 100644 --- a/src/qql/parser.py +++ b/src/qql/parser.py @@ -42,9 +42,16 @@ UpdatePayloadStmt, VectorsConfig, HnswRuntimeConfig, + BatchBlockStmt, ) from .exceptions import QQLSyntaxError from .lexer import Token, TokenKind +from .utils import ( + parse_search_group_by, + parse_search_lookup, + parse_search_using, + parse_search_with, +) # Comparison operator token → string symbol mapping _CMP_OPS: dict[TokenKind, str] = { @@ -56,9 +63,6 @@ TokenKind.LTE: "<=", } -_HYBRID_FUSION_VALUES = {"rrf", "dbsf"} - - class Parser: def __init__(self, tokens: list[Token]) -> None: self._tokens = tokens @@ -67,36 +71,65 @@ def __init__(self, tokens: list[Token]) -> None: # ── Public entry point ──────────────────────────────────────────────── def parse(self) -> ASTNode: + node = self._parse_single_statement() + self._expect(TokenKind.EOF) + return node + + def _parse_single_statement(self) -> ASTNode: tok = self._peek() if tok.kind == TokenKind.INSERT: - node = self._parse_insert() + return self._parse_insert() elif tok.kind == TokenKind.CREATE: - node = self._parse_create() + return self._parse_create() elif tok.kind == TokenKind.ALTER: - node = self._parse_alter() + return self._parse_alter() elif tok.kind == TokenKind.DROP: - node = self._parse_drop() + return self._parse_drop() elif tok.kind == TokenKind.SHOW: - node = self._parse_show() + return self._parse_show() elif tok.kind == TokenKind.SCROLL: - node = self._parse_scroll() + return self._parse_scroll() elif tok.kind == TokenKind.SELECT: - node = self._parse_select() + return self._parse_select() elif tok.kind == TokenKind.SEARCH: - node = self._parse_search() + return self._parse_search() elif tok.kind == TokenKind.RECOMMEND: - node = self._parse_recommend() + return self._parse_recommend() elif tok.kind == TokenKind.DELETE: - node = self._parse_delete() + return self._parse_delete() elif tok.kind == TokenKind.UPDATE: - node = self._parse_update() + return self._parse_update() + elif tok.kind == TokenKind.BEGIN: + return self._parse_batch_block() else: raise QQLSyntaxError( f"Unexpected token '{tok.value}'; expected a QQL statement keyword", tok.pos, ) - self._expect(TokenKind.EOF) - return node + + def _parse_batch_block(self) -> BatchBlockStmt: + self._expect(TokenKind.BEGIN) + self._expect(TokenKind.BATCH) + statements = [] + while True: + while self._peek().kind == TokenKind.SEMICOLON: + self._advance() + + if self._peek().kind == TokenKind.END: + self._advance() + self._expect(TokenKind.BATCH) + break + + if self._peek().kind == TokenKind.EOF: + raise QQLSyntaxError("Unterminated batch block; expected END BATCH", self._peek().pos) + + stmt = self._parse_single_statement() + statements.append(stmt) + + if self._peek().kind == TokenKind.SEMICOLON: + self._advance() + + return BatchBlockStmt(statements=tuple(statements)) # ── Statement parsers ───────────────────────────────────────────────── @@ -696,80 +729,14 @@ def _parse_search(self) -> SearchStmt: self._expect(TokenKind.THRESHOLD) score_threshold = float(self._parse_number()) - lookup_from: tuple[str, str | None] | None = None - if self._peek().kind == TokenKind.LOOKUP: - self._advance() - self._expect(TokenKind.FROM) - lookup_collection = self._parse_identifier() - lookup_vector: str | None = None - if self._peek().kind == TokenKind.VECTOR: - self._advance() - lookup_vector = self._expect(TokenKind.STRING).value - lookup_from = (lookup_collection, lookup_vector) + lookup_from = parse_search_lookup(self) with_clause: SearchWith | None = None if self._peek().kind == TokenKind.EXACT: self._advance() with_clause = SearchWith(exact=True) - model: str | None = None - hybrid: bool = False - fusion: str | None = None - sparse_only: bool = False - sparse_model: str | None = None - dense_vector: str | None = None - sparse_vector: str | None = None - if self._peek().kind == TokenKind.USING: - self._advance() # consume USING - if self._peek().kind == TokenKind.HYBRID: - self._advance() # consume HYBRID - hybrid = True - # Optional FUSION / DENSE|SPARSE MODEL|VECTOR sub-clauses, any order. - while self._peek().kind in (TokenKind.FUSION, TokenKind.DENSE, TokenKind.SPARSE): - sub = self._advance() - if sub.kind == TokenKind.FUSION: - value_tok = self._expect(TokenKind.STRING) - fusion = value_tok.value.lower() - if fusion not in _HYBRID_FUSION_VALUES: - raise QQLSyntaxError( - f"Unsupported hybrid fusion '{value_tok.value}'; expected 'rrf' or 'dbsf'", - value_tok.pos, - ) - continue - if self._peek().kind == TokenKind.MODEL: - self._advance() - m = self._expect(TokenKind.STRING).value - if sub.kind == TokenKind.DENSE: - model = m - else: - sparse_model = m - elif self._peek().kind == TokenKind.VECTOR: - self._advance() - name = self._expect(TokenKind.STRING).value - if sub.kind == TokenKind.DENSE: - dense_vector = name - else: - sparse_vector = name - else: - raise QQLSyntaxError( - "Expected MODEL or VECTOR after DENSE/SPARSE in USING HYBRID", - self._peek().pos, - ) - elif self._peek().kind == TokenKind.SPARSE: - self._advance() # consume SPARSE - sparse_only = True - while self._peek().kind in (TokenKind.MODEL, TokenKind.VECTOR): - sub = self._advance() - if sub.kind == TokenKind.MODEL: - sparse_model = self._expect(TokenKind.STRING).value - else: - sparse_vector = self._expect(TokenKind.STRING).value - elif self._peek().kind == TokenKind.VECTOR: - self._advance() - dense_vector = self._expect(TokenKind.STRING).value - else: - self._expect(TokenKind.MODEL) - model = self._expect(TokenKind.STRING).value + using = parse_search_using(self) query_filter: FilterExpr | None = None if self._peek().kind == TokenKind.WHERE: self._advance() # consume WHERE @@ -783,79 +750,25 @@ def _parse_search(self) -> SearchStmt: self._advance() # consume MODEL rerank_model = self._expect(TokenKind.STRING).value - if self._peek().kind == TokenKind.EXACT: - self._advance() - if with_clause is None: - with_clause = SearchWith(exact=True) - else: - with_clause = SearchWith( - hnsw_ef=with_clause.hnsw_ef, - exact=True, - acorn=with_clause.acorn, - indexed_only=with_clause.indexed_only, - quantization=with_clause.quantization, - mmr_diversity=with_clause.mmr_diversity, - mmr_candidates=with_clause.mmr_candidates, - ) - - if self._peek().kind == TokenKind.WITH: - self._advance() # consume WITH - parsed_with = self._parse_with_clause() - if with_clause is None: - with_clause = parsed_with - else: - with_clause = SearchWith( - hnsw_ef=parsed_with.hnsw_ef or with_clause.hnsw_ef, - exact=parsed_with.exact or with_clause.exact, - acorn=parsed_with.acorn or with_clause.acorn, - indexed_only=parsed_with.indexed_only or with_clause.indexed_only, - quantization=parsed_with.quantization or with_clause.quantization, - mmr_diversity=( - parsed_with.mmr_diversity - if parsed_with.mmr_diversity is not None - else with_clause.mmr_diversity - ), - mmr_candidates=parsed_with.mmr_candidates or with_clause.mmr_candidates, - ) - group_by: str | None = None - group_size: int = 3 - if self._peek().kind == TokenKind.GROUP: - if offset > 0: - raise QQLSyntaxError("OFFSET cannot be used with GROUP BY", self._peek().pos) - self._advance() # consume GROUP - self._expect(TokenKind.BY) - group_by = self._parse_field_path() - if rerank: - raise QQLSyntaxError( - "GROUP BY and RERANK cannot be combined in the same SEARCH statement", - self._peek().pos, - ) - if self._peek().kind == TokenKind.GROUP_SIZE: - self._advance() # consume GROUP_SIZE - gs_tok = self._peek() - group_size = int(self._expect(TokenKind.INTEGER).value) - if group_size <= 0: - raise QQLSyntaxError( - f"GROUP_SIZE must be a positive integer, got {group_size}", - gs_tok.pos, - ) + with_clause = parse_search_with(self, with_clause) + group = parse_search_group_by(self, offset, rerank) return SearchStmt( collection=collection, query_text=query_text, limit=limit, - model=model, - hybrid=hybrid, - fusion=fusion, - sparse_only=sparse_only, - sparse_model=sparse_model, + model=using.model, + hybrid=using.hybrid, + fusion=using.fusion, + sparse_only=using.sparse_only, + sparse_model=using.sparse_model, query_filter=query_filter, rerank=rerank, rerank_model=rerank_model, with_clause=with_clause, - group_by=group_by, - group_size=group_size, - dense_vector=dense_vector, - sparse_vector=sparse_vector, + group_by=group.group_by, + group_size=group.group_size, + dense_vector=using.dense_vector, + sparse_vector=using.sparse_vector, offset=offset, score_threshold=score_threshold, lookup_from=lookup_from, diff --git a/src/qql/script.py b/src/qql/script.py index ab2de16..08023b6 100644 --- a/src/qql/script.py +++ b/src/qql/script.py @@ -30,6 +30,7 @@ TokenKind.SEARCH, TokenKind.RECOMMEND, TokenKind.DELETE, + TokenKind.BEGIN, } _DEPTH_OPEN = {TokenKind.LBRACE, TokenKind.LBRACKET, TokenKind.LPAREN} @@ -84,13 +85,14 @@ def split_statements(tokens: list[Token]) -> list[list[Token]]: """Split a flat token list into per-statement chunks. A new chunk begins whenever a statement-starter keyword (INSERT, CREATE, - DROP, SHOW, SCROLL, SELECT, SEARCH, RECOMMEND, DELETE) is encountered at - brace/bracket/paren depth 0. + DROP, SHOW, SCROLL, SELECT, SEARCH, RECOMMEND, DELETE, BEGIN) is encountered at + brace/bracket/paren depth 0 and batch_depth 0. The EOF sentinel is consumed and never included in any chunk. """ chunks: list[list[Token]] = [] current: list[Token] = [] depth = 0 + batch_depth = 0 for tok in tokens: if tok.kind == TokenKind.EOF: @@ -100,11 +102,16 @@ def split_statements(tokens: list[Token]) -> list[list[Token]]: elif tok.kind in _DEPTH_CLOSE: depth -= 1 - # New statement starts when we see a starter at the top level - if tok.kind in _STMT_STARTERS and depth == 0 and current: + # New statement starts when we see a starter at the top level outside of a batch block + if tok.kind in _STMT_STARTERS and depth == 0 and batch_depth == 0 and current: chunks.append(current) current = [] + if tok.kind == TokenKind.BEGIN: + batch_depth += 1 + elif tok.kind == TokenKind.END: + batch_depth -= 1 + current.append(tok) if current: diff --git a/src/qql/utils.py b/src/qql/utils.py new file mode 100644 index 0000000..e72d1db --- /dev/null +++ b/src/qql/utils.py @@ -0,0 +1,562 @@ +from __future__ import annotations + +import uuid +from dataclasses import dataclass +from typing import Any + +from qdrant_client.models import ( + FieldCondition, + Filter, + Fusion, + HasIdCondition, + IsEmptyCondition, + IsNullCondition, + MatchAny, + MatchExcept, + MatchPhrase, + MatchText, + MatchTextAny, + MatchValue, + Mmr, + NearestQuery, + PayloadField, + Range, + RecommendStrategy, +) + +from .ast_nodes import ( + ASTNode, + AndExpr, + BetweenExpr, + CompareExpr, + FilterExpr, + InExpr, + InsertBulkStmt, + InsertStmt, + IsEmptyExpr, + IsNotEmptyExpr, + IsNotNullExpr, + IsNullExpr, + MatchAnyExpr, + MatchPhraseExpr, + MatchTextExpr, + NotExpr, + NotInExpr, + OrExpr, + RecommendStmt, + SearchStmt, + SearchWith, +) +from .exceptions import QQLRuntimeError, QQLSyntaxError +from .lexer import TokenKind + +_HYBRID_FUSION_VALUES = {"rrf", "dbsf"} + + +@dataclass(frozen=True) +class BatchGroup: + kind: str + collection: str | None + statements: list[ASTNode] + + +@dataclass(frozen=True) +class SearchUsingOptions: + model: str | None = None + hybrid: bool = False + fusion: str | None = None + sparse_only: bool = False + sparse_model: str | None = None + dense_vector: str | None = None + sparse_vector: str | None = None + + +@dataclass(frozen=True) +class SearchGroupByOptions: + group_by: str | None = None + group_size: int = 3 + + +def render_parameterized_query(template: str, params: dict[str, Any]) -> str: + query_str = template + for k in sorted(params.keys(), key=len, reverse=True): + val = params[k] + placeholder = f":{k}" + if isinstance(val, str): + escaped_val = val.replace("'", "\\'") + repr_val = f"'{escaped_val}'" + elif isinstance(val, bool): + repr_val = "true" if val else "false" + else: + repr_val = str(val) + query_str = query_str.replace(placeholder, repr_val) + return query_str + + +def collection_topology_kwargs(vectors: Any, sparse_vectors: Any) -> dict[str, Any]: + if isinstance(vectors, dict): + dense_names = tuple(vectors.keys()) + dense_sizes = tuple( + (k, v.size) + for k, v in vectors.items() + if getattr(v, "size", None) is not None + ) + has_unnamed_dense = False + is_named_dense = True + elif vectors is None: + dense_names = () + dense_sizes = () + has_unnamed_dense = False + is_named_dense = False + else: + dense_names = () + unnamed_size = getattr(vectors, "size", None) + dense_sizes = (("", unnamed_size),) if unnamed_size is not None else () + has_unnamed_dense = True + is_named_dense = False + + sparse_names = tuple(sparse_vectors.keys()) if isinstance(sparse_vectors, dict) else () + return { + "exists": True, + "is_named_dense": is_named_dense, + "has_unnamed_dense": has_unnamed_dense, + "dense_names": dense_names, + "sparse_names": sparse_names, + "dense_sizes": dense_sizes, + } + + +def _append_batch_group( + groups: list[BatchGroup], + kind: str | None, + collection: str | None, + statements: list[ASTNode], +) -> None: + if statements: + groups.append(BatchGroup(kind or "other", collection, statements)) + + +def _compatible_insert_batch(current_group: list[ASTNode], stmt: InsertStmt) -> bool: + if not current_group: + return False + prev_stmt = current_group[0] + if not isinstance(prev_stmt, InsertStmt): + return False + return ( + stmt.model == prev_stmt.model + and stmt.hybrid == prev_stmt.hybrid + and stmt.sparse_model == prev_stmt.sparse_model + and stmt.dense_vector == prev_stmt.dense_vector + and stmt.sparse_vector == prev_stmt.sparse_vector + ) + + +def group_batch_statements(statements: tuple[ASTNode, ...]) -> list[BatchGroup]: + groups: list[BatchGroup] = [] + current_type: str | None = None + current_collection: str | None = None + current_group: list[ASTNode] = [] + + for stmt in statements: + if isinstance(stmt, (SearchStmt, RecommendStmt)): + coll = stmt.collection + if current_type == "query" and current_collection == coll: + current_group.append(stmt) + continue + _append_batch_group(groups, current_type, current_collection, current_group) + current_type = "query" + current_collection = coll + current_group = [stmt] + continue + + if isinstance(stmt, InsertStmt): + coll = stmt.collection + if ( + current_type == "insert" + and current_collection == coll + and _compatible_insert_batch(current_group, stmt) + ): + current_group.append(stmt) + continue + _append_batch_group(groups, current_type, current_collection, current_group) + current_type = "insert" + current_collection = coll + current_group = [stmt] + continue + + _append_batch_group(groups, current_type, current_collection, current_group) + groups.append(BatchGroup("other", None, [stmt])) + current_type = None + current_collection = None + current_group = [] + + _append_batch_group(groups, current_type, current_collection, current_group) + return groups + + +def build_bulk_insert_from_group( + collection: str, + statements: list[ASTNode], +) -> InsertBulkStmt: + first = statements[0] + if not isinstance(first, InsertStmt): + raise QQLRuntimeError("Batch insert group must contain INSERT statements") + insert_statements = [stmt for stmt in statements if isinstance(stmt, InsertStmt)] + return InsertBulkStmt( + collection=collection, + values_list=tuple(stmt.values for stmt in insert_statements), + model=first.model, + hybrid=first.hybrid, + sparse_model=first.sparse_model, + dense_vector=first.dense_vector, + sparse_vector=first.sparse_vector, + ) + + +def inserted_point_results( + result: Any, + statements: list[ASTNode], + result_type: type, +) -> list[Any]: + if not result.success: + return [result_type(success=False, message=result.message) for _ in statements] + + inserted_ids = result.data.get("ids", []) if isinstance(result.data, dict) else [] + is_hybrid = "hybrid" in result.message + label = "hybrid, batched" if is_hybrid else "batched" + rows = [] + for idx, stmt in enumerate(statements): + if not isinstance(stmt, InsertStmt): + continue + point_id = inserted_ids[idx] if idx < len(inserted_ids) else "unknown" + rows.append( + result_type( + success=True, + message=f"Inserted 1 point [{point_id}] ({label})", + data={"id": point_id, "collection": stmt.collection}, + ) + ) + return rows + + +def build_qdrant_filter(expr: FilterExpr) -> Any: + if isinstance(expr, AndExpr): + return Filter(must=[build_qdrant_filter(op) for op in expr.operands]) + if isinstance(expr, OrExpr): + return Filter(should=[build_qdrant_filter(op) for op in expr.operands]) + if isinstance(expr, NotExpr): + return Filter(must_not=[build_qdrant_filter(expr.operand)]) + if isinstance(expr, CompareExpr): + if expr.op == "=": + return FieldCondition(key=expr.field, match=MatchValue(value=expr.value)) + if expr.op == "!=": + return Filter( + must_not=[ + FieldCondition(key=expr.field, match=MatchValue(value=expr.value)) + ] + ) + range_key = {">": "gt", ">=": "gte", "<": "lt", "<=": "lte"}[expr.op] + return FieldCondition(key=expr.field, range=Range(**{range_key: expr.value})) + if isinstance(expr, BetweenExpr): + return FieldCondition(key=expr.field, range=Range(gte=expr.low, lte=expr.high)) + if isinstance(expr, InExpr): + return FieldCondition(key=expr.field, match=MatchAny(any=list(expr.values))) + if isinstance(expr, NotInExpr): + return FieldCondition( + key=expr.field, + match=MatchExcept(**{"except": list(expr.values)}), + ) + if isinstance(expr, IsNullExpr): + return IsNullCondition(is_null=PayloadField(key=expr.field)) + if isinstance(expr, IsNotNullExpr): + return Filter(must_not=[IsNullCondition(is_null=PayloadField(key=expr.field))]) + if isinstance(expr, IsEmptyExpr): + return IsEmptyCondition(is_empty=PayloadField(key=expr.field)) + if isinstance(expr, IsNotEmptyExpr): + return Filter(must_not=[IsEmptyCondition(is_empty=PayloadField(key=expr.field))]) + if isinstance(expr, MatchTextExpr): + return FieldCondition(key=expr.field, match=MatchText(text=expr.text)) + if isinstance(expr, MatchAnyExpr): + return FieldCondition(key=expr.field, match=MatchTextAny(text_any=expr.text)) + if isinstance(expr, MatchPhraseExpr): + return FieldCondition(key=expr.field, match=MatchPhrase(phrase=expr.text)) + raise QQLRuntimeError(f"Unknown filter expression type: {type(expr)}") + + +def wrap_as_filter(qdrant_expr: Any) -> Filter: + if isinstance(qdrant_expr, Filter): + return qdrant_expr + return Filter(must=[qdrant_expr]) + + +def resolve_hybrid_fusion(fusion: str | None) -> Fusion: + if fusion is None or fusion == "rrf": + return Fusion.RRF + if fusion == "dbsf": + return Fusion.DBSF + raise QQLRuntimeError( + f"Unsupported hybrid fusion '{fusion}'; expected 'rrf' or 'dbsf'" + ) + + +def has_mmr(with_clause: SearchWith | None) -> bool: + return with_clause is not None and ( + with_clause.mmr_diversity is not None or with_clause.mmr_candidates is not None + ) + + +def validate_search_mmr_usage(node: SearchStmt) -> None: + if not has_mmr(node.with_clause): + return + if node.sparse_only: + raise QQLRuntimeError("MMR is not supported with USING SPARSE yet") + + +def build_dense_query( + vector: list[float], + with_clause: SearchWith | None, +) -> list[float] | NearestQuery: + if not has_mmr(with_clause): + return vector + return NearestQuery( + nearest=vector, + mmr=Mmr( + diversity=with_clause.mmr_diversity, + candidates_limit=with_clause.mmr_candidates, + ), + ) + + +def parse_recommend_strategy(strategy: str | None) -> RecommendStrategy | None: + if strategy is None: + return None + try: + return RecommendStrategy(strategy) + except ValueError as e: + raise QQLRuntimeError( + "Unknown recommend strategy " + f"'{strategy}'. Expected one of: average_vector, best_score, sum_scores" + ) from e + + +def exclude_ids_from_filter( + query_filter: Filter | None, + point_ids: list[str | int], +) -> Filter | None: + if not point_ids: + return query_filter + + exclude_condition = HasIdCondition(has_id=point_ids) + if query_filter is None: + return Filter(must_not=[exclude_condition]) + + return Filter( + must=list(query_filter.must or []), + should=list(query_filter.should or []), + must_not=[*(query_filter.must_not or []), exclude_condition], + min_should=query_filter.min_should, + ) + + +def extract_point_id_and_payload( + values: dict[str, Any], +) -> tuple[str | int, dict[str, Any]]: + payload = dict(values) + if "id" not in payload: + return str(uuid.uuid4()), payload + + point_id = payload.pop("id") + if isinstance(point_id, bool): + raise QQLRuntimeError( + "INSERT id must be an unsigned integer or UUID string when provided" + ) + if isinstance(point_id, int): + if point_id < 0: + raise QQLRuntimeError( + "INSERT id must be an unsigned integer or UUID string when provided" + ) + return point_id, payload + if isinstance(point_id, str): + try: + uuid.UUID(point_id) + except ValueError as e: + raise QQLRuntimeError( + "INSERT id must be an unsigned integer or UUID string when provided" + ) from e + return point_id, payload + raise QQLRuntimeError( + "INSERT id must be an unsigned integer or UUID string when provided" + ) + + +def build_dense_point_vector( + topology: Any, + vector: list[float], + explicit_vector: str | None, + default_dense_vector_name: str, +) -> list[float] | dict[str, list[float]]: + if not topology.exists: + return {explicit_vector or default_dense_vector_name: vector} + vector_name = topology.dense_payload_name(explicit_vector) + if vector_name is None: + return vector + return {vector_name: vector} + + +def merge_search_with(base: SearchWith | None, override: SearchWith) -> SearchWith: + if base is None: + return override + return SearchWith( + hnsw_ef=override.hnsw_ef or base.hnsw_ef, + exact=override.exact or base.exact, + acorn=override.acorn or base.acorn, + indexed_only=override.indexed_only or base.indexed_only, + quantization=override.quantization or base.quantization, + mmr_diversity=( + override.mmr_diversity + if override.mmr_diversity is not None + else base.mmr_diversity + ), + mmr_candidates=override.mmr_candidates or base.mmr_candidates, + ) + + +def parse_search_lookup(parser: Any) -> tuple[str, str | None] | None: + if parser._peek().kind != TokenKind.LOOKUP: + return None + parser._advance() + parser._expect(TokenKind.FROM) + lookup_collection = parser._parse_identifier() + lookup_vector: str | None = None + if parser._peek().kind == TokenKind.VECTOR: + parser._advance() + lookup_vector = parser._expect(TokenKind.STRING).value + return lookup_collection, lookup_vector + + +def parse_search_using(parser: Any) -> SearchUsingOptions: + if parser._peek().kind != TokenKind.USING: + return SearchUsingOptions() + + parser._advance() + if parser._peek().kind == TokenKind.HYBRID: + return _parse_hybrid_using(parser) + if parser._peek().kind == TokenKind.SPARSE: + return _parse_sparse_using(parser) + if parser._peek().kind == TokenKind.VECTOR: + parser._advance() + return SearchUsingOptions(dense_vector=parser._expect(TokenKind.STRING).value) + + parser._expect(TokenKind.MODEL) + return SearchUsingOptions(model=parser._expect(TokenKind.STRING).value) + + +def _parse_hybrid_using(parser: Any) -> SearchUsingOptions: + parser._advance() + model: str | None = None + fusion: str | None = None + sparse_model: str | None = None + dense_vector: str | None = None + sparse_vector: str | None = None + + while parser._peek().kind in (TokenKind.FUSION, TokenKind.DENSE, TokenKind.SPARSE): + sub = parser._advance() + if sub.kind == TokenKind.FUSION: + value_tok = parser._expect(TokenKind.STRING) + fusion = value_tok.value.lower() + if fusion not in _HYBRID_FUSION_VALUES: + raise QQLSyntaxError( + f"Unsupported hybrid fusion '{value_tok.value}'; expected 'rrf' or 'dbsf'", + value_tok.pos, + ) + continue + if parser._peek().kind == TokenKind.MODEL: + parser._advance() + parsed_model = parser._expect(TokenKind.STRING).value + if sub.kind == TokenKind.DENSE: + model = parsed_model + else: + sparse_model = parsed_model + continue + if parser._peek().kind == TokenKind.VECTOR: + parser._advance() + name = parser._expect(TokenKind.STRING).value + if sub.kind == TokenKind.DENSE: + dense_vector = name + else: + sparse_vector = name + continue + raise QQLSyntaxError( + "Expected MODEL or VECTOR after DENSE/SPARSE in USING HYBRID", + parser._peek().pos, + ) + + return SearchUsingOptions( + model=model, + hybrid=True, + fusion=fusion, + sparse_model=sparse_model, + dense_vector=dense_vector, + sparse_vector=sparse_vector, + ) + + +def _parse_sparse_using(parser: Any) -> SearchUsingOptions: + parser._advance() + sparse_model: str | None = None + sparse_vector: str | None = None + while parser._peek().kind in (TokenKind.MODEL, TokenKind.VECTOR): + sub = parser._advance() + if sub.kind == TokenKind.MODEL: + sparse_model = parser._expect(TokenKind.STRING).value + else: + sparse_vector = parser._expect(TokenKind.STRING).value + return SearchUsingOptions( + sparse_only=True, + sparse_model=sparse_model, + sparse_vector=sparse_vector, + ) + + +def parse_search_with(parser: Any, with_clause: SearchWith | None) -> SearchWith | None: + if parser._peek().kind == TokenKind.EXACT: + parser._advance() + with_clause = merge_search_with(with_clause, SearchWith(exact=True)) + + if parser._peek().kind == TokenKind.WITH: + parser._advance() + with_clause = merge_search_with(with_clause, parser._parse_with_clause()) + + return with_clause + + +def parse_search_group_by( + parser: Any, + offset: int, + rerank: bool, +) -> SearchGroupByOptions: + if parser._peek().kind != TokenKind.GROUP: + return SearchGroupByOptions() + + if offset > 0: + raise QQLSyntaxError("OFFSET cannot be used with GROUP BY", parser._peek().pos) + parser._advance() + parser._expect(TokenKind.BY) + group_by = parser._parse_field_path() + if rerank: + raise QQLSyntaxError( + "GROUP BY and RERANK cannot be combined in the same SEARCH statement", + parser._peek().pos, + ) + + group_size = 3 + if parser._peek().kind == TokenKind.GROUP_SIZE: + parser._advance() + group_size_tok = parser._peek() + group_size = int(parser._expect(TokenKind.INTEGER).value) + if group_size <= 0: + raise QQLSyntaxError( + f"GROUP_SIZE must be a positive integer, got {group_size}", + group_size_tok.pos, + ) + return SearchGroupByOptions(group_by=group_by, group_size=group_size) diff --git a/tests/test_async_connection.py b/tests/test_async_connection.py new file mode 100644 index 0000000..e148d72 --- /dev/null +++ b/tests/test_async_connection.py @@ -0,0 +1,389 @@ +"""Tests for the AsyncConnection and QQLAsyncBatch classes. + +All tests mock AsyncQdrantClient so no live Qdrant instance is required. +""" +from __future__ import annotations + +import pytest +from unittest.mock import AsyncMock + +from qql import ( + AsyncConnection, + QQLConfig, + AsyncExecutor, + ExecutionResult, + QQLAsyncBatch, +) +from qql.exceptions import QQLSyntaxError + + +@pytest.fixture +def anyio_backend(): + return "asyncio" + + +# ── TestAsyncConnectionInit ─────────────────────────────────────────────────── + +class TestAsyncConnectionInit: + """AsyncConnection.__init__ stores config and wires up the async executor.""" + + def test_default_url_and_no_secret(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + conn = AsyncConnection() + assert conn.config.url == "http://localhost:6333" + assert conn.config.secret is None + + def test_custom_url_and_secret_passed_to_async_qdrant_client(self, mocker): + mock_client_cls = mocker.patch("qql.async_connection.AsyncQdrantClient") + AsyncConnection("https://cloud.example.io", secret="s3cr3t") + mock_client_cls.assert_called_once_with( + url="https://cloud.example.io", api_key="s3cr3t" + ) + + def test_custom_default_model_stored_in_config(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + conn = AsyncConnection("http://localhost:6333", default_model="BAAI/bge-small-en-v1.5") + assert conn.config.default_model == "BAAI/bge-small-en-v1.5" + + def test_config_and_executor_properties_return_correct_types(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + conn = AsyncConnection("http://localhost:6333") + assert isinstance(conn.config, QQLConfig) + assert isinstance(conn.executor, AsyncExecutor) + + +# ── TestAsyncConnectionRunQuery ──────────────────────────────────────────────── + +@pytest.mark.anyio +class TestAsyncConnectionRunQuery: + """AsyncConnection.run_query() pipes through the Lexer → Parser → AsyncExecutor.""" + + async def test_run_query_calls_executor_execute(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + mock_executor = AsyncMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, message="ok", data=[] + ) + mocker.patch("qql.async_connection.AsyncExecutor", return_value=mock_executor) + + conn = AsyncConnection("http://localhost:6333") + await conn.run_query("SHOW COLLECTIONS") + mock_executor.execute.assert_called_once() + + async def test_executor_instance_reused_across_queries(self, mocker): + """AsyncExecutor() is constructed once; run_query() never re-instantiates it.""" + mocker.patch("qql.async_connection.AsyncQdrantClient") + mock_executor = AsyncMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, message="ok", data=[] + ) + executor_cls = mocker.patch("qql.async_connection.AsyncExecutor", return_value=mock_executor) + + conn = AsyncConnection("http://localhost:6333") + await conn.run_query("SHOW COLLECTIONS") + await conn.run_query("SHOW COLLECTIONS") + await conn.run_query("SHOW COLLECTIONS") + + # AsyncExecutor constructor called exactly once, not once per query + executor_cls.assert_called_once() + # But execute() called three times + assert mock_executor.execute.call_count == 3 + + async def test_invalid_query_raises_qql_syntax_error(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + conn = AsyncConnection("http://localhost:6333") + with pytest.raises(QQLSyntaxError): + await conn.run_query("TOTALLY INVALID QUERY GIBBERISH") + + async def test_run_query_returns_execution_result(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + mock_executor = AsyncMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, message="1 collection(s) found", data=["docs"] + ) + mocker.patch("qql.async_connection.AsyncExecutor", return_value=mock_executor) + + conn = AsyncConnection("http://localhost:6333") + result = await conn.run_query("SHOW COLLECTIONS") + assert isinstance(result, ExecutionResult) + assert result.success is True + + +# ── TestAsyncConnectionLifecycle ─────────────────────────────────────────────── + +@pytest.mark.anyio +class TestAsyncConnectionLifecycle: + """AsyncConnection.close() and the async context-manager protocol.""" + + async def test_close_calls_client_close(self, mocker): + mock_client = AsyncMock() + mocker.patch("qql.async_connection.AsyncQdrantClient", return_value=mock_client) + conn = AsyncConnection("http://localhost:6333") + await conn.close() + mock_client.close.assert_called_once() + + async def test_context_manager_closes_on_exit(self, mocker): + mock_client = AsyncMock() + mocker.patch("qql.async_connection.AsyncQdrantClient", return_value=mock_client) + + async with AsyncConnection("http://localhost:6333") as conn: + assert conn._client is mock_client + + mock_client.close.assert_called_once() + + async def test_context_manager_closes_on_exception(self, mocker): + mock_client = AsyncMock() + mocker.patch("qql.async_connection.AsyncQdrantClient", return_value=mock_client) + + with pytest.raises(RuntimeError, match="oops"): + async with AsyncConnection("http://localhost:6333"): + raise RuntimeError("oops") + + mock_client.close.assert_called_once() + + +# ── TestAsyncConnectionBatch ─────────────────────────────────────────────────── + +@pytest.mark.anyio +class TestAsyncConnectionBatch: + """AsyncConnection batching support (run_queries_batch, run_parameterized_batch, QQLAsyncBatch).""" + + async def test_run_queries_batch(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + mock_executor = AsyncMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, + message="Batch executed successfully", + data=[ + ExecutionResult(success=True, message="Found 1 result(s)"), + ExecutionResult(success=True, message="Found 2 result(s)"), + ], + ) + mocker.patch("qql.async_connection.AsyncExecutor", return_value=mock_executor) + + conn = AsyncConnection() + results = await conn.run_queries_batch([ + "SEARCH docs SIMILAR TO 'neurology' LIMIT 5", + "SEARCH docs SIMILAR TO 'cardiology' LIMIT 5", + ]) + assert len(results) == 2 + assert results[0].message == "Found 1 result(s)" + assert results[1].message == "Found 2 result(s)" + + async def test_run_parameterized_batch(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + mock_executor = AsyncMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, + message="Batch executed successfully", + data=[ + ExecutionResult(success=True, message="ok"), + ExecutionResult(success=True, message="ok"), + ], + ) + mocker.patch("qql.async_connection.AsyncExecutor", return_value=mock_executor) + + conn = AsyncConnection() + results = await conn.run_parameterized_batch( + "SEARCH docs SIMILAR TO :query LIMIT 5 WHERE category = :category", + [ + {"query": "brain stroke", "category": "Neurology"}, + {"query": "heart attack", "category": "Cardiology"}, + ], + ) + assert len(results) == 2 + mock_executor.execute.assert_called_once() + stmt = mock_executor.execute.call_args[0][0] + # Verify both statements compiled correctly + assert len(stmt.statements) == 2 + + async def test_run_parameterized_query(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + mock_executor = AsyncMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, + message="ok", + data="res1", + ) + mocker.patch("qql.async_connection.AsyncExecutor", return_value=mock_executor) + + conn = AsyncConnection() + result = await conn.run_parameterized_query( + "SEARCH docs SIMILAR TO :query LIMIT 5 WHERE category = :category", + {"query": "brain stroke", "category": "Neurology"}, + ) + + stmt = mock_executor.execute.call_args[0][0] + assert result.data == "res1" + assert stmt.query_text == "brain stroke" + + async def test_qql_async_batch_context_manager(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + mock_executor = AsyncMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, + message="Batch executed successfully", + data=[ + ExecutionResult(success=True, message="Res 1", data="d1"), + ExecutionResult(success=True, message="Res 2", data="d2"), + ], + ) + mocker.patch("qql.async_connection.AsyncExecutor", return_value=mock_executor) + + conn = AsyncConnection() + async with QQLAsyncBatch(conn) as batch: + ref1 = batch.add("SEARCH docs SIMILAR TO 'neurology' LIMIT 5") + ref2 = batch.add("SEARCH docs SIMILAR TO 'cardiology' LIMIT 5") + + assert ref1.result.message == "Res 1" + assert ref2.result.message == "Res 2" + assert ref1.result.data == "d1" + assert ref2.result.data == "d2" + + +# ── TestArchitecturalGapsClosed ──────────────────────────────────────────────── + +@pytest.mark.anyio +class TestArchitecturalGapsClosed: + """Rigorous tests covering async execution, race conditions, strict parser validation, and ID propagation.""" + + async def test_async_search_embeds_once(self, mocker): + """AsyncExecutor keeps the hot path direct and avoids threadpool overhead for cached embeddings.""" + mock_client = AsyncMock() + mock_client.collection_exists.return_value = True + + # Mock embedders to track how they are called + mocker.patch("qql.async_executor.Embedder.__init__", return_value=None) + mock_embed = mocker.patch("qql.async_executor.Embedder.embed", return_value=[0.1, 0.2]) + + from qql import QQLConfig + executor = AsyncExecutor(mock_client, QQLConfig(url="http://localhost:6333")) + + from qql.parser import Parser + from qql.lexer import Lexer + + node = Parser(Lexer().tokenize("SEARCH docs SIMILAR TO 'neurology' LIMIT 5")).parse() + + result = await executor.execute(node) + assert result.success is True + mock_embed.assert_called_once_with("neurology") + + async def test_batch_parsing_rejects_trailing_statements(self): + """run_queries_batch must raise QQLSyntaxError when a query contains trailing tokens.""" + conn = AsyncConnection() + with pytest.raises(QQLSyntaxError, match="Expected EOF"): + await conn.run_queries_batch([ + "SHOW COLLECTIONS; DROP COLLECTION x" + ]) + + async def test_batched_insert_propagates_correct_ids(self, mocker): + """_execute_batch_block preserves individual point IDs when aggregating Inserts into a bulk Insert.""" + mock_client = AsyncMock() + mock_client.collection_exists.return_value = True + mock_client.upsert.return_value = None + + # Mock get_collection to return a mock config with matching vector size + mock_info = mocker.MagicMock() + mock_info.config.params.vectors.size = 2 + mock_client.get_collection.return_value = mock_info + + mocker.patch("qql.async_executor.Embedder.__init__", return_value=None) + mocker.patch("qql.async_executor.Embedder.embed", return_value=[0.1, 0.2]) + + from qql.executor import CollectionTopology + topology = CollectionTopology( + exists=True, + is_named_dense=False, + has_unnamed_dense=True, + dense_names=(), + sparse_names=(), + ) + mocker.patch("qql.async_executor.AsyncExecutor._resolve_topology", return_value=topology) + + from qql import QQLConfig + executor = AsyncExecutor(mock_client, QQLConfig(url="http://localhost:6333")) + + from qql.parser import Parser + from qql.lexer import Lexer + + # Aggregate multiple INSERTs inside a Batch block statement + qql_batch = ( + "BEGIN BATCH\n" + "INSERT INTO COLLECTION docs VALUES {'text': 'a', 'id': 101};\n" + "INSERT INTO COLLECTION docs VALUES {'text': 'b', 'id': 102};\n" + "END BATCH" + ) + node = Parser(Lexer().tokenize(qql_batch)).parse() + res = await executor.execute(node) + + assert res.success is True + # Check that individual execution results correctly maintain their original custom point ID identity! + assert len(res.data) == 2 + assert res.data[0].data["id"] == 101 + assert res.data[1].data["id"] == 102 + + async def test_race_condition_collection_creation(self, mocker): + """Concurrent inserts into a non-existent collection serialize creation to avoid Qdrant conflicts.""" + import asyncio + mock_client = AsyncMock() + + # Mock get_collection to return a mock config with matching vector size + mock_info = mocker.MagicMock() + mock_info.config.params.vectors.size = 2 + mock_client.get_collection.return_value = mock_info + + from qql.executor import CollectionTopology + # Mock resolve_topology sequence using real CollectionTopology objects + topology_sequence = [ + CollectionTopology(exists=False, is_named_dense=False), # First insert task resolve topology + CollectionTopology(exists=False, is_named_dense=False), # Second insert task resolve topology + CollectionTopology(exists=False, is_named_dense=False), # Inside lock for first insert + CollectionTopology(exists=True, is_named_dense=False, has_unnamed_dense=True, dense_names=(), sparse_names=()), # Inside lock for second insert + ] + + mocker.patch("qql.async_executor.Embedder.__init__", return_value=None) + mocker.patch("qql.async_executor.Embedder.embed", return_value=[0.1, 0.2]) + + # Override _resolve_topology to yield the sequence + calls = 0 + async def mock_resolve(*args, **kwargs): + nonlocal calls + val = topology_sequence[calls] + calls += 1 + return val + + mocker.patch("qql.async_executor.AsyncExecutor._resolve_topology", side_effect=mock_resolve) + mocker.patch("qql.async_executor.AsyncExecutor._create_collection_and_wait", return_value=None) + + from qql import QQLConfig + executor = AsyncExecutor(mock_client, QQLConfig(url="http://localhost:6333")) + + from qql.parser import Parser + from qql.lexer import Lexer + + insert_node_1 = Parser(Lexer().tokenize("INSERT INTO COLLECTION docs VALUES {'text': 'a', 'id': 1}")).parse() + insert_node_2 = Parser(Lexer().tokenize("INSERT INTO COLLECTION docs VALUES {'text': 'b', 'id': 2}")).parse() + + # Fire both concurrently + res1, res2 = await asyncio.gather( + executor.execute(insert_node_1), + executor.execute(insert_node_2), + ) + + assert res1.success is True + assert res2.success is True + # Verify that _create_collection_and_wait was called exactly once despite concurrency! + executor._create_collection_and_wait.assert_called_once() + + def test_strict_batch_grammar(self): + """Parser must raise QQLSyntaxError if a batch block ends with bare END instead of END BATCH.""" + from qql.parser import Parser + from qql.lexer import Lexer + + invalid_batch = ( + "BEGIN BATCH\n" + "SHOW COLLECTIONS\n" + "END" + ) + with pytest.raises(QQLSyntaxError, match="Expected BATCH"): + Parser(Lexer().tokenize(invalid_batch)).parse() diff --git a/tests/test_connection.py b/tests/test_connection.py index c209698..4d40638 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -191,3 +191,107 @@ def test_run_query_invalid_syntax_still_raises(self, mocker): mocker.patch("qdrant_client.QdrantClient") with pytest.raises(QQLSyntaxError): run_query("TOTALLY INVALID", url="http://localhost:6333") + + +# ── TestConnectionBatching ─────────────────────────────────────────────────── + +class TestConnectionBatching: + """Connection batching functionality: prefer_grpc, QQLBatch, parameterized run.""" + + def test_prefer_grpc_passed_to_qdrant_client(self, mocker): + mock_client_cls = mocker.patch("qdrant_client.QdrantClient") + Connection("http://localhost:6333", prefer_grpc=True, grpc_port=9999) + mock_client_cls.assert_called_once_with( + url="http://localhost:6333", api_key=None, prefer_grpc=True, grpc_port=9999 + ) + + def test_run_queries_batch_pipes_to_executor(self, mocker): + mocker.patch("qdrant_client.QdrantClient") + mock_executor = mocker.MagicMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, message="ok", data=[ + ExecutionResult(success=True, message="ok1", data="res1"), + ExecutionResult(success=True, message="ok2", data="res2") + ] + ) + mocker.patch("qql.connection.Executor", return_value=mock_executor) + + conn = Connection() + results = conn.run_queries_batch([ + "SHOW COLLECTIONS", + "SHOW COLLECTIONS" + ]) + assert len(results) == 2 + assert results[0].data == "res1" + assert results[1].data == "res2" + mock_executor.execute.assert_called_once() + + def test_run_parameterized_batch_substitutes_vars(self, mocker): + mocker.patch("qdrant_client.QdrantClient") + mock_executor = mocker.MagicMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, message="ok", data=[ + ExecutionResult(success=True, message="ok1", data="res1"), + ExecutionResult(success=True, message="ok2", data="res2") + ] + ) + mocker.patch("qql.connection.Executor", return_value=mock_executor) + + conn = Connection() + results = conn.run_parameterized_batch( + "SEARCH docs SIMILAR TO :query LIMIT :limit WHERE category = :cat", + [ + {"query": "ML", "limit": 5, "cat": "news"}, + {"query": "AI", "limit": 10, "cat": "tech"} + ] + ) + assert len(results) == 2 + # Check that AST was called with reconstructed strings + called_node = mock_executor.execute.call_args[0][0] + # In BatchBlockStmt, statements should contain substituted strings + assert called_node.statements[0].query_text == "ML" + assert called_node.statements[0].limit == 5 + assert called_node.statements[1].query_text == "AI" + assert called_node.statements[1].limit == 10 + + def test_run_parameterized_query_substitutes_vars(self, mocker): + mocker.patch("qdrant_client.QdrantClient") + mock_executor = mocker.MagicMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, message="ok", data="res1" + ) + mocker.patch("qql.connection.Executor", return_value=mock_executor) + + conn = Connection() + result = conn.run_parameterized_query( + "SEARCH docs SIMILAR TO :query LIMIT :limit WHERE category = :cat", + {"query": "ML", "limit": 5, "cat": "news"}, + ) + + called_node = mock_executor.execute.call_args[0][0] + assert result.data == "res1" + assert called_node.query_text == "ML" + assert called_node.limit == 5 + + def test_qql_batch_context_manager(self, mocker): + mocker.patch("qdrant_client.QdrantClient") + mock_executor = mocker.MagicMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, message="ok", data=[ + ExecutionResult(success=True, message="ok1", data="res1"), + ExecutionResult(success=True, message="ok2", data="res2") + ] + ) + mocker.patch("qql.connection.Executor", return_value=mock_executor) + + from qql import QQLBatch + conn = Connection() + with QQLBatch(conn) as batch: + p1 = batch.add("SHOW COLLECTIONS") + p2 = batch.add("SHOW COLLECTIONS") + with pytest.raises(RuntimeError): + p1.result + + assert p1.result.data == "res1" + assert p2.result.data == "res2" + mock_executor.execute.assert_called_once() From 43e0b3a6b86f28cb273fcd7069711db7c10ce43f Mon Sep 17 00:00:00 2001 From: Srimon Date: Mon, 25 May 2026 08:42:09 +0530 Subject: [PATCH 2/4] feat: enhance batch processing with strict result count checks and null handling --- docs/reference.md | 2 +- src/qql/async_connection.py | 16 ++++++- src/qql/async_executor.py | 56 ++++++++++++---------- src/qql/connection.py | 16 ++++++- src/qql/parser.py | 15 ++++-- src/qql/utils.py | 86 +++++++++++++++++++++++++++++----- tests/test_async_connection.py | 58 +++++++++++++++++++++++ tests/test_connection.py | 43 +++++++++++++++++ tests/test_executor.py | 55 ++++++++++++++++++++++ tests/test_parser.py | 15 ++++++ 10 files changed, 315 insertions(+), 47 deletions(-) diff --git a/docs/reference.md b/docs/reference.md index dda9a45..7164c22 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -171,7 +171,7 @@ Both sync and async connections support: ## Project Structure -``` +```text qql/ ├── pyproject.toml # Package config; installs the `qql` CLI command ├── src/ diff --git a/src/qql/async_connection.py b/src/qql/async_connection.py index 1125e48..b51ad0b 100644 --- a/src/qql/async_connection.py +++ b/src/qql/async_connection.py @@ -190,8 +190,16 @@ async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseExc if not self._queries: return results = await self.connection.run_queries_batch(self._queries) - for proxy, res in zip(self._proxies, results): + for proxy, res in zip(self._proxies, results, strict=False): proxy._resolve(res) + if len(results) != len(self._proxies): + error = RuntimeError( + "Batch result count mismatch: " + f"expected {len(self._proxies)}, got {len(results)}" + ) + for proxy in self._proxies[len(results):]: + proxy._reject(error) + raise error class AsyncOperationProxy: @@ -199,13 +207,19 @@ class AsyncOperationProxy: def __init__(self) -> None: self._result: ExecutionResult | None = None + self._exception: RuntimeError | None = None def _resolve(self, result: ExecutionResult) -> None: self._result = result + def _reject(self, exception: RuntimeError) -> None: + self._exception = exception + @property def result(self) -> ExecutionResult: """The resolved ExecutionResult.""" + if self._exception is not None: + raise self._exception if self._result is None: raise RuntimeError("AsyncBatch has not been executed yet.") return self._result diff --git a/src/qql/async_executor.py b/src/qql/async_executor.py index 7d7ec36..53bc933 100644 --- a/src/qql/async_executor.py +++ b/src/qql/async_executor.py @@ -138,12 +138,16 @@ async def _ensure_collection( vector_size: int, topology: CollectionTopology, explicit_vector: str | None, - ) -> None: + ) -> CollectionTopology: if topology.exists: info = await self._client.get_collection(name) vectors = info.config.params.vectors # type: ignore[union-attr] + sparse_vectors = info.config.params.sparse_vectors or {} + current_topology = CollectionTopology( + **collection_topology_kwargs(vectors, sparse_vectors) + ) if isinstance(vectors, dict): - vector_name = topology.dense_using(explicit_vector) + vector_name = current_topology.dense_using(explicit_vector) if vector_name is None: raise QQLRuntimeError("Collection has no dense vector") vector_config = vectors[vector_name] @@ -164,12 +168,12 @@ async def _ensure_collection( ) else: raise QQLRuntimeError("Collection has no dense vector") + return current_topology else: async with self._creation_lock: current_topology = await self._resolve_topology(name) if current_topology.exists: - await self._ensure_collection(name, vector_size, current_topology, explicit_vector) - return + return await self._ensure_collection(name, vector_size, current_topology, explicit_vector) await self._create_collection_and_wait( collection_name=name, @@ -179,6 +183,7 @@ async def _ensure_collection( ) }, ) + return await self._resolve_topology(name) async def _create_collection_and_wait(self, **kwargs: Any) -> None: collection_name = kwargs["collection_name"] @@ -290,7 +295,7 @@ async def _execute_insert(self, node: InsertStmt) -> ExecutionResult: embedder = Embedder(model_name) vector = embedder.embed(node.values["text"]) - await self._ensure_collection( + topology = await self._ensure_collection( node.collection, len(vector), topology, node.dense_vector ) point_vector = build_dense_point_vector( @@ -351,22 +356,6 @@ async def _execute_insert_bulk(self, node: InsertBulkStmt) -> ExecutionResult: sparse_objs = [sparse_embedder.embed(vals["text"]) for vals in node.values_list] first_dense_vector = dense_vectors[0] if dense_vectors else None - points: list[PointStruct] = [] - for idx, vals in enumerate(node.values_list): - point_id, payload = extract_point_id_and_payload(vals) - dense_vector = dense_vectors[idx] - sparse_obj = sparse_objs[idx] - sparse_vector = SparseVector( - indices=sparse_obj["indices"], values=sparse_obj["values"] - ) - points.append( - PointStruct( - id=point_id, - vector={dense_name: dense_vector, sparse_name: sparse_vector}, - payload=payload, - ) - ) - if not topology.exists: assert first_dense_vector is not None async with self._creation_lock: @@ -385,6 +374,22 @@ async def _execute_insert_bulk(self, node: InsertBulkStmt) -> ExecutionResult: dense_name = current_topology.dense_using(node.dense_vector) or dense_name sparse_name = current_topology.sparse_using(node.sparse_vector) + points: list[PointStruct] = [] + for idx, vals in enumerate(node.values_list): + point_id, payload = extract_point_id_and_payload(vals) + dense_vector = dense_vectors[idx] + sparse_obj = sparse_objs[idx] + sparse_vector = SparseVector( + indices=sparse_obj["indices"], values=sparse_obj["values"] + ) + points.append( + PointStruct( + id=point_id, + vector={dense_name: dense_vector, sparse_name: sparse_vector}, + payload=payload, + ) + ) + try: await self._client.upsert( collection_name=node.collection, @@ -406,6 +411,10 @@ async def _execute_insert_bulk(self, node: InsertBulkStmt) -> ExecutionResult: vectors = [embedder.embed(vals["text"]) for vals in node.values_list] first_vector = vectors[0] if vectors else None + assert first_vector is not None + topology = await self._ensure_collection( + node.collection, len(first_vector), topology, node.dense_vector + ) points = [] for idx, vals in enumerate(node.values_list): vector = vectors[idx] @@ -420,11 +429,6 @@ async def _execute_insert_bulk(self, node: InsertBulkStmt) -> ExecutionResult: PointStruct(id=point_id, vector=point_vector, payload=payload) ) - assert first_vector is not None - await self._ensure_collection( - node.collection, len(first_vector), topology, node.dense_vector - ) - try: await self._client.upsert( collection_name=node.collection, diff --git a/src/qql/connection.py b/src/qql/connection.py index 3b26ace..0fa5092 100644 --- a/src/qql/connection.py +++ b/src/qql/connection.py @@ -198,8 +198,16 @@ def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException if not self._queries: return results = self.connection.run_queries_batch(self._queries) - for proxy, res in zip(self._proxies, results): + for proxy, res in zip(self._proxies, results, strict=False): proxy._resolve(res) + if len(results) != len(self._proxies): + error = RuntimeError( + "Batch result count mismatch: " + f"expected {len(self._proxies)}, got {len(results)}" + ) + for proxy in self._proxies[len(results):]: + proxy._reject(error) + raise error class OperationProxy: @@ -207,13 +215,19 @@ class OperationProxy: def __init__(self) -> None: self._result: ExecutionResult | None = None + self._exception: RuntimeError | None = None def _resolve(self, result: ExecutionResult) -> None: self._result = result + def _reject(self, exception: RuntimeError) -> None: + self._exception = exception + @property def result(self) -> ExecutionResult: """The resolved ExecutionResult.""" + if self._exception is not None: + raise self._exception if self._result is None: raise RuntimeError("Batch has not been executed yet.") return self._result diff --git a/src/qql/parser.py b/src/qql/parser.py index 7134fce..617f76d 100644 --- a/src/qql/parser.py +++ b/src/qql/parser.py @@ -72,6 +72,8 @@ def __init__(self, tokens: list[Token]) -> None: def parse(self) -> ASTNode: node = self._parse_single_statement() + while self._peek().kind == TokenKind.SEMICOLON: + self._advance() self._expect(TokenKind.EOF) return node @@ -1078,12 +1080,15 @@ def _parse_field_path(self) -> str: f"Expected a field name, got '{tok.value}'", tok.pos ) - def _parse_literal(self) -> str | int | float | bool: - """STRING | INTEGER | FLOAT | boolean""" + def _parse_literal(self) -> str | int | float | bool | None: + """STRING | INTEGER | FLOAT | boolean | NULL""" tok = self._peek() if tok.kind == TokenKind.STRING: self._advance() return tok.value + if tok.kind == TokenKind.NULL: + self._advance() + return None if tok.kind == TokenKind.INTEGER: self._advance() return int(tok.value) @@ -1099,7 +1104,7 @@ def _parse_literal(self) -> str | int | float | bool: self._advance() return False raise QQLSyntaxError( - f"Expected a literal value (string, integer, float, or boolean), got '{tok.value}'", + f"Expected a literal value (string, integer, float, boolean, or null), got '{tok.value}'", tok.pos, ) @@ -1116,10 +1121,10 @@ def _parse_number(self) -> int | float: f"Expected a number, got '{tok.value}'", tok.pos ) - def _parse_literal_list(self) -> list[str | int | float | bool]: + def _parse_literal_list(self) -> list[str | int | float | bool | None]: """'(' literal { ',' literal } [','] ')' — used by IN / NOT IN.""" self._expect(TokenKind.LPAREN) - items: list[str | int | float | bool] = [] + items: list[str | int | float | bool | None] = [] if self._peek().kind == TokenKind.RPAREN: self._advance() return items diff --git a/src/qql/utils.py b/src/qql/utils.py index e72d1db..024cc7b 100644 --- a/src/qql/utils.py +++ b/src/qql/utils.py @@ -78,19 +78,65 @@ class SearchGroupByOptions: def render_parameterized_query(template: str, params: dict[str, Any]) -> str: - query_str = template - for k in sorted(params.keys(), key=len, reverse=True): - val = params[k] - placeholder = f":{k}" - if isinstance(val, str): - escaped_val = val.replace("'", "\\'") - repr_val = f"'{escaped_val}'" - elif isinstance(val, bool): - repr_val = "true" if val else "false" - else: - repr_val = str(val) - query_str = query_str.replace(placeholder, repr_val) - return query_str + rendered = [] + in_string = False + quote_char = "" + i = 0 + while i < len(template): + ch = template[i] + if in_string: + rendered.append(ch) + if ch == "\\" and i + 1 < len(template): + rendered.append(template[i + 1]) + i += 2 + continue + if ch == quote_char: + in_string = False + quote_char = "" + i += 1 + continue + + if ch in ("'", '"'): + in_string = True + quote_char = ch + rendered.append(ch) + i += 1 + continue + + if ch == ":": + name_start = i + 1 + name_end = name_start + while name_end < len(template) and ( + template[name_end].isalnum() or template[name_end] == "_" + ): + name_end += 1 + name = template[name_start:name_end] + if name in params: + rendered.append(_qql_literal(params[name])) + i = name_end + continue + + rendered.append(ch) + i += 1 + + return "".join(rendered) + + +def _qql_literal(value: Any) -> str: + if value is None: + return "null" + if isinstance(value, str): + escaped = ( + value.replace("\\", "\\\\") + .replace("'", "\\'") + .replace("\n", "\\n") + .replace("\t", "\\t") + .replace("\r", "\\r") + ) + return f"'{escaped}'" + if isinstance(value, bool): + return "true" if value else "false" + return str(value) def collection_topology_kwargs(vectors: Any, sparse_vectors: Any) -> dict[str, Any]: @@ -158,6 +204,14 @@ def group_batch_statements(statements: tuple[ASTNode, ...]) -> list[BatchGroup]: current_group: list[ASTNode] = [] for stmt in statements: + if isinstance(stmt, SearchStmt) and stmt.group_by is not None: + _append_batch_group(groups, current_type, current_collection, current_group) + groups.append(BatchGroup("other", stmt.collection, [stmt])) + current_type = None + current_collection = None + current_group = [] + continue + if isinstance(stmt, (SearchStmt, RecommendStmt)): coll = stmt.collection if current_type == "query" and current_collection == coll: @@ -247,6 +301,12 @@ def build_qdrant_filter(expr: FilterExpr) -> Any: if isinstance(expr, NotExpr): return Filter(must_not=[build_qdrant_filter(expr.operand)]) if isinstance(expr, CompareExpr): + if expr.value is None: + null_condition = IsNullCondition(is_null=PayloadField(key=expr.field)) + if expr.op == "=": + return null_condition + if expr.op == "!=": + return Filter(must_not=[null_condition]) if expr.op == "=": return FieldCondition(key=expr.field, match=MatchValue(value=expr.value)) if expr.op == "!=": diff --git a/tests/test_async_connection.py b/tests/test_async_connection.py index e148d72..c23641a 100644 --- a/tests/test_async_connection.py +++ b/tests/test_async_connection.py @@ -240,6 +240,26 @@ async def test_qql_async_batch_context_manager(self, mocker): assert ref1.result.data == "d1" assert ref2.result.data == "d2" + async def test_qql_async_batch_raises_when_result_count_mismatches_proxy_count(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + mock_executor = AsyncMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, + message="Batch executed", + data=[ExecutionResult(success=True, message="Res 1", data="d1")], + ) + mocker.patch("qql.async_connection.AsyncExecutor", return_value=mock_executor) + + conn = AsyncConnection() + with pytest.raises(RuntimeError, match="Batch result count mismatch"): + async with QQLAsyncBatch(conn) as batch: + ref1 = batch.add("SHOW COLLECTIONS") + ref2 = batch.add("SHOW COLLECTIONS") + + assert ref1.result.data == "d1" + with pytest.raises(RuntimeError, match="Batch result count mismatch"): + ref2.result + # ── TestArchitecturalGapsClosed ──────────────────────────────────────────────── @@ -375,6 +395,44 @@ async def mock_resolve(*args, **kwargs): # Verify that _create_collection_and_wait was called exactly once despite concurrency! executor._create_collection_and_wait.assert_called_once() + async def test_async_insert_uses_refreshed_topology_after_create_race(self, mocker): + """If another coroutine creates the collection, upsert must use its vector name.""" + mock_client = AsyncMock() + mock_client.upsert.return_value = None + + mocker.patch("qql.async_executor.Embedder.__init__", return_value=None) + mocker.patch("qql.async_executor.Embedder.embed", return_value=[0.1, 0.2]) + + from qql.executor import CollectionTopology + + stale_topology = CollectionTopology(exists=False, is_named_dense=False) + created_topology = CollectionTopology( + exists=True, + is_named_dense=True, + dense_names=("body",), + sparse_names=(), + ) + + async def mock_ensure(*args, **kwargs): + return created_topology + + mocker.patch("qql.async_executor.AsyncExecutor._resolve_topology", return_value=stale_topology) + mocker.patch("qql.async_executor.AsyncExecutor._ensure_collection", side_effect=mock_ensure) + + from qql import QQLConfig + executor = AsyncExecutor(mock_client, QQLConfig(url="http://localhost:6333")) + + from qql.parser import Parser + from qql.lexer import Lexer + + node = Parser( + Lexer().tokenize("INSERT INTO COLLECTION docs VALUES {'text': 'hello', 'id': 1}") + ).parse() + await executor.execute(node) + + point = mock_client.upsert.call_args.kwargs["points"][0] + assert point.vector == {"body": [0.1, 0.2]} + def test_strict_batch_grammar(self): """Parser must raise QQLSyntaxError if a batch block ends with bare END instead of END BATCH.""" from qql.parser import Parser diff --git a/tests/test_connection.py b/tests/test_connection.py index 4d40638..f9eb1d8 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -6,6 +6,7 @@ from qql import Connection, QQLConfig, Executor, ExecutionResult, run_query from qql.exceptions import QQLSyntaxError +from qql.utils import render_parameterized_query # ── TestConnectionInit ──────────────────────────────────────────────────────── @@ -295,3 +296,45 @@ def test_qql_batch_context_manager(self, mocker): assert p1.result.data == "res1" assert p2.result.data == "res2" mock_executor.execute.assert_called_once() + + def test_qql_batch_raises_when_result_count_mismatches_proxy_count(self, mocker): + mocker.patch("qdrant_client.QdrantClient") + mock_executor = mocker.MagicMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, + message="ok", + data=[ExecutionResult(success=True, message="only one", data="res1")], + ) + mocker.patch("qql.connection.Executor", return_value=mock_executor) + + from qql import QQLBatch + conn = Connection() + with pytest.raises(RuntimeError, match="Batch result count mismatch"): + with QQLBatch(conn) as batch: + p1 = batch.add("SHOW COLLECTIONS") + p2 = batch.add("SHOW COLLECTIONS") + + assert p1.result.data == "res1" + with pytest.raises(RuntimeError, match="Batch result count mismatch"): + p2.result + + +class TestParameterizedRendering: + def test_render_parameterized_query_escapes_strings_and_nulls(self): + rendered = render_parameterized_query( + "SEARCH docs SIMILAR TO :query LIMIT 5 WHERE category = :category", + {"query": "O'Reilly\\notes\nnext", "category": None}, + ) + + assert rendered == ( + "SEARCH docs SIMILAR TO 'O\\'Reilly\\\\notes\\nnext' " + "LIMIT 5 WHERE category = null" + ) + + def test_render_parameterized_query_does_not_replace_inside_strings(self): + rendered = render_parameterized_query( + "SEARCH docs SIMILAR TO ':query' LIMIT 5 WHERE tag = :query", + {"query": "needle"}, + ) + + assert rendered == "SEARCH docs SIMILAR TO ':query' LIMIT 5 WHERE tag = 'needle'" diff --git a/tests/test_executor.py b/tests/test_executor.py index 0a2a74f..bb8da56 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -3201,6 +3201,61 @@ def test_grouped_search_params_with_clause_forwarded(self, executor, mock_client assert kwargs.get("search_params") is not None +class TestBatchGroupedSearch: + def test_grouped_search_in_batch_runs_individually(self, executor, mock_client, mocker): + _mock_hybrid_collection(mock_client) + mock_group_response = mocker.MagicMock() + mock_group_response.groups = [] + mock_client.query_points_groups.return_value = mock_group_response + mock_query_response = mocker.MagicMock() + mock_query_response.points = [] + mock_client.query_batch_points.return_value = [mock_query_response] + + from qql.lexer import Lexer + from qql.parser import Parser + + node = Parser( + Lexer().tokenize( + "BEGIN BATCH " + "SEARCH articles SIMILAR TO 'q' LIMIT 5 GROUP BY category; " + "SEARCH articles SIMILAR TO 'plain' LIMIT 5 " + "END BATCH" + ) + ).parse() + + result = executor.execute(node) + + assert result.success is True + mock_client.query_points_groups.assert_called_once() + mock_client.query_batch_points.assert_called_once() + assert "group(s)" in result.data[0].message + assert "result(s)" in result.data[1].message + + +class TestNullComparisonFilters: + def test_equal_null_builds_is_null_condition(self, executor): + from qql.ast_nodes import CompareExpr + from qdrant_client.models import IsNullCondition + + result = executor._build_qdrant_filter( + CompareExpr(field="deleted_at", op="=", value=None) + ) + + assert isinstance(result, IsNullCondition) + assert result.is_null.key == "deleted_at" + + def test_not_equal_null_builds_not_null_filter(self, executor): + from qql.ast_nodes import CompareExpr + from qdrant_client.models import Filter, IsNullCondition + + result = executor._build_qdrant_filter( + CompareExpr(field="deleted_at", op="!=", value=None) + ) + + assert isinstance(result, Filter) + assert isinstance(result.must_not[0], IsNullCondition) + + class TestUpdateVectorVectorShape: """Gaps 12 & 13 — verify exact vector shape sent to Qdrant for named/unnamed collections.""" diff --git a/tests/test_parser.py b/tests/test_parser.py index ec9e051..6270340 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -46,6 +46,21 @@ def parse(query: str): return Parser(tokens).parse() +class TestStatementBoundaries: + def test_top_level_statement_allows_trailing_semicolon(self): + node = parse("SHOW COLLECTIONS;") + assert isinstance(node, ShowCollectionsStmt) + + def test_batch_block_allows_trailing_semicolon(self): + node = parse( + "BEGIN BATCH\n" + "SHOW COLLECTIONS;\n" + "END BATCH;" + ) + assert len(node.statements) == 1 + assert isinstance(node.statements[0], ShowCollectionsStmt) + + class TestInsert: def test_basic_insert(self): node = parse("INSERT INTO COLLECTION notes VALUES {'text': 'hello'}") From 9865b0b59cbec63c1e9e8d3df6e1b5533b9a6741 Mon Sep 17 00:00:00 2001 From: Srimon Date: Mon, 25 May 2026 09:16:01 +0530 Subject: [PATCH 3/4] feat: enhance batch processing and query handling with improved error management and null support --- src/qql/ast_nodes.py | 6 ++-- src/qql/async_connection.py | 36 ++++++++++++---------- src/qql/connection.py | 36 ++++++++++++---------- src/qql/parser.py | 6 ++-- src/qql/utils.py | 55 +++++++++++++++++++++++++--------- tests/test_async_connection.py | 28 ++++++++++++++++- tests/test_connection.py | 47 ++++++++++++++++++++++++++++- tests/test_executor.py | 35 ++++++++++++++++++++++ tests/test_parser.py | 11 +++++++ 9 files changed, 208 insertions(+), 52 deletions(-) diff --git a/src/qql/ast_nodes.py b/src/qql/ast_nodes.py index f50d92b..f8a826b 100644 --- a/src/qql/ast_nodes.py +++ b/src/qql/ast_nodes.py @@ -25,9 +25,9 @@ class QuantizationConfig: class SearchWith: """Query-time search params supported by Qdrant SearchParams.""" hnsw_ef: int | None = None - exact: bool = False - acorn: bool = False - indexed_only: bool = False + exact: bool | None = None + acorn: bool | None = None + indexed_only: bool | None = None quantization: "QuantizationSearchWith | None" = None mmr_diversity: float | None = None mmr_candidates: int | None = None diff --git a/src/qql/async_connection.py b/src/qql/async_connection.py index b51ad0b..2e611e6 100644 --- a/src/qql/async_connection.py +++ b/src/qql/async_connection.py @@ -182,24 +182,30 @@ def add(self, query: str) -> AsyncOperationProxy: return proxy async def __aenter__(self) -> QQLAsyncBatch: + self._queries.clear() + self._proxies.clear() return self async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: - if exc_type is not None: - return - if not self._queries: - return - results = await self.connection.run_queries_batch(self._queries) - for proxy, res in zip(self._proxies, results, strict=False): - proxy._resolve(res) - if len(results) != len(self._proxies): - error = RuntimeError( - "Batch result count mismatch: " - f"expected {len(self._proxies)}, got {len(results)}" - ) - for proxy in self._proxies[len(results):]: - proxy._reject(error) - raise error + try: + if exc_type is not None: + return + if not self._queries: + return + results = await self.connection.run_queries_batch(self._queries) + if len(results) != len(self._proxies): + error = RuntimeError( + "Batch result count mismatch: " + f"expected {len(self._proxies)}, got {len(results)}" + ) + for proxy in self._proxies: + proxy._reject(error) + raise error + for proxy, res in zip(self._proxies, results, strict=True): + proxy._resolve(res) + finally: + self._queries.clear() + self._proxies.clear() class AsyncOperationProxy: diff --git a/src/qql/connection.py b/src/qql/connection.py index 0fa5092..2a21721 100644 --- a/src/qql/connection.py +++ b/src/qql/connection.py @@ -190,24 +190,30 @@ def add(self, query: str) -> OperationProxy: return proxy def __enter__(self) -> QQLBatch: + self._queries.clear() + self._proxies.clear() return self def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: - if exc_type is not None: - return - if not self._queries: - return - results = self.connection.run_queries_batch(self._queries) - for proxy, res in zip(self._proxies, results, strict=False): - proxy._resolve(res) - if len(results) != len(self._proxies): - error = RuntimeError( - "Batch result count mismatch: " - f"expected {len(self._proxies)}, got {len(results)}" - ) - for proxy in self._proxies[len(results):]: - proxy._reject(error) - raise error + try: + if exc_type is not None: + return + if not self._queries: + return + results = self.connection.run_queries_batch(self._queries) + if len(results) != len(self._proxies): + error = RuntimeError( + "Batch result count mismatch: " + f"expected {len(self._proxies)}, got {len(results)}" + ) + for proxy in self._proxies: + proxy._reject(error) + raise error + for proxy, res in zip(self._proxies, results, strict=True): + proxy._resolve(res) + finally: + self._queries.clear() + self._proxies.clear() class OperationProxy: diff --git a/src/qql/parser.py b/src/qql/parser.py index 617f76d..3c3b2f0 100644 --- a/src/qql/parser.py +++ b/src/qql/parser.py @@ -1278,9 +1278,9 @@ def _parse_value(self) -> Any: def _parse_with_clause(self) -> SearchWith: self._expect(TokenKind.LBRACE) hnsw_ef: int | None = None - exact: bool = False - acorn: bool = False - indexed_only: bool = False + exact: bool | None = None + acorn: bool | None = None + indexed_only: bool | None = None quantization: QuantizationSearchWith | None = None mmr_diversity: float | None = None mmr_candidates: int | None = None diff --git a/src/qql/utils.py b/src/qql/utils.py index 024cc7b..fdd8c07 100644 --- a/src/qql/utils.py +++ b/src/qql/utils.py @@ -126,19 +126,31 @@ def _qql_literal(value: Any) -> str: if value is None: return "null" if isinstance(value, str): - escaped = ( - value.replace("\\", "\\\\") - .replace("'", "\\'") - .replace("\n", "\\n") - .replace("\t", "\\t") - .replace("\r", "\\r") - ) - return f"'{escaped}'" + return _qql_string_literal(value) if isinstance(value, bool): return "true" if value else "false" + if isinstance(value, (list, tuple)): + return "[" + ", ".join(_qql_literal(item) for item in value) + "]" + if isinstance(value, dict): + items = ", ".join( + f"{_qql_string_literal(str(key))}: {_qql_literal(item)}" + for key, item in value.items() + ) + return "{" + items + "}" return str(value) +def _qql_string_literal(value: str) -> str: + escaped = ( + value.replace("\\", "\\\\") + .replace("'", "\\'") + .replace("\n", "\\n") + .replace("\t", "\\t") + .replace("\r", "\\r") + ) + return f"'{escaped}'" + + def collection_topology_kwargs(vectors: Any, sparse_vectors: Any) -> dict[str, Any]: if isinstance(vectors, dict): dense_names = tuple(vectors.keys()) @@ -307,6 +319,9 @@ def build_qdrant_filter(expr: FilterExpr) -> Any: return null_condition if expr.op == "!=": return Filter(must_not=[null_condition]) + raise QQLRuntimeError( + f"Cannot use operator '{expr.op}' with null for field '{expr.field}'" + ) if expr.op == "=": return FieldCondition(key=expr.field, match=MatchValue(value=expr.value)) if expr.op == "!=": @@ -467,17 +482,29 @@ def merge_search_with(base: SearchWith | None, override: SearchWith) -> SearchWi if base is None: return override return SearchWith( - hnsw_ef=override.hnsw_ef or base.hnsw_ef, - exact=override.exact or base.exact, - acorn=override.acorn or base.acorn, - indexed_only=override.indexed_only or base.indexed_only, - quantization=override.quantization or base.quantization, + hnsw_ef=override.hnsw_ef if override.hnsw_ef is not None else base.hnsw_ef, + exact=override.exact if override.exact is not None else base.exact, + acorn=override.acorn if override.acorn is not None else base.acorn, + indexed_only=( + override.indexed_only + if override.indexed_only is not None + else base.indexed_only + ), + quantization=( + override.quantization + if override.quantization is not None + else base.quantization + ), mmr_diversity=( override.mmr_diversity if override.mmr_diversity is not None else base.mmr_diversity ), - mmr_candidates=override.mmr_candidates or base.mmr_candidates, + mmr_candidates=( + override.mmr_candidates + if override.mmr_candidates is not None + else base.mmr_candidates + ), ) diff --git a/tests/test_async_connection.py b/tests/test_async_connection.py index c23641a..5661032 100644 --- a/tests/test_async_connection.py +++ b/tests/test_async_connection.py @@ -256,10 +256,36 @@ async def test_qql_async_batch_raises_when_result_count_mismatches_proxy_count(s ref1 = batch.add("SHOW COLLECTIONS") ref2 = batch.add("SHOW COLLECTIONS") - assert ref1.result.data == "d1" + with pytest.raises(RuntimeError, match="Batch result count mismatch"): + ref1.result with pytest.raises(RuntimeError, match="Batch result count mismatch"): ref2.result + async def test_qql_async_batch_reuse_does_not_replay_previous_queries(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + mock_executor = AsyncMock() + mock_executor.execute.side_effect = [ + ExecutionResult(success=True, message="ok", data=[ + ExecutionResult(success=True, message="first", data="d1"), + ]), + ExecutionResult(success=True, message="ok", data=[ + ExecutionResult(success=True, message="second", data="d2"), + ]), + ] + mocker.patch("qql.async_connection.AsyncExecutor", return_value=mock_executor) + + conn = AsyncConnection() + batch = QQLAsyncBatch(conn) + async with batch: + first = batch.add("SHOW COLLECTIONS") + async with batch: + second = batch.add("SHOW COLLECTIONS") + + assert first.result.data == "d1" + assert second.result.data == "d2" + assert len(mock_executor.execute.call_args_list[0].args[0].statements) == 1 + assert len(mock_executor.execute.call_args_list[1].args[0].statements) == 1 + # ── TestArchitecturalGapsClosed ──────────────────────────────────────────────── diff --git a/tests/test_connection.py b/tests/test_connection.py index f9eb1d8..605c93e 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -314,10 +314,37 @@ def test_qql_batch_raises_when_result_count_mismatches_proxy_count(self, mocker) p1 = batch.add("SHOW COLLECTIONS") p2 = batch.add("SHOW COLLECTIONS") - assert p1.result.data == "res1" + with pytest.raises(RuntimeError, match="Batch result count mismatch"): + p1.result with pytest.raises(RuntimeError, match="Batch result count mismatch"): p2.result + def test_qql_batch_reuse_does_not_replay_previous_queries(self, mocker): + mocker.patch("qdrant_client.QdrantClient") + mock_executor = mocker.MagicMock() + mock_executor.execute.side_effect = [ + ExecutionResult(success=True, message="ok", data=[ + ExecutionResult(success=True, message="first", data="res1"), + ]), + ExecutionResult(success=True, message="ok", data=[ + ExecutionResult(success=True, message="second", data="res2"), + ]), + ] + mocker.patch("qql.connection.Executor", return_value=mock_executor) + + from qql import QQLBatch + conn = Connection() + batch = QQLBatch(conn) + with batch: + first = batch.add("SHOW COLLECTIONS") + with batch: + second = batch.add("SHOW COLLECTIONS") + + assert first.result.data == "res1" + assert second.result.data == "res2" + assert len(mock_executor.execute.call_args_list[0].args[0].statements) == 1 + assert len(mock_executor.execute.call_args_list[1].args[0].statements) == 1 + class TestParameterizedRendering: def test_render_parameterized_query_escapes_strings_and_nulls(self): @@ -338,3 +365,21 @@ def test_render_parameterized_query_does_not_replace_inside_strings(self): ) assert rendered == "SEARCH docs SIMILAR TO ':query' LIMIT 5 WHERE tag = 'needle'" + + def test_render_parameterized_query_renders_compound_literals(self): + rendered = render_parameterized_query( + "INSERT INTO COLLECTION docs VALUES :values", + { + "values": { + "text": "hello", + "tags": ["a", "b"], + "meta": {"score": 1, "active": True}, + } + }, + ) + + assert rendered == ( + "INSERT INTO COLLECTION docs VALUES " + "{'text': 'hello', 'tags': ['a', 'b'], " + "'meta': {'score': 1, 'active': true}}" + ) diff --git a/tests/test_executor.py b/tests/test_executor.py index bb8da56..24589fe 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -3255,6 +3255,41 @@ def test_not_equal_null_builds_not_null_filter(self, executor): assert isinstance(result, Filter) assert isinstance(result.must_not[0], IsNullCondition) + def test_ordering_comparison_to_null_raises_clear_error(self, executor): + from qql.ast_nodes import CompareExpr + + with pytest.raises(QQLRuntimeError, match="Cannot use operator '>' with null"): + executor._build_qdrant_filter( + CompareExpr(field="deleted_at", op=">", value=None) + ) + + +class TestMergeSearchWith: + def test_merge_search_with_preserves_zero_values(self): + from qql.ast_nodes import SearchWith + from qql.utils import merge_search_with + + merged = merge_search_with( + SearchWith(hnsw_ef=128, mmr_candidates=10), + SearchWith(hnsw_ef=0, mmr_candidates=0), + ) + + assert merged.hnsw_ef == 0 + assert merged.mmr_candidates == 0 + + def test_merge_search_with_can_override_true_with_false(self): + from qql.ast_nodes import SearchWith + from qql.utils import merge_search_with + + merged = merge_search_with( + SearchWith(exact=True, acorn=True, indexed_only=True), + SearchWith(exact=False, acorn=False, indexed_only=False), + ) + + assert merged.exact is False + assert merged.acorn is False + assert merged.indexed_only is False + class TestUpdateVectorVectorShape: """Gaps 12 & 13 — verify exact vector shape sent to Qdrant for named/unnamed collections.""" diff --git a/tests/test_parser.py b/tests/test_parser.py index 6270340..c2bd64b 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1101,6 +1101,17 @@ def test_with_exact_false(self): assert node.with_clause is not None assert node.with_clause.exact is False + def test_exact_keyword_survives_with_clause_merge(self): + node = parse("SEARCH col SIMILAR TO 'q' LIMIT 5 EXACT WITH { hnsw_ef: 128 }") + assert node.with_clause is not None + assert node.with_clause.exact is True + assert node.with_clause.hnsw_ef == 128 + + def test_with_exact_false_can_override_exact_keyword(self): + node = parse("SEARCH col SIMILAR TO 'q' LIMIT 5 EXACT WITH { exact: false }") + assert node.with_clause is not None + assert node.with_clause.exact is False + def test_with_acorn(self): node = parse("SEARCH col SIMILAR TO 'q' LIMIT 5 WITH { acorn: true }") assert node.with_clause is not None From bfb81c42bfb2ca5881242e158763880b541efcbb Mon Sep 17 00:00:00 2001 From: Srimon Date: Mon, 25 May 2026 10:13:02 +0530 Subject: [PATCH 4/4] feat: enhance InExpr and NotInExpr handling with null support in filter building --- src/qql/ast_nodes.py | 4 ++-- src/qql/utils.py | 24 ++++++++++++++++++-- tests/test_async_connection.py | 4 ++-- tests/test_connection.py | 6 ++--- tests/test_executor.py | 40 ++++++++++++++++++++++++++++++++++ 5 files changed, 69 insertions(+), 9 deletions(-) diff --git a/src/qql/ast_nodes.py b/src/qql/ast_nodes.py index f8a826b..a7fbc95 100644 --- a/src/qql/ast_nodes.py +++ b/src/qql/ast_nodes.py @@ -114,14 +114,14 @@ class BetweenExpr: class InExpr: """field IN (v1, v2, ...)""" field: str - values: tuple[str | int | float | bool, ...] + values: tuple[str | int | float | bool | None, ...] @dataclass(frozen=True) class NotInExpr: """field NOT IN (v1, v2, ...)""" field: str - values: tuple[str | int | float | bool, ...] + values: tuple[str | int | float | bool | None, ...] @dataclass(frozen=True) diff --git a/src/qql/utils.py b/src/qql/utils.py index fdd8c07..52749de 100644 --- a/src/qql/utils.py +++ b/src/qql/utils.py @@ -335,11 +335,31 @@ def build_qdrant_filter(expr: FilterExpr) -> Any: if isinstance(expr, BetweenExpr): return FieldCondition(key=expr.field, range=Range(gte=expr.low, lte=expr.high)) if isinstance(expr, InExpr): - return FieldCondition(key=expr.field, match=MatchAny(any=list(expr.values))) + non_nulls = [value for value in expr.values if value is not None] + null_condition = IsNullCondition(is_null=PayloadField(key=expr.field)) + if len(non_nulls) == len(expr.values): + return FieldCondition(key=expr.field, match=MatchAny(any=non_nulls)) + if not non_nulls: + return null_condition + return Filter( + should=[ + null_condition, + FieldCondition(key=expr.field, match=MatchAny(any=non_nulls)), + ] + ) if isinstance(expr, NotInExpr): + non_nulls = [value for value in expr.values if value is not None] + null_condition = IsNullCondition(is_null=PayloadField(key=expr.field)) + if len(non_nulls) != len(expr.values): + must_not = [null_condition] + if non_nulls: + must_not.append( + FieldCondition(key=expr.field, match=MatchAny(any=non_nulls)) + ) + return Filter(must_not=must_not) return FieldCondition( key=expr.field, - match=MatchExcept(**{"except": list(expr.values)}), + match=MatchExcept(**{"except": non_nulls}), ) if isinstance(expr, IsNullExpr): return IsNullCondition(is_null=PayloadField(key=expr.field)) diff --git a/tests/test_async_connection.py b/tests/test_async_connection.py index 5661032..57b892a 100644 --- a/tests/test_async_connection.py +++ b/tests/test_async_connection.py @@ -257,9 +257,9 @@ async def test_qql_async_batch_raises_when_result_count_mismatches_proxy_count(s ref2 = batch.add("SHOW COLLECTIONS") with pytest.raises(RuntimeError, match="Batch result count mismatch"): - ref1.result + _ = ref1.result with pytest.raises(RuntimeError, match="Batch result count mismatch"): - ref2.result + _ = ref2.result async def test_qql_async_batch_reuse_does_not_replay_previous_queries(self, mocker): mocker.patch("qql.async_connection.AsyncQdrantClient") diff --git a/tests/test_connection.py b/tests/test_connection.py index 605c93e..6be0853 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -291,7 +291,7 @@ def test_qql_batch_context_manager(self, mocker): p1 = batch.add("SHOW COLLECTIONS") p2 = batch.add("SHOW COLLECTIONS") with pytest.raises(RuntimeError): - p1.result + _ = p1.result assert p1.result.data == "res1" assert p2.result.data == "res2" @@ -315,9 +315,9 @@ def test_qql_batch_raises_when_result_count_mismatches_proxy_count(self, mocker) p2 = batch.add("SHOW COLLECTIONS") with pytest.raises(RuntimeError, match="Batch result count mismatch"): - p1.result + _ = p1.result with pytest.raises(RuntimeError, match="Batch result count mismatch"): - p2.result + _ = p2.result def test_qql_batch_reuse_does_not_replay_previous_queries(self, mocker): mocker.patch("qdrant_client.QdrantClient") diff --git a/tests/test_executor.py b/tests/test_executor.py index 24589fe..9912a78 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -1729,6 +1729,46 @@ def test_build_not_in(self, executor): assert isinstance(result, FieldCondition) assert isinstance(result.match, MatchExcept) + def test_build_in_with_only_null(self, executor): + from qdrant_client.models import IsNullCondition + from qql.ast_nodes import InExpr + + result = executor._build_qdrant_filter(InExpr(field="status", values=(None,))) + assert isinstance(result, IsNullCondition) + assert result.is_null.key == "status" + + def test_build_in_with_null_and_values(self, executor): + from qdrant_client.models import Filter, FieldCondition, IsNullCondition, MatchAny + from qql.ast_nodes import InExpr + + result = executor._build_qdrant_filter(InExpr(field="status", values=(None, "draft"))) + assert isinstance(result, Filter) + assert isinstance(result.should[0], IsNullCondition) + assert isinstance(result.should[1], FieldCondition) + assert isinstance(result.should[1].match, MatchAny) + assert result.should[1].match.any == ["draft"] + + def test_build_not_in_with_only_null(self, executor): + from qdrant_client.models import Filter, IsNullCondition + from qql.ast_nodes import NotInExpr + + result = executor._build_qdrant_filter(NotInExpr(field="status", values=(None,))) + assert isinstance(result, Filter) + assert isinstance(result.must_not[0], IsNullCondition) + + def test_build_not_in_with_null_and_values(self, executor): + from qdrant_client.models import Filter, FieldCondition, IsNullCondition, MatchAny + from qql.ast_nodes import NotInExpr + + result = executor._build_qdrant_filter( + NotInExpr(field="status", values=(None, "deleted")) + ) + assert isinstance(result, Filter) + assert isinstance(result.must_not[0], IsNullCondition) + assert isinstance(result.must_not[1], FieldCondition) + assert isinstance(result.must_not[1].match, MatchAny) + assert result.must_not[1].match.any == ["deleted"] + def test_build_is_null(self, executor): from qdrant_client.models import IsNullCondition from qql.ast_nodes import IsNullExpr