From bfae537def7fab5fc9ee6282636fd454536aaed4 Mon Sep 17 00:00:00 2001 From: Jordan Russell Date: Thu, 5 Feb 2026 07:45:35 +1300 Subject: [PATCH 1/6] feat: add sqlglot dialect extensions for PGVector and ParadeDB Adds dialect extensions for PGVector and ParadeDB. This uses a workaround for missing Tokens by 'sharing' unused Postgres tokens. New extension tokens won't be upstreamed as per https://github.com/tobymao/sqlglot/issues/6949 --- .../extension_dialects/__init__.py | 1 + .../extension_dialects/postgres/__init__.py | 6 ++ .../extension_dialects/postgres/paradedb.py | 84 ++++++++++++++++ .../extension_dialects/postgres/pgvector.py | 98 +++++++++++++++++++ 4 files changed, 189 insertions(+) create mode 100644 sqlspec/data_dictionary/extension_dialects/__init__.py create mode 100644 sqlspec/data_dictionary/extension_dialects/postgres/__init__.py create mode 100644 sqlspec/data_dictionary/extension_dialects/postgres/paradedb.py create mode 100644 sqlspec/data_dictionary/extension_dialects/postgres/pgvector.py diff --git a/sqlspec/data_dictionary/extension_dialects/__init__.py b/sqlspec/data_dictionary/extension_dialects/__init__.py new file mode 100644 index 00000000..bcdffecc --- /dev/null +++ b/sqlspec/data_dictionary/extension_dialects/__init__.py @@ -0,0 +1 @@ +"""Extension dialects""" diff --git a/sqlspec/data_dictionary/extension_dialects/postgres/__init__.py b/sqlspec/data_dictionary/extension_dialects/postgres/__init__.py new file mode 100644 index 00000000..23f19a17 --- /dev/null +++ b/sqlspec/data_dictionary/extension_dialects/postgres/__init__.py @@ -0,0 +1,6 @@ +"""Postgres extension dialects""" + +from sqlspec.data_dictionary.extension_dialects.postgres.paradedb import ParadeDB +from sqlspec.data_dictionary.extension_dialects.postgres.pgvector import PGVector + +__all__ = ("PGVector", "ParadeDB") diff --git a/sqlspec/data_dictionary/extension_dialects/postgres/paradedb.py b/sqlspec/data_dictionary/extension_dialects/postgres/paradedb.py new file mode 100644 index 00000000..57839d7e --- /dev/null +++ b/sqlspec/data_dictionary/extension_dialects/postgres/paradedb.py @@ -0,0 +1,84 @@ +"""ParadeDB dialect extending PGVector with pg_search BM25/search operators. + +Adds support for ParadeDB search operators: +- @@@ : BM25 full-text search +- &&& : Boolean AND search +- ||| : Boolean OR search +- === : Exact term match +- ### : Score/rank retrieval +- ## : Snippet/highlight retrieval +- ##> : Snippet/highlight with options + +Also inherits pgvector distance operators from PGVector: +- <-> : L2 (Euclidean) distance +- <#> : Negative inner product +- <=> : Cosine distance +- <+> : L1 (Taxicab/Manhattan) distance +- <~> : Hamming distance (binary vectors) +- <%> : Jaccard distance (binary vectors) +""" + +from __future__ import annotations + +from sqlglot import exp +from sqlglot.dialects.dialect import Dialect +from sqlglot.tokens import TokenType + +from sqlspec.data_dictionary.extension_dialects.postgres.pgvector import ( + PGVector, + PGVectorGenerator, + PGVectorParser, + PGVectorTokenizer, +) + +__all__ = ("ParadeDB",) + +_PARADEDB_SEARCH_TOKEN = TokenType.DAT + + +class SearchOperator(exp.Binary): + """ParadeDB search operation that preserves the original operator.""" + + arg_types = {"this": True, "expression": True, "operator": True} + + +class ParadeDBTokenizer(PGVectorTokenizer): + """Tokenizer with ParadeDB search operators and pgvector distance operators.""" + + KEYWORDS = { + **PGVectorTokenizer.KEYWORDS, + "@@@": _PARADEDB_SEARCH_TOKEN, + "&&&": _PARADEDB_SEARCH_TOKEN, + "|||": _PARADEDB_SEARCH_TOKEN, + "===": _PARADEDB_SEARCH_TOKEN, + "###": _PARADEDB_SEARCH_TOKEN, + "##": _PARADEDB_SEARCH_TOKEN, + "##>": _PARADEDB_SEARCH_TOKEN, + } + + +class ParadeDBParser(PGVectorParser): + """Parser with ParadeDB search operators and pgvector distance operators.""" + + FACTOR = {**PGVectorParser.FACTOR, _PARADEDB_SEARCH_TOKEN: SearchOperator} + + +class ParadeDBGenerator(PGVectorGenerator): + """Generator that renders ParadeDB search operators and pgvector distance operators.""" + + def searchoperator_sql(self, expression: SearchOperator) -> str: + op = expression.args.get("operator", "@@@") + left = self.sql(expression, "this") + right = self.sql(expression, "expression") + return f"{left} {op} {right}" + + +class ParadeDB(PGVector): + """ParadeDB dialect with pg_search and pgvector extension support.""" + + Tokenizer = ParadeDBTokenizer + Parser = ParadeDBParser + Generator = ParadeDBGenerator + + +Dialect.classes["paradedb"] = ParadeDB diff --git a/sqlspec/data_dictionary/extension_dialects/postgres/pgvector.py b/sqlspec/data_dictionary/extension_dialects/postgres/pgvector.py new file mode 100644 index 00000000..77e892cd --- /dev/null +++ b/sqlspec/data_dictionary/extension_dialects/postgres/pgvector.py @@ -0,0 +1,98 @@ +"""PGVector dialect extending Postgres with vector distance operators. + +Adds support for pgvector distance operators: +- <-> : L2 (Euclidean) distance (already in base Postgres) +- <#> : Negative inner product +- <=> : Cosine distance +- <+> : L1 (Taxicab/Manhattan) distance +- <~> : Hamming distance (binary vectors) +- <%> : Jaccard distance (binary vectors) +""" + +from __future__ import annotations + +from sqlglot import exp +from sqlglot.dialects.dialect import Dialect +from sqlglot.dialects.postgres import Postgres +from sqlglot.tokens import TokenType + +__all__ = ("PGVector",) + +# Use a single unused token type for all pgvector distance operators. +# The actual operator string is captured during parsing and stored in the expression. +# SQLGlot is not going to add extension operators, even as unused tokens, so this allows us +# to work around the limitation: https://github.com/tobymao/sqlglot/issues/6949 +_PGVECTOR_DISTANCE_TOKEN = TokenType.CARET_AT + + +class VectorDistance(exp.Binary): + """Vector distance operation that preserves the original operator.""" + + arg_types = {"this": True, "expression": True, "operator": True} + + +class PGVectorTokenizer(Postgres.Tokenizer): + """Tokenizer with pgvector distance operators.""" + + KEYWORDS = { + **Postgres.Tokenizer.KEYWORDS, + "<#>": _PGVECTOR_DISTANCE_TOKEN, + "<=>": _PGVECTOR_DISTANCE_TOKEN, + "<+>": _PGVECTOR_DISTANCE_TOKEN, + "<~>": _PGVECTOR_DISTANCE_TOKEN, + "<%>": _PGVECTOR_DISTANCE_TOKEN, + } + + +class PGVectorParser(Postgres.Parser): + """Parser that captures the original operator string for pgvector operations.""" + + FACTOR = {**Postgres.Parser.FACTOR, _PGVECTOR_DISTANCE_TOKEN: VectorDistance} + + def _parse_factor(self) -> exp.Expression | None: + parse_method = self._parse_exponent if self.EXPONENT else self._parse_unary + this = self._parse_at_time_zone(parse_method()) + + while self._match_set(self.FACTOR): + klass = self.FACTOR[self._prev.token_type] + comments = self._prev_comments + operator_text = self._prev.text + expression = parse_method() + + if not expression and klass is exp.IntDiv and self._prev.text.isalpha(): + self._retreat(self._index - 1) + return this + + if "operator" in klass.arg_types: + this = self.expression( + klass, this=this, comments=comments, expression=expression, operator=operator_text + ) + else: + this = self.expression(klass, this=this, comments=comments, expression=expression) + + if isinstance(this, exp.Div): + this.set("typed", self.dialect.TYPED_DIVISION) + this.set("safe", self.dialect.SAFE_DIVISION) + + return this + + +class PGVectorGenerator(Postgres.Generator): + """Generator that renders pgvector distance operators.""" + + def vectordistance_sql(self, expression: VectorDistance) -> str: + op = expression.args.get("operator", "<->") + left = self.sql(expression, "this") + right = self.sql(expression, "expression") + return f"{left} {op} {right}" + + +class PGVector(Postgres): + """PostgreSQL dialect with pgvector extension support.""" + + Tokenizer = PGVectorTokenizer + Parser = PGVectorParser + Generator = PGVectorGenerator + + +Dialect.classes["pgvector"] = PGVector From 9545769d3b1eda78f312bf728cad5e7526a153e0 Mon Sep 17 00:00:00 2001 From: Jordan Russell Date: Thu, 5 Feb 2026 07:59:18 +1300 Subject: [PATCH 2/6] test(dialects): add tests for the ParadeDB & PGVector SQLGlot dialects --- tests/unit/dialects/test_paradedb.py | 98 ++++++++++++++++++++++++++++ tests/unit/dialects/test_pgvector.py | 60 +++++++++++++++++ 2 files changed, 158 insertions(+) create mode 100644 tests/unit/dialects/test_paradedb.py create mode 100644 tests/unit/dialects/test_pgvector.py diff --git a/tests/unit/dialects/test_paradedb.py b/tests/unit/dialects/test_paradedb.py new file mode 100644 index 00000000..ce2f928e --- /dev/null +++ b/tests/unit/dialects/test_paradedb.py @@ -0,0 +1,98 @@ +"""Dialect unit tests for the ParadeDB (PostgreSQL + pgvector + pg_search) dialect.""" + +from sqlglot import parse_one + +import sqlspec.data_dictionary.extension_dialects.postgres.paradedb # noqa: F401 + + +def _render(sql: str) -> str: + return parse_one(sql, dialect="paradedb").sql(dialect="paradedb") + + +def test_bm25_search_operator() -> None: + sql = "SELECT * FROM mock_items WHERE description @@@ 'shoes'" + rendered = _render(sql) + assert "@@@" in rendered + + +def test_match_conjunction_operator() -> None: + sql = "SELECT * FROM mock_items WHERE description &&& 'running shoes'" + rendered = _render(sql) + assert "&&&" in rendered + + +def test_match_disjunction_operator() -> None: + sql = "SELECT * FROM mock_items WHERE description ||| 'shoes'" + rendered = _render(sql) + assert "|||" in rendered + + +def test_term_query_operator() -> None: + sql = "SELECT * FROM mock_items WHERE category === 'footwear'" + rendered = _render(sql) + assert "===" in rendered + + +def test_phrase_query_operator() -> None: + sql = "SELECT * FROM mock_items WHERE description ### 'running shoes'" + rendered = _render(sql) + assert "###" in rendered + + +def test_proximity_operator() -> None: + sql = "SELECT description, rating, category FROM mock_items WHERE description @@@ ('sleek' ## 1 ## 'shoes')" + rendered = _render(sql) + assert "##" in rendered + + +def test_directional_proximity_operator() -> None: + sql = "SELECT description, rating, category FROM mock_items WHERE description @@@ ('sleek' ##> 1 ##> 'shoes')" + rendered = _render(sql) + assert "##>" in rendered + + +def test_snippet_with_named_args() -> None: + sql = ( + "SELECT id, pdb.snippet(description, start_tag => '', end_tag => '') " + "FROM mock_items WHERE description ||| 'shoes' LIMIT 5" + ) + rendered = _render(sql) + assert "|||" in rendered + assert "pdb.snippet" in rendered + + +def test_pgvector_operators_still_work() -> None: + sql = "SELECT embedding <=> '[1,2,3]' AS cosine_dist, embedding <#> '[1,2,3]' AS inner_prod FROM items" + rendered = _render(sql) + assert "<=>" in rendered + assert "<#>" in rendered + + +def test_search_in_where_clause() -> None: + sql = "SELECT * FROM mock_items WHERE description @@@ 'shoes' AND active = TRUE" + rendered = _render(sql) + assert "@@@" in rendered + assert "WHERE" in rendered + + +def test_multiple_search_operators() -> None: + sql = ( + "SELECT description ## 'query' AS snippet, description ### 'query' AS score " + "FROM mock_items WHERE description @@@ 'query'" + ) + rendered = _render(sql) + assert "##" in rendered + assert "###" in rendered + assert "@@@" in rendered + + +def test_fuzzy_cast() -> None: + sql = "SELECT * FROM mock_items WHERE description @@@ 'runing shose'::pdb.fuzzy(2)" + rendered = _render(sql) + assert "@@@" in rendered + + +def test_prox_regex() -> None: + sql = "SELECT * FROM mock_items WHERE description @@@ pdb.prox_regex('sho.*', 2, 'run.*')" + rendered = _render(sql) + assert "@@@" in rendered diff --git a/tests/unit/dialects/test_pgvector.py b/tests/unit/dialects/test_pgvector.py new file mode 100644 index 00000000..ad32ad39 --- /dev/null +++ b/tests/unit/dialects/test_pgvector.py @@ -0,0 +1,60 @@ +"""Dialect unit tests for the PGVector (PostgreSQL + pgvector) dialect.""" + +from sqlglot import parse_one + +import sqlspec.data_dictionary.extension_dialects.postgres.pgvector # noqa: F401 + + +def _render(sql: str) -> str: + return parse_one(sql, dialect="pgvector").sql(dialect="pgvector") + + +def test_cosine_distance_operator() -> None: + sql = "SELECT embedding <=> '[1,2,3]' FROM items" + rendered = _render(sql) + assert "<=>" in rendered + + +def test_negative_inner_product_operator() -> None: + sql = "SELECT embedding <#> '[1,2,3]' FROM items" + rendered = _render(sql) + assert "<#>" in rendered + + +def test_l1_distance_operator() -> None: + sql = "SELECT embedding <+> '[1,2,3]' FROM items" + rendered = _render(sql) + assert "<+>" in rendered + + +def test_hamming_distance_operator() -> None: + sql = "SELECT embedding <~> '[1,0,1]' FROM items" + rendered = _render(sql) + assert "<~>" in rendered + + +def test_jaccard_distance_operator() -> None: + sql = "SELECT embedding <%> '[1,0,1]' FROM items" + rendered = _render(sql) + assert "<%>" in rendered + + +def test_order_by_cosine_distance() -> None: + sql = "SELECT * FROM items ORDER BY embedding <=> '[1,2,3]'" + rendered = _render(sql) + assert "ORDER BY" in rendered + assert "<=>" in rendered + + +def test_distance_in_where_clause() -> None: + sql = "SELECT * FROM items WHERE embedding <=> '[1,2,3]' < 0.5" + rendered = _render(sql) + assert "<=>" in rendered + assert "WHERE" in rendered + + +def test_multiple_distance_operators() -> None: + sql = "SELECT embedding <=> '[1,2,3]' AS cosine_dist, embedding <#> '[1,2,3]' AS inner_prod FROM items" + rendered = _render(sql) + assert "<=>" in rendered + assert "<#>" in rendered From 1be6b7248a7e020b765440ef9360b707daf19150 Mon Sep 17 00:00:00 2001 From: Jordan Russell Date: Thu, 5 Feb 2026 08:13:38 +1300 Subject: [PATCH 3/6] feat(asyncpg): enable pgvector and paradedb dialects based on detected featureset Adds support to the asyncpg dialect to register paradedb extension availability and dialects based on that and pgvector availability --- sqlspec/adapters/asyncpg/__init__.py | 1 + sqlspec/adapters/asyncpg/config.py | 71 +++++++++++++++++++++++----- sqlspec/adapters/asyncpg/core.py | 1 + 3 files changed, 61 insertions(+), 12 deletions(-) diff --git a/sqlspec/adapters/asyncpg/__init__.py b/sqlspec/adapters/asyncpg/__init__.py index bcfad5bb..40104c13 100644 --- a/sqlspec/adapters/asyncpg/__init__.py +++ b/sqlspec/adapters/asyncpg/__init__.py @@ -4,6 +4,7 @@ from sqlspec.adapters.asyncpg.config import AsyncpgConfig, AsyncpgConnectionConfig, AsyncpgPoolConfig from sqlspec.adapters.asyncpg.core import default_statement_config from sqlspec.adapters.asyncpg.driver import AsyncpgCursor, AsyncpgDriver, AsyncpgExceptionHandler +from sqlspec.data_dictionary.extension_dialects import postgres # noqa: F401 __all__ = ( "AsyncpgConfig", diff --git a/sqlspec/adapters/asyncpg/config.py b/sqlspec/adapters/asyncpg/config.py index 80416835..64126871 100644 --- a/sqlspec/adapters/asyncpg/config.py +++ b/sqlspec/adapters/asyncpg/config.py @@ -102,6 +102,11 @@ class AsyncpgDriverFeatures(TypedDict): Defaults to True when pgvector-python is installed. Provides automatic conversion between Python objects and PostgreSQL vector types. Enables vector similarity operations and index support. + enable_paradedb: Enable ParadeDB (pg_search) extension detection. + When enabled and the pg_search extension is detected, the SQL dialect + switches to "paradedb" which supports search operators (@@@, &&&, etc.) + and inherits all pgvector distance operators. + Defaults to True. Independent of enable_pgvector. enable_cloud_sql: Enable Google Cloud SQL connector integration. Requires cloud-sql-python-connector package. Defaults to False (explicit opt-in required). @@ -146,6 +151,7 @@ class AsyncpgDriverFeatures(TypedDict): json_deserializer: NotRequired["Callable[[str], Any]"] enable_json_codecs: NotRequired[bool] enable_pgvector: NotRequired[bool] + enable_paradedb: NotRequired[bool] enable_cloud_sql: NotRequired[bool] cloud_sql_instance: NotRequired[str] cloud_sql_enable_iam_auth: NotRequired[bool] @@ -328,6 +334,7 @@ def __init__( self._cloud_sql_connector: Any | None = None self._alloydb_connector: Any | None = None self._pgvector_available: bool | None = None + self._paradedb_available: bool | None = None self._validate_connector_config() @@ -435,7 +442,42 @@ async def _create_pool(self) -> "Pool[Record]": config.setdefault("init", self._init_connection) - return await asyncpg_create_pool(**config) + pool = await asyncpg_create_pool(**config) + await self._detect_extensions(pool) + return pool + + async def _detect_extensions(self, pool: "Pool[Record]") -> None: + """Detect database extensions and update dialect accordingly. + + Args: + pool: Connection pool to acquire a connection from. + """ + extensions = [ + name + for name, enabled in [ + ("vector", self.driver_features.get("enable_pgvector", False)), + ("pg_search", self.driver_features.get("enable_paradedb", False)), + ] + if enabled + ] + if not extensions: + return + + connection = await pool.acquire() + try: + results = await connection.fetch( + "SELECT extname FROM pg_extension WHERE extname = ANY($1::text[])", extensions + ) + detected = {r["extname"] for r in results} + self._pgvector_available = "vector" in detected + self._paradedb_available = "pg_search" in detected + except Exception: + self._pgvector_available = False + self._paradedb_available = False + finally: + await pool.release(connection) + + self._update_dialect_for_extensions() async def _init_connection(self, connection: "AsyncpgConnection") -> None: """Initialize connection with JSON codecs, pgvector support, and user callback. @@ -450,22 +492,27 @@ async def _init_connection(self, connection: "AsyncpgConnection") -> None: decoder=self.driver_features.get("json_deserializer", from_json), ) - if self.driver_features.get("enable_pgvector", False): - if self._pgvector_available is None: - try: - result = await connection.fetchval("SELECT 1 FROM pg_extension WHERE extname = 'vector'") - self._pgvector_available = bool(result) - except Exception: - # If we can't query extensions, assume false to be safe and avoid errors - self._pgvector_available = False - - if self._pgvector_available: - await register_pgvector_support(connection) + if self._pgvector_available: + await register_pgvector_support(connection) # Call user-provided callback after internal setup if self._user_connection_hook is not None: await self._user_connection_hook(connection) + def _update_dialect_for_extensions(self) -> None: + """Update statement_config dialect based on detected extensions. + + Priority: paradedb > pgvector > postgres (default). + """ + current_dialect = getattr(self.statement_config, "dialect", "postgres") + if current_dialect != "postgres": + return + + if self._paradedb_available: + self.statement_config = self.statement_config.replace(dialect="paradedb") + elif self._pgvector_available: + self.statement_config = self.statement_config.replace(dialect="pgvector") + async def _close_pool(self) -> None: """Close the actual async connection pool and cleanup connectors.""" if self.connection_instance: diff --git a/sqlspec/adapters/asyncpg/core.py b/sqlspec/adapters/asyncpg/core.py index c4cc80e8..5792b85c 100644 --- a/sqlspec/adapters/asyncpg/core.py +++ b/sqlspec/adapters/asyncpg/core.py @@ -244,6 +244,7 @@ def apply_driver_features( deserializer = processed_features.setdefault("json_deserializer", from_json) processed_features.setdefault("enable_json_codecs", True) processed_features.setdefault("enable_pgvector", PGVECTOR_INSTALLED) + processed_features.setdefault("enable_paradedb", True) processed_features.setdefault("enable_cloud_sql", False) processed_features.setdefault("enable_alloydb", False) From bdc5fe9e72f9fde40f720318178683517cecd35b Mon Sep 17 00:00:00 2001 From: Jordan Russell Date: Thu, 5 Feb 2026 08:24:57 +1300 Subject: [PATCH 4/6] chore(tests): fix cloud connector tests with no-op _detect_extensions --- tests/unit/adapters/test_asyncpg/test_cloud_connectors.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unit/adapters/test_asyncpg/test_cloud_connectors.py b/tests/unit/adapters/test_asyncpg/test_cloud_connectors.py index 89e0ac2e..4d14a8d2 100644 --- a/tests/unit/adapters/test_asyncpg/test_cloud_connectors.py +++ b/tests/unit/adapters/test_asyncpg/test_cloud_connectors.py @@ -16,7 +16,8 @@ def disable_connectors_by_default(): """Disable both connectors by default for clean test state.""" with patch("sqlspec.adapters.asyncpg.config.CLOUD_SQL_CONNECTOR_INSTALLED", False): with patch("sqlspec.adapters.asyncpg.config.ALLOYDB_CONNECTOR_INSTALLED", False): - yield + with patch.object(AsyncpgConfig, "_detect_extensions", new_callable=AsyncMock): + yield @pytest.fixture From 388814a4c1cbbc105007e30ddef8fc61844e614d Mon Sep 17 00:00:00 2001 From: Jordan Russell Date: Wed, 25 Feb 2026 07:18:40 +1300 Subject: [PATCH 5/6] feat(psycopg): enable pgvector and paradedb dialects based on detected features --- sqlspec/adapters/psycopg/__init__.py | 1 + sqlspec/adapters/psycopg/config.py | 118 +++++++++++++++++++++++++-- sqlspec/adapters/psycopg/core.py | 1 + 3 files changed, 114 insertions(+), 6 deletions(-) diff --git a/sqlspec/adapters/psycopg/__init__.py b/sqlspec/adapters/psycopg/__init__.py index 890644e3..b63de6c7 100644 --- a/sqlspec/adapters/psycopg/__init__.py +++ b/sqlspec/adapters/psycopg/__init__.py @@ -14,6 +14,7 @@ PsycopgSyncDriver, PsycopgSyncExceptionHandler, ) +from sqlspec.data_dictionary.extension_dialects import postgres # noqa: F401 __all__ = ( "PsycopgAsyncConfig", diff --git a/sqlspec/adapters/psycopg/config.py b/sqlspec/adapters/psycopg/config.py index 40cfeb6b..f4e4b774 100644 --- a/sqlspec/adapters/psycopg/config.py +++ b/sqlspec/adapters/psycopg/config.py @@ -75,6 +75,11 @@ class PsycopgDriverFeatures(TypedDict): Provides automatic conversion between Python objects and PostgreSQL vector types. Enables vector similarity operations and index support. Set to False to disable pgvector support even when package is available. + enable_paradedb: Enable ParadeDB (pg_search) extension detection. + When enabled and the pg_search extension is detected, the SQL dialect + switches to "paradedb" which supports search operators (@@@, &&&, etc.) + and inherits all pgvector distance operators. + Defaults to True. Independent of enable_pgvector. json_serializer: Custom JSON serializer for StatementConfig parameter handling. json_deserializer: Custom JSON deserializer reference stored alongside the serializer for parity with asyncpg. on_connection_create: Callback executed when a connection is created/acquired from the pool. @@ -95,6 +100,7 @@ class PsycopgDriverFeatures(TypedDict): """ enable_pgvector: NotRequired[bool] + enable_paradedb: NotRequired[bool] json_serializer: NotRequired["Callable[[Any], str]"] json_deserializer: NotRequired["Callable[[str], Any]"] on_connection_create: NotRequired["Callable[..., Any]"] @@ -210,6 +216,8 @@ def __init__( self._user_connection_hook: Callable[[PsycopgSyncConnection], None] | None = features_dict.pop( "on_connection_create", None ) + self._pgvector_available: bool | None = None + self._paradedb_available: bool | None = None super().__init__( connection_config=connection_config, @@ -244,18 +252,67 @@ def _create_pool(self) -> "ConnectionPool": conninfo = all_config.pop("conninfo", None) if conninfo: - return ConnectionPool(conninfo, open=True, **pool_parameters) + pool = ConnectionPool(conninfo, open=True, **pool_parameters) + else: + kwargs = all_config.pop("kwargs", {}) + all_config.update(kwargs) + pool = ConnectionPool("", kwargs=all_config, open=True, **pool_parameters) + + self._detect_extensions(pool) + return pool + + def _detect_extensions(self, pool: "ConnectionPool") -> None: + """Detect database extensions and update dialect accordingly. + + Args: + pool: Connection pool to acquire a connection from. + """ + extensions = [ + name + for name, enabled in [ + ("vector", self.driver_features.get("enable_pgvector", False)), + ("pg_search", self.driver_features.get("enable_paradedb", False)), + ] + if enabled + ] + if not extensions: + return + + with pool.connection() as connection: + try: + cursor = connection.execute( + "SELECT extname FROM pg_extension WHERE extname = ANY(%s::text[])", (extensions,) + ) + results = cursor.fetchall() + detected = {r[0] for r in results} + self._pgvector_available = "vector" in detected + self._paradedb_available = "pg_search" in detected + except Exception: + self._pgvector_available = False + self._paradedb_available = False + + self._update_dialect_for_extensions() + + def _update_dialect_for_extensions(self) -> None: + """Update statement_config dialect based on detected extensions. + + Priority: paradedb > pgvector > postgres (default). + """ + current_dialect = getattr(self.statement_config, "dialect", "postgres") + if current_dialect != "postgres": + return - kwargs = all_config.pop("kwargs", {}) - all_config.update(kwargs) - return ConnectionPool("", kwargs=all_config, open=True, **pool_parameters) + if self._paradedb_available: + self.statement_config = self.statement_config.replace(dialect="paradedb") + elif self._pgvector_available: + self.statement_config = self.statement_config.replace(dialect="pgvector") def _configure_connection(self, conn: "PsycopgSyncConnection") -> None: autocommit_setting = self.connection_config.get("autocommit") if autocommit_setting is not None: conn.autocommit = autocommit_setting - if self.driver_features.get("enable_pgvector", False): + if self._pgvector_available: register_pgvector_sync(conn) # Call user-provided callback after internal setup @@ -448,6 +505,8 @@ def __init__( self._user_connection_hook: Callable[[PsycopgAsyncConnection], Awaitable[None]] | None = features_dict.pop( "on_connection_create", None ) + self._pgvector_available: bool | None = None + self._paradedb_available: bool | None = None super().__init__( connection_config=connection_config, @@ -490,15 +549,62 @@ async def _create_pool(self) -> "AsyncConnectionPool": pool = AsyncConnectionPool("", kwargs=all_config, open=False, **pool_parameters) await pool.open() + await self._detect_extensions(pool) return pool + async def _detect_extensions(self, pool: "AsyncConnectionPool") -> None: + """Detect database extensions and update dialect accordingly. + + Args: + pool: Connection pool to acquire a connection from. + """ + extensions = [ + name + for name, enabled in [ + ("vector", self.driver_features.get("enable_pgvector", False)), + ("pg_search", self.driver_features.get("enable_paradedb", False)), + ] + if enabled + ] + if not extensions: + return + + async with pool.connection() as connection: + try: + cursor = await connection.execute( + "SELECT extname FROM pg_extension WHERE extname = ANY(%s::text[])", (extensions,) + ) + results = await cursor.fetchall() + detected = {r[0] for r in results} + self._pgvector_available = "vector" in detected + self._paradedb_available = "pg_search" in detected + except Exception: + self._pgvector_available = False + self._paradedb_available = False + + self._update_dialect_for_extensions() + + def _update_dialect_for_extensions(self) -> None: + """Update statement_config dialect based on detected extensions. + + Priority: paradedb > pgvector > postgres (default). + """ + current_dialect = getattr(self.statement_config, "dialect", "postgres") + if current_dialect != "postgres": + return + + if self._paradedb_available: + self.statement_config = self.statement_config.replace(dialect="paradedb") + elif self._pgvector_available: + self.statement_config = self.statement_config.replace(dialect="pgvector") + async def _configure_async_connection(self, conn: "PsycopgAsyncConnection") -> None: autocommit_setting = self.connection_config.get("autocommit") if autocommit_setting is not None: await conn.set_autocommit(autocommit_setting) - if self.driver_features.get("enable_pgvector", False): + if self._pgvector_available: await register_pgvector_async(conn) # Call user-provided callback after internal setup diff --git a/sqlspec/adapters/psycopg/core.py b/sqlspec/adapters/psycopg/core.py index 1125a861..bf771619 100644 --- a/sqlspec/adapters/psycopg/core.py +++ b/sqlspec/adapters/psycopg/core.py @@ -198,6 +198,7 @@ def apply_driver_features( serializer = features.get("json_serializer", to_json) features.setdefault("json_serializer", serializer) features.setdefault("enable_pgvector", PGVECTOR_INSTALLED) + features.setdefault("enable_paradedb", True) parameter_config = _build_psycopg_parameter_config(driver_profile, serializer) statement_config = statement_config.replace(parameter_config=parameter_config) From fe4007feb00a95bc0b8de6c340033c7a81743051 Mon Sep 17 00:00:00 2001 From: Jordan Russell Date: Wed, 25 Feb 2026 07:21:10 +1300 Subject: [PATCH 6/6] feat(psqlpy): enable pgvector and paradedb dialects based on detected featuresets --- sqlspec/adapters/psqlpy/__init__.py | 1 + sqlspec/adapters/psqlpy/config.py | 57 ++++++++++++++++++++++++++++- sqlspec/adapters/psqlpy/core.py | 1 + 3 files changed, 58 insertions(+), 1 deletion(-) diff --git a/sqlspec/adapters/psqlpy/__init__.py b/sqlspec/adapters/psqlpy/__init__.py index ecd92ecd..f8b30bfe 100644 --- a/sqlspec/adapters/psqlpy/__init__.py +++ b/sqlspec/adapters/psqlpy/__init__.py @@ -4,6 +4,7 @@ from sqlspec.adapters.psqlpy.config import PsqlpyConfig, PsqlpyConnectionParams, PsqlpyPoolParams from sqlspec.adapters.psqlpy.core import default_statement_config from sqlspec.adapters.psqlpy.driver import PsqlpyCursor, PsqlpyDriver, PsqlpyExceptionHandler +from sqlspec.data_dictionary.extension_dialects import postgres # noqa: F401 __all__ = ( "PsqlpyConfig", diff --git a/sqlspec/adapters/psqlpy/config.py b/sqlspec/adapters/psqlpy/config.py index f8f255c8..81ebbf9e 100644 --- a/sqlspec/adapters/psqlpy/config.py +++ b/sqlspec/adapters/psqlpy/config.py @@ -79,6 +79,11 @@ class PsqlpyDriverFeatures(TypedDict): Requires pgvector-python package installed. Defaults to True when pgvector is installed. Provides automatic conversion between NumPy arrays and PostgreSQL vector types. + enable_paradedb: Enable ParadeDB (pg_search) extension detection. + When enabled and the pg_search extension is detected, the SQL dialect + switches to "paradedb" which supports search operators (@@@, &&&, etc.) + and inherits all pgvector distance operators. + Defaults to True. Independent of enable_pgvector. json_serializer: Custom JSON serializer applied to the statement configuration. json_deserializer: Custom JSON deserializer retained alongside the serializer for parity with asyncpg. on_connection_create: Async callback executed when a connection is acquired from pool. @@ -97,6 +102,7 @@ class PsqlpyDriverFeatures(TypedDict): """ enable_pgvector: NotRequired[bool] + enable_paradedb: NotRequired[bool] json_serializer: NotRequired["Callable[[Any], str]"] json_deserializer: NotRequired["Callable[[str], Any]"] on_connection_create: "NotRequired[Callable[[PsqlpyConnection], Awaitable[None]]]" @@ -207,6 +213,8 @@ def __init__( ) # Track initialized connections by ID (psqlpy connections don't support weak refs) self._initialized_connection_ids: set[int] = set() + self._pgvector_available: bool | None = None + self._paradedb_available: bool | None = None super().__init__( connection_config=connection_config, @@ -230,7 +238,54 @@ async def _ensure_connection_initialized(self, connection: "PsqlpyConnection") - async def _create_pool(self) -> "ConnectionPool": """Create the actual async connection pool.""" - return ConnectionPool(**build_connection_config(self.connection_config)) + pool = ConnectionPool(**build_connection_config(self.connection_config)) + await self._detect_extensions(pool) + return pool + + async def _detect_extensions(self, pool: "ConnectionPool") -> None: + """Detect database extensions and update dialect accordingly. + + Args: + pool: Connection pool to acquire a connection from. + """ + extensions = [ + name + for name, enabled in [ + ("vector", self.driver_features.get("enable_pgvector", False)), + ("pg_search", self.driver_features.get("enable_paradedb", False)), + ] + if enabled + ] + if not extensions: + return + + async with pool.acquire() as connection: + try: + result = await connection.fetch( + "SELECT extname FROM pg_extension WHERE extname = ANY($1::text[])", [extensions] + ) + detected = {r["extname"] for r in result} + self._pgvector_available = "vector" in detected + self._paradedb_available = "pg_search" in detected + except Exception: + self._pgvector_available = False + self._paradedb_available = False + + self._update_dialect_for_extensions() + + def _update_dialect_for_extensions(self) -> None: + """Update statement_config dialect based on detected extensions. + + Priority: paradedb > pgvector > postgres (default). + """ + current_dialect = getattr(self.statement_config, "dialect", "postgres") + if current_dialect != "postgres": + return + + if self._paradedb_available: + self.statement_config = self.statement_config.replace(dialect="paradedb") + elif self._pgvector_available: + self.statement_config = self.statement_config.replace(dialect="pgvector") async def _close_pool(self) -> None: """Close the actual async connection pool.""" diff --git a/sqlspec/adapters/psqlpy/core.py b/sqlspec/adapters/psqlpy/core.py index b94f02f9..85b71947 100644 --- a/sqlspec/adapters/psqlpy/core.py +++ b/sqlspec/adapters/psqlpy/core.py @@ -249,6 +249,7 @@ def apply_driver_features( serializer = features.get("json_serializer", to_json) features.setdefault("json_serializer", serializer) features.setdefault("enable_pgvector", PGVECTOR_INSTALLED) + features.setdefault("enable_paradedb", True) parameter_config = _build_psqlpy_parameter_config(driver_profile, serializer) statement_config = statement_config.replace(parameter_config=parameter_config)