diff --git a/README.md b/README.md index 7b15af8..25ba17b 100644 --- a/README.md +++ b/README.md @@ -5,9 +5,9 @@ [![PyPI version](https://img.shields.io/pypi/v/qql-cli?color=blue&label=PyPI)](https://pypi.org/project/qql-cli/) [![Python 3.12+](https://img.shields.io/pypi/pyversions/qql-cli)](https://pypi.org/project/qql-cli/) [![MIT License](https://img.shields.io/badge/license-MIT-green)](LICENSE) -[![Tests](https://img.shields.io/badge/tests-549%20passing-brightgreen)](tests/) +[![Tests](https://img.shields.io/badge/tests-635%20passing-brightgreen)](tests/) -Write `INSERT`, `SELECT`, `SEARCH`, `SCROLL`, `RECOMMEND`, `UPDATE`, `DELETE`, and `CREATE COLLECTION` statements instead of Python SDK calls. Supports hybrid dense+sparse vector search, grouped search (GROUP BY), cross-encoder reranking, quantization (scalar, turbo, binary, product), SQL-style `WHERE` filters, script execution, and collection dump/restore. +Write `INSERT`, `SELECT`, `SEARCH`, `SCROLL`, `RECOMMEND`, `UPDATE`, `DELETE`, and `CREATE COLLECTION` statements instead of Python SDK calls. Supports hybrid dense+sparse vector search, grouped search (GROUP BY), cross-encoder reranking, quantization (scalar, turbo, binary, product), SQL-style `WHERE` filters, script execution, collection dump/restore, async execution, gRPC transport, parameterized queries, and batched query execution. ``` qql> INSERT INTO COLLECTION notes VALUES {'text': 'Qdrant is a vector database', 'author': 'alice', 'year': 2024} @@ -50,16 +50,23 @@ Your query string When you run `INSERT`, the `text` field is automatically converted into a dense vector using [Fastembed](https://github.com/qdrant/fastembed). In **hybrid mode** (`USING HYBRID`), a sparse BM25 vector is also generated alongside the dense vector, and searches use Qdrant's Reciprocal Rank Fusion (RRF) by default to merge the results of both retrieval methods. You can switch hybrid search to DBSF with `FUSION 'dbsf'`. -QQL also exposes a **programmatic API** for use inside Python applications — no CLI required: +QQL also exposes a **programmatic API** for use inside Python applications — no CLI required. Use `Connection` for sync code, `AsyncConnection` for async apps, and batch helpers when you want QQL to combine compatible operations into fewer Qdrant requests: ```python -from qql import Connection +from qql import Connection, QQLBatch with Connection("http://localhost:6333") as conn: conn.run_query("INSERT INTO COLLECTION notes VALUES {'text': 'Qdrant is fast'}") - result = conn.run_query("SEARCH notes SIMILAR TO 'vector database' LIMIT 5") - for hit in result.data: - print(hit["score"], hit["payload"]) + result = conn.run_parameterized_query( + "SEARCH notes SIMILAR TO :query LIMIT 5", + {"query": "vector database"}, + ) + + with QQLBatch(conn) as batch: + neurology = batch.add("SEARCH notes SIMILAR TO 'neurology' LIMIT 5") + cardiology = batch.add("SEARCH notes SIMILAR TO 'cardiology' LIMIT 5") + + print(neurology.result.data, cardiology.result.data) ``` --- @@ -97,8 +104,8 @@ Full documentation lives in the [`docs/`](docs/) folder and at **[pavanjava.gith | [SEARCH / SELECT / SCROLL / RECOMMEND / Hybrid / GROUP BY / RERANK](docs/search.md) | Semantic search, grouped search, point retrieval, pagination, hybrid, reranking, recommendations | | [WHERE Filters](docs/filters.md) | Full SQL-style filter operators | | [Collections & Quantization](docs/collections.md) | SHOW, CREATE, DROP, QUANTIZE (scalar/turbo/binary/product), CREATE INDEX, UPDATE VECTOR, UPDATE PAYLOAD | -| [Scripts: EXECUTE / DUMP](docs/scripts.md) | Script files, collection backup/restore | -| [Programmatic Usage](docs/programmatic.md) | Use QQL as a Python library via `Connection` or `run_query()` | +| [Scripts: EXECUTE / DUMP](docs/scripts.md) | Script files, `BEGIN BATCH` blocks, collection backup/restore | +| [Programmatic Usage](docs/programmatic.md) | Sync/async Python APIs, parameterized queries, batching, gRPC | | [Reference: Models / Config / Errors](docs/reference.md) | Embedding models, config file, error reference | --- @@ -170,6 +177,12 @@ DELETE FROM articles WHERE year < 2020 -- Scripts EXECUTE /path/to/script.qql DUMP articles /path/to/backup.qql + +-- Batch block +BEGIN BATCH; + SEARCH articles SIMILAR TO 'query one' LIMIT 5; + SEARCH articles SIMILAR TO 'query two' LIMIT 5; +END BATCH ``` --- @@ -182,7 +195,7 @@ Tests do not require a running Qdrant instance — the Qdrant client is mocked. pytest tests/ -v ``` -Expected: **549 tests passing**. +Expected: **635 tests passing**. --- diff --git a/docs/getting-started.md b/docs/getting-started.md index bff4473..ab64e9d 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -5,7 +5,7 @@ title: "Getting Started" # Getting Started with QQL -QQL is a SQL-like query language and CLI for [Qdrant](https://qdrant.tech). Instead of writing Python SDK calls you write natural query statements to insert, search, manage, and delete vector data. +QQL is a SQL-like query language and CLI for [Qdrant](https://qdrant.tech). Instead of writing Python SDK calls you write natural query statements to insert, search, manage, and delete vector data. It can also be used as a sync or async Python library with batching, parameterized queries, and optional gRPC transport. --- @@ -154,6 +154,12 @@ SHOW COLLECTION notes -- Retrieve a point by ID SELECT * FROM notes WHERE id = 1 + +-- Run compatible queries as one batch +BEGIN BATCH; + SEARCH notes SIMILAR TO 'vector databases' LIMIT 5; + SEARCH notes SIMILAR TO 'semantic search' LIMIT 5; +END BATCH ``` --- @@ -164,5 +170,6 @@ SELECT * FROM notes WHERE id = 1 - [SEARCH / SELECT / SCROLL / RECOMMEND / Hybrid / RERANK](search.md) — querying - [WHERE Filters](filters.md) — payload filtering - [Collections & Quantization](collections.md) — managing collections -- [Scripts: EXECUTE / DUMP](scripts.md) — automating with script files +- [Scripts: EXECUTE / DUMP](scripts.md) — automating with script files and batch blocks +- [Programmatic Usage](programmatic.md) — sync/async APIs, batching, parameterized queries, gRPC - [Embedding Models](reference.md#embedding-models) — model reference diff --git a/docs/programmatic.md b/docs/programmatic.md index d632a0d..8db6599 100644 --- a/docs/programmatic.md +++ b/docs/programmatic.md @@ -16,6 +16,8 @@ single connection to Qdrant once and reuses it for every `run_query()` call — more efficient than the legacy `run_query()` function, which creates a new client on every invocation. +Use `AsyncConnection` when your application already runs on `asyncio`. + ### Basic usage ```python @@ -70,6 +72,22 @@ with Connection("https://.qdrant.io", secret="") as print(result.data) ``` +### gRPC transport + +QQL can ask the Qdrant client to prefer gRPC for lower request overhead: + +```python +from qql import Connection + +with Connection( + "http://localhost:6333", + prefer_grpc=True, + grpc_port=6334, +) as conn: + result = conn.run_query("SHOW COLLECTIONS") + print(result.data) +``` + ### Custom embedding model ```python @@ -155,9 +173,117 @@ with Connection("http://localhost:6333") as conn: | `url` | `str` | `"http://localhost:6333"` | Qdrant instance URL | | `secret` | `str \| None` | `None` | API key; `None` for unauthenticated | | `default_model` | `str \| None` | `None` → `sentence-transformers/all-MiniLM-L6-v2` | Dense embedding model used when no `USING MODEL` clause is given | +| `prefer_grpc` | `bool` | `False` | Passes `prefer_grpc=True` to the Qdrant client | +| `grpc_port` | `int` | `6334` | gRPC port used when `prefer_grpc=True` | | `default_dense_vector_name` | `str` | `"dense"` | Dense vector name used when QQL creates a collection and no explicit `USING VECTOR` name is given | | `default_sparse_vector_name` | `str` | `"sparse"` | Sparse vector name used when QQL creates a hybrid collection and no explicit sparse vector name is given | +--- + +## Parameterized Queries + +Parameterized helpers render `:name` placeholders before parsing the QQL statement. String values are quoted and escaped; booleans are rendered as `true` / `false`. + +```python +from qql import Connection + +with Connection("http://localhost:6333") as conn: + result = conn.run_parameterized_query( + "SEARCH notes SIMILAR TO :query LIMIT 5 WHERE author = :author", + {"query": "vector database", "author": "alice"}, + ) + + results = conn.run_parameterized_batch( + "SEARCH notes SIMILAR TO :query LIMIT 5 WHERE category = :category", + [ + {"query": "brain stroke", "category": "Neurology"}, + {"query": "heart attack", "category": "Cardiology"}, + ], + ) +``` + +Parameterized queries are a convenience for building QQL strings safely in application code; they are not sent to Qdrant as server-side prepared statements. + +--- + +## Batch Execution + +`run_queries_batch()` parses multiple QQL strings into a `BatchBlockStmt`. The executor groups compatible statements: + +- compatible `SEARCH` / `RECOMMEND` statements use Qdrant `query_batch_points` +- compatible `INSERT` statements become one `INSERT BULK` +- mixed or incompatible statements still execute in order + +```python +from qql import Connection + +with Connection("http://localhost:6333") as conn: + results = conn.run_queries_batch([ + "SEARCH docs SIMILAR TO 'neurology' LIMIT 5", + "SEARCH docs SIMILAR TO 'cardiology' LIMIT 5", + ]) + + for result in results: + print(result.message) +``` + +For ergonomic batching in application code, use `QQLBatch`: + +```python +from qql import Connection, QQLBatch + +with Connection("http://localhost:6333") as conn: + with QQLBatch(conn) as batch: + neuro = batch.add("SEARCH docs SIMILAR TO 'neurology' LIMIT 5") + cardio = batch.add("SEARCH docs SIMILAR TO 'cardiology' LIMIT 5") + + print(neuro.result.data) + print(cardio.result.data) +``` + +Each proxy's `.result` becomes available after the context manager exits. + +--- + +## Async API + +`AsyncConnection` mirrors the sync API for `asyncio` applications and uses `AsyncQdrantClient` under the hood. + +```python +from qql import AsyncConnection + +async with AsyncConnection("http://localhost:6333") as conn: + await conn.run_query( + "INSERT INTO COLLECTION notes VALUES {'text': 'async QQL'}" + ) + result = await conn.run_query( + "SEARCH notes SIMILAR TO 'async vector search' LIMIT 5" + ) + print(result.data) +``` + +Async batching and parameterized helpers are also available: + +```python +from qql import AsyncConnection, QQLAsyncBatch + +async with AsyncConnection("http://localhost:6333", prefer_grpc=True) as conn: + result = await conn.run_parameterized_query( + "SEARCH docs SIMILAR TO :query LIMIT 5", + {"query": "clinical notes"}, + ) + + async with QQLAsyncBatch(conn) as batch: + first = batch.add("SEARCH docs SIMILAR TO 'neurology' LIMIT 5") + second = batch.add("SEARCH docs SIMILAR TO 'cardiology' LIMIT 5") + + print(first.result.data, second.result.data) +``` + +The async executor preserves the same `ExecutionResult` shape as the sync executor. + +--- + ### Power-user: `executor` property For low-level access to the pipeline, use `conn.executor` directly: @@ -250,7 +376,8 @@ class ExecutionResult: |---|---| | INSERT (dense) | `{"id": int \| "", "collection": ""}` | | INSERT (hybrid) | `{"id": int \| "", "collection": ""}` | -| INSERT BULK | `None` (count in `result.message`) | +| INSERT BULK | `{"ids": [int \| "", ...]}` | +| BEGIN BATCH / programmatic batch | `[ExecutionResult, ...]` | | SELECT | `{"id": str, "payload": dict}` or `None` when not found | | SEARCH | `[{"id": str, "score": float, "payload": dict}, ...]` | | SCROLL | `{"points": [{"id": str, "payload": dict}, ...], "next_offset": str \| int \| None}` | diff --git a/docs/reference.md b/docs/reference.md index ce9734e..7164c22 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -5,7 +5,7 @@ title: "Reference" # Reference — Models, Config, Project Structure, Errors -Default embedding models, configuration parameters, project layout, and common error codes for troubleshooting. +Default embedding models, configuration parameters, public APIs, project layout, and common error codes for troubleshooting. --- @@ -147,23 +147,48 @@ You can edit this file directly to change the default model without reconnecting --- +## Public Python API + +| API | Description | +|---|---| +| `Connection` | Stateful sync QQL client backed by `QdrantClient` | +| `AsyncConnection` | Stateful async QQL client backed by `AsyncQdrantClient` | +| `QQLBatch` | Sync context manager for collecting statements and resolving per-statement results after execution | +| `QQLAsyncBatch` | Async context manager equivalent of `QQLBatch` | +| `Executor` | Low-level sync AST executor | +| `AsyncExecutor` | Low-level async AST executor | +| `ExecutionResult` | Standard result object returned by all operations | + +Both sync and async connections support: + +- `run_query(query)` +- `run_queries_batch([query, ...])` +- `run_parameterized_query(template, params)` +- `run_parameterized_batch(template, [params, ...])` +- `prefer_grpc=True` and `grpc_port=` connection options + +--- + ## Project Structure -``` +```text qql/ ├── pyproject.toml # Package config; installs the `qql` CLI command ├── src/ │ └── qql/ -│ ├── __init__.py # Public API: Connection, run_query() +│ ├── __init__.py # Public API exports: sync, async, batching, parser/executor │ ├── cli.py # CLI entry point: connect, disconnect, execute, dump, REPL │ ├── config.py # QQLConfig dataclass + ~/.qql/config.json I/O -│ ├── connection.py # Connection class — stateful programmatic API +│ ├── connection.py # Sync Connection, QQLBatch, parameterized query helpers +│ ├── async_connection.py # AsyncConnection and QQLAsyncBatch │ ├── exceptions.py # QQLError, QQLSyntaxError, QQLRuntimeError │ ├── lexer.py # Tokenizer: string → List[Token] │ ├── ast_nodes.py # Frozen dataclasses for each statement and filter type │ ├── parser.py # Recursive descent parser: tokens → AST node │ ├── embedder.py # Embedder (dense) + SparseEmbedder (BM25) + CrossEncoderEmbedder (rerank) -│ ├── executor.py # AST node → Qdrant client call + filter + hybrid search +│ ├── executor.py # Sync AST node → Qdrant client call +│ ├── async_executor.py # Async AST node → AsyncQdrantClient call +│ ├── utils.py # Shared pure helpers for parsing, filters, batching, vectors │ ├── script.py # Script runner: parse and execute .qql files statement by statement │ └── dumper.py # Collection exporter: scroll all points → .qql INSERT BULK script └── tests/ @@ -171,6 +196,7 @@ qql/ ├── test_parser.py # Parser unit tests ├── test_executor.py # Executor unit tests (mocked Qdrant client) ├── test_connection.py # Connection class unit tests (mocked Qdrant client) + ├── test_async_connection.py # AsyncConnection / AsyncExecutor tests ├── test_script.py # Script runner unit tests └── test_dumper.py # Dumper unit tests ``` @@ -185,7 +211,7 @@ Tests do not require a running Qdrant instance — the Qdrant client is mocked. pytest tests/ -v ``` -Expected output: **604 tests passing**. +Expected output: **635 tests passing**. --- @@ -218,3 +244,6 @@ Expected output: **604 tests passing**. | `Unknown index type '...'` | Invalid schema type in CREATE INDEX | Use one of: `keyword`, `integer`, `float`, `bool`, `text`, `geo`, `datetime`, `uuid` | | `Unknown CREATE INDEX option '...'` | Unsupported advanced option for the chosen payload index type | Check which `WITH { ... }` keys are supported for `keyword`, `uuid`, or `text` | | `Qdrant error during CREATE INDEX: ...` | Qdrant rejected the index creation | Check field name and collection state | +| `Unterminated batch block; expected END BATCH` | A `BEGIN BATCH` block was not closed | Add `END BATCH` at the end of the block | +| `Batch has not been executed yet.` | Read a `QQLBatch` proxy result before leaving the context manager | Access `.result` only after the `with QQLBatch(...)` block exits | +| `AsyncBatch has not been executed yet.` | Read a `QQLAsyncBatch` proxy result before leaving the async context manager | Access `.result` only after the `async with QQLAsyncBatch(...)` block exits | diff --git a/docs/scripts.md b/docs/scripts.md index 1fd512f..6938eb8 100644 --- a/docs/scripts.md +++ b/docs/scripts.md @@ -5,7 +5,7 @@ title: "Scripts: EXECUTE / DUMP" # Script Files — EXECUTE and DUMP -QQL supports reading from and writing to `.qql` script files, making it easy to automate bulk operations, seed databases, and back up collections. +QQL supports reading from and writing to `.qql` script files, making it easy to automate bulk operations, seed databases, and back up collections. Scripts can contain regular statements or explicit `BEGIN BATCH ... END BATCH` blocks. --- @@ -51,10 +51,54 @@ SHOW COLLECTIONS **Rules:** - `--` to end-of-line is a comment and is ignored (inline or full-line) - Statements can span multiple lines (e.g. `INSERT BULK ... VALUES [...]`) +- `BEGIN BATCH ... END BATCH` is treated as one statement by the script splitter +- Semicolons are optional between top-level statements, but useful inside batch blocks - `RECOMMEND` statements work in `.qql` files the same way they do in the REPL - Blank lines between statements are ignored - By default all statements run even if one fails; use `--stop-on-error` to halt early +--- + +## BEGIN BATCH — group statements for fewer Qdrant calls + +Use `BEGIN BATCH ... END BATCH` when you want QQL to parse several statements as one executable batch. The executor keeps statement order in the returned results while grouping compatible operations internally. + +```sql +BEGIN BATCH; + SEARCH articles SIMILAR TO 'stroke symptoms' LIMIT 5 WHERE department = 'neurology'; + SEARCH articles SIMILAR TO 'cardiac markers' LIMIT 5 WHERE department = 'cardiology'; + RECOMMEND FROM articles POSITIVE IDS (1001, 1002) LIMIT 5; +END BATCH +``` + +Batch execution rules: + +- compatible `SEARCH` / `RECOMMEND` statements for the same collection use Qdrant `query_batch_points` +- compatible `INSERT` statements are combined into one bulk insert +- incompatible or mutation statements still execute in order +- each child statement produces its own `ExecutionResult` + +You can also use batch blocks directly in the REPL or through `Connection.run_query()`. + +```python +from qql import Connection + +with Connection("http://localhost:6333") as conn: + result = conn.run_query(""" + BEGIN BATCH; + SEARCH articles SIMILAR TO 'neurology' LIMIT 5; + SEARCH articles SIMILAR TO 'cardiology' LIMIT 5; + END BATCH + """) + + for child in result.data: + print(child.message) +``` + +Programmatic callers can use `run_queries_batch()` or `QQLBatch` instead of writing a batch block by hand. See [Programmatic Usage](programmatic.md#batch-execution). + +--- + **Included examples:** - [`resources/sample.qql`](../resources/sample.qql) seeds the demo medical dataset - [`resources/sample_v2.qql`](../resources/sample_v2.qql) is a compact end-to-end example with explicit IDs and runnable `RECOMMEND` statements diff --git a/src/qql/__init__.py b/src/qql/__init__.py index deeb737..296dbfe 100644 --- a/src/qql/__init__.py +++ b/src/qql/__init__.py @@ -12,15 +12,22 @@ QQLConfig, load_config, ) -from .connection import Connection +from .connection import Connection, QQLBatch, OperationProxy +from .async_connection import AsyncConnection, QQLAsyncBatch, AsyncOperationProxy from .exceptions import QQLError, QQLRuntimeError, QQLSyntaxError from .executor import ExecutionResult, Executor +from .async_executor import AsyncExecutor from .lexer import Lexer from .parser import Parser __all__ = [ "__version__", "Connection", + "QQLBatch", + "OperationProxy", + "AsyncConnection", + "QQLAsyncBatch", + "AsyncOperationProxy", "DEFAULT_DENSE_VECTOR_NAME", "DEFAULT_MODEL", "DEFAULT_SPARSE_VECTOR_NAME", @@ -30,6 +37,7 @@ "QQLSyntaxError", "ExecutionResult", "Executor", + "AsyncExecutor", "Lexer", "Parser", "load_config", diff --git a/src/qql/ast_nodes.py b/src/qql/ast_nodes.py index 5f0b2f2..a7fbc95 100644 --- a/src/qql/ast_nodes.py +++ b/src/qql/ast_nodes.py @@ -25,9 +25,9 @@ class QuantizationConfig: class SearchWith: """Query-time search params supported by Qdrant SearchParams.""" hnsw_ef: int | None = None - exact: bool = False - acorn: bool = False - indexed_only: bool = False + exact: bool | None = None + acorn: bool | None = None + indexed_only: bool | None = None quantization: "QuantizationSearchWith | None" = None mmr_diversity: float | None = None mmr_candidates: int | None = None @@ -114,14 +114,14 @@ class BetweenExpr: class InExpr: """field IN (v1, v2, ...)""" field: str - values: tuple[str | int | float | bool, ...] + values: tuple[str | int | float | bool | None, ...] @dataclass(frozen=True) class NotInExpr: """field NOT IN (v1, v2, ...)""" field: str - values: tuple[str | int | float | bool, ...] + values: tuple[str | int | float | bool | None, ...] @dataclass(frozen=True) @@ -340,6 +340,11 @@ class UpdatePayloadStmt: query_filter: FilterExpr | None = None +@dataclass(frozen=True) +class BatchBlockStmt: + statements: tuple[ASTNode, ...] + + # Union type for all top-level statement nodes ASTNode = ( InsertStmt @@ -357,4 +362,5 @@ class UpdatePayloadStmt: | DeleteStmt | UpdateVectorStmt | UpdatePayloadStmt + | BatchBlockStmt ) diff --git a/src/qql/async_connection.py b/src/qql/async_connection.py new file mode 100644 index 0000000..2e611e6 --- /dev/null +++ b/src/qql/async_connection.py @@ -0,0 +1,231 @@ +from __future__ import annotations + +from typing import Any +from qdrant_client import AsyncQdrantClient + +from .config import DEFAULT_MODEL, QQLConfig +from .async_executor import AsyncExecutor +from .executor import ExecutionResult +from .lexer import Lexer +from .parser import Parser +from .utils import render_parameterized_query + + +class AsyncConnection: + """Stateful asynchronous connection to a Qdrant instance. + + Creates a single ``AsyncQdrantClient`` and ``AsyncExecutor`` once and reuses + them for every :meth:`run_query` call — more efficient than the standalone + one-shot helpers, which create a fresh client on every invocation. + + **Basic usage**:: + + conn = AsyncConnection("http://localhost:6333", secret="my-key") + result = await conn.run_query( + "INSERT INTO COLLECTION docs VALUES {'text': 'hello world'}" + ) + result = await conn.run_query("SEARCH docs SIMILAR TO 'hello' LIMIT 5") + await conn.close() + + **Context manager (preferred)**:: + + async with AsyncConnection("http://localhost:6333") as conn: + result = await conn.run_query("SHOW COLLECTIONS") + print(result.data) + + **Qdrant Cloud**:: + + async with AsyncConnection("https://.qdrant.io", secret="") as conn: + result = await conn.run_query("SHOW COLLECTIONS") + + **Custom embedding model**:: + + async with AsyncConnection( + "http://localhost:6333", + default_model="BAAI/bge-base-en-v1.5", + ) as conn: + result = await conn.run_query( + "INSERT INTO COLLECTION docs VALUES {'text': 'hello'}" + ) + """ + + def __init__( + self, + url: str = "http://localhost:6333", + secret: str | None = None, + default_model: str | None = None, + prefer_grpc: bool = False, + grpc_port: int = 6334, + ) -> None: + """Create an asynchronous connection to a Qdrant instance. + + Args: + url: Base URL of the Qdrant instance (default: ``http://localhost:6333``). + secret: API key for authenticated instances; ``None`` for unauthenticated. + default_model: Dense embedding model used when no ``USING MODEL`` clause + is specified. Defaults to ``sentence-transformers/all-MiniLM-L6-v2``. + prefer_grpc: Whether to connect via fast gRPC transport. + grpc_port: The gRPC port of Qdrant instance (default: 6334). + """ + self._config = QQLConfig( + url=url, + secret=secret, + default_model=default_model or DEFAULT_MODEL, + ) + client_kwargs = {"url": url, "api_key": secret} + if prefer_grpc: + client_kwargs["prefer_grpc"] = True + client_kwargs["grpc_port"] = grpc_port + self._client = AsyncQdrantClient(**client_kwargs) + self._executor = AsyncExecutor(self._client, self._config) + + # ── Public API ──────────────────────────────────────────────────────── + + async def run_query(self, query: str) -> ExecutionResult: + """Parse and execute a single QQL statement asynchronously. + + Args: + query: A QQL query string, e.g. ``"SEARCH docs SIMILAR TO 'hello' LIMIT 5"``. + + Returns: + An :class:`~qql.ExecutionResult` with ``success``, ``message``, and ``data`` fields. + + Raises: + QQLSyntaxError: The query string could not be parsed. + QQLRuntimeError: The query parsed correctly but Qdrant rejected it. + """ + tokens = Lexer().tokenize(query) + node = Parser(tokens).parse() + return await self._executor.execute(node) + + async def run_queries_batch(self, queries: list[str]) -> list[ExecutionResult]: + """Parse and execute a batch of QQL statements asynchronously. + + Combines compatible operations (such as SEARCH queries) to execute in + a single network request. + """ + from .ast_nodes import BatchBlockStmt + nodes = [] + for q in queries: + tokens = Lexer().tokenize(q) + node = Parser(tokens).parse() + nodes.append(node) + + batch_node = BatchBlockStmt(statements=tuple(nodes)) + res = await self._executor.execute(batch_node) + return res.data + + async def run_parameterized_query(self, template: str, params: dict[str, Any]) -> ExecutionResult: + """Execute one QQL query template with named parameters asynchronously. + + Uses named placeholders prefixed with ':' (e.g. :query, :category). + """ + return await self.run_query(render_parameterized_query(template, params)) + + async def run_parameterized_batch(self, template: str, params: list[dict[str, Any]]) -> list[ExecutionResult]: + """Execute a single QQL query template with a batch of parameters asynchronously. + + Uses named placeholders prefixed with ':' (e.g. :query, :category). + """ + queries = [render_parameterized_query(template, p) for p in params] + return await self.run_queries_batch(queries) + + async def close(self) -> None: + """Close the underlying Qdrant asynchronous connection pool.""" + await self._client.close() + + # ── Context manager ─────────────────────────────────────────────────── + + async def __aenter__(self) -> AsyncConnection: + return self + + async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: + await self.close() + + # ── Power-user properties ───────────────────────────────────────────── + + @property + def config(self) -> QQLConfig: + """The :class:`~qql.QQLConfig` in use (url, secret, default_model).""" + return self._config + + @property + def executor(self) -> AsyncExecutor: + """Direct access to the :class:`~qql.AsyncExecutor` for low-level use. + + Example — run multiple statements sharing a pre-built AST node:: + + from qql.lexer import Lexer + from qql.parser import Parser + + conn = AsyncConnection("http://localhost:6333") + tokens = Lexer().tokenize("SHOW COLLECTIONS") + node = Parser(tokens).parse() + result = await conn.executor.execute(node) + """ + return self._executor + + +class QQLAsyncBatch: + """Asynchronous session context manager for executing batch queries and mutations in QQL.""" + + def __init__(self, connection: AsyncConnection) -> None: + self.connection = connection + self._queries: list[str] = [] + self._proxies: list[AsyncOperationProxy] = [] + + def add(self, query: str) -> AsyncOperationProxy: + """Queue a QQL statement for batch execution.""" + self._queries.append(query) + proxy = AsyncOperationProxy() + self._proxies.append(proxy) + return proxy + + async def __aenter__(self) -> QQLAsyncBatch: + self._queries.clear() + self._proxies.clear() + return self + + async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: + try: + if exc_type is not None: + return + if not self._queries: + return + results = await self.connection.run_queries_batch(self._queries) + if len(results) != len(self._proxies): + error = RuntimeError( + "Batch result count mismatch: " + f"expected {len(self._proxies)}, got {len(results)}" + ) + for proxy in self._proxies: + proxy._reject(error) + raise error + for proxy, res in zip(self._proxies, results, strict=True): + proxy._resolve(res) + finally: + self._queries.clear() + self._proxies.clear() + + +class AsyncOperationProxy: + """Proxy handle that resolves to an ExecutionResult after QQLAsyncBatch exits.""" + + def __init__(self) -> None: + self._result: ExecutionResult | None = None + self._exception: RuntimeError | None = None + + def _resolve(self, result: ExecutionResult) -> None: + self._result = result + + def _reject(self, exception: RuntimeError) -> None: + self._exception = exception + + @property + def result(self) -> ExecutionResult: + """The resolved ExecutionResult.""" + if self._exception is not None: + raise self._exception + if self._result is None: + raise RuntimeError("AsyncBatch has not been executed yet.") + return self._result diff --git a/src/qql/async_executor.py b/src/qql/async_executor.py new file mode 100644 index 0000000..53bc933 --- /dev/null +++ b/src/qql/async_executor.py @@ -0,0 +1,1406 @@ +from __future__ import annotations + +import time +import asyncio +from typing import Any + +from qdrant_client import AsyncQdrantClient +from qdrant_client.http.exceptions import UnexpectedResponse +from qdrant_client.models import ( + Distance, + Filter, + FusionQuery, + LookupLocation, + Modifier, + PointStruct, + PointVectors, + Prefetch, + RecommendInput, + RecommendQuery, + QueryRequest, + SearchParams, + SparseVector, + SparseVectorParams, + VectorParams, + PayloadSchemaType, +) + +from .ast_nodes import ( + ASTNode, + AlterCollectionStmt, + CreateCollectionStmt, + CreateIndexStmt, + DeleteStmt, + DropCollectionStmt, + InsertBulkStmt, + InsertStmt, + RecommendStmt, + SelectStmt, + ScrollStmt, + SearchStmt, + ShowCollectionStmt, + ShowCollectionsStmt, + UpdateVectorStmt, + UpdatePayloadStmt, + BatchBlockStmt, +) +from .config import QQLConfig +from .embedder import Embedder, SparseEmbedder +from .exceptions import QQLRuntimeError +from .executor import Executor, ExecutionResult, CollectionTopology +from .utils import ( + build_bulk_insert_from_group, + build_dense_point_vector, + build_dense_query, + collection_topology_kwargs, + exclude_ids_from_filter, + extract_point_id_and_payload, + group_batch_statements, + has_mmr, + inserted_point_results, + parse_recommend_strategy, + resolve_hybrid_fusion, + validate_search_mmr_usage, +) + +_RERANK_FETCH_MULTIPLIER = 4 +_HYBRID_PREFETCH_MULTIPLIER = 4 +_COLLECTION_VISIBILITY_TIMEOUT_SECONDS = 5.0 +_COLLECTION_VISIBILITY_POLL_SECONDS = 0.05 + + +class AsyncExecutor(Executor): + """Asynchronous QQL execution engine for ``AsyncQdrantClient``. + + The async executor mirrors :class:`~qql.Executor` at the statement boundary: + every AST node supported by the sync executor has an async execution path + here. Pure parsing, validation, vector-shaping, filter-building, and result + formatting helpers live in ``qql.utils`` or are inherited from + :class:`~qql.Executor`; only Qdrant client calls and collection-creation + coordination are implemented with ``async``/``await`` in this module. + """ + + def __init__(self, client: AsyncQdrantClient, config: QQLConfig) -> None: + super().__init__(client=client, config=config) # type: ignore[arg-type] + self._client: AsyncQdrantClient = client + self._creation_lock = asyncio.Lock() + + async def execute(self, node: ASTNode) -> ExecutionResult: + if isinstance(node, InsertBulkStmt): + return await self._execute_insert_bulk(node) + if isinstance(node, InsertStmt): + return await self._execute_insert(node) + if isinstance(node, CreateCollectionStmt): + return await self._execute_create(node) + if isinstance(node, AlterCollectionStmt): + return await self._execute_alter_collection(node) + if isinstance(node, CreateIndexStmt): + return await self._execute_create_index(node) + if isinstance(node, DropCollectionStmt): + return await self._execute_drop(node) + if isinstance(node, ShowCollectionsStmt): + return await self._execute_show(node) + if isinstance(node, ShowCollectionStmt): + return await self._execute_show_collection(node) + if isinstance(node, ScrollStmt): + return await self._execute_scroll(node) + if isinstance(node, SelectStmt): + return await self._execute_select(node) + if isinstance(node, SearchStmt): + return await self._execute_search(node) + if isinstance(node, RecommendStmt): + return await self._execute_recommend(node) + if isinstance(node, DeleteStmt): + return await self._execute_delete(node) + if isinstance(node, UpdateVectorStmt): + return await self._execute_update_vector(node) + if isinstance(node, UpdatePayloadStmt): + return await self._execute_update_payload(node) + if isinstance(node, BatchBlockStmt): + return await self._execute_batch_block(node) + raise QQLRuntimeError(f"Unknown AST node type: {type(node)}") + + # ── Topology & Helper methods ───────────────────────────────────────── + + async def _resolve_topology(self, name: str) -> CollectionTopology: + if not await self._client.collection_exists(name): + return CollectionTopology(exists=False, is_named_dense=False) + + info = await self._client.get_collection(name) + params = info.config.params + vectors = params.vectors # type: ignore[union-attr] + sparse_vectors = params.sparse_vectors or {} + return CollectionTopology(**collection_topology_kwargs(vectors, sparse_vectors)) + + async def _ensure_collection( + self, + name: str, + vector_size: int, + topology: CollectionTopology, + explicit_vector: str | None, + ) -> CollectionTopology: + if topology.exists: + info = await self._client.get_collection(name) + vectors = info.config.params.vectors # type: ignore[union-attr] + sparse_vectors = info.config.params.sparse_vectors or {} + current_topology = CollectionTopology( + **collection_topology_kwargs(vectors, sparse_vectors) + ) + if isinstance(vectors, dict): + vector_name = current_topology.dense_using(explicit_vector) + if vector_name is None: + raise QQLRuntimeError("Collection has no dense vector") + vector_config = vectors[vector_name] + expected_size = getattr(vector_config, "size", None) + if expected_size is not None and expected_size != vector_size: + raise QQLRuntimeError( + f"Vector dimension mismatch: collection '{name}' vector " + f"'{vector_name}' expects {expected_size} dims, but " + f"model produces {vector_size} dims. Specify a compatible " + "model with USING MODEL ''." + ) + elif vectors is not None: + if vectors.size != vector_size: + raise QQLRuntimeError( + f"Vector dimension mismatch: collection '{name}' expects " + f"{vectors.size} dims, but model produces {vector_size} dims. " + f"Specify a compatible model with USING MODEL ''." + ) + else: + raise QQLRuntimeError("Collection has no dense vector") + return current_topology + else: + async with self._creation_lock: + current_topology = await self._resolve_topology(name) + if current_topology.exists: + return await self._ensure_collection(name, vector_size, current_topology, explicit_vector) + + await self._create_collection_and_wait( + collection_name=name, + vectors_config={ + explicit_vector or self._default_dense_vector_name(): VectorParams( + size=vector_size, distance=Distance.COSINE + ) + }, + ) + return await self._resolve_topology(name) + + async def _create_collection_and_wait(self, **kwargs: Any) -> None: + collection_name = kwargs["collection_name"] + await self._client.create_collection(**kwargs) + + deadline = time.monotonic() + _COLLECTION_VISIBILITY_TIMEOUT_SECONDS + while time.monotonic() < deadline: + if await self._client.collection_exists(collection_name): + return + await asyncio.sleep(_COLLECTION_VISIBILITY_POLL_SECONDS) + + raise QQLRuntimeError( + f"Collection '{collection_name}' was created but did not become visible in time" + ) + + async def _build_hybrid_vectors( + self, + query_text: str, + dense_model: str, + sparse_model_name: str, + ) -> tuple[list[float], SparseVector]: + dense_embedder = Embedder(dense_model) + sparse_embedder = SparseEmbedder(sparse_model_name) + + dense_vector = dense_embedder.embed(query_text) + sparse_obj = sparse_embedder.query_embed(query_text) + sparse_vector = SparseVector( + indices=sparse_obj["indices"], + values=sparse_obj["values"], + ) + return dense_vector, sparse_vector + + # ── Statement executors ─────────────────────────────────────────────── + + async def _execute_insert(self, node: InsertStmt) -> ExecutionResult: + if "text" not in node.values: + raise QQLRuntimeError("INSERT requires a 'text' field in VALUES") + + topology = await self._resolve_topology(node.collection) + use_hybrid = node.hybrid or (topology.exists and topology.is_hybrid) + + if use_hybrid: + dense_model = node.model or self._config.default_model + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + dense_embedder = Embedder(dense_model) + sparse_embedder = SparseEmbedder(sparse_model_name) + + dense_vector = dense_embedder.embed(node.values["text"]) + sparse_obj = sparse_embedder.embed(node.values["text"]) + + sparse_vector = SparseVector( + indices=sparse_obj["indices"], + values=sparse_obj["values"], + ) + + dense_name = node.dense_vector or self._default_dense_vector_name() + sparse_name = node.sparse_vector or self._default_sparse_vector_name() + + if topology.exists: + resolved_dense = topology.dense_using(node.dense_vector) + if resolved_dense is None: + raise QQLRuntimeError( + "Hybrid collections must use named dense vectors" + ) + dense_name = resolved_dense + sparse_name = topology.sparse_using(node.sparse_vector) + else: + async with self._creation_lock: + current_topology = await self._resolve_topology(node.collection) + if not current_topology.exists: + await self._create_collection_and_wait( + collection_name=node.collection, + vectors_config={ + dense_name: VectorParams( + size=len(dense_vector), distance=Distance.COSINE + ) + }, + sparse_vectors_config={ + sparse_name: SparseVectorParams(modifier=Modifier.IDF) + }, + ) + else: + dense_name = current_topology.dense_using(node.dense_vector) or dense_name + sparse_name = current_topology.sparse_using(node.sparse_vector) + + point_id, payload = extract_point_id_and_payload(node.values) + try: + await self._client.upsert( + collection_name=node.collection, + wait=True, + points=[ + PointStruct( + id=point_id, + vector={dense_name: dense_vector, sparse_name: sparse_vector}, + payload=payload, + ) + ], + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during INSERT: {e}") from e + + return ExecutionResult( + success=True, + message=f"Inserted 1 point [{point_id}] (hybrid)", + data={"id": point_id, "collection": node.collection}, + ) + + model_name = node.model or self._config.default_model + embedder = Embedder(model_name) + vector = embedder.embed(node.values["text"]) + + topology = await self._ensure_collection( + node.collection, len(vector), topology, node.dense_vector + ) + point_vector = build_dense_point_vector( + topology, + vector, + node.dense_vector, + self._default_dense_vector_name(), + ) + + point_id, payload = extract_point_id_and_payload(node.values) + + try: + await self._client.upsert( + collection_name=node.collection, + wait=True, + points=[PointStruct(id=point_id, vector=point_vector, payload=payload)], + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during INSERT: {e}") from e + + return ExecutionResult( + success=True, + message=f"Inserted 1 point [{point_id}]", + data={"id": point_id, "collection": node.collection}, + ) + + async def _execute_insert_bulk(self, node: InsertBulkStmt) -> ExecutionResult: + if not node.values_list: + raise QQLRuntimeError("INSERT BULK VALUES list is empty") + for i, vals in enumerate(node.values_list): + if "text" not in vals: + raise QQLRuntimeError( + f"INSERT BULK: item at index {i} is missing required 'text' field" + ) + + topology = await self._resolve_topology(node.collection) + use_hybrid = node.hybrid or (topology.exists and topology.is_hybrid) + + if use_hybrid: + dense_model = node.model or self._config.default_model + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + dense_embedder = Embedder(dense_model) + sparse_embedder = SparseEmbedder(sparse_model_name) + dense_name = node.dense_vector or self._default_dense_vector_name() + sparse_name = node.sparse_vector or self._default_sparse_vector_name() + if topology.exists: + resolved_dense = topology.dense_using(node.dense_vector) + if resolved_dense is None: + raise QQLRuntimeError( + "Hybrid collections must use named dense vectors" + ) + dense_name = resolved_dense + sparse_name = topology.sparse_using(node.sparse_vector) + + dense_vectors = [ + dense_embedder.embed(vals["text"]) for vals in node.values_list + ] + sparse_objs = [sparse_embedder.embed(vals["text"]) for vals in node.values_list] + + first_dense_vector = dense_vectors[0] if dense_vectors else None + if not topology.exists: + assert first_dense_vector is not None + async with self._creation_lock: + current_topology = await self._resolve_topology(node.collection) + if not current_topology.exists: + await self._create_collection_and_wait( + collection_name=node.collection, + vectors_config={ + dense_name: VectorParams(size=len(first_dense_vector), distance=Distance.COSINE) + }, + sparse_vectors_config={ + sparse_name: SparseVectorParams(modifier=Modifier.IDF) + }, + ) + else: + dense_name = current_topology.dense_using(node.dense_vector) or dense_name + sparse_name = current_topology.sparse_using(node.sparse_vector) + + points: list[PointStruct] = [] + for idx, vals in enumerate(node.values_list): + point_id, payload = extract_point_id_and_payload(vals) + dense_vector = dense_vectors[idx] + sparse_obj = sparse_objs[idx] + sparse_vector = SparseVector( + indices=sparse_obj["indices"], values=sparse_obj["values"] + ) + points.append( + PointStruct( + id=point_id, + vector={dense_name: dense_vector, sparse_name: sparse_vector}, + payload=payload, + ) + ) + + try: + await self._client.upsert( + collection_name=node.collection, + wait=True, + points=points, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during INSERT BULK: {e}") from e + + return ExecutionResult( + success=True, + message=f"Inserted {len(points)} points (hybrid)", + data={"ids": [p.id for p in points]}, + ) + + model_name = node.model or self._config.default_model + embedder = Embedder(model_name) + + vectors = [embedder.embed(vals["text"]) for vals in node.values_list] + + first_vector = vectors[0] if vectors else None + assert first_vector is not None + topology = await self._ensure_collection( + node.collection, len(first_vector), topology, node.dense_vector + ) + points = [] + for idx, vals in enumerate(node.values_list): + vector = vectors[idx] + point_id, payload = extract_point_id_and_payload(vals) + point_vector = build_dense_point_vector( + topology, + vector, + node.dense_vector, + self._default_dense_vector_name(), + ) + points.append( + PointStruct(id=point_id, vector=point_vector, payload=payload) + ) + + try: + await self._client.upsert( + collection_name=node.collection, + wait=True, + points=points, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during INSERT BULK: {e}") from e + + return ExecutionResult( + success=True, + message=f"Inserted {len(points)} points", + data={"ids": [p.id for p in points]}, + ) + + async def _execute_create(self, node: CreateCollectionStmt) -> ExecutionResult: + if await self._client.collection_exists(node.collection): + return ExecutionResult( + success=True, + message=f"Collection '{node.collection}' already exists", + ) + + dense_model_name = node.model or self._config.default_model + + quant_config = ( + self._build_quantization_config(node.quantization) + if node.quantization is not None + else None + ) + quant_label = ( + f", {node.quantization.type.value} quantization" + if node.quantization is not None + else "" + ) + hnsw_config = self._build_hnsw_config(node.config) + optimizers_config = self._build_optimizers_config(node.config) + params_config = self._build_collection_params_create_kwargs(node.config) + config_label = self._describe_collection_config(node.config) + vector_on_disk = ( + node.config.vectors.on_disk + if node.config is not None and node.config.vectors is not None + else None + ) + + if node.hybrid: + embedder = Embedder(dense_model_name) + dims = embedder.dimensions + dense_name = node.dense_vector or self._default_dense_vector_name() + sparse_name = node.sparse_vector or self._default_sparse_vector_name() + create_kwargs: dict[str, Any] = { + "collection_name": node.collection, + "vectors_config": { + dense_name: VectorParams( + size=dims, + distance=Distance.COSINE, + on_disk=vector_on_disk, + ) + }, + "sparse_vectors_config": { + sparse_name: SparseVectorParams(modifier=Modifier.IDF) + }, + } + if quant_config is not None: + create_kwargs["quantization_config"] = quant_config + if hnsw_config is not None: + create_kwargs["hnsw_config"] = hnsw_config + if optimizers_config is not None: + create_kwargs["optimizers_config"] = optimizers_config + create_kwargs.update(params_config) + await self._create_collection_and_wait(**create_kwargs) + return ExecutionResult( + success=True, + message=( + f"Collection '{node.collection}' created " + f"(hybrid: {dims}-dim dense + BM25 sparse, cosine distance{quant_label}{config_label})" + ), + ) + + embedder = Embedder(dense_model_name) + dims = embedder.dimensions + dense_name = node.dense_vector or self._default_dense_vector_name() + create_kwargs = { + "collection_name": node.collection, + "vectors_config": { + dense_name: VectorParams( + size=dims, + distance=Distance.COSINE, + on_disk=vector_on_disk, + ) + }, + } + if quant_config is not None: + create_kwargs["quantization_config"] = quant_config + if hnsw_config is not None: + create_kwargs["hnsw_config"] = hnsw_config + if optimizers_config is not None: + create_kwargs["optimizers_config"] = optimizers_config + create_kwargs.update(params_config) + await self._create_collection_and_wait(**create_kwargs) + return ExecutionResult( + success=True, + message=f"Collection '{node.collection}' created ({dims}-dimensional vectors, cosine distance{quant_label}{config_label})", + ) + + async def _execute_alter_collection(self, node: AlterCollectionStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + topology = await self._resolve_topology(node.collection) + + update_kwargs: dict[str, Any] = {"collection_name": node.collection} + vectors_config = self._build_vectors_config_diff(topology, node.config) + hnsw_config = self._build_hnsw_config(node.config) + optimizers_config = self._build_optimizers_config(node.config) + collection_params = self._build_collection_params_diff(node.config) + quantization_config = self._build_alter_quantization_config(node.quantization) + + if vectors_config is not None: + update_kwargs["vectors_config"] = vectors_config + if hnsw_config is not None: + update_kwargs["hnsw_config"] = hnsw_config + if optimizers_config is not None: + update_kwargs["optimizers_config"] = optimizers_config + if collection_params is not None: + update_kwargs["collection_params"] = collection_params + if quantization_config is not None: + update_kwargs["quantization_config"] = quantization_config + + try: + await self._client.update_collection(**update_kwargs) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during ALTER COLLECTION: {e}") from e + + return ExecutionResult( + success=True, + message=( + f"Collection '{node.collection}' altered" + f"{self._describe_collection_config(node.config)}" + f"{self._describe_quantization_update(node.quantization)}" + ), + ) + + async def _execute_create_index(self, node: CreateIndexStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + + schema_map = { + "keyword": PayloadSchemaType.KEYWORD, + "integer": PayloadSchemaType.INTEGER, + "float": PayloadSchemaType.FLOAT, + "bool": PayloadSchemaType.BOOL, + "text": PayloadSchemaType.TEXT, + "geo": PayloadSchemaType.GEO, + "datetime": PayloadSchemaType.DATETIME, + "uuid": PayloadSchemaType.UUID, + } + try: + schema_map[node.schema] + except KeyError as e: + raise QQLRuntimeError( + "Unknown index type '" + f"{node.schema}'. Expected one of: keyword, integer, float, bool, text, geo, datetime, uuid" + ) from e + field_schema = self._build_payload_index_schema(node) + + try: + await self._client.create_payload_index( + collection_name=node.collection, + field_name=node.field_name, + field_schema=field_schema, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during CREATE INDEX: {e}") from e + + option_label = f" with options {node.options}" if node.options else "" + return ExecutionResult( + success=True, + message=( + f"Created index on '{node.collection}.{node.field_name}' as '{node.schema}'{option_label}" + ), + ) + + async def _execute_drop(self, node: DropCollectionStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + await self._client.delete_collection(node.collection) + return ExecutionResult( + success=True, + message=f"Collection '{node.collection}' dropped", + ) + + async def _execute_show(self, node: ShowCollectionsStmt) -> ExecutionResult: + response = await self._client.get_collections() + names = [c.name for c in response.collections] + return ExecutionResult( + success=True, + message=f"{len(names)} collection(s) found", + data=names, + ) + + async def _execute_show_collection(self, node: ShowCollectionStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + + info = await self._client.get_collection(node.collection) + config = info.config + params = config.params + + vectors = params.vectors # type: ignore[union-attr] + sparse_vector_params = params.sparse_vectors or {} + if isinstance(vectors, dict): + vector_details = {} + for vname, vconfig in vectors.items(): + vector_details[vname] = { + "size": vconfig.size, + "distance": str(vconfig.distance) if vconfig.distance else None, + "on_disk": vconfig.on_disk, + } + elif vectors is None: + raise QQLRuntimeError( + f"Collection '{node.collection}' has no vector configuration" + ) + else: + vector_details = { + "": { + "size": vectors.size, + "distance": str(vectors.distance) if vectors.distance else None, + "on_disk": vectors.on_disk, + } + } + topology = "hybrid" if sparse_vector_params else "dense" + + sparse_vectors = {} + if sparse_vector_params: + for sname, sconfig in sparse_vector_params.items(): + sparse_vectors[sname] = { + "modifier": str(sconfig.modifier) if sconfig.modifier else None, + } + + quant_config = config.quantization_config + quantization = None + if quant_config is not None: + qtype = type(quant_config).__name__ + if hasattr(quant_config, "scalar"): + quantization = "scalar" + elif hasattr(quant_config, "binary"): + quantization = "binary" + elif hasattr(quant_config, "product"): + quantization = "product" + elif hasattr(quant_config, "turbo"): + quantization = "turbo" + else: + quantization = qtype + + hnsw = { + "m": config.hnsw_config.m, + "ef_construct": config.hnsw_config.ef_construct, + } + if config.hnsw_config.full_scan_threshold is not None: + hnsw["full_scan_threshold"] = config.hnsw_config.full_scan_threshold + if config.hnsw_config.max_indexing_threads is not None: + hnsw["max_indexing_threads"] = config.hnsw_config.max_indexing_threads + if config.hnsw_config.on_disk is not None: + hnsw["on_disk"] = config.hnsw_config.on_disk + if config.hnsw_config.payload_m is not None: + hnsw["payload_m"] = config.hnsw_config.payload_m + if config.hnsw_config.inline_storage is not None: + hnsw["inline_storage"] = config.hnsw_config.inline_storage + + payload_indexes = {} + for field_name, idx_info in (info.payload_schema or {}).items(): + payload_indexes[field_name] = self._serialize_payload_index_info(idx_info) + + sharding = { + "shard_number": params.shard_number, + "replication_factor": params.replication_factor, + "write_consistency_factor": params.write_consistency_factor, + "read_fan_out_factor": params.read_fan_out_factor, + "read_fan_out_delay_ms": params.read_fan_out_delay_ms, + "on_disk_payload": params.on_disk_payload, + } + + data = { + "name": node.collection, + "status": str(info.status), + "points_count": info.points_count, + "indexed_vectors_count": info.indexed_vectors_count, + "segments_count": info.segments_count, + "topology": topology, + "vectors": vector_details, + "sparse_vectors": sparse_vectors or None, + "quantization": quantization, + "hnsw_config": hnsw, + "payload_schema": payload_indexes or None, + "sharding": sharding, + } + + return ExecutionResult( + success=True, + message=f"Collection '{node.collection}' diagnostics", + data=data, + ) + + async def _execute_scroll(self, node: ScrollStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + + scroll_filter: Filter | None = None + if node.query_filter is not None: + scroll_filter = self._wrap_as_filter( + self._build_qdrant_filter(node.query_filter) + ) + + try: + records, next_offset = await self._client.scroll( + collection_name=node.collection, + scroll_filter=scroll_filter, + limit=node.limit, + offset=node.after, + with_payload=True, + with_vectors=False, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during SCROLL: {e}") from e + + points = [ + {"id": str(rec.id), "payload": rec.payload or {}} + for rec in records + ] + return ExecutionResult( + success=True, + message=f"Scrolled {len(points)} point(s) from '{node.collection}'", + data={"points": points, "next_offset": next_offset}, + ) + + async def _execute_select(self, node: SelectStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + + try: + records = await self._client.retrieve( + collection_name=node.collection, + ids=[node.point_id], + with_payload=True, + with_vectors=False, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during SELECT: {e}") from e + + if not records: + return ExecutionResult( + success=True, + message=f"Point '{node.point_id}' not found in '{node.collection}'", + ) + + record = records[0] + return ExecutionResult( + success=True, + message=f"Retrieved point '{node.point_id}' from '{node.collection}'", + data={"id": str(record.id), "payload": record.payload or {}}, + ) + + async def _execute_search(self, node: SearchStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + topology = await self._resolve_topology(node.collection) + + qdrant_filter: Filter | None = None + if node.query_filter is not None: + qdrant_filter = self._wrap_as_filter( + self._build_qdrant_filter(node.query_filter) + ) + + search_params = self._build_search_params(node.with_clause) + validate_search_mmr_usage(node) + + fetch_limit = node.limit * _RERANK_FETCH_MULTIPLIER if node.rerank else node.limit + + lookup_from: LookupLocation | None = None + if node.lookup_from is not None: + lookup_from = LookupLocation( + collection=node.lookup_from[0], + vector=node.lookup_from[1], + ) + + if node.group_by is not None: + return await self._execute_search_groups( + node, qdrant_filter, search_params, topology + ) + + if node.hybrid: + dense_model = node.model or self._config.default_model + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + dense_vector, sparse_vector = await self._build_hybrid_vectors( + node.query_text, dense_model, sparse_model_name + ) + + try: + response = await self._client.query_points( + collection_name=node.collection, + prefetch=[ + Prefetch( + query=build_dense_query(dense_vector, node.with_clause), + using=topology.dense_using(node.dense_vector), + limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, + params=search_params, + ), + Prefetch( + query=sparse_vector, + using=topology.sparse_using(node.sparse_vector), + limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, + params=search_params, + ), + ], + query=FusionQuery(fusion=resolve_hybrid_fusion(node.fusion)), + limit=fetch_limit, + offset=node.offset or None, + query_filter=qdrant_filter, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during SEARCH: {e}") from e + + results = [ + {"id": str(h.id), "score": round(h.score, 4), "payload": h.payload} + for h in response.points + ] + + if node.rerank: + results = self._apply_reranking(node.query_text, results, node.limit, node.rerank_model) + return ExecutionResult( + success=True, + message=f"Found {len(results)} result(s) (hybrid, reranked)", + data=results, + ) + + return ExecutionResult( + success=True, + message=f"Found {len(results)} result(s) (hybrid)", + data=results, + ) + + if node.sparse_only: + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + sparse_embedder = SparseEmbedder(sparse_model_name) + sparse_obj = sparse_embedder.query_embed(node.query_text) + sparse_vector = SparseVector( + indices=sparse_obj["indices"], + values=sparse_obj["values"], + ) + + try: + response = await self._client.query_points( + collection_name=node.collection, + query=sparse_vector, + using=topology.sparse_using(node.sparse_vector), + limit=fetch_limit, + offset=node.offset or None, + query_filter=qdrant_filter, + search_params=search_params, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during SEARCH: {e}") from e + + results = [ + {"id": str(h.id), "score": round(h.score, 4), "payload": h.payload} + for h in response.points + ] + + if node.rerank: + results = self._apply_reranking(node.query_text, results, node.limit, node.rerank_model) + return ExecutionResult( + success=True, + message=f"Found {len(results)} result(s) (sparse, reranked)", + data=results, + ) + + return ExecutionResult( + success=True, + message=f"Found {len(results)} result(s) (sparse)", + data=results, + ) + + model_name = node.model or self._config.default_model + embedder = Embedder(model_name) + vector = embedder.embed(node.query_text) + + try: + query_using = topology.dense_using(node.dense_vector) + response = await self._client.query_points( + collection_name=node.collection, + query=build_dense_query(vector, node.with_clause), + using=query_using, + limit=fetch_limit, + offset=node.offset or None, + query_filter=qdrant_filter, + search_params=search_params, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during SEARCH: {e}") from e + + results = [ + {"id": str(h.id), "score": round(h.score, 4), "payload": h.payload} + for h in response.points + ] + + if node.rerank: + results = self._apply_reranking(node.query_text, results, node.limit, node.rerank_model) + return ExecutionResult( + success=True, + message=f"Found {len(results)} result(s) (reranked)", + data=results, + ) + + return ExecutionResult( + success=True, + message=f"Found {len(results)} result(s)", + data=results, + ) + + async def _execute_search_groups( + self, + node: SearchStmt, + qdrant_filter: Filter | None, + search_params: SearchParams | None, + topology: CollectionTopology, + ) -> ExecutionResult: + lookup_from: LookupLocation | None = None + if node.lookup_from is not None: + lookup_from = LookupLocation( + collection=node.lookup_from[0], + vector=node.lookup_from[1], + ) + + try: + if node.hybrid: + dense_model = node.model or self._config.default_model + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + dense_vector, sparse_vector = await self._build_hybrid_vectors( + node.query_text, dense_model, sparse_model_name + ) + response = await self._client.query_points_groups( + collection_name=node.collection, + group_by=node.group_by, + prefetch=[ + Prefetch( + query=build_dense_query(dense_vector, node.with_clause), + using=topology.dense_using(node.dense_vector), + limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, + params=search_params, + ), + Prefetch( + query=sparse_vector, + using=topology.sparse_using(node.sparse_vector), + limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, + params=search_params, + ), + ], + query=FusionQuery(fusion=resolve_hybrid_fusion(node.fusion)), + limit=node.limit, + group_size=node.group_size, + query_filter=qdrant_filter, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + ) + label = "hybrid, grouped" + elif node.sparse_only: + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + sparse_obj = SparseEmbedder(sparse_model_name).query_embed(node.query_text) + sparse_vector = SparseVector( + indices=sparse_obj["indices"], + values=sparse_obj["values"], + ) + response = await self._client.query_points_groups( + collection_name=node.collection, + group_by=node.group_by, + query=sparse_vector, + using=topology.sparse_using(node.sparse_vector), + limit=node.limit, + group_size=node.group_size, + query_filter=qdrant_filter, + search_params=search_params, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + ) + label = "sparse, grouped" + else: + model_name = node.model or self._config.default_model + vector = Embedder(model_name).embed(node.query_text) + query_using = topology.dense_using(node.dense_vector) + response = await self._client.query_points_groups( + collection_name=node.collection, + group_by=node.group_by, + query=build_dense_query(vector, node.with_clause), + using=query_using, + limit=node.limit, + group_size=node.group_size, + query_filter=qdrant_filter, + search_params=search_params, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + ) + label = "grouped" + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during GROUP BY SEARCH: {e}") from e + + groups = [ + { + "group_id": str(g.id), + "hits": [ + {"id": str(h.id), "score": round(h.score, 4), "payload": h.payload} + for h in g.hits + ], + } + for g in response.groups + ] + return ExecutionResult( + success=True, + message=f"Found {len(groups)} group(s) by '{node.group_by}' ({label})", + data=groups, + ) + + async def _execute_recommend(self, node: RecommendStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + + qdrant_filter: Filter | None = None + if node.query_filter is not None: + qdrant_filter = self._wrap_as_filter( + self._build_qdrant_filter(node.query_filter) + ) + qdrant_filter = exclude_ids_from_filter( + qdrant_filter, + [*node.positive_ids, *node.negative_ids], + ) + + recommend_input = RecommendInput( + positive=list(node.positive_ids), + negative=list(node.negative_ids) or None, + strategy=parse_recommend_strategy(node.strategy), + ) + + search_params = self._build_search_params(node.with_clause) + if has_mmr(node.with_clause): + raise QQLRuntimeError("MMR is supported only for SEARCH statements") + + lookup_from: LookupLocation | None = None + if node.lookup_from is not None: + lookup_from = LookupLocation( + collection=node.lookup_from[0], + vector=node.lookup_from[1], + ) + + try: + response = await self._client.query_points( + collection_name=node.collection, + query=RecommendQuery(recommend=recommend_input), + limit=node.limit, + offset=node.offset or None, + query_filter=qdrant_filter, + search_params=search_params, + score_threshold=node.score_threshold, + using=node.using, + lookup_from=lookup_from, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during RECOMMEND: {e}") from e + + results = [ + {"id": str(h.id), "score": round(h.score, 4), "payload": h.payload} + for h in response.points + ] + + return ExecutionResult( + success=True, + message=f"Found {len(results)} recommendation(s)", + data=results, + ) + + async def _execute_delete(self, node: DeleteStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + + try: + if node.query_filter is not None: + await self._client.delete( + collection_name=node.collection, + wait=True, + points_selector=self._wrap_as_filter( + self._build_qdrant_filter(node.query_filter) + ), + ) + return ExecutionResult( + success=True, + message=f"Deleted points from '{node.collection}' by filter", + ) + + from qdrant_client.models import PointIdsList + + if node.point_id is None: + raise QQLRuntimeError("DELETE requires either a point id or a filter") + + await self._client.delete( + collection_name=node.collection, + wait=True, + points_selector=PointIdsList(points=[node.point_id]), + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during DELETE: {e}") from e + + return ExecutionResult( + success=True, + message=f"Deleted point '{node.point_id}' from '{node.collection}'", + ) + + async def _execute_update_vector(self, node: UpdateVectorStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + topology = await self._resolve_topology(node.collection) + vector_name = topology.dense_payload_name(node.vector_name) + vector_struct: Any = ( + {vector_name: list(node.vector)} if vector_name else list(node.vector) + ) + try: + await self._client.update_vectors( + collection_name=node.collection, + points=[PointVectors(id=node.point_id, vector=vector_struct)], + wait=True, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during UPDATE VECTOR: {e}") from e + return ExecutionResult( + success=True, + message=f"Updated vector for point [{node.point_id}] in '{node.collection}'", + data=[], + ) + + async def _execute_update_payload(self, node: UpdatePayloadStmt) -> ExecutionResult: + if not await self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + try: + if node.query_filter is not None: + qdrant_filter = self._wrap_as_filter( + self._build_qdrant_filter(node.query_filter) + ) + await self._client.set_payload( + collection_name=node.collection, + payload=node.payload, + points=qdrant_filter, + wait=True, + ) + return ExecutionResult( + success=True, + message=f"Payload updated in '{node.collection}' (filter-based)", + data=[], + ) + await self._client.set_payload( + collection_name=node.collection, + payload=node.payload, + points=[node.point_id], + wait=True, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during UPDATE PAYLOAD: {e}") from e + return ExecutionResult( + success=True, + message=f"Payload updated for point [{node.point_id}] in '{node.collection}'", + data=[], + ) + + async def _execute_batch_block(self, node: BatchBlockStmt) -> ExecutionResult: + if not node.statements: + return ExecutionResult(success=True, message="Executed empty batch", data=[]) + + all_results = [] + succeeded_count = 0 + + for group in group_batch_statements(node.statements): + if group.kind == 'query': + res = await self._execute_query_batch(group.collection, group.statements) + all_results.extend(res) + succeeded_count += len([r for r in res if r.success]) + elif group.kind == 'insert': + bulk_node = build_bulk_insert_from_group( + group.collection, + group.statements, + ) + res = await self._execute_insert_bulk(bulk_node) + insert_results = inserted_point_results( + res, + group.statements, + ExecutionResult, + ) + all_results.extend(insert_results) + succeeded_count += len([r for r in insert_results if r.success]) + else: + for s in group.statements: + res = await self.execute(s) + all_results.append(res) + if res.success: + succeeded_count += 1 + + total_stmts = len(node.statements) + return ExecutionResult( + success=succeeded_count == total_stmts, + message=f"Batch executed {succeeded_count}/{total_stmts} statement(s) successfully", + data=all_results, + ) + + async def _execute_query_batch( + self, + collection_name: str, + nodes: list[SearchStmt | RecommendStmt], + ) -> list[ExecutionResult]: + if not await self._client.collection_exists(collection_name): + raise QQLRuntimeError(f"Collection '{collection_name}' does not exist") + + topology = await self._resolve_topology(collection_name) + requests = [] + + for node in nodes: + qdrant_filter = None + if node.query_filter is not None: + qdrant_filter = self._wrap_as_filter( + self._build_qdrant_filter(node.query_filter) + ) + + search_params = self._build_search_params(node.with_clause) + + lookup_from = None + if node.lookup_from is not None: + lookup_from = LookupLocation( + collection=node.lookup_from[0], + vector=node.lookup_from[1], + ) + + if isinstance(node, SearchStmt): + validate_search_mmr_usage(node) + fetch_limit = node.limit * _RERANK_FETCH_MULTIPLIER if node.rerank else node.limit + + if node.hybrid: + dense_model = node.model or self._config.default_model + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + dense_vector, sparse_vector = await self._build_hybrid_vectors( + node.query_text, dense_model, sparse_model_name + ) + + req = QueryRequest( + prefetch=[ + Prefetch( + query=build_dense_query(dense_vector, node.with_clause), + using=topology.dense_using(node.dense_vector), + limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, + params=search_params, + ), + Prefetch( + query=sparse_vector, + using=topology.sparse_using(node.sparse_vector), + limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, + params=search_params, + ), + ], + query=FusionQuery(fusion=resolve_hybrid_fusion(node.fusion)), + limit=fetch_limit, + offset=node.offset or None, + filter=qdrant_filter, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + with_payload=True, + with_vector=False, + ) + elif node.sparse_only: + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + sparse_embedder = SparseEmbedder(sparse_model_name) + sparse_obj = sparse_embedder.query_embed(node.query_text) + sparse_vector = SparseVector( + indices=sparse_obj["indices"], + values=sparse_obj["values"], + ) + + req = QueryRequest( + query=sparse_vector, + using=topology.sparse_using(node.sparse_vector), + limit=fetch_limit, + offset=node.offset or None, + filter=qdrant_filter, + params=search_params, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + with_payload=True, + with_vector=False, + ) + else: + model_name = node.model or self._config.default_model + embedder = Embedder(model_name) + vector = embedder.embed(node.query_text) + query_using = topology.dense_using(node.dense_vector) + + req = QueryRequest( + query=build_dense_query(vector, node.with_clause), + using=query_using, + limit=fetch_limit, + offset=node.offset or None, + filter=qdrant_filter, + params=search_params, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + with_payload=True, + with_vector=False, + ) + else: + qdrant_filter = exclude_ids_from_filter( + qdrant_filter, + [*node.positive_ids, *node.negative_ids], + ) + recommend_input = RecommendInput( + positive=list(node.positive_ids), + negative=list(node.negative_ids) or None, + strategy=parse_recommend_strategy(node.strategy), + ) + if has_mmr(node.with_clause): + raise QQLRuntimeError("MMR is supported only for SEARCH statements") + + req = QueryRequest( + query=RecommendQuery(recommend=recommend_input), + limit=node.limit, + offset=node.offset or None, + filter=qdrant_filter, + params=search_params, + score_threshold=node.score_threshold, + using=node.using, + lookup_from=lookup_from, + with_payload=True, + with_vector=False, + ) + + requests.append(req) + + try: + responses = await self._client.query_batch_points( + collection_name=collection_name, + requests=requests, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during Batch Query: {e}") from e + + execution_results = [] + for i, response in enumerate(responses): + node = nodes[i] + results = [ + {"id": str(h.id), "score": round(h.score, 4), "payload": h.payload} + for h in response.points + ] + + if isinstance(node, SearchStmt) and node.rerank: + results = self._apply_reranking(node.query_text, results, node.limit, node.rerank_model) + label = "hybrid, reranked" if node.hybrid else ("sparse, reranked" if node.sparse_only else "reranked") + msg = f"Found {len(results)} result(s) ({label})" + else: + if isinstance(node, SearchStmt): + label = "hybrid" if node.hybrid else ("sparse" if node.sparse_only else "") + label_suffix = f" ({label})" if label else "" + msg = f"Found {len(results)} result(s){label_suffix}" + else: + msg = f"Found {len(results)} recommendation(s)" + + execution_results.append( + ExecutionResult(success=True, message=msg, data=results) + ) + + return execution_results diff --git a/src/qql/connection.py b/src/qql/connection.py index e51f19f..2a21721 100644 --- a/src/qql/connection.py +++ b/src/qql/connection.py @@ -1,9 +1,11 @@ from __future__ import annotations +from typing import Any from .config import DEFAULT_MODEL, QQLConfig from .executor import Executor, ExecutionResult from .lexer import Lexer from .parser import Parser +from .utils import render_parameterized_query class Connection: @@ -51,6 +53,8 @@ def __init__( url: str = "http://localhost:6333", secret: str | None = None, default_model: str | None = None, + prefer_grpc: bool = False, + grpc_port: int = 6334, ) -> None: """Create a connection to a Qdrant instance. @@ -60,6 +64,8 @@ def __init__( default_model: Dense embedding model used when no ``USING MODEL`` clause is specified. Defaults to ``sentence-transformers/all-MiniLM-L6-v2``. + prefer_grpc: Whether to connect via fast gRPC transport. + grpc_port: The gRPC port of Qdrant instance (default: 6334). """ from qdrant_client import QdrantClient @@ -68,7 +74,11 @@ def __init__( secret=secret, default_model=default_model or DEFAULT_MODEL, ) - self._client = QdrantClient(url=url, api_key=secret) + client_kwargs = {"url": url, "api_key": secret} + if prefer_grpc: + client_kwargs["prefer_grpc"] = True + client_kwargs["grpc_port"] = grpc_port + self._client = QdrantClient(**client_kwargs) self._executor = Executor(self._client, self._config) # ── Public API ──────────────────────────────────────────────────────── @@ -92,6 +102,38 @@ def run_query(self, query: str) -> ExecutionResult: node = Parser(tokens).parse() return self._executor.execute(node) + def run_queries_batch(self, queries: list[str]) -> list[ExecutionResult]: + """Parse and execute a batch of QQL statements. + + Combines compatible operations (such as SEARCH queries) to execute in + a single network request. + """ + from .ast_nodes import BatchBlockStmt + nodes = [] + for q in queries: + tokens = Lexer().tokenize(q) + node = Parser(tokens).parse() + nodes.append(node) + + batch_node = BatchBlockStmt(statements=tuple(nodes)) + res = self._executor.execute(batch_node) + return res.data + + def run_parameterized_query(self, template: str, params: dict[str, Any]) -> ExecutionResult: + """Execute one QQL query template with named parameters. + + Uses named placeholders prefixed with ':' (e.g. :query, :category). + """ + return self.run_query(render_parameterized_query(template, params)) + + def run_parameterized_batch(self, template: str, params: list[dict[str, Any]]) -> list[ExecutionResult]: + """Execute a single QQL query template with a batch of parameters. + + Uses named placeholders prefixed with ':' (e.g. :query, :category). + """ + queries = [render_parameterized_query(template, p) for p in params] + return self.run_queries_batch(queries) + def close(self) -> None: """Close the underlying Qdrant HTTP connection pool. @@ -130,3 +172,68 @@ def executor(self) -> Executor: result = conn.executor.execute(node) """ return self._executor + + +class QQLBatch: + """Session context manager for executing batch queries and mutations in QQL.""" + + def __init__(self, connection: Connection) -> None: + self.connection = connection + self._queries: list[str] = [] + self._proxies: list[OperationProxy] = [] + + def add(self, query: str) -> OperationProxy: + """Queue a QQL statement for batch execution.""" + self._queries.append(query) + proxy = OperationProxy() + self._proxies.append(proxy) + return proxy + + def __enter__(self) -> QQLBatch: + self._queries.clear() + self._proxies.clear() + return self + + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: + try: + if exc_type is not None: + return + if not self._queries: + return + results = self.connection.run_queries_batch(self._queries) + if len(results) != len(self._proxies): + error = RuntimeError( + "Batch result count mismatch: " + f"expected {len(self._proxies)}, got {len(results)}" + ) + for proxy in self._proxies: + proxy._reject(error) + raise error + for proxy, res in zip(self._proxies, results, strict=True): + proxy._resolve(res) + finally: + self._queries.clear() + self._proxies.clear() + + +class OperationProxy: + """Proxy handle that resolves to an ExecutionResult after QQLBatch exits.""" + + def __init__(self) -> None: + self._result: ExecutionResult | None = None + self._exception: RuntimeError | None = None + + def _resolve(self, result: ExecutionResult) -> None: + self._result = result + + def _reject(self, exception: RuntimeError) -> None: + self._exception = exception + + @property + def result(self) -> ExecutionResult: + """The resolved ExecutionResult.""" + if self._exception is not None: + raise self._exception + if self._result is None: + raise RuntimeError("Batch has not been executed yet.") + return self._result diff --git a/src/qql/executor.py b/src/qql/executor.py index 975947d..fa6f6f6 100644 --- a/src/qql/executor.py +++ b/src/qql/executor.py @@ -1,7 +1,6 @@ from __future__ import annotations import time -import uuid from dataclasses import dataclass from typing import Any @@ -15,29 +14,15 @@ CompressionRatio, Distance, Disabled, - FieldCondition, Filter, - Fusion, FusionQuery, - HasIdCondition, HnswConfigDiff, - IsEmptyCondition, - IsNullCondition, KeywordIndexParams, KeywordIndexType, Language, LookupLocation, - MatchAny, - MatchExcept, - MatchPhrase, - MatchText, - MatchTextAny, - MatchValue, - Mmr, Modifier, - NearestQuery, OptimizersConfigDiff, - PayloadField, PayloadSchemaType, PointStruct, PointVectors, @@ -45,10 +30,9 @@ ProductQuantization, ProductQuantizationConfig, QuantizationSearchParams, - Range, RecommendInput, RecommendQuery, - RecommendStrategy, + QueryRequest, ScalarQuantization, ScalarQuantizationConfig, ScalarType, @@ -72,28 +56,14 @@ from .ast_nodes import ( ASTNode, AlterCollectionStmt, - AndExpr, - BetweenExpr, CollectionConfig, - CompareExpr, CreateCollectionStmt, CreateIndexStmt, DeleteStmt, DropCollectionStmt, FilterExpr, - InExpr, InsertBulkStmt, InsertStmt, - IsEmptyExpr, - IsNotEmptyExpr, - IsNotNullExpr, - IsNullExpr, - MatchAnyExpr, - MatchPhraseExpr, - MatchTextExpr, - NotExpr, - NotInExpr, - OrExpr, QuantizationUpdate, QuantizationConfig, QuantizationType, @@ -106,10 +76,27 @@ ShowCollectionsStmt, UpdateVectorStmt, UpdatePayloadStmt, + BatchBlockStmt, ) from .config import QQLConfig from .embedder import CrossEncoderEmbedder, Embedder, SparseEmbedder from .exceptions import QQLRuntimeError +from .utils import ( + build_bulk_insert_from_group, + build_dense_point_vector, + build_dense_query, + build_qdrant_filter, + collection_topology_kwargs, + exclude_ids_from_filter, + extract_point_id_and_payload, + group_batch_statements, + has_mmr, + inserted_point_results, + parse_recommend_strategy, + resolve_hybrid_fusion, + validate_search_mmr_usage, + wrap_as_filter, +) _RERANK_FETCH_MULTIPLIER = 4 _HYBRID_PREFETCH_MULTIPLIER = 4 @@ -236,6 +223,8 @@ def execute(self, node: ASTNode) -> ExecutionResult: return self._execute_update_vector(node) if isinstance(node, UpdatePayloadStmt): return self._execute_update_payload(node) + if isinstance(node, BatchBlockStmt): + return self._execute_batch_block(node) raise QQLRuntimeError(f"Unknown AST node type: {type(node)}") # ── Statement executors ─────────────────────────────────────────────── @@ -255,6 +244,10 @@ def _fetch_collection_info(self, name: str): raise QQLRuntimeError( f"Qdrant error fetching collection '{name}': {e}" ) from e + except ValueError as e: + if f"Collection {name} not found" in str(e): + return None + raise def _topology_from_collection_info(self, info: Any) -> CollectionTopology: """Parse a CollectionInfo object into a :class:`CollectionTopology`. @@ -265,40 +258,7 @@ def _topology_from_collection_info(self, info: Any) -> CollectionTopology: params = info.config.params vectors = params.vectors # type: ignore[union-attr] sparse_vectors = params.sparse_vectors or {} - - if isinstance(vectors, dict): - dense_names = tuple(vectors.keys()) - dense_sizes: tuple[tuple[str, int], ...] = tuple( - (k, v.size) - for k, v in vectors.items() - if getattr(v, "size", None) is not None - ) - has_unnamed_dense = False - is_named_dense = True - elif vectors is None: - dense_names = () - dense_sizes = () - has_unnamed_dense = False - is_named_dense = False - else: - # Single unnamed dense vector - dense_names = () - unnamed_size = getattr(vectors, "size", None) - dense_sizes = (("", unnamed_size),) if unnamed_size is not None else () - has_unnamed_dense = True - is_named_dense = False - - sparse_names = ( - tuple(sparse_vectors.keys()) if isinstance(sparse_vectors, dict) else () - ) - return CollectionTopology( - exists=True, - is_named_dense=is_named_dense, - has_unnamed_dense=has_unnamed_dense, - dense_names=dense_names, - sparse_names=sparse_names, - dense_sizes=dense_sizes, - ) + return CollectionTopology(**collection_topology_kwargs(vectors, sparse_vectors)) def _resolve_topology(self, name: str) -> CollectionTopology: """Return the topology for *name* using exactly one Qdrant API call. @@ -362,7 +322,7 @@ def _execute_insert(self, node: InsertStmt) -> ExecutionResult: }, ) - point_id, payload = self._extract_point_id_and_payload(node.values) + point_id, payload = extract_point_id_and_payload(node.values) try: self._client.upsert( collection_name=node.collection, @@ -392,9 +352,14 @@ def _execute_insert(self, node: InsertStmt) -> ExecutionResult: self._ensure_collection( node.collection, len(vector), topology, node.dense_vector ) - point_vector = self._build_dense_point_vector(topology, vector, node.dense_vector) + point_vector = build_dense_point_vector( + topology, + vector, + node.dense_vector, + self._default_dense_vector_name(), + ) - point_id, payload = self._extract_point_id_and_payload(node.values) + point_id, payload = extract_point_id_and_payload(node.values) try: self._client.upsert( @@ -443,7 +408,7 @@ def _execute_insert_bulk(self, node: InsertBulkStmt) -> ExecutionResult: first_dense_vector: list[float] | None = None points: list[PointStruct] = [] for vals in node.values_list: - point_id, payload = self._extract_point_id_and_payload(vals) + point_id, payload = extract_point_id_and_payload(vals) dense_vector = dense_embedder.embed(vals["text"]) if first_dense_vector is None: first_dense_vector = dense_vector @@ -483,6 +448,7 @@ def _execute_insert_bulk(self, node: InsertBulkStmt) -> ExecutionResult: return ExecutionResult( success=True, message=f"Inserted {len(points)} points (hybrid)", + data={"ids": [p.id for p in points]}, ) # ── Standard dense-only bulk INSERT ─────────────────────────────── @@ -495,9 +461,12 @@ def _execute_insert_bulk(self, node: InsertBulkStmt) -> ExecutionResult: vector = embedder.embed(vals["text"]) if first_vector is None: first_vector = vector - point_id, payload = self._extract_point_id_and_payload(vals) - point_vector = self._build_dense_point_vector( - topology, vector, node.dense_vector + point_id, payload = extract_point_id_and_payload(vals) + point_vector = build_dense_point_vector( + topology, + vector, + node.dense_vector, + self._default_dense_vector_name(), ) points.append( PointStruct(id=point_id, vector=point_vector, payload=payload) @@ -520,6 +489,7 @@ def _execute_insert_bulk(self, node: InsertBulkStmt) -> ExecutionResult: return ExecutionResult( success=True, message=f"Inserted {len(points)} points", + data={"ids": [p.id for p in points]}, ) def _execute_create(self, node: CreateCollectionStmt) -> ExecutionResult: @@ -889,7 +859,7 @@ def _execute_search(self, node: SearchStmt) -> ExecutionResult: ) search_params = self._build_search_params(node.with_clause) - self._validate_search_mmr_usage(node) + validate_search_mmr_usage(node) # When reranking is requested, fetch more candidates so the reranker has # enough material to reorder; only `node.limit` results are returned. @@ -921,7 +891,7 @@ def _execute_search(self, node: SearchStmt) -> ExecutionResult: collection_name=node.collection, prefetch=[ Prefetch( - query=self._build_dense_query(dense_vector, node.with_clause), + query=build_dense_query(dense_vector, node.with_clause), using=topology.dense_using(node.dense_vector), limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, params=search_params, @@ -933,7 +903,7 @@ def _execute_search(self, node: SearchStmt) -> ExecutionResult: params=search_params, ), ], - query=FusionQuery(fusion=self._resolve_hybrid_fusion(node.fusion)), + query=FusionQuery(fusion=resolve_hybrid_fusion(node.fusion)), limit=fetch_limit, offset=node.offset or None, query_filter=qdrant_filter, @@ -1015,7 +985,7 @@ def _execute_search(self, node: SearchStmt) -> ExecutionResult: query_using = topology.dense_using(node.dense_vector) response = self._client.query_points( collection_name=node.collection, - query=self._build_dense_query(vector, node.with_clause), + query=build_dense_query(vector, node.with_clause), using=query_using, limit=fetch_limit, offset=node.offset or None, @@ -1066,15 +1036,6 @@ def _build_hybrid_vectors( ) return dense_vector, sparse_vector - def _resolve_hybrid_fusion(self, fusion: str | None) -> Fusion: - if fusion is None or fusion == "rrf": - return Fusion.RRF - if fusion == "dbsf": - return Fusion.DBSF - raise QQLRuntimeError( - f"Unsupported hybrid fusion '{fusion}'; expected 'rrf' or 'dbsf'" - ) - def _execute_recommend(self, node: RecommendStmt) -> ExecutionResult: if not self._client.collection_exists(node.collection): raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") @@ -1084,7 +1045,7 @@ def _execute_recommend(self, node: RecommendStmt) -> ExecutionResult: qdrant_filter = self._wrap_as_filter( self._build_qdrant_filter(node.query_filter) ) - qdrant_filter = self._exclude_ids_from_filter( + qdrant_filter = exclude_ids_from_filter( qdrant_filter, [*node.positive_ids, *node.negative_ids], ) @@ -1092,11 +1053,11 @@ def _execute_recommend(self, node: RecommendStmt) -> ExecutionResult: recommend_input = RecommendInput( positive=list(node.positive_ids), negative=list(node.negative_ids) or None, - strategy=self._parse_recommend_strategy(node.strategy), + strategy=parse_recommend_strategy(node.strategy), ) search_params = self._build_search_params(node.with_clause) - if self._has_mmr(node.with_clause): + if has_mmr(node.with_clause): raise QQLRuntimeError("MMR is supported only for SEARCH statements") lookup_from: LookupLocation | None = None @@ -1501,107 +1462,6 @@ def _describe_quantization_update( return f", quantization={quantization.config.type.value}" return "" - def _has_mmr(self, with_clause: SearchWith | None) -> bool: - return with_clause is not None and ( - with_clause.mmr_diversity is not None or with_clause.mmr_candidates is not None - ) - - def _validate_search_mmr_usage(self, node: SearchStmt) -> None: - if not self._has_mmr(node.with_clause): - return - if node.sparse_only: - raise QQLRuntimeError("MMR is not supported with USING SPARSE yet") - - def _build_dense_query( - self, - vector: list[float], - with_clause: SearchWith | None, - ) -> list[float] | NearestQuery: - if not self._has_mmr(with_clause): - return vector - return NearestQuery( - nearest=vector, - mmr=Mmr( - diversity=with_clause.mmr_diversity, - candidates_limit=with_clause.mmr_candidates, - ), - ) - - def _parse_recommend_strategy( - self, strategy: str | None - ) -> RecommendStrategy | None: - if strategy is None: - return None - try: - return RecommendStrategy(strategy) - except ValueError as e: - raise QQLRuntimeError( - "Unknown recommend strategy " - f"'{strategy}'. Expected one of: average_vector, best_score, sum_scores" - ) from e - - def _exclude_ids_from_filter( - self, - query_filter: Filter | None, - point_ids: list[str | int], - ) -> Filter | None: - if not point_ids: - return query_filter - - exclude_condition = HasIdCondition(has_id=point_ids) - if query_filter is None: - return Filter(must_not=[exclude_condition]) - - return Filter( - must=list(query_filter.must or []), - should=list(query_filter.should or []), - must_not=[*(query_filter.must_not or []), exclude_condition], - min_should=query_filter.min_should, - ) - - def _extract_point_id_and_payload( - self, values: dict[str, Any] - ) -> tuple[str | int, dict[str, Any]]: - payload = dict(values) - if "id" not in payload: - return str(uuid.uuid4()), payload - - point_id = payload.pop("id") - if isinstance(point_id, bool): - raise QQLRuntimeError( - "INSERT id must be an unsigned integer or UUID string when provided" - ) - if isinstance(point_id, int): - if point_id < 0: - raise QQLRuntimeError( - "INSERT id must be an unsigned integer or UUID string when provided" - ) - return point_id, payload - if isinstance(point_id, str): - try: - uuid.UUID(point_id) - except ValueError as e: - raise QQLRuntimeError( - "INSERT id must be an unsigned integer or UUID string when provided" - ) from e - return point_id, payload - raise QQLRuntimeError( - "INSERT id must be an unsigned integer or UUID string when provided" - ) - - def _build_dense_point_vector( - self, - topology: CollectionTopology, - vector: list[float], - explicit_vector: str | None, - ) -> list[float] | dict[str, list[float]]: - if not topology.exists: - return {explicit_vector or self._default_dense_vector_name(): vector} - vector_name = topology.dense_payload_name(explicit_vector) - if vector_name is None: - return vector - return {vector_name: vector} - def _apply_reranking( self, query: str, @@ -1682,7 +1542,7 @@ def _execute_search_groups( group_by=node.group_by, prefetch=[ Prefetch( - query=self._build_dense_query(dense_vector, node.with_clause), + query=build_dense_query(dense_vector, node.with_clause), using=topology.dense_using(node.dense_vector), limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, params=search_params, @@ -1694,7 +1554,7 @@ def _execute_search_groups( params=search_params, ), ], - query=FusionQuery(fusion=self._resolve_hybrid_fusion(node.fusion)), + query=FusionQuery(fusion=resolve_hybrid_fusion(node.fusion)), limit=node.limit, group_size=node.group_size, query_filter=qdrant_filter, @@ -1729,7 +1589,7 @@ def _execute_search_groups( response = self._client.query_points_groups( collection_name=node.collection, group_by=node.group_by, - query=self._build_dense_query(vector, node.with_clause), + query=build_dense_query(vector, node.with_clause), using=query_using, limit=node.limit, group_size=node.group_size, @@ -1815,98 +1675,221 @@ def _execute_update_payload(self, node: UpdatePayloadStmt) -> ExecutionResult: data=[], ) - # ── Filter conversion ───────────────────────────────────────────────── - - def _build_qdrant_filter(self, expr: FilterExpr) -> Any: - """Convert a FilterExpr AST node into a Qdrant model object. - - Returns one of: Filter, FieldCondition, IsNullCondition, IsEmptyCondition. - Use _wrap_as_filter() to guarantee the top-level result is a Filter. - """ - # ── Logical combinators ─────────────────────────────────────────── - if isinstance(expr, AndExpr): - return Filter(must=[self._build_qdrant_filter(op) for op in expr.operands]) - - if isinstance(expr, OrExpr): - return Filter(should=[self._build_qdrant_filter(op) for op in expr.operands]) - - if isinstance(expr, NotExpr): - return Filter(must_not=[self._build_qdrant_filter(expr.operand)]) - - # ── Comparison ──────────────────────────────────────────────────── - if isinstance(expr, CompareExpr): - if expr.op == "=": - return FieldCondition( - key=expr.field, match=MatchValue(value=expr.value) + def _execute_batch_block(self, node: BatchBlockStmt) -> ExecutionResult: + if not node.statements: + return ExecutionResult(success=True, message="Executed empty batch", data=[]) + + all_results = [] + succeeded_count = 0 + + for group in group_batch_statements(node.statements): + if group.kind == 'query': + res = self._execute_query_batch(group.collection, group.statements) + all_results.extend(res) + succeeded_count += len([r for r in res if r.success]) + elif group.kind == 'insert': + bulk_node = build_bulk_insert_from_group( + group.collection, + group.statements, ) - if expr.op == "!=": - return Filter( - must_not=[ - FieldCondition(key=expr.field, match=MatchValue(value=expr.value)) - ] + res = self._execute_insert_bulk(bulk_node) + insert_results = inserted_point_results( + res, + group.statements, + ExecutionResult, ) - _range_key = {">": "gt", ">=": "gte", "<": "lt", "<=": "lte"}[expr.op] - return FieldCondition( - key=expr.field, range=Range(**{_range_key: expr.value}) - ) - - # ── BETWEEN ─────────────────────────────────────────────────────── - if isinstance(expr, BetweenExpr): - return FieldCondition( - key=expr.field, range=Range(gte=expr.low, lte=expr.high) - ) - - # ── IN / NOT IN ─────────────────────────────────────────────────── - if isinstance(expr, InExpr): - return FieldCondition( - key=expr.field, match=MatchAny(any=list(expr.values)) - ) - - if isinstance(expr, NotInExpr): - return FieldCondition( - key=expr.field, - match=MatchExcept(**{"except": list(expr.values)}), - ) - - # ── IS NULL / IS NOT NULL ───────────────────────────────────────── - if isinstance(expr, IsNullExpr): - return IsNullCondition(is_null=PayloadField(key=expr.field)) + all_results.extend(insert_results) + succeeded_count += len([r for r in insert_results if r.success]) + else: + for s in group.statements: + res = self.execute(s) + all_results.append(res) + if res.success: + succeeded_count += 1 + + total_stmts = len(node.statements) + return ExecutionResult( + success=succeeded_count == total_stmts, + message=f"Batch executed {succeeded_count}/{total_stmts} statement(s) successfully", + data=all_results, + ) - if isinstance(expr, IsNotNullExpr): - return Filter( - must_not=[IsNullCondition(is_null=PayloadField(key=expr.field))] + def _execute_query_batch( + self, + collection_name: str, + nodes: list[SearchStmt | RecommendStmt], + ) -> list[ExecutionResult]: + if not self._client.collection_exists(collection_name): + raise QQLRuntimeError(f"Collection '{collection_name}' does not exist") + + topology = self._resolve_topology(collection_name) + requests = [] + + for node in nodes: + qdrant_filter = None + if node.query_filter is not None: + qdrant_filter = self._wrap_as_filter( + self._build_qdrant_filter(node.query_filter) + ) + + search_params = self._build_search_params(node.with_clause) + + lookup_from = None + if node.lookup_from is not None: + lookup_from = LookupLocation( + collection=node.lookup_from[0], + vector=node.lookup_from[1], + ) + + if isinstance(node, SearchStmt): + validate_search_mmr_usage(node) + fetch_limit = node.limit * _RERANK_FETCH_MULTIPLIER if node.rerank else node.limit + + if node.hybrid: + dense_model = node.model or self._config.default_model + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + dense_vector, sparse_vector = self._build_hybrid_vectors( + node.query_text, dense_model, sparse_model_name + ) + + req = QueryRequest( + prefetch=[ + Prefetch( + query=build_dense_query(dense_vector, node.with_clause), + using=topology.dense_using(node.dense_vector), + limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, + params=search_params, + ), + Prefetch( + query=sparse_vector, + using=topology.sparse_using(node.sparse_vector), + limit=node.limit * _HYBRID_PREFETCH_MULTIPLIER, + params=search_params, + ), + ], + query=FusionQuery(fusion=resolve_hybrid_fusion(node.fusion)), + limit=fetch_limit, + offset=node.offset or None, + filter=qdrant_filter, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + with_payload=True, + with_vector=False, + ) + elif node.sparse_only: + sparse_model_name = node.sparse_model or SparseEmbedder.DEFAULT_MODEL + sparse_embedder = SparseEmbedder(sparse_model_name) + sparse_obj = sparse_embedder.query_embed(node.query_text) + sparse_vector = SparseVector( + indices=sparse_obj["indices"], + values=sparse_obj["values"], + ) + + req = QueryRequest( + query=sparse_vector, + using=topology.sparse_using(node.sparse_vector), + limit=fetch_limit, + offset=node.offset or None, + filter=qdrant_filter, + params=search_params, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + with_payload=True, + with_vector=False, + ) + else: + model_name = node.model or self._config.default_model + embedder = Embedder(model_name) + vector = embedder.embed(node.query_text) + query_using = topology.dense_using(node.dense_vector) + + req = QueryRequest( + query=build_dense_query(vector, node.with_clause), + using=query_using, + limit=fetch_limit, + offset=node.offset or None, + filter=qdrant_filter, + params=search_params, + score_threshold=node.score_threshold, + lookup_from=lookup_from, + with_payload=True, + with_vector=False, + ) + else: + qdrant_filter = exclude_ids_from_filter( + qdrant_filter, + [*node.positive_ids, *node.negative_ids], + ) + recommend_input = RecommendInput( + positive=list(node.positive_ids), + negative=list(node.negative_ids) or None, + strategy=parse_recommend_strategy(node.strategy), + ) + if has_mmr(node.with_clause): + raise QQLRuntimeError("MMR is supported only for SEARCH statements") + + req = QueryRequest( + query=RecommendQuery(recommend=recommend_input), + limit=node.limit, + offset=node.offset or None, + filter=qdrant_filter, + params=search_params, + score_threshold=node.score_threshold, + using=node.using, + lookup_from=lookup_from, + with_payload=True, + with_vector=False, + ) + + requests.append(req) + + try: + responses = self._client.query_batch_points( + collection_name=collection_name, + requests=requests, ) - - # ── IS EMPTY / IS NOT EMPTY ─────────────────────────────────────── - if isinstance(expr, IsEmptyExpr): - return IsEmptyCondition(is_empty=PayloadField(key=expr.field)) - - if isinstance(expr, IsNotEmptyExpr): - return Filter( - must_not=[IsEmptyCondition(is_empty=PayloadField(key=expr.field))] + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during Batch Query: {e}") from e + + execution_results = [] + for i, response in enumerate(responses): + node = nodes[i] + results = [ + {"id": str(h.id), "score": round(h.score, 4), "payload": h.payload} + for h in response.points + ] + + if isinstance(node, SearchStmt) and node.rerank: + results = self._apply_reranking(node.query_text, results, node.limit, node.rerank_model) + label = "hybrid, reranked" if node.hybrid else ("sparse, reranked" if node.sparse_only else "reranked") + msg = f"Found {len(results)} result(s) ({label})" + else: + if isinstance(node, SearchStmt): + label = "hybrid" if node.hybrid else ("sparse" if node.sparse_only else "") + label_suffix = f" ({label})" if label else "" + msg = f"Found {len(results)} result(s){label_suffix}" + else: + msg = f"Found {len(results)} recommendation(s)" + + execution_results.append( + ExecutionResult(success=True, message=msg, data=results) ) + + return execution_results - # ── Full-text MATCH ─────────────────────────────────────────────── - if isinstance(expr, MatchTextExpr): - return FieldCondition(key=expr.field, match=MatchText(text=expr.text)) - - if isinstance(expr, MatchAnyExpr): - return FieldCondition( - key=expr.field, match=MatchTextAny(text_any=expr.text) - ) + # ── Filter conversion ───────────────────────────────────────────────── - if isinstance(expr, MatchPhraseExpr): - return FieldCondition( - key=expr.field, match=MatchPhrase(phrase=expr.text) - ) + def _build_qdrant_filter(self, expr: FilterExpr) -> Any: + """Convert a FilterExpr AST node into a Qdrant model object. - raise QQLRuntimeError(f"Unknown filter expression type: {type(expr)}") + Returns one of: Filter, FieldCondition, IsNullCondition, IsEmptyCondition. + Use _wrap_as_filter() to guarantee the top-level result is a Filter. + """ + return build_qdrant_filter(expr) def _wrap_as_filter(self, qdrant_expr: Any) -> Filter: """Ensure the top-level expression is a Filter (required by query_points).""" - if isinstance(qdrant_expr, Filter): - return qdrant_expr - return Filter(must=[qdrant_expr]) + return wrap_as_filter(qdrant_expr) # ── Collection helpers ──────────────────────────────────────────────── diff --git a/src/qql/lexer.py b/src/qql/lexer.py index 39babee..1df588d 100644 --- a/src/qql/lexer.py +++ b/src/qql/lexer.py @@ -71,6 +71,10 @@ class TokenKind(Enum): UPDATE = auto() SET = auto() PAYLOAD = auto() + # ── Batch keywords ──────────────────────────────────────────────────── + BEGIN = auto() + BATCH = auto() + END = auto() # ── Filter keywords ─────────────────────────────────────────────────── AND = auto() OR = auto() @@ -99,6 +103,7 @@ class TokenKind(Enum): COLON = auto() COMMA = auto() EQUALS = auto() + SEMICOLON = auto() # ── Comparison operators ────────────────────────────────────────────── NOT_EQUALS = auto() # != GT = auto() # > @@ -176,6 +181,9 @@ class TokenKind(Enum): "UPDATE": TokenKind.UPDATE, "SET": TokenKind.SET, "PAYLOAD": TokenKind.PAYLOAD, + "BEGIN": TokenKind.BEGIN, + "BATCH": TokenKind.BATCH, + "END": TokenKind.END, # Filter keywords "AND": TokenKind.AND, "OR": TokenKind.OR, @@ -239,6 +247,9 @@ def tokenize(self, query: str) -> list[Token]: elif ch == ",": tokens.append(Token(TokenKind.COMMA, ",", i)) i += 1 + elif ch == ";": + tokens.append(Token(TokenKind.SEMICOLON, ";", i)) + i += 1 # ── Comparison operators (multi-char look-ahead) ────────────── elif ch == "=": diff --git a/src/qql/parser.py b/src/qql/parser.py index b62b9b0..3c3b2f0 100644 --- a/src/qql/parser.py +++ b/src/qql/parser.py @@ -42,9 +42,16 @@ UpdatePayloadStmt, VectorsConfig, HnswRuntimeConfig, + BatchBlockStmt, ) from .exceptions import QQLSyntaxError from .lexer import Token, TokenKind +from .utils import ( + parse_search_group_by, + parse_search_lookup, + parse_search_using, + parse_search_with, +) # Comparison operator token → string symbol mapping _CMP_OPS: dict[TokenKind, str] = { @@ -56,9 +63,6 @@ TokenKind.LTE: "<=", } -_HYBRID_FUSION_VALUES = {"rrf", "dbsf"} - - class Parser: def __init__(self, tokens: list[Token]) -> None: self._tokens = tokens @@ -67,36 +71,67 @@ def __init__(self, tokens: list[Token]) -> None: # ── Public entry point ──────────────────────────────────────────────── def parse(self) -> ASTNode: + node = self._parse_single_statement() + while self._peek().kind == TokenKind.SEMICOLON: + self._advance() + self._expect(TokenKind.EOF) + return node + + def _parse_single_statement(self) -> ASTNode: tok = self._peek() if tok.kind == TokenKind.INSERT: - node = self._parse_insert() + return self._parse_insert() elif tok.kind == TokenKind.CREATE: - node = self._parse_create() + return self._parse_create() elif tok.kind == TokenKind.ALTER: - node = self._parse_alter() + return self._parse_alter() elif tok.kind == TokenKind.DROP: - node = self._parse_drop() + return self._parse_drop() elif tok.kind == TokenKind.SHOW: - node = self._parse_show() + return self._parse_show() elif tok.kind == TokenKind.SCROLL: - node = self._parse_scroll() + return self._parse_scroll() elif tok.kind == TokenKind.SELECT: - node = self._parse_select() + return self._parse_select() elif tok.kind == TokenKind.SEARCH: - node = self._parse_search() + return self._parse_search() elif tok.kind == TokenKind.RECOMMEND: - node = self._parse_recommend() + return self._parse_recommend() elif tok.kind == TokenKind.DELETE: - node = self._parse_delete() + return self._parse_delete() elif tok.kind == TokenKind.UPDATE: - node = self._parse_update() + return self._parse_update() + elif tok.kind == TokenKind.BEGIN: + return self._parse_batch_block() else: raise QQLSyntaxError( f"Unexpected token '{tok.value}'; expected a QQL statement keyword", tok.pos, ) - self._expect(TokenKind.EOF) - return node + + def _parse_batch_block(self) -> BatchBlockStmt: + self._expect(TokenKind.BEGIN) + self._expect(TokenKind.BATCH) + statements = [] + while True: + while self._peek().kind == TokenKind.SEMICOLON: + self._advance() + + if self._peek().kind == TokenKind.END: + self._advance() + self._expect(TokenKind.BATCH) + break + + if self._peek().kind == TokenKind.EOF: + raise QQLSyntaxError("Unterminated batch block; expected END BATCH", self._peek().pos) + + stmt = self._parse_single_statement() + statements.append(stmt) + + if self._peek().kind == TokenKind.SEMICOLON: + self._advance() + + return BatchBlockStmt(statements=tuple(statements)) # ── Statement parsers ───────────────────────────────────────────────── @@ -696,80 +731,14 @@ def _parse_search(self) -> SearchStmt: self._expect(TokenKind.THRESHOLD) score_threshold = float(self._parse_number()) - lookup_from: tuple[str, str | None] | None = None - if self._peek().kind == TokenKind.LOOKUP: - self._advance() - self._expect(TokenKind.FROM) - lookup_collection = self._parse_identifier() - lookup_vector: str | None = None - if self._peek().kind == TokenKind.VECTOR: - self._advance() - lookup_vector = self._expect(TokenKind.STRING).value - lookup_from = (lookup_collection, lookup_vector) + lookup_from = parse_search_lookup(self) with_clause: SearchWith | None = None if self._peek().kind == TokenKind.EXACT: self._advance() with_clause = SearchWith(exact=True) - model: str | None = None - hybrid: bool = False - fusion: str | None = None - sparse_only: bool = False - sparse_model: str | None = None - dense_vector: str | None = None - sparse_vector: str | None = None - if self._peek().kind == TokenKind.USING: - self._advance() # consume USING - if self._peek().kind == TokenKind.HYBRID: - self._advance() # consume HYBRID - hybrid = True - # Optional FUSION / DENSE|SPARSE MODEL|VECTOR sub-clauses, any order. - while self._peek().kind in (TokenKind.FUSION, TokenKind.DENSE, TokenKind.SPARSE): - sub = self._advance() - if sub.kind == TokenKind.FUSION: - value_tok = self._expect(TokenKind.STRING) - fusion = value_tok.value.lower() - if fusion not in _HYBRID_FUSION_VALUES: - raise QQLSyntaxError( - f"Unsupported hybrid fusion '{value_tok.value}'; expected 'rrf' or 'dbsf'", - value_tok.pos, - ) - continue - if self._peek().kind == TokenKind.MODEL: - self._advance() - m = self._expect(TokenKind.STRING).value - if sub.kind == TokenKind.DENSE: - model = m - else: - sparse_model = m - elif self._peek().kind == TokenKind.VECTOR: - self._advance() - name = self._expect(TokenKind.STRING).value - if sub.kind == TokenKind.DENSE: - dense_vector = name - else: - sparse_vector = name - else: - raise QQLSyntaxError( - "Expected MODEL or VECTOR after DENSE/SPARSE in USING HYBRID", - self._peek().pos, - ) - elif self._peek().kind == TokenKind.SPARSE: - self._advance() # consume SPARSE - sparse_only = True - while self._peek().kind in (TokenKind.MODEL, TokenKind.VECTOR): - sub = self._advance() - if sub.kind == TokenKind.MODEL: - sparse_model = self._expect(TokenKind.STRING).value - else: - sparse_vector = self._expect(TokenKind.STRING).value - elif self._peek().kind == TokenKind.VECTOR: - self._advance() - dense_vector = self._expect(TokenKind.STRING).value - else: - self._expect(TokenKind.MODEL) - model = self._expect(TokenKind.STRING).value + using = parse_search_using(self) query_filter: FilterExpr | None = None if self._peek().kind == TokenKind.WHERE: self._advance() # consume WHERE @@ -783,79 +752,25 @@ def _parse_search(self) -> SearchStmt: self._advance() # consume MODEL rerank_model = self._expect(TokenKind.STRING).value - if self._peek().kind == TokenKind.EXACT: - self._advance() - if with_clause is None: - with_clause = SearchWith(exact=True) - else: - with_clause = SearchWith( - hnsw_ef=with_clause.hnsw_ef, - exact=True, - acorn=with_clause.acorn, - indexed_only=with_clause.indexed_only, - quantization=with_clause.quantization, - mmr_diversity=with_clause.mmr_diversity, - mmr_candidates=with_clause.mmr_candidates, - ) - - if self._peek().kind == TokenKind.WITH: - self._advance() # consume WITH - parsed_with = self._parse_with_clause() - if with_clause is None: - with_clause = parsed_with - else: - with_clause = SearchWith( - hnsw_ef=parsed_with.hnsw_ef or with_clause.hnsw_ef, - exact=parsed_with.exact or with_clause.exact, - acorn=parsed_with.acorn or with_clause.acorn, - indexed_only=parsed_with.indexed_only or with_clause.indexed_only, - quantization=parsed_with.quantization or with_clause.quantization, - mmr_diversity=( - parsed_with.mmr_diversity - if parsed_with.mmr_diversity is not None - else with_clause.mmr_diversity - ), - mmr_candidates=parsed_with.mmr_candidates or with_clause.mmr_candidates, - ) - group_by: str | None = None - group_size: int = 3 - if self._peek().kind == TokenKind.GROUP: - if offset > 0: - raise QQLSyntaxError("OFFSET cannot be used with GROUP BY", self._peek().pos) - self._advance() # consume GROUP - self._expect(TokenKind.BY) - group_by = self._parse_field_path() - if rerank: - raise QQLSyntaxError( - "GROUP BY and RERANK cannot be combined in the same SEARCH statement", - self._peek().pos, - ) - if self._peek().kind == TokenKind.GROUP_SIZE: - self._advance() # consume GROUP_SIZE - gs_tok = self._peek() - group_size = int(self._expect(TokenKind.INTEGER).value) - if group_size <= 0: - raise QQLSyntaxError( - f"GROUP_SIZE must be a positive integer, got {group_size}", - gs_tok.pos, - ) + with_clause = parse_search_with(self, with_clause) + group = parse_search_group_by(self, offset, rerank) return SearchStmt( collection=collection, query_text=query_text, limit=limit, - model=model, - hybrid=hybrid, - fusion=fusion, - sparse_only=sparse_only, - sparse_model=sparse_model, + model=using.model, + hybrid=using.hybrid, + fusion=using.fusion, + sparse_only=using.sparse_only, + sparse_model=using.sparse_model, query_filter=query_filter, rerank=rerank, rerank_model=rerank_model, with_clause=with_clause, - group_by=group_by, - group_size=group_size, - dense_vector=dense_vector, - sparse_vector=sparse_vector, + group_by=group.group_by, + group_size=group.group_size, + dense_vector=using.dense_vector, + sparse_vector=using.sparse_vector, offset=offset, score_threshold=score_threshold, lookup_from=lookup_from, @@ -1165,12 +1080,15 @@ def _parse_field_path(self) -> str: f"Expected a field name, got '{tok.value}'", tok.pos ) - def _parse_literal(self) -> str | int | float | bool: - """STRING | INTEGER | FLOAT | boolean""" + def _parse_literal(self) -> str | int | float | bool | None: + """STRING | INTEGER | FLOAT | boolean | NULL""" tok = self._peek() if tok.kind == TokenKind.STRING: self._advance() return tok.value + if tok.kind == TokenKind.NULL: + self._advance() + return None if tok.kind == TokenKind.INTEGER: self._advance() return int(tok.value) @@ -1186,7 +1104,7 @@ def _parse_literal(self) -> str | int | float | bool: self._advance() return False raise QQLSyntaxError( - f"Expected a literal value (string, integer, float, or boolean), got '{tok.value}'", + f"Expected a literal value (string, integer, float, boolean, or null), got '{tok.value}'", tok.pos, ) @@ -1203,10 +1121,10 @@ def _parse_number(self) -> int | float: f"Expected a number, got '{tok.value}'", tok.pos ) - def _parse_literal_list(self) -> list[str | int | float | bool]: + def _parse_literal_list(self) -> list[str | int | float | bool | None]: """'(' literal { ',' literal } [','] ')' — used by IN / NOT IN.""" self._expect(TokenKind.LPAREN) - items: list[str | int | float | bool] = [] + items: list[str | int | float | bool | None] = [] if self._peek().kind == TokenKind.RPAREN: self._advance() return items @@ -1360,9 +1278,9 @@ def _parse_value(self) -> Any: def _parse_with_clause(self) -> SearchWith: self._expect(TokenKind.LBRACE) hnsw_ef: int | None = None - exact: bool = False - acorn: bool = False - indexed_only: bool = False + exact: bool | None = None + acorn: bool | None = None + indexed_only: bool | None = None quantization: QuantizationSearchWith | None = None mmr_diversity: float | None = None mmr_candidates: int | None = None diff --git a/src/qql/script.py b/src/qql/script.py index ab2de16..08023b6 100644 --- a/src/qql/script.py +++ b/src/qql/script.py @@ -30,6 +30,7 @@ TokenKind.SEARCH, TokenKind.RECOMMEND, TokenKind.DELETE, + TokenKind.BEGIN, } _DEPTH_OPEN = {TokenKind.LBRACE, TokenKind.LBRACKET, TokenKind.LPAREN} @@ -84,13 +85,14 @@ def split_statements(tokens: list[Token]) -> list[list[Token]]: """Split a flat token list into per-statement chunks. A new chunk begins whenever a statement-starter keyword (INSERT, CREATE, - DROP, SHOW, SCROLL, SELECT, SEARCH, RECOMMEND, DELETE) is encountered at - brace/bracket/paren depth 0. + DROP, SHOW, SCROLL, SELECT, SEARCH, RECOMMEND, DELETE, BEGIN) is encountered at + brace/bracket/paren depth 0 and batch_depth 0. The EOF sentinel is consumed and never included in any chunk. """ chunks: list[list[Token]] = [] current: list[Token] = [] depth = 0 + batch_depth = 0 for tok in tokens: if tok.kind == TokenKind.EOF: @@ -100,11 +102,16 @@ def split_statements(tokens: list[Token]) -> list[list[Token]]: elif tok.kind in _DEPTH_CLOSE: depth -= 1 - # New statement starts when we see a starter at the top level - if tok.kind in _STMT_STARTERS and depth == 0 and current: + # New statement starts when we see a starter at the top level outside of a batch block + if tok.kind in _STMT_STARTERS and depth == 0 and batch_depth == 0 and current: chunks.append(current) current = [] + if tok.kind == TokenKind.BEGIN: + batch_depth += 1 + elif tok.kind == TokenKind.END: + batch_depth -= 1 + current.append(tok) if current: diff --git a/src/qql/utils.py b/src/qql/utils.py new file mode 100644 index 0000000..52749de --- /dev/null +++ b/src/qql/utils.py @@ -0,0 +1,669 @@ +from __future__ import annotations + +import uuid +from dataclasses import dataclass +from typing import Any + +from qdrant_client.models import ( + FieldCondition, + Filter, + Fusion, + HasIdCondition, + IsEmptyCondition, + IsNullCondition, + MatchAny, + MatchExcept, + MatchPhrase, + MatchText, + MatchTextAny, + MatchValue, + Mmr, + NearestQuery, + PayloadField, + Range, + RecommendStrategy, +) + +from .ast_nodes import ( + ASTNode, + AndExpr, + BetweenExpr, + CompareExpr, + FilterExpr, + InExpr, + InsertBulkStmt, + InsertStmt, + IsEmptyExpr, + IsNotEmptyExpr, + IsNotNullExpr, + IsNullExpr, + MatchAnyExpr, + MatchPhraseExpr, + MatchTextExpr, + NotExpr, + NotInExpr, + OrExpr, + RecommendStmt, + SearchStmt, + SearchWith, +) +from .exceptions import QQLRuntimeError, QQLSyntaxError +from .lexer import TokenKind + +_HYBRID_FUSION_VALUES = {"rrf", "dbsf"} + + +@dataclass(frozen=True) +class BatchGroup: + kind: str + collection: str | None + statements: list[ASTNode] + + +@dataclass(frozen=True) +class SearchUsingOptions: + model: str | None = None + hybrid: bool = False + fusion: str | None = None + sparse_only: bool = False + sparse_model: str | None = None + dense_vector: str | None = None + sparse_vector: str | None = None + + +@dataclass(frozen=True) +class SearchGroupByOptions: + group_by: str | None = None + group_size: int = 3 + + +def render_parameterized_query(template: str, params: dict[str, Any]) -> str: + rendered = [] + in_string = False + quote_char = "" + i = 0 + while i < len(template): + ch = template[i] + if in_string: + rendered.append(ch) + if ch == "\\" and i + 1 < len(template): + rendered.append(template[i + 1]) + i += 2 + continue + if ch == quote_char: + in_string = False + quote_char = "" + i += 1 + continue + + if ch in ("'", '"'): + in_string = True + quote_char = ch + rendered.append(ch) + i += 1 + continue + + if ch == ":": + name_start = i + 1 + name_end = name_start + while name_end < len(template) and ( + template[name_end].isalnum() or template[name_end] == "_" + ): + name_end += 1 + name = template[name_start:name_end] + if name in params: + rendered.append(_qql_literal(params[name])) + i = name_end + continue + + rendered.append(ch) + i += 1 + + return "".join(rendered) + + +def _qql_literal(value: Any) -> str: + if value is None: + return "null" + if isinstance(value, str): + return _qql_string_literal(value) + if isinstance(value, bool): + return "true" if value else "false" + if isinstance(value, (list, tuple)): + return "[" + ", ".join(_qql_literal(item) for item in value) + "]" + if isinstance(value, dict): + items = ", ".join( + f"{_qql_string_literal(str(key))}: {_qql_literal(item)}" + for key, item in value.items() + ) + return "{" + items + "}" + return str(value) + + +def _qql_string_literal(value: str) -> str: + escaped = ( + value.replace("\\", "\\\\") + .replace("'", "\\'") + .replace("\n", "\\n") + .replace("\t", "\\t") + .replace("\r", "\\r") + ) + return f"'{escaped}'" + + +def collection_topology_kwargs(vectors: Any, sparse_vectors: Any) -> dict[str, Any]: + if isinstance(vectors, dict): + dense_names = tuple(vectors.keys()) + dense_sizes = tuple( + (k, v.size) + for k, v in vectors.items() + if getattr(v, "size", None) is not None + ) + has_unnamed_dense = False + is_named_dense = True + elif vectors is None: + dense_names = () + dense_sizes = () + has_unnamed_dense = False + is_named_dense = False + else: + dense_names = () + unnamed_size = getattr(vectors, "size", None) + dense_sizes = (("", unnamed_size),) if unnamed_size is not None else () + has_unnamed_dense = True + is_named_dense = False + + sparse_names = tuple(sparse_vectors.keys()) if isinstance(sparse_vectors, dict) else () + return { + "exists": True, + "is_named_dense": is_named_dense, + "has_unnamed_dense": has_unnamed_dense, + "dense_names": dense_names, + "sparse_names": sparse_names, + "dense_sizes": dense_sizes, + } + + +def _append_batch_group( + groups: list[BatchGroup], + kind: str | None, + collection: str | None, + statements: list[ASTNode], +) -> None: + if statements: + groups.append(BatchGroup(kind or "other", collection, statements)) + + +def _compatible_insert_batch(current_group: list[ASTNode], stmt: InsertStmt) -> bool: + if not current_group: + return False + prev_stmt = current_group[0] + if not isinstance(prev_stmt, InsertStmt): + return False + return ( + stmt.model == prev_stmt.model + and stmt.hybrid == prev_stmt.hybrid + and stmt.sparse_model == prev_stmt.sparse_model + and stmt.dense_vector == prev_stmt.dense_vector + and stmt.sparse_vector == prev_stmt.sparse_vector + ) + + +def group_batch_statements(statements: tuple[ASTNode, ...]) -> list[BatchGroup]: + groups: list[BatchGroup] = [] + current_type: str | None = None + current_collection: str | None = None + current_group: list[ASTNode] = [] + + for stmt in statements: + if isinstance(stmt, SearchStmt) and stmt.group_by is not None: + _append_batch_group(groups, current_type, current_collection, current_group) + groups.append(BatchGroup("other", stmt.collection, [stmt])) + current_type = None + current_collection = None + current_group = [] + continue + + if isinstance(stmt, (SearchStmt, RecommendStmt)): + coll = stmt.collection + if current_type == "query" and current_collection == coll: + current_group.append(stmt) + continue + _append_batch_group(groups, current_type, current_collection, current_group) + current_type = "query" + current_collection = coll + current_group = [stmt] + continue + + if isinstance(stmt, InsertStmt): + coll = stmt.collection + if ( + current_type == "insert" + and current_collection == coll + and _compatible_insert_batch(current_group, stmt) + ): + current_group.append(stmt) + continue + _append_batch_group(groups, current_type, current_collection, current_group) + current_type = "insert" + current_collection = coll + current_group = [stmt] + continue + + _append_batch_group(groups, current_type, current_collection, current_group) + groups.append(BatchGroup("other", None, [stmt])) + current_type = None + current_collection = None + current_group = [] + + _append_batch_group(groups, current_type, current_collection, current_group) + return groups + + +def build_bulk_insert_from_group( + collection: str, + statements: list[ASTNode], +) -> InsertBulkStmt: + first = statements[0] + if not isinstance(first, InsertStmt): + raise QQLRuntimeError("Batch insert group must contain INSERT statements") + insert_statements = [stmt for stmt in statements if isinstance(stmt, InsertStmt)] + return InsertBulkStmt( + collection=collection, + values_list=tuple(stmt.values for stmt in insert_statements), + model=first.model, + hybrid=first.hybrid, + sparse_model=first.sparse_model, + dense_vector=first.dense_vector, + sparse_vector=first.sparse_vector, + ) + + +def inserted_point_results( + result: Any, + statements: list[ASTNode], + result_type: type, +) -> list[Any]: + if not result.success: + return [result_type(success=False, message=result.message) for _ in statements] + + inserted_ids = result.data.get("ids", []) if isinstance(result.data, dict) else [] + is_hybrid = "hybrid" in result.message + label = "hybrid, batched" if is_hybrid else "batched" + rows = [] + for idx, stmt in enumerate(statements): + if not isinstance(stmt, InsertStmt): + continue + point_id = inserted_ids[idx] if idx < len(inserted_ids) else "unknown" + rows.append( + result_type( + success=True, + message=f"Inserted 1 point [{point_id}] ({label})", + data={"id": point_id, "collection": stmt.collection}, + ) + ) + return rows + + +def build_qdrant_filter(expr: FilterExpr) -> Any: + if isinstance(expr, AndExpr): + return Filter(must=[build_qdrant_filter(op) for op in expr.operands]) + if isinstance(expr, OrExpr): + return Filter(should=[build_qdrant_filter(op) for op in expr.operands]) + if isinstance(expr, NotExpr): + return Filter(must_not=[build_qdrant_filter(expr.operand)]) + if isinstance(expr, CompareExpr): + if expr.value is None: + null_condition = IsNullCondition(is_null=PayloadField(key=expr.field)) + if expr.op == "=": + return null_condition + if expr.op == "!=": + return Filter(must_not=[null_condition]) + raise QQLRuntimeError( + f"Cannot use operator '{expr.op}' with null for field '{expr.field}'" + ) + if expr.op == "=": + return FieldCondition(key=expr.field, match=MatchValue(value=expr.value)) + if expr.op == "!=": + return Filter( + must_not=[ + FieldCondition(key=expr.field, match=MatchValue(value=expr.value)) + ] + ) + range_key = {">": "gt", ">=": "gte", "<": "lt", "<=": "lte"}[expr.op] + return FieldCondition(key=expr.field, range=Range(**{range_key: expr.value})) + if isinstance(expr, BetweenExpr): + return FieldCondition(key=expr.field, range=Range(gte=expr.low, lte=expr.high)) + if isinstance(expr, InExpr): + non_nulls = [value for value in expr.values if value is not None] + null_condition = IsNullCondition(is_null=PayloadField(key=expr.field)) + if len(non_nulls) == len(expr.values): + return FieldCondition(key=expr.field, match=MatchAny(any=non_nulls)) + if not non_nulls: + return null_condition + return Filter( + should=[ + null_condition, + FieldCondition(key=expr.field, match=MatchAny(any=non_nulls)), + ] + ) + if isinstance(expr, NotInExpr): + non_nulls = [value for value in expr.values if value is not None] + null_condition = IsNullCondition(is_null=PayloadField(key=expr.field)) + if len(non_nulls) != len(expr.values): + must_not = [null_condition] + if non_nulls: + must_not.append( + FieldCondition(key=expr.field, match=MatchAny(any=non_nulls)) + ) + return Filter(must_not=must_not) + return FieldCondition( + key=expr.field, + match=MatchExcept(**{"except": non_nulls}), + ) + if isinstance(expr, IsNullExpr): + return IsNullCondition(is_null=PayloadField(key=expr.field)) + if isinstance(expr, IsNotNullExpr): + return Filter(must_not=[IsNullCondition(is_null=PayloadField(key=expr.field))]) + if isinstance(expr, IsEmptyExpr): + return IsEmptyCondition(is_empty=PayloadField(key=expr.field)) + if isinstance(expr, IsNotEmptyExpr): + return Filter(must_not=[IsEmptyCondition(is_empty=PayloadField(key=expr.field))]) + if isinstance(expr, MatchTextExpr): + return FieldCondition(key=expr.field, match=MatchText(text=expr.text)) + if isinstance(expr, MatchAnyExpr): + return FieldCondition(key=expr.field, match=MatchTextAny(text_any=expr.text)) + if isinstance(expr, MatchPhraseExpr): + return FieldCondition(key=expr.field, match=MatchPhrase(phrase=expr.text)) + raise QQLRuntimeError(f"Unknown filter expression type: {type(expr)}") + + +def wrap_as_filter(qdrant_expr: Any) -> Filter: + if isinstance(qdrant_expr, Filter): + return qdrant_expr + return Filter(must=[qdrant_expr]) + + +def resolve_hybrid_fusion(fusion: str | None) -> Fusion: + if fusion is None or fusion == "rrf": + return Fusion.RRF + if fusion == "dbsf": + return Fusion.DBSF + raise QQLRuntimeError( + f"Unsupported hybrid fusion '{fusion}'; expected 'rrf' or 'dbsf'" + ) + + +def has_mmr(with_clause: SearchWith | None) -> bool: + return with_clause is not None and ( + with_clause.mmr_diversity is not None or with_clause.mmr_candidates is not None + ) + + +def validate_search_mmr_usage(node: SearchStmt) -> None: + if not has_mmr(node.with_clause): + return + if node.sparse_only: + raise QQLRuntimeError("MMR is not supported with USING SPARSE yet") + + +def build_dense_query( + vector: list[float], + with_clause: SearchWith | None, +) -> list[float] | NearestQuery: + if not has_mmr(with_clause): + return vector + return NearestQuery( + nearest=vector, + mmr=Mmr( + diversity=with_clause.mmr_diversity, + candidates_limit=with_clause.mmr_candidates, + ), + ) + + +def parse_recommend_strategy(strategy: str | None) -> RecommendStrategy | None: + if strategy is None: + return None + try: + return RecommendStrategy(strategy) + except ValueError as e: + raise QQLRuntimeError( + "Unknown recommend strategy " + f"'{strategy}'. Expected one of: average_vector, best_score, sum_scores" + ) from e + + +def exclude_ids_from_filter( + query_filter: Filter | None, + point_ids: list[str | int], +) -> Filter | None: + if not point_ids: + return query_filter + + exclude_condition = HasIdCondition(has_id=point_ids) + if query_filter is None: + return Filter(must_not=[exclude_condition]) + + return Filter( + must=list(query_filter.must or []), + should=list(query_filter.should or []), + must_not=[*(query_filter.must_not or []), exclude_condition], + min_should=query_filter.min_should, + ) + + +def extract_point_id_and_payload( + values: dict[str, Any], +) -> tuple[str | int, dict[str, Any]]: + payload = dict(values) + if "id" not in payload: + return str(uuid.uuid4()), payload + + point_id = payload.pop("id") + if isinstance(point_id, bool): + raise QQLRuntimeError( + "INSERT id must be an unsigned integer or UUID string when provided" + ) + if isinstance(point_id, int): + if point_id < 0: + raise QQLRuntimeError( + "INSERT id must be an unsigned integer or UUID string when provided" + ) + return point_id, payload + if isinstance(point_id, str): + try: + uuid.UUID(point_id) + except ValueError as e: + raise QQLRuntimeError( + "INSERT id must be an unsigned integer or UUID string when provided" + ) from e + return point_id, payload + raise QQLRuntimeError( + "INSERT id must be an unsigned integer or UUID string when provided" + ) + + +def build_dense_point_vector( + topology: Any, + vector: list[float], + explicit_vector: str | None, + default_dense_vector_name: str, +) -> list[float] | dict[str, list[float]]: + if not topology.exists: + return {explicit_vector or default_dense_vector_name: vector} + vector_name = topology.dense_payload_name(explicit_vector) + if vector_name is None: + return vector + return {vector_name: vector} + + +def merge_search_with(base: SearchWith | None, override: SearchWith) -> SearchWith: + if base is None: + return override + return SearchWith( + hnsw_ef=override.hnsw_ef if override.hnsw_ef is not None else base.hnsw_ef, + exact=override.exact if override.exact is not None else base.exact, + acorn=override.acorn if override.acorn is not None else base.acorn, + indexed_only=( + override.indexed_only + if override.indexed_only is not None + else base.indexed_only + ), + quantization=( + override.quantization + if override.quantization is not None + else base.quantization + ), + mmr_diversity=( + override.mmr_diversity + if override.mmr_diversity is not None + else base.mmr_diversity + ), + mmr_candidates=( + override.mmr_candidates + if override.mmr_candidates is not None + else base.mmr_candidates + ), + ) + + +def parse_search_lookup(parser: Any) -> tuple[str, str | None] | None: + if parser._peek().kind != TokenKind.LOOKUP: + return None + parser._advance() + parser._expect(TokenKind.FROM) + lookup_collection = parser._parse_identifier() + lookup_vector: str | None = None + if parser._peek().kind == TokenKind.VECTOR: + parser._advance() + lookup_vector = parser._expect(TokenKind.STRING).value + return lookup_collection, lookup_vector + + +def parse_search_using(parser: Any) -> SearchUsingOptions: + if parser._peek().kind != TokenKind.USING: + return SearchUsingOptions() + + parser._advance() + if parser._peek().kind == TokenKind.HYBRID: + return _parse_hybrid_using(parser) + if parser._peek().kind == TokenKind.SPARSE: + return _parse_sparse_using(parser) + if parser._peek().kind == TokenKind.VECTOR: + parser._advance() + return SearchUsingOptions(dense_vector=parser._expect(TokenKind.STRING).value) + + parser._expect(TokenKind.MODEL) + return SearchUsingOptions(model=parser._expect(TokenKind.STRING).value) + + +def _parse_hybrid_using(parser: Any) -> SearchUsingOptions: + parser._advance() + model: str | None = None + fusion: str | None = None + sparse_model: str | None = None + dense_vector: str | None = None + sparse_vector: str | None = None + + while parser._peek().kind in (TokenKind.FUSION, TokenKind.DENSE, TokenKind.SPARSE): + sub = parser._advance() + if sub.kind == TokenKind.FUSION: + value_tok = parser._expect(TokenKind.STRING) + fusion = value_tok.value.lower() + if fusion not in _HYBRID_FUSION_VALUES: + raise QQLSyntaxError( + f"Unsupported hybrid fusion '{value_tok.value}'; expected 'rrf' or 'dbsf'", + value_tok.pos, + ) + continue + if parser._peek().kind == TokenKind.MODEL: + parser._advance() + parsed_model = parser._expect(TokenKind.STRING).value + if sub.kind == TokenKind.DENSE: + model = parsed_model + else: + sparse_model = parsed_model + continue + if parser._peek().kind == TokenKind.VECTOR: + parser._advance() + name = parser._expect(TokenKind.STRING).value + if sub.kind == TokenKind.DENSE: + dense_vector = name + else: + sparse_vector = name + continue + raise QQLSyntaxError( + "Expected MODEL or VECTOR after DENSE/SPARSE in USING HYBRID", + parser._peek().pos, + ) + + return SearchUsingOptions( + model=model, + hybrid=True, + fusion=fusion, + sparse_model=sparse_model, + dense_vector=dense_vector, + sparse_vector=sparse_vector, + ) + + +def _parse_sparse_using(parser: Any) -> SearchUsingOptions: + parser._advance() + sparse_model: str | None = None + sparse_vector: str | None = None + while parser._peek().kind in (TokenKind.MODEL, TokenKind.VECTOR): + sub = parser._advance() + if sub.kind == TokenKind.MODEL: + sparse_model = parser._expect(TokenKind.STRING).value + else: + sparse_vector = parser._expect(TokenKind.STRING).value + return SearchUsingOptions( + sparse_only=True, + sparse_model=sparse_model, + sparse_vector=sparse_vector, + ) + + +def parse_search_with(parser: Any, with_clause: SearchWith | None) -> SearchWith | None: + if parser._peek().kind == TokenKind.EXACT: + parser._advance() + with_clause = merge_search_with(with_clause, SearchWith(exact=True)) + + if parser._peek().kind == TokenKind.WITH: + parser._advance() + with_clause = merge_search_with(with_clause, parser._parse_with_clause()) + + return with_clause + + +def parse_search_group_by( + parser: Any, + offset: int, + rerank: bool, +) -> SearchGroupByOptions: + if parser._peek().kind != TokenKind.GROUP: + return SearchGroupByOptions() + + if offset > 0: + raise QQLSyntaxError("OFFSET cannot be used with GROUP BY", parser._peek().pos) + parser._advance() + parser._expect(TokenKind.BY) + group_by = parser._parse_field_path() + if rerank: + raise QQLSyntaxError( + "GROUP BY and RERANK cannot be combined in the same SEARCH statement", + parser._peek().pos, + ) + + group_size = 3 + if parser._peek().kind == TokenKind.GROUP_SIZE: + parser._advance() + group_size_tok = parser._peek() + group_size = int(parser._expect(TokenKind.INTEGER).value) + if group_size <= 0: + raise QQLSyntaxError( + f"GROUP_SIZE must be a positive integer, got {group_size}", + group_size_tok.pos, + ) + return SearchGroupByOptions(group_by=group_by, group_size=group_size) diff --git a/tests/test_async_connection.py b/tests/test_async_connection.py new file mode 100644 index 0000000..57b892a --- /dev/null +++ b/tests/test_async_connection.py @@ -0,0 +1,473 @@ +"""Tests for the AsyncConnection and QQLAsyncBatch classes. + +All tests mock AsyncQdrantClient so no live Qdrant instance is required. +""" +from __future__ import annotations + +import pytest +from unittest.mock import AsyncMock + +from qql import ( + AsyncConnection, + QQLConfig, + AsyncExecutor, + ExecutionResult, + QQLAsyncBatch, +) +from qql.exceptions import QQLSyntaxError + + +@pytest.fixture +def anyio_backend(): + return "asyncio" + + +# ── TestAsyncConnectionInit ─────────────────────────────────────────────────── + +class TestAsyncConnectionInit: + """AsyncConnection.__init__ stores config and wires up the async executor.""" + + def test_default_url_and_no_secret(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + conn = AsyncConnection() + assert conn.config.url == "http://localhost:6333" + assert conn.config.secret is None + + def test_custom_url_and_secret_passed_to_async_qdrant_client(self, mocker): + mock_client_cls = mocker.patch("qql.async_connection.AsyncQdrantClient") + AsyncConnection("https://cloud.example.io", secret="s3cr3t") + mock_client_cls.assert_called_once_with( + url="https://cloud.example.io", api_key="s3cr3t" + ) + + def test_custom_default_model_stored_in_config(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + conn = AsyncConnection("http://localhost:6333", default_model="BAAI/bge-small-en-v1.5") + assert conn.config.default_model == "BAAI/bge-small-en-v1.5" + + def test_config_and_executor_properties_return_correct_types(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + conn = AsyncConnection("http://localhost:6333") + assert isinstance(conn.config, QQLConfig) + assert isinstance(conn.executor, AsyncExecutor) + + +# ── TestAsyncConnectionRunQuery ──────────────────────────────────────────────── + +@pytest.mark.anyio +class TestAsyncConnectionRunQuery: + """AsyncConnection.run_query() pipes through the Lexer → Parser → AsyncExecutor.""" + + async def test_run_query_calls_executor_execute(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + mock_executor = AsyncMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, message="ok", data=[] + ) + mocker.patch("qql.async_connection.AsyncExecutor", return_value=mock_executor) + + conn = AsyncConnection("http://localhost:6333") + await conn.run_query("SHOW COLLECTIONS") + mock_executor.execute.assert_called_once() + + async def test_executor_instance_reused_across_queries(self, mocker): + """AsyncExecutor() is constructed once; run_query() never re-instantiates it.""" + mocker.patch("qql.async_connection.AsyncQdrantClient") + mock_executor = AsyncMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, message="ok", data=[] + ) + executor_cls = mocker.patch("qql.async_connection.AsyncExecutor", return_value=mock_executor) + + conn = AsyncConnection("http://localhost:6333") + await conn.run_query("SHOW COLLECTIONS") + await conn.run_query("SHOW COLLECTIONS") + await conn.run_query("SHOW COLLECTIONS") + + # AsyncExecutor constructor called exactly once, not once per query + executor_cls.assert_called_once() + # But execute() called three times + assert mock_executor.execute.call_count == 3 + + async def test_invalid_query_raises_qql_syntax_error(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + conn = AsyncConnection("http://localhost:6333") + with pytest.raises(QQLSyntaxError): + await conn.run_query("TOTALLY INVALID QUERY GIBBERISH") + + async def test_run_query_returns_execution_result(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + mock_executor = AsyncMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, message="1 collection(s) found", data=["docs"] + ) + mocker.patch("qql.async_connection.AsyncExecutor", return_value=mock_executor) + + conn = AsyncConnection("http://localhost:6333") + result = await conn.run_query("SHOW COLLECTIONS") + assert isinstance(result, ExecutionResult) + assert result.success is True + + +# ── TestAsyncConnectionLifecycle ─────────────────────────────────────────────── + +@pytest.mark.anyio +class TestAsyncConnectionLifecycle: + """AsyncConnection.close() and the async context-manager protocol.""" + + async def test_close_calls_client_close(self, mocker): + mock_client = AsyncMock() + mocker.patch("qql.async_connection.AsyncQdrantClient", return_value=mock_client) + conn = AsyncConnection("http://localhost:6333") + await conn.close() + mock_client.close.assert_called_once() + + async def test_context_manager_closes_on_exit(self, mocker): + mock_client = AsyncMock() + mocker.patch("qql.async_connection.AsyncQdrantClient", return_value=mock_client) + + async with AsyncConnection("http://localhost:6333") as conn: + assert conn._client is mock_client + + mock_client.close.assert_called_once() + + async def test_context_manager_closes_on_exception(self, mocker): + mock_client = AsyncMock() + mocker.patch("qql.async_connection.AsyncQdrantClient", return_value=mock_client) + + with pytest.raises(RuntimeError, match="oops"): + async with AsyncConnection("http://localhost:6333"): + raise RuntimeError("oops") + + mock_client.close.assert_called_once() + + +# ── TestAsyncConnectionBatch ─────────────────────────────────────────────────── + +@pytest.mark.anyio +class TestAsyncConnectionBatch: + """AsyncConnection batching support (run_queries_batch, run_parameterized_batch, QQLAsyncBatch).""" + + async def test_run_queries_batch(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + mock_executor = AsyncMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, + message="Batch executed successfully", + data=[ + ExecutionResult(success=True, message="Found 1 result(s)"), + ExecutionResult(success=True, message="Found 2 result(s)"), + ], + ) + mocker.patch("qql.async_connection.AsyncExecutor", return_value=mock_executor) + + conn = AsyncConnection() + results = await conn.run_queries_batch([ + "SEARCH docs SIMILAR TO 'neurology' LIMIT 5", + "SEARCH docs SIMILAR TO 'cardiology' LIMIT 5", + ]) + assert len(results) == 2 + assert results[0].message == "Found 1 result(s)" + assert results[1].message == "Found 2 result(s)" + + async def test_run_parameterized_batch(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + mock_executor = AsyncMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, + message="Batch executed successfully", + data=[ + ExecutionResult(success=True, message="ok"), + ExecutionResult(success=True, message="ok"), + ], + ) + mocker.patch("qql.async_connection.AsyncExecutor", return_value=mock_executor) + + conn = AsyncConnection() + results = await conn.run_parameterized_batch( + "SEARCH docs SIMILAR TO :query LIMIT 5 WHERE category = :category", + [ + {"query": "brain stroke", "category": "Neurology"}, + {"query": "heart attack", "category": "Cardiology"}, + ], + ) + assert len(results) == 2 + mock_executor.execute.assert_called_once() + stmt = mock_executor.execute.call_args[0][0] + # Verify both statements compiled correctly + assert len(stmt.statements) == 2 + + async def test_run_parameterized_query(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + mock_executor = AsyncMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, + message="ok", + data="res1", + ) + mocker.patch("qql.async_connection.AsyncExecutor", return_value=mock_executor) + + conn = AsyncConnection() + result = await conn.run_parameterized_query( + "SEARCH docs SIMILAR TO :query LIMIT 5 WHERE category = :category", + {"query": "brain stroke", "category": "Neurology"}, + ) + + stmt = mock_executor.execute.call_args[0][0] + assert result.data == "res1" + assert stmt.query_text == "brain stroke" + + async def test_qql_async_batch_context_manager(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + mock_executor = AsyncMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, + message="Batch executed successfully", + data=[ + ExecutionResult(success=True, message="Res 1", data="d1"), + ExecutionResult(success=True, message="Res 2", data="d2"), + ], + ) + mocker.patch("qql.async_connection.AsyncExecutor", return_value=mock_executor) + + conn = AsyncConnection() + async with QQLAsyncBatch(conn) as batch: + ref1 = batch.add("SEARCH docs SIMILAR TO 'neurology' LIMIT 5") + ref2 = batch.add("SEARCH docs SIMILAR TO 'cardiology' LIMIT 5") + + assert ref1.result.message == "Res 1" + assert ref2.result.message == "Res 2" + assert ref1.result.data == "d1" + assert ref2.result.data == "d2" + + async def test_qql_async_batch_raises_when_result_count_mismatches_proxy_count(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + mock_executor = AsyncMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, + message="Batch executed", + data=[ExecutionResult(success=True, message="Res 1", data="d1")], + ) + mocker.patch("qql.async_connection.AsyncExecutor", return_value=mock_executor) + + conn = AsyncConnection() + with pytest.raises(RuntimeError, match="Batch result count mismatch"): + async with QQLAsyncBatch(conn) as batch: + ref1 = batch.add("SHOW COLLECTIONS") + ref2 = batch.add("SHOW COLLECTIONS") + + with pytest.raises(RuntimeError, match="Batch result count mismatch"): + _ = ref1.result + with pytest.raises(RuntimeError, match="Batch result count mismatch"): + _ = ref2.result + + async def test_qql_async_batch_reuse_does_not_replay_previous_queries(self, mocker): + mocker.patch("qql.async_connection.AsyncQdrantClient") + mock_executor = AsyncMock() + mock_executor.execute.side_effect = [ + ExecutionResult(success=True, message="ok", data=[ + ExecutionResult(success=True, message="first", data="d1"), + ]), + ExecutionResult(success=True, message="ok", data=[ + ExecutionResult(success=True, message="second", data="d2"), + ]), + ] + mocker.patch("qql.async_connection.AsyncExecutor", return_value=mock_executor) + + conn = AsyncConnection() + batch = QQLAsyncBatch(conn) + async with batch: + first = batch.add("SHOW COLLECTIONS") + async with batch: + second = batch.add("SHOW COLLECTIONS") + + assert first.result.data == "d1" + assert second.result.data == "d2" + assert len(mock_executor.execute.call_args_list[0].args[0].statements) == 1 + assert len(mock_executor.execute.call_args_list[1].args[0].statements) == 1 + + +# ── TestArchitecturalGapsClosed ──────────────────────────────────────────────── + +@pytest.mark.anyio +class TestArchitecturalGapsClosed: + """Rigorous tests covering async execution, race conditions, strict parser validation, and ID propagation.""" + + async def test_async_search_embeds_once(self, mocker): + """AsyncExecutor keeps the hot path direct and avoids threadpool overhead for cached embeddings.""" + mock_client = AsyncMock() + mock_client.collection_exists.return_value = True + + # Mock embedders to track how they are called + mocker.patch("qql.async_executor.Embedder.__init__", return_value=None) + mock_embed = mocker.patch("qql.async_executor.Embedder.embed", return_value=[0.1, 0.2]) + + from qql import QQLConfig + executor = AsyncExecutor(mock_client, QQLConfig(url="http://localhost:6333")) + + from qql.parser import Parser + from qql.lexer import Lexer + + node = Parser(Lexer().tokenize("SEARCH docs SIMILAR TO 'neurology' LIMIT 5")).parse() + + result = await executor.execute(node) + assert result.success is True + mock_embed.assert_called_once_with("neurology") + + async def test_batch_parsing_rejects_trailing_statements(self): + """run_queries_batch must raise QQLSyntaxError when a query contains trailing tokens.""" + conn = AsyncConnection() + with pytest.raises(QQLSyntaxError, match="Expected EOF"): + await conn.run_queries_batch([ + "SHOW COLLECTIONS; DROP COLLECTION x" + ]) + + async def test_batched_insert_propagates_correct_ids(self, mocker): + """_execute_batch_block preserves individual point IDs when aggregating Inserts into a bulk Insert.""" + mock_client = AsyncMock() + mock_client.collection_exists.return_value = True + mock_client.upsert.return_value = None + + # Mock get_collection to return a mock config with matching vector size + mock_info = mocker.MagicMock() + mock_info.config.params.vectors.size = 2 + mock_client.get_collection.return_value = mock_info + + mocker.patch("qql.async_executor.Embedder.__init__", return_value=None) + mocker.patch("qql.async_executor.Embedder.embed", return_value=[0.1, 0.2]) + + from qql.executor import CollectionTopology + topology = CollectionTopology( + exists=True, + is_named_dense=False, + has_unnamed_dense=True, + dense_names=(), + sparse_names=(), + ) + mocker.patch("qql.async_executor.AsyncExecutor._resolve_topology", return_value=topology) + + from qql import QQLConfig + executor = AsyncExecutor(mock_client, QQLConfig(url="http://localhost:6333")) + + from qql.parser import Parser + from qql.lexer import Lexer + + # Aggregate multiple INSERTs inside a Batch block statement + qql_batch = ( + "BEGIN BATCH\n" + "INSERT INTO COLLECTION docs VALUES {'text': 'a', 'id': 101};\n" + "INSERT INTO COLLECTION docs VALUES {'text': 'b', 'id': 102};\n" + "END BATCH" + ) + node = Parser(Lexer().tokenize(qql_batch)).parse() + res = await executor.execute(node) + + assert res.success is True + # Check that individual execution results correctly maintain their original custom point ID identity! + assert len(res.data) == 2 + assert res.data[0].data["id"] == 101 + assert res.data[1].data["id"] == 102 + + async def test_race_condition_collection_creation(self, mocker): + """Concurrent inserts into a non-existent collection serialize creation to avoid Qdrant conflicts.""" + import asyncio + mock_client = AsyncMock() + + # Mock get_collection to return a mock config with matching vector size + mock_info = mocker.MagicMock() + mock_info.config.params.vectors.size = 2 + mock_client.get_collection.return_value = mock_info + + from qql.executor import CollectionTopology + # Mock resolve_topology sequence using real CollectionTopology objects + topology_sequence = [ + CollectionTopology(exists=False, is_named_dense=False), # First insert task resolve topology + CollectionTopology(exists=False, is_named_dense=False), # Second insert task resolve topology + CollectionTopology(exists=False, is_named_dense=False), # Inside lock for first insert + CollectionTopology(exists=True, is_named_dense=False, has_unnamed_dense=True, dense_names=(), sparse_names=()), # Inside lock for second insert + ] + + mocker.patch("qql.async_executor.Embedder.__init__", return_value=None) + mocker.patch("qql.async_executor.Embedder.embed", return_value=[0.1, 0.2]) + + # Override _resolve_topology to yield the sequence + calls = 0 + async def mock_resolve(*args, **kwargs): + nonlocal calls + val = topology_sequence[calls] + calls += 1 + return val + + mocker.patch("qql.async_executor.AsyncExecutor._resolve_topology", side_effect=mock_resolve) + mocker.patch("qql.async_executor.AsyncExecutor._create_collection_and_wait", return_value=None) + + from qql import QQLConfig + executor = AsyncExecutor(mock_client, QQLConfig(url="http://localhost:6333")) + + from qql.parser import Parser + from qql.lexer import Lexer + + insert_node_1 = Parser(Lexer().tokenize("INSERT INTO COLLECTION docs VALUES {'text': 'a', 'id': 1}")).parse() + insert_node_2 = Parser(Lexer().tokenize("INSERT INTO COLLECTION docs VALUES {'text': 'b', 'id': 2}")).parse() + + # Fire both concurrently + res1, res2 = await asyncio.gather( + executor.execute(insert_node_1), + executor.execute(insert_node_2), + ) + + assert res1.success is True + assert res2.success is True + # Verify that _create_collection_and_wait was called exactly once despite concurrency! + executor._create_collection_and_wait.assert_called_once() + + async def test_async_insert_uses_refreshed_topology_after_create_race(self, mocker): + """If another coroutine creates the collection, upsert must use its vector name.""" + mock_client = AsyncMock() + mock_client.upsert.return_value = None + + mocker.patch("qql.async_executor.Embedder.__init__", return_value=None) + mocker.patch("qql.async_executor.Embedder.embed", return_value=[0.1, 0.2]) + + from qql.executor import CollectionTopology + + stale_topology = CollectionTopology(exists=False, is_named_dense=False) + created_topology = CollectionTopology( + exists=True, + is_named_dense=True, + dense_names=("body",), + sparse_names=(), + ) + + async def mock_ensure(*args, **kwargs): + return created_topology + + mocker.patch("qql.async_executor.AsyncExecutor._resolve_topology", return_value=stale_topology) + mocker.patch("qql.async_executor.AsyncExecutor._ensure_collection", side_effect=mock_ensure) + + from qql import QQLConfig + executor = AsyncExecutor(mock_client, QQLConfig(url="http://localhost:6333")) + + from qql.parser import Parser + from qql.lexer import Lexer + + node = Parser( + Lexer().tokenize("INSERT INTO COLLECTION docs VALUES {'text': 'hello', 'id': 1}") + ).parse() + await executor.execute(node) + + point = mock_client.upsert.call_args.kwargs["points"][0] + assert point.vector == {"body": [0.1, 0.2]} + + def test_strict_batch_grammar(self): + """Parser must raise QQLSyntaxError if a batch block ends with bare END instead of END BATCH.""" + from qql.parser import Parser + from qql.lexer import Lexer + + invalid_batch = ( + "BEGIN BATCH\n" + "SHOW COLLECTIONS\n" + "END" + ) + with pytest.raises(QQLSyntaxError, match="Expected BATCH"): + Parser(Lexer().tokenize(invalid_batch)).parse() diff --git a/tests/test_connection.py b/tests/test_connection.py index c209698..6be0853 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -6,6 +6,7 @@ from qql import Connection, QQLConfig, Executor, ExecutionResult, run_query from qql.exceptions import QQLSyntaxError +from qql.utils import render_parameterized_query # ── TestConnectionInit ──────────────────────────────────────────────────────── @@ -191,3 +192,194 @@ def test_run_query_invalid_syntax_still_raises(self, mocker): mocker.patch("qdrant_client.QdrantClient") with pytest.raises(QQLSyntaxError): run_query("TOTALLY INVALID", url="http://localhost:6333") + + +# ── TestConnectionBatching ─────────────────────────────────────────────────── + +class TestConnectionBatching: + """Connection batching functionality: prefer_grpc, QQLBatch, parameterized run.""" + + def test_prefer_grpc_passed_to_qdrant_client(self, mocker): + mock_client_cls = mocker.patch("qdrant_client.QdrantClient") + Connection("http://localhost:6333", prefer_grpc=True, grpc_port=9999) + mock_client_cls.assert_called_once_with( + url="http://localhost:6333", api_key=None, prefer_grpc=True, grpc_port=9999 + ) + + def test_run_queries_batch_pipes_to_executor(self, mocker): + mocker.patch("qdrant_client.QdrantClient") + mock_executor = mocker.MagicMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, message="ok", data=[ + ExecutionResult(success=True, message="ok1", data="res1"), + ExecutionResult(success=True, message="ok2", data="res2") + ] + ) + mocker.patch("qql.connection.Executor", return_value=mock_executor) + + conn = Connection() + results = conn.run_queries_batch([ + "SHOW COLLECTIONS", + "SHOW COLLECTIONS" + ]) + assert len(results) == 2 + assert results[0].data == "res1" + assert results[1].data == "res2" + mock_executor.execute.assert_called_once() + + def test_run_parameterized_batch_substitutes_vars(self, mocker): + mocker.patch("qdrant_client.QdrantClient") + mock_executor = mocker.MagicMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, message="ok", data=[ + ExecutionResult(success=True, message="ok1", data="res1"), + ExecutionResult(success=True, message="ok2", data="res2") + ] + ) + mocker.patch("qql.connection.Executor", return_value=mock_executor) + + conn = Connection() + results = conn.run_parameterized_batch( + "SEARCH docs SIMILAR TO :query LIMIT :limit WHERE category = :cat", + [ + {"query": "ML", "limit": 5, "cat": "news"}, + {"query": "AI", "limit": 10, "cat": "tech"} + ] + ) + assert len(results) == 2 + # Check that AST was called with reconstructed strings + called_node = mock_executor.execute.call_args[0][0] + # In BatchBlockStmt, statements should contain substituted strings + assert called_node.statements[0].query_text == "ML" + assert called_node.statements[0].limit == 5 + assert called_node.statements[1].query_text == "AI" + assert called_node.statements[1].limit == 10 + + def test_run_parameterized_query_substitutes_vars(self, mocker): + mocker.patch("qdrant_client.QdrantClient") + mock_executor = mocker.MagicMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, message="ok", data="res1" + ) + mocker.patch("qql.connection.Executor", return_value=mock_executor) + + conn = Connection() + result = conn.run_parameterized_query( + "SEARCH docs SIMILAR TO :query LIMIT :limit WHERE category = :cat", + {"query": "ML", "limit": 5, "cat": "news"}, + ) + + called_node = mock_executor.execute.call_args[0][0] + assert result.data == "res1" + assert called_node.query_text == "ML" + assert called_node.limit == 5 + + def test_qql_batch_context_manager(self, mocker): + mocker.patch("qdrant_client.QdrantClient") + mock_executor = mocker.MagicMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, message="ok", data=[ + ExecutionResult(success=True, message="ok1", data="res1"), + ExecutionResult(success=True, message="ok2", data="res2") + ] + ) + mocker.patch("qql.connection.Executor", return_value=mock_executor) + + from qql import QQLBatch + conn = Connection() + with QQLBatch(conn) as batch: + p1 = batch.add("SHOW COLLECTIONS") + p2 = batch.add("SHOW COLLECTIONS") + with pytest.raises(RuntimeError): + _ = p1.result + + assert p1.result.data == "res1" + assert p2.result.data == "res2" + mock_executor.execute.assert_called_once() + + def test_qql_batch_raises_when_result_count_mismatches_proxy_count(self, mocker): + mocker.patch("qdrant_client.QdrantClient") + mock_executor = mocker.MagicMock() + mock_executor.execute.return_value = ExecutionResult( + success=True, + message="ok", + data=[ExecutionResult(success=True, message="only one", data="res1")], + ) + mocker.patch("qql.connection.Executor", return_value=mock_executor) + + from qql import QQLBatch + conn = Connection() + with pytest.raises(RuntimeError, match="Batch result count mismatch"): + with QQLBatch(conn) as batch: + p1 = batch.add("SHOW COLLECTIONS") + p2 = batch.add("SHOW COLLECTIONS") + + with pytest.raises(RuntimeError, match="Batch result count mismatch"): + _ = p1.result + with pytest.raises(RuntimeError, match="Batch result count mismatch"): + _ = p2.result + + def test_qql_batch_reuse_does_not_replay_previous_queries(self, mocker): + mocker.patch("qdrant_client.QdrantClient") + mock_executor = mocker.MagicMock() + mock_executor.execute.side_effect = [ + ExecutionResult(success=True, message="ok", data=[ + ExecutionResult(success=True, message="first", data="res1"), + ]), + ExecutionResult(success=True, message="ok", data=[ + ExecutionResult(success=True, message="second", data="res2"), + ]), + ] + mocker.patch("qql.connection.Executor", return_value=mock_executor) + + from qql import QQLBatch + conn = Connection() + batch = QQLBatch(conn) + with batch: + first = batch.add("SHOW COLLECTIONS") + with batch: + second = batch.add("SHOW COLLECTIONS") + + assert first.result.data == "res1" + assert second.result.data == "res2" + assert len(mock_executor.execute.call_args_list[0].args[0].statements) == 1 + assert len(mock_executor.execute.call_args_list[1].args[0].statements) == 1 + + +class TestParameterizedRendering: + def test_render_parameterized_query_escapes_strings_and_nulls(self): + rendered = render_parameterized_query( + "SEARCH docs SIMILAR TO :query LIMIT 5 WHERE category = :category", + {"query": "O'Reilly\\notes\nnext", "category": None}, + ) + + assert rendered == ( + "SEARCH docs SIMILAR TO 'O\\'Reilly\\\\notes\\nnext' " + "LIMIT 5 WHERE category = null" + ) + + def test_render_parameterized_query_does_not_replace_inside_strings(self): + rendered = render_parameterized_query( + "SEARCH docs SIMILAR TO ':query' LIMIT 5 WHERE tag = :query", + {"query": "needle"}, + ) + + assert rendered == "SEARCH docs SIMILAR TO ':query' LIMIT 5 WHERE tag = 'needle'" + + def test_render_parameterized_query_renders_compound_literals(self): + rendered = render_parameterized_query( + "INSERT INTO COLLECTION docs VALUES :values", + { + "values": { + "text": "hello", + "tags": ["a", "b"], + "meta": {"score": 1, "active": True}, + } + }, + ) + + assert rendered == ( + "INSERT INTO COLLECTION docs VALUES " + "{'text': 'hello', 'tags': ['a', 'b'], " + "'meta': {'score': 1, 'active': true}}" + ) diff --git a/tests/test_executor.py b/tests/test_executor.py index 0a2a74f..9912a78 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -1729,6 +1729,46 @@ def test_build_not_in(self, executor): assert isinstance(result, FieldCondition) assert isinstance(result.match, MatchExcept) + def test_build_in_with_only_null(self, executor): + from qdrant_client.models import IsNullCondition + from qql.ast_nodes import InExpr + + result = executor._build_qdrant_filter(InExpr(field="status", values=(None,))) + assert isinstance(result, IsNullCondition) + assert result.is_null.key == "status" + + def test_build_in_with_null_and_values(self, executor): + from qdrant_client.models import Filter, FieldCondition, IsNullCondition, MatchAny + from qql.ast_nodes import InExpr + + result = executor._build_qdrant_filter(InExpr(field="status", values=(None, "draft"))) + assert isinstance(result, Filter) + assert isinstance(result.should[0], IsNullCondition) + assert isinstance(result.should[1], FieldCondition) + assert isinstance(result.should[1].match, MatchAny) + assert result.should[1].match.any == ["draft"] + + def test_build_not_in_with_only_null(self, executor): + from qdrant_client.models import Filter, IsNullCondition + from qql.ast_nodes import NotInExpr + + result = executor._build_qdrant_filter(NotInExpr(field="status", values=(None,))) + assert isinstance(result, Filter) + assert isinstance(result.must_not[0], IsNullCondition) + + def test_build_not_in_with_null_and_values(self, executor): + from qdrant_client.models import Filter, FieldCondition, IsNullCondition, MatchAny + from qql.ast_nodes import NotInExpr + + result = executor._build_qdrant_filter( + NotInExpr(field="status", values=(None, "deleted")) + ) + assert isinstance(result, Filter) + assert isinstance(result.must_not[0], IsNullCondition) + assert isinstance(result.must_not[1], FieldCondition) + assert isinstance(result.must_not[1].match, MatchAny) + assert result.must_not[1].match.any == ["deleted"] + def test_build_is_null(self, executor): from qdrant_client.models import IsNullCondition from qql.ast_nodes import IsNullExpr @@ -3201,6 +3241,96 @@ def test_grouped_search_params_with_clause_forwarded(self, executor, mock_client assert kwargs.get("search_params") is not None +class TestBatchGroupedSearch: + def test_grouped_search_in_batch_runs_individually(self, executor, mock_client, mocker): + _mock_hybrid_collection(mock_client) + mock_group_response = mocker.MagicMock() + mock_group_response.groups = [] + mock_client.query_points_groups.return_value = mock_group_response + mock_query_response = mocker.MagicMock() + mock_query_response.points = [] + mock_client.query_batch_points.return_value = [mock_query_response] + + from qql.lexer import Lexer + from qql.parser import Parser + + node = Parser( + Lexer().tokenize( + "BEGIN BATCH " + "SEARCH articles SIMILAR TO 'q' LIMIT 5 GROUP BY category; " + "SEARCH articles SIMILAR TO 'plain' LIMIT 5 " + "END BATCH" + ) + ).parse() + + result = executor.execute(node) + + assert result.success is True + mock_client.query_points_groups.assert_called_once() + mock_client.query_batch_points.assert_called_once() + assert "group(s)" in result.data[0].message + assert "result(s)" in result.data[1].message + + +class TestNullComparisonFilters: + def test_equal_null_builds_is_null_condition(self, executor): + from qql.ast_nodes import CompareExpr + from qdrant_client.models import IsNullCondition + + result = executor._build_qdrant_filter( + CompareExpr(field="deleted_at", op="=", value=None) + ) + + assert isinstance(result, IsNullCondition) + assert result.is_null.key == "deleted_at" + + def test_not_equal_null_builds_not_null_filter(self, executor): + from qql.ast_nodes import CompareExpr + from qdrant_client.models import Filter, IsNullCondition + + result = executor._build_qdrant_filter( + CompareExpr(field="deleted_at", op="!=", value=None) + ) + + assert isinstance(result, Filter) + assert isinstance(result.must_not[0], IsNullCondition) + + def test_ordering_comparison_to_null_raises_clear_error(self, executor): + from qql.ast_nodes import CompareExpr + + with pytest.raises(QQLRuntimeError, match="Cannot use operator '>' with null"): + executor._build_qdrant_filter( + CompareExpr(field="deleted_at", op=">", value=None) + ) + + +class TestMergeSearchWith: + def test_merge_search_with_preserves_zero_values(self): + from qql.ast_nodes import SearchWith + from qql.utils import merge_search_with + + merged = merge_search_with( + SearchWith(hnsw_ef=128, mmr_candidates=10), + SearchWith(hnsw_ef=0, mmr_candidates=0), + ) + + assert merged.hnsw_ef == 0 + assert merged.mmr_candidates == 0 + + def test_merge_search_with_can_override_true_with_false(self): + from qql.ast_nodes import SearchWith + from qql.utils import merge_search_with + + merged = merge_search_with( + SearchWith(exact=True, acorn=True, indexed_only=True), + SearchWith(exact=False, acorn=False, indexed_only=False), + ) + + assert merged.exact is False + assert merged.acorn is False + assert merged.indexed_only is False + + class TestUpdateVectorVectorShape: """Gaps 12 & 13 — verify exact vector shape sent to Qdrant for named/unnamed collections.""" diff --git a/tests/test_parser.py b/tests/test_parser.py index ec9e051..c2bd64b 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -46,6 +46,21 @@ def parse(query: str): return Parser(tokens).parse() +class TestStatementBoundaries: + def test_top_level_statement_allows_trailing_semicolon(self): + node = parse("SHOW COLLECTIONS;") + assert isinstance(node, ShowCollectionsStmt) + + def test_batch_block_allows_trailing_semicolon(self): + node = parse( + "BEGIN BATCH\n" + "SHOW COLLECTIONS;\n" + "END BATCH;" + ) + assert len(node.statements) == 1 + assert isinstance(node.statements[0], ShowCollectionsStmt) + + class TestInsert: def test_basic_insert(self): node = parse("INSERT INTO COLLECTION notes VALUES {'text': 'hello'}") @@ -1086,6 +1101,17 @@ def test_with_exact_false(self): assert node.with_clause is not None assert node.with_clause.exact is False + def test_exact_keyword_survives_with_clause_merge(self): + node = parse("SEARCH col SIMILAR TO 'q' LIMIT 5 EXACT WITH { hnsw_ef: 128 }") + assert node.with_clause is not None + assert node.with_clause.exact is True + assert node.with_clause.hnsw_ef == 128 + + def test_with_exact_false_can_override_exact_keyword(self): + node = parse("SEARCH col SIMILAR TO 'q' LIMIT 5 EXACT WITH { exact: false }") + assert node.with_clause is not None + assert node.with_clause.exact is False + def test_with_acorn(self): node = parse("SEARCH col SIMILAR TO 'q' LIMIT 5 WITH { acorn: true }") assert node.with_clause is not None