From 9e5b48d3f9e1cdc20970932cd48e31bfe79f33ad Mon Sep 17 00:00:00 2001 From: fresioAS Date: Wed, 29 Apr 2026 08:38:13 +0200 Subject: [PATCH 01/10] initial engine adapter --- sqlmesh/core/config/connection.py | 47 ++++ sqlmesh/core/engine_adapter/__init__.py | 3 + sqlmesh/core/engine_adapter/feldera.py | 154 +++++++++++++ sqlmesh/engines/feldera/__init__.py | 1 + sqlmesh/engines/feldera/db_api.py | 255 ++++++++++++++++++++++ tests/core/engine_adapter/test_feldera.py | 112 ++++++++++ tests/core/test_connection_config.py | 20 ++ 7 files changed, 592 insertions(+) create mode 100644 sqlmesh/core/engine_adapter/feldera.py create mode 100644 sqlmesh/engines/feldera/__init__.py create mode 100644 sqlmesh/engines/feldera/db_api.py create mode 100644 tests/core/engine_adapter/test_feldera.py diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index d930537711..557e111b6e 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -2341,6 +2341,53 @@ def init(cursor: t.Any) -> None: return init + +class FelderaConnectionConfig(ConnectionConfig): + """Feldera connection configuration.""" + + host: str = "http://localhost:8080" + api_key: t.Optional[str] = None + pipeline_name: str + compilation_profile: str = "dev" + workers: int = 4 + timeout: int = 300 + + type_: t.Literal["feldera"] = Field(alias="type", default="feldera") + DIALECT: t.ClassVar[t.Literal["felderadialect"]] = "felderadialect" + DISPLAY_NAME: t.ClassVar[t.Literal["Feldera"]] = "Feldera" + DISPLAY_ORDER: t.ClassVar[t.Literal[18]] = 18 + + concurrent_tasks: int = 1 + register_comments: t.Literal[False] = False + pre_ping: t.Literal[False] = False + + _engine_import_validator = _get_engine_import_validator("feldera", "feldera") + + @property + def _connection_kwargs_keys(self) -> t.Set[str]: + return { + "host", + "api_key", + "pipeline_name", + "workers", + "compilation_profile", + "timeout", + } + + @property + def _engine_adapter(self) -> t.Type[EngineAdapter]: + from sqlmesh.core.engine_adapter.feldera import FelderaEngineAdapter + + return FelderaEngineAdapter + + @property + def _connection_factory(self) -> t.Callable: + from sqlmesh.engines.feldera.db_api import connect + + return connect + + def get_catalog(self) -> t.Optional[str]: + return None _CONNECTION_CONFIG_EXCLUDE: t.Set[t.Type[ConnectionConfig]] = { ConnectionConfig, # type: ignore[type-abstract] diff --git a/sqlmesh/core/engine_adapter/__init__.py b/sqlmesh/core/engine_adapter/__init__.py index ab29885c7b..a9d30a0732 100644 --- a/sqlmesh/core/engine_adapter/__init__.py +++ b/sqlmesh/core/engine_adapter/__init__.py @@ -20,6 +20,7 @@ from sqlmesh.core.engine_adapter.athena import AthenaEngineAdapter from sqlmesh.core.engine_adapter.risingwave import RisingwaveEngineAdapter from sqlmesh.core.engine_adapter.fabric import FabricEngineAdapter +from sqlmesh.core.engine_adapter.feldera import FelderaEngineAdapter DIALECT_TO_ENGINE_ADAPTER = { "hive": SparkEngineAdapter, @@ -37,6 +38,8 @@ "athena": AthenaEngineAdapter, "risingwave": RisingwaveEngineAdapter, "fabric": FabricEngineAdapter, + "felderadialect": FelderaEngineAdapter, + "feldera": FelderaEngineAdapter, } DIALECT_ALIASES = { diff --git a/sqlmesh/core/engine_adapter/feldera.py b/sqlmesh/core/engine_adapter/feldera.py new file mode 100644 index 0000000000..6ea24e2102 --- /dev/null +++ b/sqlmesh/core/engine_adapter/feldera.py @@ -0,0 +1,154 @@ +from __future__ import annotations + +import typing as t + +from sqlglot import exp +from sqlglot.dialects import generator +from sqlglot.dialects.dialect import Dialect + +from sqlmesh.core.dialect import to_schema +from sqlmesh.core.engine_adapter.base import EngineAdapter +from sqlmesh.core.engine_adapter.shared import ( + CommentCreationTable, + CommentCreationView, + DataObject, + DataObjectType, +) +from sqlmesh.utils.errors import SQLMeshError + +if t.TYPE_CHECKING: + import pandas as pd + + from sqlmesh.core._typing import SchemaName, TableName + + +class FelderaDialect(Dialect): + class Generator(generator.Generator): + TYPE_MAPPING = { + **generator.Generator.TYPE_MAPPING, + exp.DataType.Type.FLOAT: "REAL", + exp.DataType.Type.INT: "INTEGER", + } + + +_FELDERA_TO_EXP_TYPE: t.Dict[str, exp.DataType.Type] = { + "BOOLEAN": exp.DataType.Type.BOOLEAN, + "TINYINT": exp.DataType.Type.TINYINT, + "SMALLINT": exp.DataType.Type.SMALLINT, + "INTEGER": exp.DataType.Type.INT, + "INT": exp.DataType.Type.INT, + "BIGINT": exp.DataType.Type.BIGINT, + "REAL": exp.DataType.Type.FLOAT, + "DOUBLE": exp.DataType.Type.DOUBLE, + "DECIMAL": exp.DataType.Type.DECIMAL, + "NUMERIC": exp.DataType.Type.DECIMAL, + "VARCHAR": exp.DataType.Type.VARCHAR, + "CHAR": exp.DataType.Type.CHAR, + "DATE": exp.DataType.Type.DATE, + "TIME": exp.DataType.Type.TIME, + "TIMESTAMP": exp.DataType.Type.TIMESTAMP, + "ARRAY": exp.DataType.Type.ARRAY, +} + + +def _feldera_type_to_exp(dtype_str: str) -> exp.DataType: + base = dtype_str.split("(")[0].strip().upper() + kind = _FELDERA_TO_EXP_TYPE.get(base, exp.DataType.Type.TEXT) + return exp.DataType(this=kind) + + +class FelderaEngineAdapter(EngineAdapter): + DIALECT = "felderadialect" + SUPPORTS_TRANSACTIONS = False + SUPPORTS_INDEXES = False + SUPPORTS_MATERIALIZED_VIEWS = True + SUPPORTS_REPLACE_TABLE = False + COMMENT_CREATION_TABLE = CommentCreationTable.UNSUPPORTED + COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED + + def _fetch_native_df( + self, + query: t.Union[exp.Expression, str], + quote_identifiers: bool = False, + ) -> pd.DataFrame: + with self.transaction(): + self.execute(query, quote_identifiers=quote_identifiers) + return self.cursor.fetchdf() + + def _get_data_objects( + self, + schema_name: SchemaName, + object_names: t.Optional[t.Set[str]] = None, + ) -> t.List[DataObject]: + from feldera.pipeline import Pipeline + + connection = self.connection + pipeline_name = to_schema(schema_name).db + lower_object_names = ( + {name.lower() for name in object_names} if object_names else None + ) + + try: + pipeline = Pipeline.get(pipeline_name, connection._client) + except Exception: + return [] + + objects: t.List[DataObject] = [] + for table in pipeline.tables(): + name = table.name.lower() + if lower_object_names and name not in lower_object_names: + continue + objects.append( + DataObject( + catalog=None, + schema=pipeline_name, + name=name, + type=DataObjectType.TABLE, + ) + ) + + for view in pipeline.views(): + name = view.name.lower() + if lower_object_names and name not in lower_object_names: + continue + objects.append( + DataObject( + catalog=None, + schema=pipeline_name, + name=name, + type=DataObjectType.MATERIALIZED_VIEW, + ) + ) + + return objects + + def columns( + self, table_name: TableName, include_pseudo_columns: bool = False + ) -> t.Dict[str, exp.DataType]: + from feldera.pipeline import Pipeline + + connection = self.connection + pipeline = Pipeline.get(connection._pipeline_name, connection._client) + target = exp.to_table(table_name).name.lower() + + for obj in (*pipeline.tables(), *pipeline.views()): + if obj.name.lower() == target: + return { + field["name"]: _feldera_type_to_exp( + field.get("columntype", {}).get("type", "VARCHAR") + if isinstance(field.get("columntype"), dict) + else str(field.get("columntype", "VARCHAR")) + ) + for field in (obj.fields or []) + } + + raise SQLMeshError( + "Table/view " + f"'{target}' not found in pipeline '{connection._pipeline_name}'" + ) + + def get_current_catalog(self) -> t.Optional[str]: + return None + + def ping(self) -> None: + self.connection._client.get_config() diff --git a/sqlmesh/engines/feldera/__init__.py b/sqlmesh/engines/feldera/__init__.py new file mode 100644 index 0000000000..6c43ea250f --- /dev/null +++ b/sqlmesh/engines/feldera/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations \ No newline at end of file diff --git a/sqlmesh/engines/feldera/db_api.py b/sqlmesh/engines/feldera/db_api.py new file mode 100644 index 0000000000..3d70a47747 --- /dev/null +++ b/sqlmesh/engines/feldera/db_api.py @@ -0,0 +1,255 @@ +from __future__ import annotations + +import logging +import threading +import typing as t +from enum import Enum + +logger = logging.getLogger(__name__) + +if t.TYPE_CHECKING: + import pandas as pd + + +class SqlIntent(Enum): + ADHOC_QUERY = "adhoc_query" + PIPELINE_DDL = "pipeline_ddl" + DATA_INGRESS = "data_ingress" + NO_OP = "no_op" + + +def _classify(sql: str) -> SqlIntent: + """Route SQL to the correct Feldera endpoint.""" + stripped = sql.strip().lstrip("/*").strip() + if not stripped: + return SqlIntent.NO_OP + upper = stripped.upper() + if upper.startswith(("CREATE", "DROP", "ALTER")): + return SqlIntent.PIPELINE_DDL + if upper.startswith("INSERT"): + return SqlIntent.DATA_INGRESS + return SqlIntent.ADHOC_QUERY + + +class PipelineStateManager: + """Accumulates DDL and deploys it as a single Feldera pipeline program.""" + + def __init__(self) -> None: + self._lock = threading.RLock() + self._tables: t.Dict[str, str] = {} + self._views: t.Dict[str, str] = {} + self._pipeline = None + + def register_ddl(self, sql: str) -> None: + with self._lock: + upper = sql.strip().upper() + if "CREATE TABLE" in upper or "CREATE MATERIALIZED TABLE" in upper: + self._tables[_extract_name(sql)] = sql + elif "CREATE VIEW" in upper or "CREATE MATERIALIZED VIEW" in upper: + self._views[_extract_name(sql)] = sql + elif "DROP TABLE" in upper: + self._tables.pop(_extract_name(sql), None) + elif "DROP VIEW" in upper: + self._views.pop(_extract_name(sql), None) + + def assemble_program(self) -> str: + with self._lock: + statements = [ + *(ddl.rstrip(";") + ";" for ddl in self._tables.values()), + *(ddl.rstrip(";") + ";" for ddl in self._views.values()), + ] + return "\n\n".join(statements) + + def deploy( + self, + client: t.Any, + pipeline_name: str, + workers: int = 4, + compilation_profile: str = "dev", + timeout: int = 300, + ) -> t.Any: + from feldera.pipeline_builder import PipelineBuilder + from feldera.rest.pipeline import PipelineStatus + from feldera.runtime_config import RuntimeConfig + + with self._lock: + sql = self.assemble_program() + if not sql.strip(): + return self._pipeline + + runtime_config = RuntimeConfig.default() + runtime_config.workers = workers + + pipeline = PipelineBuilder( + client, + name=pipeline_name, + sql=sql, + compilation_profile=compilation_profile, + runtime_config=runtime_config, + ).create_or_replace(wait=True) + pipeline.start() + pipeline.wait_for_status(PipelineStatus.RUNNING, timeout=timeout) + self._pipeline = pipeline + return pipeline + + +def _extract_name(sql: str) -> str: + """Very basic name extraction until sqlglot-based parsing is added.""" + import re + + match = re.search( + r"(?:CREATE|DROP)\s+(?:MATERIALIZED\s+)?(?:TABLE|VIEW)\s+" + r"(?:IF\s+(?:NOT\s+)?EXISTS\s+)?(\w+)", + sql, + re.IGNORECASE, + ) + return match.group(1).lower() if match else sql[:40] + + +class FelderaCursor: + """DB-API 2.0 cursor backed by Feldera's REST API.""" + + def __init__( + self, + client: t.Any, + pipeline_name: str, + state_manager: PipelineStateManager, + workers: int = 4, + compilation_profile: str = "dev", + timeout: int = 300, + ) -> None: + self._client = client + self._pipeline_name = pipeline_name + self._state = state_manager + self._workers = workers + self._compilation_profile = compilation_profile + self._timeout = timeout + self._rows: t.List[t.Dict[str, t.Any]] = [] + self._columns: t.List[str] = [] + self.rowcount = -1 + self.description: t.Optional[t.List] = None + + def execute(self, sql: str, parameters: t.Optional[t.Any] = None) -> None: + if parameters is not None: + raise NotImplementedError( + "Feldera DB-API does not support query parameters" + ) + + intent = _classify(sql) + logger.debug("Feldera execute (intent=%s): %.200s", intent.value, sql) + + if intent == SqlIntent.NO_OP: + self._rows = [] + self._columns = [] + return + + if intent == SqlIntent.PIPELINE_DDL: + self._state.register_ddl(sql) + self._state.deploy( + self._client, + self._pipeline_name, + self._workers, + self._compilation_profile, + self._timeout, + ) + self._rows = [] + self._columns = [] + return + + if intent == SqlIntent.DATA_INGRESS: + self._get_pipeline().execute(sql) + self._rows = [] + self._columns = [] + return + + rows = list(self._get_pipeline().query(sql)) + self._rows = rows + self._columns = list(rows[0].keys()) if rows else [] + self.rowcount = len(rows) + self.description = [ + (column, None, None, None, None, None, None) + for column in self._columns + ] + + def _get_pipeline(self) -> t.Any: + from feldera.pipeline import Pipeline + + return Pipeline.get(self._pipeline_name, self._client) + + def fetchone(self) -> t.Optional[t.Tuple]: + return tuple(self._rows.pop(0).values()) if self._rows else None + + def fetchmany(self, size: int = 1) -> t.List[t.Tuple]: + batch, self._rows = self._rows[:size], self._rows[size:] + return [tuple(row.values()) for row in batch] + + def fetchall(self) -> t.List[t.Tuple]: + rows, self._rows = self._rows, [] + return [tuple(row.values()) for row in rows] + + def fetchdf(self) -> pd.DataFrame: + import pandas as pd + + rows, self._rows = self._rows, [] + return pd.DataFrame(rows) + + def close(self) -> None: + self._rows = [] + + +class FelderaConnection: + """DB-API 2.0 connection wrapper around FelderaClient.""" + + def __init__( + self, + client: t.Any, + pipeline_name: str, + workers: int = 4, + compilation_profile: str = "dev", + timeout: int = 300, + ) -> None: + self._client = client + self._pipeline_name = pipeline_name + self._workers = workers + self._compilation_profile = compilation_profile + self._timeout = timeout + self._state = PipelineStateManager() + + def cursor(self) -> FelderaCursor: + return FelderaCursor( + self._client, + self._pipeline_name, + self._state, + self._workers, + self._compilation_profile, + self._timeout, + ) + + def commit(self) -> None: + return None + + def rollback(self) -> None: + return None + + def close(self) -> None: + return None + + +def connect( + host: str, + pipeline_name: str, + api_key: t.Optional[str] = None, + timeout: int = 300, + workers: int = 4, + compilation_profile: str = "dev", +) -> FelderaConnection: + from feldera.rest.feldera_client import FelderaClient + + client = FelderaClient(url=host, api_key=api_key, timeout=float(timeout)) + return FelderaConnection( + client, + pipeline_name, + workers, + compilation_profile, + timeout, + ) diff --git a/tests/core/engine_adapter/test_feldera.py b/tests/core/engine_adapter/test_feldera.py new file mode 100644 index 0000000000..a16994b76c --- /dev/null +++ b/tests/core/engine_adapter/test_feldera.py @@ -0,0 +1,112 @@ +# type: ignore + +import sys +import types +import typing as t + +import pytest +from sqlglot import exp + +from sqlmesh.core.engine_adapter import FelderaEngineAdapter +from sqlmesh.core.engine_adapter.shared import DataObjectType +from sqlmesh.utils.errors import SQLMeshError + +pytestmark = [pytest.mark.engine] + + +@pytest.fixture +def adapter(make_mocked_engine_adapter: t.Callable) -> FelderaEngineAdapter: + return make_mocked_engine_adapter(FelderaEngineAdapter, patch_get_data_objects=False) + + +def _install_feldera_pipeline(monkeypatch: pytest.MonkeyPatch, pipeline: t.Any) -> None: + pipeline_cls = type( + "Pipeline", + (), + {"get": staticmethod(lambda pipeline_name, client: pipeline)}, + ) + pipeline_module = types.ModuleType("feldera.pipeline") + pipeline_module.Pipeline = pipeline_cls + + feldera_module = types.ModuleType("feldera") + feldera_module.pipeline = pipeline_module + + monkeypatch.setitem(sys.modules, "feldera", feldera_module) + monkeypatch.setitem(sys.modules, "feldera.pipeline", pipeline_module) + + +def test_columns_uses_pipeline_metadata(adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch): + connection = adapter._connection_pool.get() + connection._client = object() + connection._pipeline_name = "configured_pipeline" + + pipeline = types.SimpleNamespace( + tables=lambda: [ + types.SimpleNamespace( + name="orders", + fields=[ + {"name": "id", "columntype": {"type": "INTEGER"}}, + {"name": "ratio", "columntype": {"type": "REAL"}}, + {"name": "name", "columntype": "VARCHAR(12)"}, + ], + ) + ], + views=lambda: [], + ) + _install_feldera_pipeline(monkeypatch, pipeline) + + assert adapter.columns("orders") == { + "id": exp.DataType.build("INT"), + "ratio": exp.DataType.build("FLOAT"), + "name": exp.DataType.build("VARCHAR"), + } + + +def test_columns_raises_sqlmesh_error_for_missing_object( + adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch +): + connection = adapter._connection_pool.get() + connection._client = object() + connection._pipeline_name = "configured_pipeline" + + pipeline = types.SimpleNamespace(tables=lambda: [], views=lambda: []) + _install_feldera_pipeline(monkeypatch, pipeline) + + with pytest.raises(SQLMeshError, match="missing"): + adapter.columns("missing") + + +def test_get_data_objects_uses_requested_pipeline_name( + adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch +): + connection = adapter._connection_pool.get() + connection._client = object() + connection._pipeline_name = "configured_pipeline" + + requested_pipeline_names: t.List[str] = [] + + def get_pipeline(pipeline_name: str, client: t.Any) -> t.Any: + requested_pipeline_names.append(pipeline_name) + return types.SimpleNamespace( + tables=lambda: [types.SimpleNamespace(name="source")], + views=lambda: [types.SimpleNamespace(name="sink")], + ) + + pipeline_cls = type("Pipeline", (), {"get": staticmethod(get_pipeline)}) + pipeline_module = types.ModuleType("feldera.pipeline") + pipeline_module.Pipeline = pipeline_cls + + feldera_module = types.ModuleType("feldera") + feldera_module.pipeline = pipeline_module + + monkeypatch.setitem(sys.modules, "feldera", feldera_module) + monkeypatch.setitem(sys.modules, "feldera.pipeline", pipeline_module) + + data_objects = adapter._get_data_objects("catalog.requested_pipeline") + + assert requested_pipeline_names == ["requested_pipeline"] + assert [(obj.schema_name, obj.name, obj.type) for obj in data_objects] == [ + ("requested_pipeline", "source", DataObjectType.TABLE), + ("requested_pipeline", "sink", DataObjectType.MATERIALIZED_VIEW), + ] + assert adapter.dialect == "felderadialect" \ No newline at end of file diff --git a/tests/core/test_connection_config.py b/tests/core/test_connection_config.py index 2ff95525f7..0d1183587a 100644 --- a/tests/core/test_connection_config.py +++ b/tests/core/test_connection_config.py @@ -14,6 +14,7 @@ DatabricksConnectionConfig, DuckDBAttachOptions, FabricConnectionConfig, + FelderaConnectionConfig, DuckDBConnectionConfig, GCPPostgresConnectionConfig, MotherDuckConnectionConfig, @@ -1955,6 +1956,25 @@ def test_fabric_pyodbc_connection_string_generation(): assert call_args[1]["autocommit"] is True +def test_feldera_connection_config(make_config): + config = make_config(type="feldera", pipeline_name="pipeline", check_import=False) + + assert isinstance(config, FelderaConnectionConfig) + assert config.DIALECT == "felderadialect" + + with patch("sqlmesh.engines.feldera.db_api.connect") as mock_connect: + config._connection_factory_with_kwargs() + + mock_connect.assert_called_once_with( + host="http://localhost:8080", + api_key=None, + pipeline_name="pipeline", + workers=4, + compilation_profile="dev", + timeout=300, + ) + + def test_schema_differ_overrides(make_config) -> None: default_config = make_config(type="duckdb") assert default_config.schema_differ_overrides is None From 4204673db8cb72e3ce3f6a76cb2f2dcbdf741af7 Mon Sep 17 00:00:00 2001 From: fresioAS Date: Mon, 4 May 2026 23:00:02 +0200 Subject: [PATCH 02/10] working_init_state --- pyproject.toml | 32 +-- sqlmesh/core/engine_adapter/feldera.py | 149 +++++++++++-- sqlmesh/core/snapshot/evaluator.py | 2 +- sqlmesh/engines/feldera/db_api.py | 297 +++++++++++++++++++++++-- tests/engines/feldera/test_db_api.py | 78 +++++++ 5 files changed, 492 insertions(+), 66 deletions(-) create mode 100644 tests/engines/feldera/test_db_api.py diff --git a/pyproject.toml b/pyproject.toml index bcc69c667e..7f8d1b50b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ "sqlglot~=30.4.2", "tenacity", "time-machine", - "json-stream" + "json-stream", ] classifiers = [ "Intended Audience :: Developers", @@ -42,10 +42,7 @@ classifiers = [ athena = ["PyAthena[Pandas]"] azuresql = ["pymssql"] azuresql-odbc = ["pyodbc>=5.0.0"] -bigquery = [ - "google-cloud-bigquery[pandas]", - "google-cloud-bigquery-storage" -] +bigquery = ["google-cloud-bigquery[pandas]", "google-cloud-bigquery-storage"] # bigframes has to be separate to support environments with an older google-cloud-bigquery pin # this is because that pin pulls in an older bigframes and the bigframes team # pinned an older SQLGlot which is incompatible with SQLMesh @@ -109,6 +106,7 @@ dbt = ["dbt-core<2"] dlt = ["dlt"] duckdb = [] fabric = ["pyodbc>=5.0.0"] +feldera = ["feldera"] gcppostgres = ["cloud-sql-python-connector[pg8000]>=1.8.0"] github = ["PyGithub>=2.6.0"] motherduck = ["duckdb>=1.3.2"] @@ -183,11 +181,7 @@ no_implicit_optional = true disallow_untyped_defs = true [[tool.mypy.overrides]] -module = [ - "examples.*.macros.*", - "tests.*", - "sqlmesh.migrations.*" -] +module = ["examples.*.macros.*", "tests.*", "sqlmesh.migrations.*"] disallow_untyped_defs = false # Sometimes it's helpful to use types within an "untyped" function because it allows IDE assistance # Unfortunately this causes MyPy to print an annoying 'By default the bodies of untyped functions are not checked' @@ -227,7 +221,7 @@ module = [ "dlt.*", "bigframes.*", "json_stream.*", - "duckdb.*" + "duckdb.*", ] ignore_missing_imports = true @@ -274,7 +268,7 @@ markers = [ # Other "set_default_connection", - "registry_isolation" + "registry_isolation", ] addopts = "-n 0 --dist=loadgroup" asyncio_default_fixture_loop_scope = "session" @@ -282,7 +276,7 @@ log_cli = false # Set this to true to enable logging during tests log_cli_format = "%(asctime)s.%(msecs)03d %(filename)s:%(lineno)d %(levelname)s %(message)s" log_cli_level = "INFO" filterwarnings = [ - "ignore:The localize method is no longer necessary, as this time zone supports the fold attribute" + "ignore:The localize method is no longer necessary, as this time zone supports the fold attribute", ] reruns_delay = 10 @@ -290,20 +284,12 @@ reruns_delay = 10 line-length = 100 [tool.ruff.lint] -select = [ - "F401", - "RET505", - "T100", -] +select = ["F401", "RET505", "T100"] extend-select = ["TID"] [tool.ruff.lint.flake8-tidy-imports] -banned-module-level-imports = [ - "duckdb", - "numpy", - "pandas", -] +banned-module-level-imports = ["duckdb", "numpy", "pandas"] # Bans imports from sqlmesh.lsp in files outside of sqlmesh/lsp [tool.ruff.lint.flake8-tidy-imports.banned-api] diff --git a/sqlmesh/core/engine_adapter/feldera.py b/sqlmesh/core/engine_adapter/feldera.py index 6ea24e2102..79065e7413 100644 --- a/sqlmesh/core/engine_adapter/feldera.py +++ b/sqlmesh/core/engine_adapter/feldera.py @@ -3,7 +3,7 @@ import typing as t from sqlglot import exp -from sqlglot.dialects import generator +from sqlglot.generator import Generator from sqlglot.dialects.dialect import Dialect from sqlmesh.core.dialect import to_schema @@ -23,12 +23,18 @@ class FelderaDialect(Dialect): - class Generator(generator.Generator): + class Generator(Generator): TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **Generator.TYPE_MAPPING, exp.DataType.Type.FLOAT: "REAL", exp.DataType.Type.INT: "INTEGER", } + TRANSFORMS = { + **Generator.TRANSFORMS, + exp.DateStrToDate: lambda self, expression: ( + f"CAST({self.sql(expression, 'this')} AS DATE)" + ), + } _FELDERA_TO_EXP_TYPE: t.Dict[str, exp.DataType.Type] = { @@ -93,34 +99,54 @@ def _get_data_objects( except Exception: return [] - objects: t.List[DataObject] = [] + objects_by_name: t.Dict[str, DataObject] = {} for table in pipeline.tables(): name = table.name.lower() if lower_object_names and name not in lower_object_names: continue - objects.append( - DataObject( - catalog=None, - schema=pipeline_name, - name=name, - type=DataObjectType.TABLE, - ) + objects_by_name[name] = DataObject( + catalog=None, + schema=pipeline_name, + name=name, + type=DataObjectType.TABLE, ) for view in pipeline.views(): name = view.name.lower() if lower_object_names and name not in lower_object_names: continue - objects.append( - DataObject( - catalog=None, - schema=pipeline_name, - name=name, - type=DataObjectType.MATERIALIZED_VIEW, - ) + objects_by_name[name] = DataObject( + catalog=None, + schema=pipeline_name, + name=name, + type=DataObjectType.MATERIALIZED_VIEW, + ) + + pending_drops = connection._state.pending_drops() + for object_name in pending_drops: + objects_by_name.pop(object_name, None) + + for object_name in connection._state.pending_tables(): + if lower_object_names and object_name not in lower_object_names: + continue + objects_by_name[object_name] = DataObject( + catalog=None, + schema=pipeline_name, + name=object_name, + type=DataObjectType.TABLE, + ) + + for object_name in connection._state.pending_views(): + if lower_object_names and object_name not in lower_object_names: + continue + objects_by_name[object_name] = DataObject( + catalog=None, + schema=pipeline_name, + name=object_name, + type=DataObjectType.MATERIALIZED_VIEW, ) - return objects + return list(objects_by_name.values()) def columns( self, table_name: TableName, include_pseudo_columns: bool = False @@ -150,5 +176,90 @@ def columns( def get_current_catalog(self) -> t.Optional[str]: return None + def _replace_materialized_view( + self, + table_name: TableName, + query_or_df: t.Any, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + table_description: t.Optional[str] = None, + column_descriptions: t.Optional[t.Dict[str, str]] = None, + source_columns: t.Optional[t.List[str]] = None, + **kwargs: t.Any, + ) -> None: + target_table = exp.to_table(table_name) + target_data_object = self.get_data_object(target_table) + + if target_data_object is not None: + self.drop_data_object(target_data_object, ignore_if_not_exists=True) + + self.create_view( + target_table, + query_or_df, + target_columns_to_types=target_columns_to_types, + replace=False, + materialized=True, + table_description=table_description, + column_descriptions=column_descriptions, + source_columns=source_columns, + **kwargs, + ) + + def replace_query( + self, + table_name: TableName, + query_or_df: t.Any, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + table_description: t.Optional[str] = None, + column_descriptions: t.Optional[t.Dict[str, str]] = None, + source_columns: t.Optional[t.List[str]] = None, + supports_replace_table_override: t.Optional[bool] = None, + **kwargs: t.Any, + ) -> None: + self._replace_materialized_view( + table_name, + query_or_df, + target_columns_to_types=target_columns_to_types, + table_description=table_description, + column_descriptions=column_descriptions, + source_columns=source_columns, + **kwargs, + ) + + def insert_overwrite_by_time_partition( + self, + table_name: TableName, + query_or_df: t.Any, + start: t.Any, + end: t.Any, + time_formatter: t.Any, + time_column: t.Any, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + **kwargs: t.Any, + ) -> None: + self._replace_materialized_view( + table_name, + query_or_df, + target_columns_to_types=target_columns_to_types, + source_columns=source_columns, + **kwargs, + ) + + def drop_table(self, table_name: TableName, exists: bool = True, **kwargs: t.Any) -> None: + target_data_object = self.get_data_object(exp.to_table(table_name)) + if target_data_object and target_data_object.type.is_view: + self.drop_view(table_name, exists=exists, **kwargs) + return + super().drop_table(table_name, exists=exists, **kwargs) + + def create_schema( + self, + schema_name: SchemaName, + ignore_if_exists: bool = True, + warn_on_error: bool = True, + properties: t.Optional[t.List[exp.Expr]] = None, + ) -> None: + return None + def ping(self) -> None: self.connection._client.get_config() diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index b1ffd4dc26..1d6a9f0c1f 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -2060,7 +2060,7 @@ def create( # Only sql models have queries that can be tested. # We also need to make sure that we don't dry run on Redshift because its planner / optimizer sometimes # breaks on our CTAS queries due to us relying on the WHERE FALSE LIMIT 0 combo. - if model.is_sql and dry_run and self.adapter.dialect != "redshift": + if model.is_sql and dry_run and self.adapter.dialect not in {"redshift", "felderadialect"}: logger.info("Dry running model '%s'", model.name) self.adapter.fetchall(ctas_query) else: diff --git a/sqlmesh/engines/feldera/db_api.py b/sqlmesh/engines/feldera/db_api.py index 3d70a47747..0099dff669 100644 --- a/sqlmesh/engines/feldera/db_api.py +++ b/sqlmesh/engines/feldera/db_api.py @@ -4,6 +4,10 @@ import threading import typing as t from enum import Enum +import re + +from sqlglot import exp, parse, parse_one +from sqlglot.errors import ParseError logger = logging.getLogger(__name__) @@ -20,9 +24,20 @@ class SqlIntent(Enum): def _classify(sql: str) -> SqlIntent: """Route SQL to the correct Feldera endpoint.""" - stripped = sql.strip().lstrip("/*").strip() + stripped = _strip_leading_comments(sql) if not stripped: return SqlIntent.NO_OP + + try: + expression = parse_one(stripped) + except (ParseError, ValueError): + expression = None + + if isinstance(expression, (exp.Create, exp.Drop, exp.Alter)): + return SqlIntent.PIPELINE_DDL + if isinstance(expression, exp.Insert): + return SqlIntent.DATA_INGRESS + upper = stripped.upper() if upper.startswith(("CREATE", "DROP", "ALTER")): return SqlIntent.PIPELINE_DDL @@ -38,19 +53,55 @@ def __init__(self) -> None: self._lock = threading.RLock() self._tables: t.Dict[str, str] = {} self._views: t.Dict[str, str] = {} + self._dropped_objects: t.Set[str] = set() self._pipeline = None + self._dirty = False + self._hydrated = False + self._hydrated_object_keys: t.Set[str] = set() def register_ddl(self, sql: str) -> None: with self._lock: upper = sql.strip().upper() + object_key = _extract_name(sql) + self._hydrated_object_keys.discard(object_key) if "CREATE TABLE" in upper or "CREATE MATERIALIZED TABLE" in upper: - self._tables[_extract_name(sql)] = sql + self._dropped_objects.discard(object_key) + self._views.pop(object_key, None) + self._tables[object_key] = sql + self._dirty = True elif "CREATE VIEW" in upper or "CREATE MATERIALIZED VIEW" in upper: - self._views[_extract_name(sql)] = sql + self._dropped_objects.discard(object_key) + self._tables.pop(object_key, None) + self._views[object_key] = sql + self._dirty = True elif "DROP TABLE" in upper: - self._tables.pop(_extract_name(sql), None) + self._tables.pop(object_key, None) + self._dropped_objects.add(object_key) + self._dirty = True elif "DROP VIEW" in upper: - self._views.pop(_extract_name(sql), None) + self._views.pop(object_key, None) + self._dropped_objects.add(object_key) + self._dirty = True + + def has_pending_changes(self) -> bool: + with self._lock: + return self._dirty + + def pending_tables(self) -> t.Set[str]: + with self._lock: + return set(self._tables) + + def pending_views(self) -> t.Set[str]: + with self._lock: + return set(self._views) + + def pending_drops(self) -> t.Set[str]: + with self._lock: + return set(self._dropped_objects) + + def current_pipeline(self) -> t.Any: + with self._lock: + return self._pipeline def assemble_program(self) -> str: with self._lock: @@ -68,42 +119,217 @@ def deploy( compilation_profile: str = "dev", timeout: int = 300, ) -> t.Any: + from feldera.enums import CompilationProfile + from feldera.pipeline import Pipeline from feldera.pipeline_builder import PipelineBuilder - from feldera.rest.pipeline import PipelineStatus + from feldera.rest.errors import FelderaAPIError + from feldera.rest.pipeline import Pipeline as InnerPipeline from feldera.runtime_config import RuntimeConfig + try: + from feldera.pipeline import PipelineStatus + except ImportError: + from feldera.rest.pipeline import PipelineStatus + with self._lock: + self._hydrate_existing_program(client, pipeline_name) + if not self._dirty: + return self._pipeline + sql = self.assemble_program() if not sql.strip(): + self._dirty = False + self._dropped_objects.clear() return self._pipeline runtime_config = RuntimeConfig.default() runtime_config.workers = workers - pipeline = PipelineBuilder( + profile = compilation_profile + if isinstance(profile, str): + profile = CompilationProfile(profile) + + try: + pipeline = self._compile_program( + client, + pipeline_name, + sql, + profile, + runtime_config, + timeout, + Pipeline, + PipelineBuilder, + PipelineStatus, + InnerPipeline, + FelderaAPIError, + ) + except RuntimeError as ex: + if "not found" not in str(ex).lower() or not self._hydrated_object_keys: + raise + + for object_key in list(self._hydrated_object_keys): + self._tables.pop(object_key, None) + self._views.pop(object_key, None) + self._hydrated_object_keys.clear() + + sql = self.assemble_program() + pipeline = self._compile_program( + client, + pipeline_name, + sql, + profile, + runtime_config, + timeout, + Pipeline, + PipelineBuilder, + PipelineStatus, + InnerPipeline, + FelderaAPIError, + ) + + pipeline.start() + pipeline.wait_for_status(PipelineStatus.RUNNING, timeout=timeout) + self._pipeline = pipeline + self._dirty = False + self._dropped_objects.clear() + return pipeline + + def _hydrate_existing_program(self, client: t.Any, pipeline_name: str) -> None: + if self._hydrated: + return + + self._hydrated = True + + try: + from feldera.pipeline import Pipeline + + pipeline = Pipeline.get(pipeline_name, client) + program_code = getattr(getattr(pipeline, "_inner", None), "program_code", "") or "" + except Exception: + return + + for expression in parse(program_code): + if expression is None: + continue + + sql = expression.sql() + object_key = _extract_name(sql) + + if isinstance(expression, exp.Create): + kind = str(expression.args.get("kind") or "").upper() + if "VIEW" in kind: + self._tables.pop(object_key, None) + self._views[object_key] = sql + elif "TABLE" in kind: + self._views.pop(object_key, None) + self._tables[object_key] = sql + self._hydrated_object_keys.add(object_key) + + def _compile_program( + self, + client: t.Any, + pipeline_name: str, + sql: str, + profile: t.Any, + runtime_config: t.Any, + timeout: int, + Pipeline: t.Any, + PipelineBuilder: t.Any, + PipelineStatus: t.Any, + InnerPipeline: t.Any, + FelderaAPIError: t.Any, + ) -> t.Any: + try: + existing_pipeline = Pipeline.get(pipeline_name, client) + except FelderaAPIError: + existing_pipeline = None + + if existing_pipeline is None: + return PipelineBuilder( client, name=pipeline_name, sql=sql, - compilation_profile=compilation_profile, + compilation_profile=profile, runtime_config=runtime_config, ).create_or_replace(wait=True) - pipeline.start() - pipeline.wait_for_status(PipelineStatus.RUNNING, timeout=timeout) - self._pipeline = pipeline - return pipeline + + existing_pipeline.stop(force=True) + existing_pipeline.wait_for_status(PipelineStatus.STOPPED, timeout=timeout) + existing_pipeline.dismiss_error() + + inner_pipeline = InnerPipeline( + name=pipeline_name, + sql=sql, + udf_rust="", + udf_toml="", + description="", + program_config={ + "profile": profile.value, + "runtime_version": None, + }, + runtime_config=runtime_config.to_dict(), + ) + inner_pipeline = client.create_or_update_pipeline(inner_pipeline, wait=True) + pipeline = Pipeline(client) + pipeline._inner = inner_pipeline + return pipeline def _extract_name(sql: str) -> str: - """Very basic name extraction until sqlglot-based parsing is added.""" - import re + stripped = _strip_leading_comments(sql) + + try: + expression = parse_one(stripped) + except (ParseError, ValueError): + return stripped[:80].lower() + + target = expression.this if isinstance(expression, (exp.Create, exp.Drop)) else None + if isinstance(target, exp.Schema): + target = target.this + + if isinstance(target, exp.Table): + return target.name.lower() - match = re.search( - r"(?:CREATE|DROP)\s+(?:MATERIALIZED\s+)?(?:TABLE|VIEW)\s+" - r"(?:IF\s+(?:NOT\s+)?EXISTS\s+)?(\w+)", + return stripped[:80].lower() + + +def _normalize_ddl(sql: str) -> str: + return re.sub( + r"\bCREATE\s+TABLE\s+IF\s+NOT\s+EXISTS\b", + "CREATE TABLE", sql, - re.IGNORECASE, + flags=re.IGNORECASE, ) - return match.group(1).lower() if match else sql[:40] + + +def _strip_leading_comments(sql: str) -> str: + stripped = sql.lstrip() + while stripped.startswith("/*"): + comment_end = stripped.find("*/") + if comment_end == -1: + return stripped + stripped = stripped[comment_end + 2 :].lstrip() + return stripped + + +def _strip_table_qualifiers(sql: str) -> str: + stripped = _strip_leading_comments(sql) + + try: + expression = parse_one(stripped) + except (ParseError, ValueError): + return stripped + + expression = expression.transform(_unqualify_table) + return expression.sql() + + +def _unqualify_table(node: exp.Expression) -> exp.Expression: + if isinstance(node, exp.Table): + node = node.copy() + node.set("db", None) + node.set("catalog", None) + return node class FelderaCursor: @@ -135,6 +361,9 @@ def execute(self, sql: str, parameters: t.Optional[t.Any] = None) -> None: "Feldera DB-API does not support query parameters" ) + sql = _strip_table_qualifiers(sql) + sql = _normalize_ddl(sql) + intent = _classify(sql) logger.debug("Feldera execute (intent=%s): %.200s", intent.value, sql) @@ -145,6 +374,11 @@ def execute(self, sql: str, parameters: t.Optional[t.Any] = None) -> None: if intent == SqlIntent.PIPELINE_DDL: self._state.register_ddl(sql) + self._rows = [] + self._columns = [] + return + + if self._state.has_pending_changes(): self._state.deploy( self._client, self._pipeline_name, @@ -152,9 +386,6 @@ def execute(self, sql: str, parameters: t.Optional[t.Any] = None) -> None: self._compilation_profile, self._timeout, ) - self._rows = [] - self._columns = [] - return if intent == SqlIntent.DATA_INGRESS: self._get_pipeline().execute(sql) @@ -174,6 +405,10 @@ def execute(self, sql: str, parameters: t.Optional[t.Any] = None) -> None: def _get_pipeline(self) -> t.Any: from feldera.pipeline import Pipeline + pipeline = self._state.current_pipeline() + if pipeline is not None: + return pipeline + return Pipeline.get(self._pipeline_name, self._client) def fetchone(self) -> t.Optional[t.Tuple]: @@ -200,20 +435,27 @@ def close(self) -> None: class FelderaConnection: """DB-API 2.0 connection wrapper around FelderaClient.""" + _state_lock = threading.Lock() + _shared_states: t.Dict[t.Tuple[str, str], PipelineStateManager] = {} + def __init__( self, client: t.Any, + host: str, pipeline_name: str, workers: int = 4, compilation_profile: str = "dev", timeout: int = 300, ) -> None: self._client = client + self._host = host self._pipeline_name = pipeline_name self._workers = workers self._compilation_profile = compilation_profile self._timeout = timeout - self._state = PipelineStateManager() + state_key = (host, pipeline_name) + with self._state_lock: + self._state = self._shared_states.setdefault(state_key, PipelineStateManager()) def cursor(self) -> FelderaCursor: return FelderaCursor( @@ -232,6 +474,14 @@ def rollback(self) -> None: return None def close(self) -> None: + if self._state.has_pending_changes(): + self._state.deploy( + self._client, + self._pipeline_name, + self._workers, + self._compilation_profile, + self._timeout, + ) return None @@ -248,6 +498,7 @@ def connect( client = FelderaClient(url=host, api_key=api_key, timeout=float(timeout)) return FelderaConnection( client, + host, pipeline_name, workers, compilation_profile, diff --git a/tests/engines/feldera/test_db_api.py b/tests/engines/feldera/test_db_api.py new file mode 100644 index 0000000000..20555d666d --- /dev/null +++ b/tests/engines/feldera/test_db_api.py @@ -0,0 +1,78 @@ +import types + +from sqlglot import parse_one + +from sqlmesh.engines.feldera import db_api + + +def test_classify_treats_comment_prefixed_create_schema_as_pipeline_ddl() -> None: + assert ( + db_api._classify("/* sqlmesh */ CREATE SCHEMA foo") + == db_api.SqlIntent.PIPELINE_DDL + ) + + +def test_cursor_defers_pipeline_deploy_until_non_ddl_statement() -> None: + class FakeStateManager: + def __init__(self) -> None: + self.pending = False + self.deploy_calls = 0 + self.pipeline = types.SimpleNamespace(query=lambda sql: [{"a": 1}]) + + def register_ddl(self, sql: str) -> None: + self.pending = True + + def has_pending_changes(self) -> bool: + return self.pending + + def deploy(self, *args: object, **kwargs: object) -> object: + self.deploy_calls += 1 + self.pending = False + return self.pipeline + + def current_pipeline(self) -> object: + return self.pipeline + + state_manager = FakeStateManager() + cursor = db_api.FelderaCursor( + client=object(), + pipeline_name="test_pipeline", + state_manager=state_manager, + ) + + cursor.execute("/* sqlmesh */ CREATE TABLE foo (id INT)") + + assert state_manager.deploy_calls == 0 + + cursor.execute("SELECT 1 AS a") + + assert state_manager.deploy_calls == 1 + assert cursor.fetchall() == [(1,)] + + +def test_hydrate_existing_program_skips_empty_parse_results(monkeypatch) -> None: + manager = db_api.PipelineStateManager() + + monkeypatch.setattr(db_api, "parse", lambda sql: [None, parse_one("CREATE TABLE foo (id INT)")]) + + pipeline_module = types.ModuleType("feldera.pipeline") + pipeline_module.Pipeline = type( + "Pipeline", + (), + { + "get": staticmethod( + lambda pipeline_name, client: types.SimpleNamespace( + _inner=types.SimpleNamespace(program_code="ignored") + ) + ) + }, + ) + feldera_module = types.ModuleType("feldera") + feldera_module.pipeline = pipeline_module + + monkeypatch.setitem(__import__("sys").modules, "feldera", feldera_module) + monkeypatch.setitem(__import__("sys").modules, "feldera.pipeline", pipeline_module) + + manager._hydrate_existing_program(object(), "test_pipeline") + + assert manager.pending_tables() == {"foo"} \ No newline at end of file From d15fc1e0af185a76a4f6c5b0b5c236517a04cfe2 Mon Sep 17 00:00:00 2001 From: fresioAS Date: Wed, 6 May 2026 01:54:29 +0200 Subject: [PATCH 03/10] materialized --- sqlmesh/core/engine_adapter/feldera.py | 144 +++++++---- sqlmesh/engines/feldera/db_api.py | 285 +++++++++++++++++++++- tests/core/engine_adapter/test_feldera.py | 199 ++++++++++++++- tests/engines/feldera/test_db_api.py | 135 +++++++++- 4 files changed, 701 insertions(+), 62 deletions(-) diff --git a/sqlmesh/core/engine_adapter/feldera.py b/sqlmesh/core/engine_adapter/feldera.py index 79065e7413..a0f1db4a87 100644 --- a/sqlmesh/core/engine_adapter/feldera.py +++ b/sqlmesh/core/engine_adapter/feldera.py @@ -1,10 +1,11 @@ from __future__ import annotations +import logging import typing as t from sqlglot import exp -from sqlglot.generator import Generator from sqlglot.dialects.dialect import Dialect +from sqlglot.generator import Generator from sqlmesh.core.dialect import to_schema from sqlmesh.core.engine_adapter.base import EngineAdapter @@ -14,12 +15,30 @@ DataObject, DataObjectType, ) +from sqlmesh.engines.feldera.db_api import QUERY_MIRROR_PREFIX from sqlmesh.utils.errors import SQLMeshError if t.TYPE_CHECKING: import pandas as pd from sqlmesh.core._typing import SchemaName, TableName + from sqlmesh.core.engine_adapter.shared import SourceQuery + + +logger = logging.getLogger(__name__) + + +def _is_query_mirror_name(name: str) -> bool: + return name.lower().startswith(QUERY_MIRROR_PREFIX) + + +def _view_type(state: t.Any, object_name: str) -> DataObjectType: + is_materialized_view = getattr(state, "is_materialized_view", None) + + if callable(is_materialized_view) and is_materialized_view(object_name): + return DataObjectType.MATERIALIZED_VIEW + + return DataObjectType.VIEW class FelderaDialect(Dialect): @@ -113,13 +132,15 @@ def _get_data_objects( for view in pipeline.views(): name = view.name.lower() + if _is_query_mirror_name(name): + continue if lower_object_names and name not in lower_object_names: continue objects_by_name[name] = DataObject( catalog=None, schema=pipeline_name, name=name, - type=DataObjectType.MATERIALIZED_VIEW, + type=_view_type(connection._state, name), ) pending_drops = connection._state.pending_drops() @@ -143,7 +164,7 @@ def _get_data_objects( catalog=None, schema=pipeline_name, name=object_name, - type=DataObjectType.MATERIALIZED_VIEW, + type=_view_type(connection._state, object_name), ) return list(objects_by_name.values()) @@ -176,80 +197,103 @@ def columns( def get_current_catalog(self) -> t.Optional[str]: return None - def _replace_materialized_view( + def create_view( self, - table_name: TableName, + view_name: TableName, query_or_df: t.Any, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + replace: bool = True, + materialized: bool = False, + materialized_properties: t.Optional[t.Dict[str, t.Any]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, + view_properties: t.Optional[t.Dict[str, exp.Expr]] = None, source_columns: t.Optional[t.List[str]] = None, - **kwargs: t.Any, + **create_kwargs: t.Any, ) -> None: - target_table = exp.to_table(table_name) - target_data_object = self.get_data_object(target_table) + if replace: + target_data_object = self.get_data_object(exp.to_table(view_name)) + if target_data_object is not None: + self.drop_data_object(target_data_object, ignore_if_not_exists=True) - if target_data_object is not None: - self.drop_data_object(target_data_object, ignore_if_not_exists=True) - - self.create_view( - target_table, + super().create_view( + view_name, query_or_df, target_columns_to_types=target_columns_to_types, replace=False, - materialized=True, + materialized=materialized, + materialized_properties=materialized_properties, table_description=table_description, column_descriptions=column_descriptions, + view_properties=view_properties, source_columns=source_columns, - **kwargs, + **create_kwargs, ) - def replace_query( + def _create_table_from_source_queries( self, table_name: TableName, - query_or_df: t.Any, + source_queries: t.List[SourceQuery], target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + exists: bool = True, + replace: bool = False, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, - source_columns: t.Optional[t.List[str]] = None, - supports_replace_table_override: t.Optional[bool] = None, + table_kind: t.Optional[str] = None, + track_rows_processed: bool = True, **kwargs: t.Any, ) -> None: - self._replace_materialized_view( - table_name, - query_or_df, - target_columns_to_types=target_columns_to_types, - table_description=table_description, - column_descriptions=column_descriptions, - source_columns=source_columns, - **kwargs, - ) + if replace: + return super()._create_table_from_source_queries( + table_name, + source_queries, + target_columns_to_types=target_columns_to_types, + exists=exists, + replace=replace, + table_description=table_description, + column_descriptions=column_descriptions, + table_kind=table_kind, + track_rows_processed=track_rows_processed, + **kwargs, + ) - def insert_overwrite_by_time_partition( - self, - table_name: TableName, - query_or_df: t.Any, - start: t.Any, - end: t.Any, - time_formatter: t.Any, - time_column: t.Any, - target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, - source_columns: t.Optional[t.List[str]] = None, - **kwargs: t.Any, - ) -> None: - self._replace_materialized_view( - table_name, - query_or_df, - target_columns_to_types=target_columns_to_types, - source_columns=source_columns, - **kwargs, - ) + if not target_columns_to_types: + raise SQLMeshError( + "Feldera requires known column types when creating a table from a query." + ) + + with self.transaction(condition=len(source_queries) > 1): + self._create_table_from_columns( + table_name, + target_columns_to_types, + exists=exists, + table_description=table_description, + column_descriptions=column_descriptions, + **kwargs, + ) + for source_query in source_queries: + with source_query as query: + self._insert_append_query( + table_name, + query, + target_columns_to_types, + track_rows_processed=track_rows_processed, + ) def drop_table(self, table_name: TableName, exists: bool = True, **kwargs: t.Any) -> None: target_data_object = self.get_data_object(exp.to_table(table_name)) - if target_data_object and target_data_object.type.is_view: - self.drop_view(table_name, exists=exists, **kwargs) - return + if target_data_object: + if target_data_object.type.is_materialized_view: + self.drop_view( + table_name, + ignore_if_not_exists=exists, + materialized=True, + **kwargs, + ) + return + if target_data_object.type.is_view: + self.drop_view(table_name, ignore_if_not_exists=exists, **kwargs) + return super().drop_table(table_name, exists=exists, **kwargs) def create_schema( diff --git a/sqlmesh/engines/feldera/db_api.py b/sqlmesh/engines/feldera/db_api.py index 0099dff669..086d3bdc69 100644 --- a/sqlmesh/engines/feldera/db_api.py +++ b/sqlmesh/engines/feldera/db_api.py @@ -11,6 +11,8 @@ logger = logging.getLogger(__name__) +QUERY_MIRROR_PREFIX = "__sqlmesh_query__" + if t.TYPE_CHECKING: import pandas as pd @@ -61,10 +63,46 @@ def __init__(self) -> None: def register_ddl(self, sql: str) -> None: with self._lock: - upper = sql.strip().upper() object_key = _extract_name(sql) + expression = None + + try: + expression = parse_one(_strip_leading_comments(sql)) + except (ParseError, ValueError): + pass + self._hydrated_object_keys.discard(object_key) + + if isinstance(expression, exp.Create): + kind = str(expression.args.get("kind") or "").upper() + if "TABLE" in kind: + sql = _rewrite_table_ctas_sql(sql) + self._dropped_objects.discard(object_key) + self._views.pop(object_key, None) + self._tables[object_key] = sql + self._dirty = True + elif "VIEW" in kind: + self._dropped_objects.discard(object_key) + self._tables.pop(object_key, None) + self._views[object_key] = sql + self._dirty = True + return + + if isinstance(expression, exp.Drop): + kind = str(expression.args.get("kind") or "").upper() + if "TABLE" in kind: + self._tables.pop(object_key, None) + self._dropped_objects.add(object_key) + self._dirty = True + elif "VIEW" in kind: + self._views.pop(object_key, None) + self._dropped_objects.add(object_key) + self._dirty = True + return + + upper = sql.strip().upper() if "CREATE TABLE" in upper or "CREATE MATERIALIZED TABLE" in upper: + sql = _rewrite_table_ctas_sql(sql) self._dropped_objects.discard(object_key) self._views.pop(object_key, None) self._tables[object_key] = sql @@ -95,6 +133,10 @@ def pending_views(self) -> t.Set[str]: with self._lock: return set(self._views) + def is_materialized_view(self, object_name: str) -> bool: + with self._lock: + return _is_materialized_view_sql(self._views.get(object_name.lower(), "")) + def pending_drops(self) -> t.Set[str]: with self._lock: return set(self._dropped_objects) @@ -103,11 +145,30 @@ def current_pipeline(self) -> t.Any: with self._lock: return self._pipeline + def queryable_relation_names(self) -> t.Set[str]: + with self._lock: + return { + *self._tables, + *( + name + for name, sql in self._views.items() + if not _is_materialized_view_sql(sql) + ), + } + def assemble_program(self) -> str: with self._lock: statements = [ - *(ddl.rstrip(";") + ";" for ddl in self._tables.values()), - *(ddl.rstrip(";") + ";" for ddl in self._views.values()), + *( + statement + for ddl in self._tables.values() + for statement in _ddl_statements_with_query_mirror(ddl) + ), + *( + statement + for ddl in self._views.values() + for statement in _ddl_statements_with_query_mirror(ddl) + ), ] return "\n\n".join(statements) @@ -213,14 +274,22 @@ def _hydrate_existing_program(self, client: t.Any, pipeline_name: str) -> None: continue sql = expression.sql() - object_key = _extract_name(sql) if isinstance(expression, exp.Create): + target = expression.this + if isinstance(target, exp.Schema): + target = target.this + + if isinstance(target, exp.Table) and _is_query_mirror_name(target.name): + continue + + object_key = _extract_name(sql) kind = str(expression.args.get("kind") or "").upper() if "VIEW" in kind: self._tables.pop(object_key, None) self._views[object_key] = sql elif "TABLE" in kind: + sql = _rewrite_table_ctas_sql(sql) self._views.pop(object_key, None) self._tables[object_key] = sql self._hydrated_object_keys.add(object_key) @@ -302,6 +371,203 @@ def _normalize_ddl(sql: str) -> str: ) +def _is_query_mirror_name(name: str) -> bool: + return name.lower().startswith(QUERY_MIRROR_PREFIX) + + +def _query_mirror_name(name: str) -> str: + return f"{QUERY_MIRROR_PREFIX}{name}" + + +def _query_mirror_table(table: exp.Table) -> exp.Table: + mirror = table.copy() + mirror.set("this", exp.to_identifier(_query_mirror_name(table.name), quoted=True)) + return mirror + + +def _ddl_statements_with_query_mirror(sql: str) -> t.List[str]: + try: + expressions = [expression for expression in parse(sql) if expression is not None] + except (ParseError, ValueError): + expressions = [] + + if not expressions: + statement = sql.rstrip(";") + ";" + mirror_sql = _query_mirror_sql(sql) + return [statement, *([mirror_sql.rstrip(";") + ";"] if mirror_sql else [])] + + statements = [] + for expression in expressions: + statement = expression.sql().rstrip(";") + ";" + statements.append(statement) + mirror_sql = _query_mirror_sql(statement) + if mirror_sql: + statements.append(mirror_sql.rstrip(";") + ";") + + return statements + + +def _query_mirror_sql(sql: str) -> t.Optional[str]: + stripped = _strip_leading_comments(sql) + + try: + expression = parse_one(stripped) + except (ParseError, ValueError): + return None + + if not isinstance(expression, exp.Create): + return None + + target = expression.this + if isinstance(target, exp.Schema): + target = target.this + + if not isinstance(target, exp.Table) or _is_query_mirror_name(target.name): + return None + + kind = str(expression.args.get("kind") or "").upper() + if "TABLE" not in kind and "VIEW" not in kind: + return None + if "VIEW" in kind and _is_materialized_view_sql(sql): + return None + + return exp.Create( + this=_query_mirror_table(target), + kind="MATERIALIZED VIEW", + expression=exp.select("*").from_(target.copy()), + ).sql() + + +def _rewrite_query_for_query_mirrors(sql: str, relation_names: t.Set[str]) -> str: + if not relation_names: + return sql + + stripped = _strip_leading_comments(sql) + + try: + expression = parse_one(stripped) + except (ParseError, ValueError): + return sql + + cte_names = { + cte.alias_or_name.lower() + for cte in expression.find_all(exp.CTE) + if cte.alias_or_name + } + + def transform(node: exp.Expression) -> exp.Expression: + if ( + isinstance(node, exp.Table) + and not _is_query_mirror_name(node.name) + and node.name.lower() in relation_names + and node.name.lower() not in cte_names + ): + return _query_mirror_table(node) + return node + + return expression.transform(transform, copy=True).sql() + + +def _execution_error(rows: t.List[t.Mapping[str, t.Any]]) -> t.Optional[str]: + for row in rows: + for value in row.values(): + if isinstance(value, str) and value.startswith("Execution error:"): + return value + + return None + + +def _is_materialized_view_sql(sql: str) -> bool: + stripped = _strip_leading_comments(sql) + upper = stripped.upper() + + if "CREATE MATERIALIZED VIEW" in upper: + return True + if "CREATE VIEW" in upper: + return False + + try: + expression = parse_one(stripped) + except (ParseError, ValueError): + expression = None + + if isinstance(expression, exp.Create): + kind = str(expression.args.get("kind") or "").upper() + return "MATERIALIZED" in kind and "VIEW" in kind + + return False + + +def _rewrite_table_ctas_sql(sql: str) -> str: + stripped = _strip_leading_comments(sql) + + try: + expression = parse_one(stripped) + except (ParseError, ValueError): + return sql + + if not isinstance(expression, exp.Create): + return sql + + kind = str(expression.args.get("kind") or "").upper() + query = expression.expression + target = expression.this + + if "TABLE" not in kind or query is None or not isinstance(target, exp.Table): + return sql + + columns_to_types = _select_columns_to_types(query) + if not columns_to_types: + return sql + + schema = exp.Schema( + this=target.copy(), + expressions=[ + exp.ColumnDef(this=exp.to_identifier(column, quoted=True), kind=data_type.copy()) + for column, data_type in columns_to_types.items() + ], + ) + create_exp = exp.Create( + this=schema, + kind=expression.args.get("kind") or "TABLE", + replace=bool(expression.args.get("replace")), + exists=bool(expression.args.get("exists")), + properties=expression.args.get("properties"), + ) + insert_exp = exp.insert(query.copy(), target.copy(), columns=list(columns_to_types)) + return f"{create_exp.sql()};\n{insert_exp.sql()}" + + +def _select_columns_to_types(query: exp.Expression) -> t.Optional[t.Dict[str, exp.DataType]]: + if not isinstance(query, exp.Query): + return None + + columns_to_types: t.Dict[str, exp.DataType] = {} + unknown = exp.DataType.build("unknown") + + for select in query.selects: + output_name = select.output_name + data_type = _projection_type(select) or (select.type or unknown).copy() + + if not output_name or output_name in columns_to_types or data_type == unknown: + return None + + columns_to_types[output_name] = data_type + + return columns_to_types or None + + +def _projection_type(select: exp.Expression) -> t.Optional[exp.DataType]: + expression = select + if isinstance(select, exp.Alias): + expression = select.this + + if isinstance(expression, exp.Cast) and isinstance(expression.args.get("to"), exp.DataType): + return expression.args["to"].copy() + + return None + + def _strip_leading_comments(sql: str) -> str: stripped = sql.lstrip() while stripped.startswith("/*"): @@ -393,7 +659,12 @@ def execute(self, sql: str, parameters: t.Optional[t.Any] = None) -> None: self._columns = [] return - rows = list(self._get_pipeline().query(sql)) + query_sql = _rewrite_query_for_query_mirrors( + sql, self._state.queryable_relation_names() + ) + rows = list(self._get_pipeline().query(query_sql)) + if error := _execution_error(rows): + raise RuntimeError(error) self._rows = rows self._columns = list(rows[0].keys()) if rows else [] self.rowcount = len(rows) @@ -403,12 +674,12 @@ def execute(self, sql: str, parameters: t.Optional[t.Any] = None) -> None: ] def _get_pipeline(self) -> t.Any: - from feldera.pipeline import Pipeline - pipeline = self._state.current_pipeline() if pipeline is not None: return pipeline + from feldera.pipeline import Pipeline + return Pipeline.get(self._pipeline_name, self._client) def fetchone(self) -> t.Optional[t.Tuple]: diff --git a/tests/core/engine_adapter/test_feldera.py b/tests/core/engine_adapter/test_feldera.py index a16994b76c..8f209ab88e 100644 --- a/tests/core/engine_adapter/test_feldera.py +++ b/tests/core/engine_adapter/test_feldera.py @@ -5,18 +5,26 @@ import typing as t import pytest +from sqlglot import parse_one from sqlglot import exp from sqlmesh.core.engine_adapter import FelderaEngineAdapter -from sqlmesh.core.engine_adapter.shared import DataObjectType +from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType from sqlmesh.utils.errors import SQLMeshError +from tests.core.engine_adapter import to_sql_calls pytestmark = [pytest.mark.engine] @pytest.fixture def adapter(make_mocked_engine_adapter: t.Callable) -> FelderaEngineAdapter: - return make_mocked_engine_adapter(FelderaEngineAdapter, patch_get_data_objects=False) + adapter = make_mocked_engine_adapter(FelderaEngineAdapter, patch_get_data_objects=False) + connection = adapter._connection_pool.get() + connection._state.pending_drops.return_value = set() + connection._state.pending_tables.return_value = set() + connection._state.pending_views.return_value = set() + connection._state.is_materialized_view.side_effect = lambda _: False + return adapter def _install_feldera_pipeline(monkeypatch: pytest.MonkeyPatch, pipeline: t.Any) -> None: @@ -89,7 +97,10 @@ def get_pipeline(pipeline_name: str, client: t.Any) -> t.Any: requested_pipeline_names.append(pipeline_name) return types.SimpleNamespace( tables=lambda: [types.SimpleNamespace(name="source")], - views=lambda: [types.SimpleNamespace(name="sink")], + views=lambda: [ + types.SimpleNamespace(name="sink"), + types.SimpleNamespace(name="__sqlmesh_query__source"), + ], ) pipeline_cls = type("Pipeline", (), {"get": staticmethod(get_pipeline)}) @@ -107,6 +118,186 @@ def get_pipeline(pipeline_name: str, client: t.Any) -> t.Any: assert requested_pipeline_names == ["requested_pipeline"] assert [(obj.schema_name, obj.name, obj.type) for obj in data_objects] == [ ("requested_pipeline", "source", DataObjectType.TABLE), + ("requested_pipeline", "sink", DataObjectType.VIEW), + ] + assert adapter.dialect == "felderadialect" + + +def test_get_data_objects_marks_materialized_views_from_state( + adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch +): + connection = adapter._connection_pool.get() + connection._client = object() + connection._state.is_materialized_view.side_effect = lambda name: name == "sink" + + pipeline = types.SimpleNamespace( + tables=lambda: [], + views=lambda: [types.SimpleNamespace(name="sink")], + ) + _install_feldera_pipeline(monkeypatch, pipeline) + + data_objects = adapter._get_data_objects("catalog.requested_pipeline") + + assert [(obj.schema_name, obj.name, obj.type) for obj in data_objects] == [ ("requested_pipeline", "sink", DataObjectType.MATERIALIZED_VIEW), ] - assert adapter.dialect == "felderadialect" \ No newline at end of file + + +def test_replace_query_creates_table( + adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch +): + monkeypatch.setattr(adapter, "_get_data_objects", lambda schema_name, object_names=None: []) + adapter.replace_query( + "db.full_model", + parse_one("SELECT a FROM tbl"), + {"a": exp.DataType.build("INT")}, + ) + + assert to_sql_calls(adapter) == [ + 'CREATE TABLE IF NOT EXISTS "db"."full_model" ("a" INTEGER)', + 'INSERT INTO "db"."full_model" ("a") SELECT "a" FROM "tbl"', + ] + + +def test_insert_overwrite_by_time_partition_uses_table_operations( + adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch +): + monkeypatch.setattr(adapter, "_get_data_objects", lambda schema_name, object_names=None: []) + + adapter.insert_overwrite_by_time_partition( + "db.incremental_model", + parse_one("SELECT id, ds FROM source"), + start="2024-01-01", + end="2024-01-02", + time_column="ds", + time_formatter=lambda x, _: exp.Literal.string(str(x)[:10]), + target_columns_to_types={ + "id": exp.DataType.build("INT"), + "ds": exp.DataType.build("DATE"), + }, + ) + + assert to_sql_calls(adapter) == [ + 'DELETE FROM "db"."incremental_model" WHERE "ds" BETWEEN \'2024-01-01\' AND \'2024-01-02\'', + 'INSERT INTO "db"."incremental_model" ("id", "ds") SELECT "id", "ds" FROM (SELECT "id", "ds" FROM "source") AS "_subquery" WHERE "ds" BETWEEN \'2024-01-01\' AND \'2024-01-02\'', + ] + + +def test_create_view_creates_view( + adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch +): + monkeypatch.setattr(adapter, "_get_data_objects", lambda schema_name, object_names=None: []) + adapter.create_view( + "db.view_model", + parse_one("SELECT a FROM tbl"), + replace=False, + materialized=False, + ) + + assert to_sql_calls(adapter) == [ + 'CREATE VIEW "db"."view_model" AS SELECT "a" FROM "tbl"', + ] + + +def test_replace_view_drops_then_creates_view( + adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch +): + monkeypatch.setattr( + adapter, + "get_data_object", + lambda table: DataObject(schema="db", name="view_model", type=DataObjectType.VIEW), + ) + + adapter.create_view( + "db.view_model", + parse_one("SELECT a FROM tbl"), + replace=True, + materialized=False, + ) + + assert to_sql_calls(adapter) == [ + 'DROP VIEW IF EXISTS "db"."view_model"', + 'CREATE VIEW "db"."view_model" AS SELECT "a" FROM "tbl"', + ] + + +def test_create_materialized_view_creates_materialized_view( + adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch +): + monkeypatch.setattr(adapter, "_get_data_objects", lambda schema_name, object_names=None: []) + adapter.create_view( + "db.materialized_view_model", + parse_one("SELECT a FROM tbl"), + replace=False, + materialized=True, + ) + + assert to_sql_calls(adapter) == [ + 'CREATE MATERIALIZED VIEW "db"."materialized_view_model" AS SELECT "a" FROM "tbl"', + ] + + +def test_replace_materialized_view_drops_then_creates_materialized_view( + adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch +): + monkeypatch.setattr( + adapter, + "get_data_object", + lambda table: DataObject( + schema="db", name="materialized_view_model", type=DataObjectType.MATERIALIZED_VIEW + ), + ) + + adapter.create_view( + "db.materialized_view_model", + parse_one("SELECT a FROM tbl"), + replace=True, + materialized=True, + ) + + assert to_sql_calls(adapter) == [ + 'DROP MATERIALIZED VIEW IF EXISTS "db"."materialized_view_model"', + 'CREATE MATERIALIZED VIEW "db"."materialized_view_model" AS SELECT "a" FROM "tbl"', + ] + + +def test_drop_table_drops_view(adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr( + adapter, + "get_data_object", + lambda table: DataObject(schema="db", name="view_model", type=DataObjectType.VIEW), + ) + + adapter.drop_table("db.view_model") + + assert to_sql_calls(adapter) == ['DROP VIEW IF EXISTS "db"."view_model"'] + + +def test_drop_table_drops_materialized_view( + adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch +): + monkeypatch.setattr( + adapter, + "get_data_object", + lambda table: DataObject( + schema="db", name="materialized_view_model", type=DataObjectType.MATERIALIZED_VIEW + ), + ) + + adapter.drop_table("db.materialized_view_model") + + assert to_sql_calls(adapter) == [ + 'DROP MATERIALIZED VIEW IF EXISTS "db"."materialized_view_model"' + ] + + +def test_drop_table_drops_table(adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr( + adapter, + "get_data_object", + lambda table: DataObject(schema="db", name="full_model", type=DataObjectType.TABLE), + ) + + adapter.drop_table("db.full_model") + + assert to_sql_calls(adapter) == ['DROP TABLE IF EXISTS "db"."full_model"'] \ No newline at end of file diff --git a/tests/engines/feldera/test_db_api.py b/tests/engines/feldera/test_db_api.py index 20555d666d..0f6e6176b1 100644 --- a/tests/engines/feldera/test_db_api.py +++ b/tests/engines/feldera/test_db_api.py @@ -1,5 +1,6 @@ import types +import pytest from sqlglot import parse_one from sqlmesh.engines.feldera import db_api @@ -33,6 +34,9 @@ def deploy(self, *args: object, **kwargs: object) -> object: def current_pipeline(self) -> object: return self.pipeline + def queryable_relation_names(self) -> set[str]: + return set() + state_manager = FakeStateManager() cursor = db_api.FelderaCursor( client=object(), @@ -75,4 +79,133 @@ def test_hydrate_existing_program_skips_empty_parse_results(monkeypatch) -> None manager._hydrate_existing_program(object(), "test_pipeline") - assert manager.pending_tables() == {"foo"} \ No newline at end of file + assert manager.pending_tables() == {"foo"} + + +def test_state_manager_tracks_view_materialization() -> None: + manager = db_api.PipelineStateManager() + + manager.register_ddl("CREATE VIEW regular_view AS SELECT 1") + manager.register_ddl("CREATE MATERIALIZED VIEW materialized_view AS SELECT 1") + + assert manager.pending_views() == {"regular_view", "materialized_view"} + assert manager.is_materialized_view("regular_view") is False + assert manager.is_materialized_view("materialized_view") is True + + +def test_state_manager_rewrites_table_ctas() -> None: + manager = db_api.PipelineStateManager() + + manager.register_ddl( + 'CREATE TABLE "seed_model" AS SELECT CAST("id" AS INTEGER) AS "id" FROM (VALUES (1)) AS "t"("id")' + ) + + assert manager.assemble_program() == ( + 'CREATE TABLE "seed_model" ("id" INT);\n' + '\n' + 'CREATE MATERIALIZED VIEW "__sqlmesh_query__seed_model" AS SELECT * FROM "seed_model";\n' + '\n' + 'INSERT INTO "seed_model" (id) SELECT CAST("id" AS INT) AS "id" FROM (VALUES (1)) AS "t"("id");' + ) + + +def test_state_manager_adds_query_mirrors_for_non_materialized_relations() -> None: + manager = db_api.PipelineStateManager() + + manager.register_ddl('CREATE TABLE "full_model" ("id" INT)') + manager.register_ddl('CREATE VIEW "view_model" AS SELECT "id" FROM "full_model"') + manager.register_ddl( + 'CREATE MATERIALIZED VIEW "materialized_view_model" AS SELECT "id" FROM "full_model"' + ) + + program = manager.assemble_program() + + assert 'CREATE MATERIALIZED VIEW "__sqlmesh_query__full_model" AS SELECT * FROM "full_model";' in program + assert 'CREATE MATERIALIZED VIEW "__sqlmesh_query__view_model" AS SELECT * FROM "view_model";' in program + assert 'CREATE MATERIALIZED VIEW "__sqlmesh_query__materialized_view_model"' not in program + + +def test_hydrate_existing_program_skips_query_mirrors(monkeypatch) -> None: + manager = db_api.PipelineStateManager() + + pipeline_module = types.ModuleType("feldera.pipeline") + pipeline_module.Pipeline = type( + "Pipeline", + (), + { + "get": staticmethod( + lambda pipeline_name, client: types.SimpleNamespace( + _inner=types.SimpleNamespace( + program_code=( + 'CREATE TABLE "full_model" ("id" INT);\n' + 'CREATE MATERIALIZED VIEW "__sqlmesh_query__full_model" AS SELECT * FROM "full_model";\n' + 'CREATE VIEW "view_model" AS SELECT "id" FROM "full_model";\n' + 'CREATE MATERIALIZED VIEW "__sqlmesh_query__view_model" AS SELECT * FROM "view_model";' + ) + ) + ) + ) + }, + ) + feldera_module = types.ModuleType("feldera") + feldera_module.pipeline = pipeline_module + + monkeypatch.setitem(__import__("sys").modules, "feldera", feldera_module) + monkeypatch.setitem(__import__("sys").modules, "feldera.pipeline", pipeline_module) + + manager._hydrate_existing_program(object(), "test_pipeline") + + assert manager.pending_tables() == {"full_model"} + assert manager.pending_views() == {"view_model"} + + +def test_cursor_rewrites_queries_to_query_mirrors() -> None: + captured_queries = [] + + class FakeStateManager: + def has_pending_changes(self) -> bool: + return False + + def queryable_relation_names(self) -> set[str]: + return {"full_model", "view_model"} + + def current_pipeline(self) -> object: + return types.SimpleNamespace( + query=lambda sql: captured_queries.append(sql) or [{"count": 1}] + ) + + cursor = db_api.FelderaCursor( + client=object(), + pipeline_name="test_pipeline", + state_manager=FakeStateManager(), + ) + + cursor.execute('SELECT COUNT(*) FROM "full_model"') + + assert parse_one(captured_queries[0]).sql() == parse_one( + 'SELECT COUNT(*) FROM "__sqlmesh_query__full_model"' + ).sql() + assert cursor.fetchone() == (1,) + + +def test_cursor_raises_execution_error_from_query_rows() -> None: + class FakeStateManager: + def has_pending_changes(self) -> bool: + return False + + def queryable_relation_names(self) -> set[str]: + return set() + + def current_pipeline(self) -> object: + return types.SimpleNamespace( + query=lambda sql: [{"COUNT(*)": "Execution error: test failure"}] + ) + + cursor = db_api.FelderaCursor( + client=object(), + pipeline_name="test_pipeline", + state_manager=FakeStateManager(), + ) + + with pytest.raises(RuntimeError, match="Execution error: test failure"): + cursor.execute("SELECT COUNT(*) FROM full_model") \ No newline at end of file From 1cce68d1494368c8c8c6bef12b3f3279ddd3fa29 Mon Sep 17 00:00:00 2001 From: fresioAS Date: Thu, 7 May 2026 21:28:16 +0200 Subject: [PATCH 04/10] rename to feldera --- sqlmesh/core/config/connection.py | 2 +- sqlmesh/core/engine_adapter/__init__.py | 1 - sqlmesh/core/engine_adapter/feldera.py | 8 ++++++-- sqlmesh/core/snapshot/evaluator.py | 2 +- tests/core/engine_adapter/test_feldera.py | 6 +++++- tests/core/test_connection_config.py | 3 +-- 6 files changed, 14 insertions(+), 8 deletions(-) diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index 557e111b6e..057a321d46 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -2353,7 +2353,7 @@ class FelderaConnectionConfig(ConnectionConfig): timeout: int = 300 type_: t.Literal["feldera"] = Field(alias="type", default="feldera") - DIALECT: t.ClassVar[t.Literal["felderadialect"]] = "felderadialect" + DIALECT: t.ClassVar[t.Literal["feldera"]] = "feldera" DISPLAY_NAME: t.ClassVar[t.Literal["Feldera"]] = "Feldera" DISPLAY_ORDER: t.ClassVar[t.Literal[18]] = 18 diff --git a/sqlmesh/core/engine_adapter/__init__.py b/sqlmesh/core/engine_adapter/__init__.py index a9d30a0732..cb0e2a8413 100644 --- a/sqlmesh/core/engine_adapter/__init__.py +++ b/sqlmesh/core/engine_adapter/__init__.py @@ -38,7 +38,6 @@ "athena": AthenaEngineAdapter, "risingwave": RisingwaveEngineAdapter, "fabric": FabricEngineAdapter, - "felderadialect": FelderaEngineAdapter, "feldera": FelderaEngineAdapter, } diff --git a/sqlmesh/core/engine_adapter/feldera.py b/sqlmesh/core/engine_adapter/feldera.py index a0f1db4a87..3d8a4a73b5 100644 --- a/sqlmesh/core/engine_adapter/feldera.py +++ b/sqlmesh/core/engine_adapter/feldera.py @@ -41,7 +41,7 @@ def _view_type(state: t.Any, object_name: str) -> DataObjectType: return DataObjectType.VIEW -class FelderaDialect(Dialect): +class _SQLMeshFeldera(Dialect): class Generator(Generator): TYPE_MAPPING = { **Generator.TYPE_MAPPING, @@ -56,6 +56,10 @@ class Generator(Generator): } +if Dialect.get("feldera") is None: + Dialect.classes["feldera"] = _SQLMeshFeldera + + _FELDERA_TO_EXP_TYPE: t.Dict[str, exp.DataType.Type] = { "BOOLEAN": exp.DataType.Type.BOOLEAN, "TINYINT": exp.DataType.Type.TINYINT, @@ -83,7 +87,7 @@ def _feldera_type_to_exp(dtype_str: str) -> exp.DataType: class FelderaEngineAdapter(EngineAdapter): - DIALECT = "felderadialect" + DIALECT = "feldera" SUPPORTS_TRANSACTIONS = False SUPPORTS_INDEXES = False SUPPORTS_MATERIALIZED_VIEWS = True diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 1d6a9f0c1f..89c179d19e 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -2060,7 +2060,7 @@ def create( # Only sql models have queries that can be tested. # We also need to make sure that we don't dry run on Redshift because its planner / optimizer sometimes # breaks on our CTAS queries due to us relying on the WHERE FALSE LIMIT 0 combo. - if model.is_sql and dry_run and self.adapter.dialect not in {"redshift", "felderadialect"}: + if model.is_sql and dry_run and self.adapter.dialect not in {"redshift", "feldera"}: logger.info("Dry running model '%s'", model.name) self.adapter.fetchall(ctas_query) else: diff --git a/tests/core/engine_adapter/test_feldera.py b/tests/core/engine_adapter/test_feldera.py index 8f209ab88e..41339b892c 100644 --- a/tests/core/engine_adapter/test_feldera.py +++ b/tests/core/engine_adapter/test_feldera.py @@ -120,7 +120,11 @@ def get_pipeline(pipeline_name: str, client: t.Any) -> t.Any: ("requested_pipeline", "source", DataObjectType.TABLE), ("requested_pipeline", "sink", DataObjectType.VIEW), ] - assert adapter.dialect == "felderadialect" + assert adapter.dialect == "feldera" + + +def test_builtin_dialect_registers_feldera_name() -> None: + assert parse_one("SELECT 1", dialect="feldera").sql(dialect="feldera") == "SELECT 1" def test_get_data_objects_marks_materialized_views_from_state( diff --git a/tests/core/test_connection_config.py b/tests/core/test_connection_config.py index 0d1183587a..1e6deb2959 100644 --- a/tests/core/test_connection_config.py +++ b/tests/core/test_connection_config.py @@ -1960,14 +1960,13 @@ def test_feldera_connection_config(make_config): config = make_config(type="feldera", pipeline_name="pipeline", check_import=False) assert isinstance(config, FelderaConnectionConfig) - assert config.DIALECT == "felderadialect" + assert config.DIALECT == "feldera" with patch("sqlmesh.engines.feldera.db_api.connect") as mock_connect: config._connection_factory_with_kwargs() mock_connect.assert_called_once_with( host="http://localhost:8080", - api_key=None, pipeline_name="pipeline", workers=4, compilation_profile="dev", From 9424ca5c20af25a4f48a30a9ab4edfe254c69354 Mon Sep 17 00:00:00 2001 From: fresioAS Date: Sat, 9 May 2026 02:23:03 +0200 Subject: [PATCH 05/10] Enhance Feldera engine adapter with new SQL functions and improve DDL handling - Added support for CURRENT_TIMESTAMP in Feldera dialect. - Implemented `_insert_overwrite_by_condition` for efficient table replacement. - Enhanced `_is_virtual_layer_ddl` to identify virtual layer views. - Updated tests to cover new functionality and ensure correctness. Co-authored-by: Copilot --- sqlmesh/core/engine_adapter/feldera.py | 36 ++ sqlmesh/engines/feldera/db_api.py | 461 ++++++++++++++++++---- tests/core/engine_adapter/test_feldera.py | 29 ++ tests/engines/feldera/test_db_api.py | 294 +++++++++++++- 4 files changed, 743 insertions(+), 77 deletions(-) diff --git a/sqlmesh/core/engine_adapter/feldera.py b/sqlmesh/core/engine_adapter/feldera.py index 3d8a4a73b5..692ec61b02 100644 --- a/sqlmesh/core/engine_adapter/feldera.py +++ b/sqlmesh/core/engine_adapter/feldera.py @@ -50,6 +50,11 @@ class Generator(Generator): } TRANSFORMS = { **Generator.TRANSFORMS, + exp.CurrentTimestamp: lambda self, expression: ( + self.func("CURRENT_TIMESTAMP", expression.this) + if expression.this + else "CURRENT_TIMESTAMP" + ), exp.DateStrToDate: lambda self, expression: ( f"CAST({self.sql(expression, 'this')} AS DATE)" ), @@ -284,6 +289,37 @@ def _create_table_from_source_queries( track_rows_processed=track_rows_processed, ) + def _insert_overwrite_by_condition( + self, + table_name: TableName, + source_queries: t.List[SourceQuery], + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + where: t.Optional[exp.Condition] = None, + insert_overwrite_strategy_override: t.Optional[t.Any] = None, + **kwargs: t.Any, + ) -> None: + # Whole-table replacement is cheaper in Feldera as DROP+CREATE than DELETE+INSERT. + if where is None and source_queries: + self.drop_table(table_name) + self._create_table_from_source_queries( + table_name, + source_queries, + target_columns_to_types=target_columns_to_types, + exists=True, + replace=False, + **kwargs, + ) + return + + super()._insert_overwrite_by_condition( + table_name, + source_queries, + target_columns_to_types=target_columns_to_types, + where=where, + insert_overwrite_strategy_override=insert_overwrite_strategy_override, + **kwargs, + ) + def drop_table(self, table_name: TableName, exists: bool = True, **kwargs: t.Any) -> None: target_data_object = self.get_data_object(exp.to_table(table_name)) if target_data_object: diff --git a/sqlmesh/engines/feldera/db_api.py b/sqlmesh/engines/feldera/db_api.py index 086d3bdc69..7d7bb63fb5 100644 --- a/sqlmesh/engines/feldera/db_api.py +++ b/sqlmesh/engines/feldera/db_api.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import logging import threading import typing as t @@ -12,6 +13,7 @@ logger = logging.getLogger(__name__) QUERY_MIRROR_PREFIX = "__sqlmesh_query__" +FELDERA_DIALECT = "feldera" if t.TYPE_CHECKING: import pandas as pd @@ -31,7 +33,7 @@ def _classify(sql: str) -> SqlIntent: return SqlIntent.NO_OP try: - expression = parse_one(stripped) + expression = parse_one(stripped, dialect=FELDERA_DIALECT) except (ParseError, ValueError): expression = None @@ -48,6 +50,36 @@ def _classify(sql: str) -> SqlIntent: return SqlIntent.ADHOC_QUERY +def _is_virtual_layer_ddl(sql: str) -> bool: + stripped = _strip_leading_comments(sql) + + try: + expression = parse_one(stripped, dialect=FELDERA_DIALECT) + except (ParseError, ValueError): + return False + + if not isinstance(expression, (exp.Create, exp.Drop)): + return False + + target = expression.this + if isinstance(target, exp.Schema): + target = target.this + + if not isinstance(target, exp.Table) or not target.db: + return False + + target_db = target.db.lower() + if target_db.startswith("sqlmesh__"): + return False + + referenced_snapshot_tables = [ + table + for table in expression.find_all(exp.Table) + if table is not target and table.db and table.db.lower().startswith("sqlmesh__") + ] + return bool(referenced_snapshot_tables) + + class PipelineStateManager: """Accumulates DDL and deploys it as a single Feldera pipeline program.""" @@ -67,7 +99,7 @@ def register_ddl(self, sql: str) -> None: expression = None try: - expression = parse_one(_strip_leading_comments(sql)) + expression = parse_one(_strip_leading_comments(sql), dialect=FELDERA_DIALECT) except (ParseError, ValueError): pass @@ -210,43 +242,31 @@ def deploy( if isinstance(profile, str): profile = CompilationProfile(profile) - try: - pipeline = self._compile_program( - client, - pipeline_name, - sql, - profile, - runtime_config, - timeout, - Pipeline, - PipelineBuilder, - PipelineStatus, - InnerPipeline, - FelderaAPIError, - ) - except RuntimeError as ex: - if "not found" not in str(ex).lower() or not self._hydrated_object_keys: - raise - - for object_key in list(self._hydrated_object_keys): - self._tables.pop(object_key, None) - self._views.pop(object_key, None) - self._hydrated_object_keys.clear() - - sql = self.assemble_program() - pipeline = self._compile_program( - client, - pipeline_name, - sql, - profile, - runtime_config, - timeout, - Pipeline, - PipelineBuilder, - PipelineStatus, - InnerPipeline, - FelderaAPIError, - ) + while True: + try: + pipeline = self._compile_program( + client, + pipeline_name, + sql, + profile, + runtime_config, + timeout, + Pipeline, + PipelineBuilder, + PipelineStatus, + InnerPipeline, + FelderaAPIError, + ) + break + except RuntimeError as ex: + if not self._evict_hydrated_objects(str(ex)): + raise + + sql = self.assemble_program() + if not sql.strip(): + self._dirty = False + self._dropped_objects.clear() + return self._pipeline pipeline.start() pipeline.wait_for_status(PipelineStatus.RUNNING, timeout=timeout) @@ -269,11 +289,11 @@ def _hydrate_existing_program(self, client: t.Any, pipeline_name: str) -> None: except Exception: return - for expression in parse(program_code): + for expression in parse(program_code, dialect=FELDERA_DIALECT): if expression is None: continue - sql = expression.sql() + sql = expression.sql(dialect=FELDERA_DIALECT) if isinstance(expression, exp.Create): target = expression.this @@ -294,6 +314,29 @@ def _hydrate_existing_program(self, client: t.Any, pipeline_name: str) -> None: self._tables[object_key] = sql self._hydrated_object_keys.add(object_key) + def _evict_hydrated_objects(self, error_message: str) -> bool: + if not self._hydrated_object_keys: + return False + + normalized_error = error_message.lower() + object_keys = [ + object_key + for object_key in self._hydrated_object_keys + if object_key in normalized_error + ] + + if not object_keys: + if "not found" not in normalized_error: + return False + object_keys = list(self._hydrated_object_keys) + + for object_key in object_keys: + self._tables.pop(object_key, None) + self._views.pop(object_key, None) + self._hydrated_object_keys.discard(object_key) + + return True + def _compile_program( self, client: t.Any, @@ -314,13 +357,16 @@ def _compile_program( existing_pipeline = None if existing_pipeline is None: - return PipelineBuilder( - client, - name=pipeline_name, - sql=sql, - compilation_profile=profile, - runtime_config=runtime_config, - ).create_or_replace(wait=True) + try: + return PipelineBuilder( + client, + name=pipeline_name, + sql=sql, + compilation_profile=profile, + runtime_config=runtime_config, + ).create_or_replace(wait=True) + except RuntimeError as ex: + raise self._format_compile_error(client, pipeline_name, ex) from ex existing_pipeline.stop(force=True) existing_pipeline.wait_for_status(PipelineStatus.STOPPED, timeout=timeout) @@ -338,17 +384,65 @@ def _compile_program( }, runtime_config=runtime_config.to_dict(), ) - inner_pipeline = client.create_or_update_pipeline(inner_pipeline, wait=True) + try: + inner_pipeline = client.create_or_update_pipeline(inner_pipeline, wait=True) + except RuntimeError as ex: + raise self._format_compile_error(client, pipeline_name, ex) from ex pipeline = Pipeline(client) pipeline._inner = inner_pipeline return pipeline + def _format_compile_error(self, client: t.Any, pipeline_name: str, error: Exception) -> RuntimeError: + error_message = str(error) + + try: + from feldera.enums import PipelineFieldSelector + except Exception: + PipelineFieldSelector = None + + try: + field_selector = PipelineFieldSelector.ALL if PipelineFieldSelector else None + pipeline = client.get_pipeline(pipeline_name, field_selector) + except Exception: + return RuntimeError(error_message) + + program_error = getattr(pipeline, "program_error", None) or {} + sql_compilation = program_error.get("sql_compilation") or {} + sql_messages = sql_compilation.get("messages") or [] + + if sql_messages: + details = self._sql_compilation_error_message(pipeline_name, sql_messages) + if details != error_message: + return RuntimeError(details) + + rust_error = program_error.get("rust_compilation") + system_error = program_error.get("system_error") + if rust_error or system_error: + message = f"The program failed to compile: {getattr(pipeline, 'program_status', 'unknown')}\n" + if rust_error is not None: + message += f"Rust Error: {rust_error}\n" + if system_error is not None: + message += f"System Error: {system_error}" + return RuntimeError(message.rstrip()) + + return RuntimeError(error_message) + + @staticmethod + def _sql_compilation_error_message( + pipeline_name: str, sql_errors: t.Sequence[t.Mapping[str, t.Any]] + ) -> str: + err_msg = f"Pipeline {pipeline_name} failed to compile:\n" + for sql_error in sql_errors: + err_msg += f"{sql_error['error_type']}\n{sql_error['message']}\n" + err_msg += f"Code snippet:\n{sql_error['snippet']}" + return err_msg + def _extract_name(sql: str) -> str: stripped = _strip_leading_comments(sql) try: - expression = parse_one(stripped) + expression = parse_one(stripped, dialect=FELDERA_DIALECT) except (ParseError, ValueError): return stripped[:80].lower() @@ -387,7 +481,9 @@ def _query_mirror_table(table: exp.Table) -> exp.Table: def _ddl_statements_with_query_mirror(sql: str) -> t.List[str]: try: - expressions = [expression for expression in parse(sql) if expression is not None] + expressions = [ + expression for expression in parse(sql, dialect=FELDERA_DIALECT) if expression is not None + ] except (ParseError, ValueError): expressions = [] @@ -398,7 +494,7 @@ def _ddl_statements_with_query_mirror(sql: str) -> t.List[str]: statements = [] for expression in expressions: - statement = expression.sql().rstrip(";") + ";" + statement = expression.sql(dialect=FELDERA_DIALECT).rstrip(";") + ";" statements.append(statement) mirror_sql = _query_mirror_sql(statement) if mirror_sql: @@ -411,7 +507,7 @@ def _query_mirror_sql(sql: str) -> t.Optional[str]: stripped = _strip_leading_comments(sql) try: - expression = parse_one(stripped) + expression = parse_one(stripped, dialect=FELDERA_DIALECT) except (ParseError, ValueError): return None @@ -424,6 +520,8 @@ def _query_mirror_sql(sql: str) -> t.Optional[str]: if not isinstance(target, exp.Table) or _is_query_mirror_name(target.name): return None + if target.db and target.db.lower().startswith("sqlmesh__"): + return None kind = str(expression.args.get("kind") or "").upper() if "TABLE" not in kind and "VIEW" not in kind: @@ -431,11 +529,13 @@ def _query_mirror_sql(sql: str) -> t.Optional[str]: if "VIEW" in kind and _is_materialized_view_sql(sql): return None - return exp.Create( + mirror_sql = exp.Create( this=_query_mirror_table(target), kind="MATERIALIZED VIEW", expression=exp.select("*").from_(target.copy()), - ).sql() + ).sql(dialect=FELDERA_DIALECT) + + return _strip_table_qualifiers(mirror_sql) def _rewrite_query_for_query_mirrors(sql: str, relation_names: t.Set[str]) -> str: @@ -445,7 +545,7 @@ def _rewrite_query_for_query_mirrors(sql: str, relation_names: t.Set[str]) -> st stripped = _strip_leading_comments(sql) try: - expression = parse_one(stripped) + expression = parse_one(stripped, dialect=FELDERA_DIALECT) except (ParseError, ValueError): return sql @@ -465,7 +565,7 @@ def transform(node: exp.Expression) -> exp.Expression: return _query_mirror_table(node) return node - return expression.transform(transform, copy=True).sql() + return expression.transform(transform, copy=True).sql(dialect=FELDERA_DIALECT) def _execution_error(rows: t.List[t.Mapping[str, t.Any]]) -> t.Optional[str]: @@ -487,7 +587,7 @@ def _is_materialized_view_sql(sql: str) -> bool: return False try: - expression = parse_one(stripped) + expression = parse_one(stripped, dialect=FELDERA_DIALECT) except (ParseError, ValueError): expression = None @@ -502,7 +602,7 @@ def _rewrite_table_ctas_sql(sql: str) -> str: stripped = _strip_leading_comments(sql) try: - expression = parse_one(stripped) + expression = parse_one(stripped, dialect=FELDERA_DIALECT) except (ParseError, ValueError): return sql @@ -535,7 +635,10 @@ def _rewrite_table_ctas_sql(sql: str) -> str: properties=expression.args.get("properties"), ) insert_exp = exp.insert(query.copy(), target.copy(), columns=list(columns_to_types)) - return f"{create_exp.sql()};\n{insert_exp.sql()}" + return ( + f"{create_exp.sql(dialect=FELDERA_DIALECT)};\n" + f"{insert_exp.sql(dialect=FELDERA_DIALECT)}" + ) def _select_columns_to_types(query: exp.Expression) -> t.Optional[t.Dict[str, exp.DataType]]: @@ -582,12 +685,200 @@ def _strip_table_qualifiers(sql: str) -> str: stripped = _strip_leading_comments(sql) try: - expression = parse_one(stripped) + expression = parse_one(stripped, dialect=FELDERA_DIALECT) except (ParseError, ValueError): return stripped expression = expression.transform(_unqualify_table) - return expression.sql() + return expression.sql(dialect=FELDERA_DIALECT) + + +def _normalize_pipeline_ddl(sql: str) -> str: + stripped = _strip_leading_comments(sql) + + try: + expression = parse_one(stripped, dialect=FELDERA_DIALECT) + except (ParseError, ValueError): + return stripped + + def transform(node: exp.Expression) -> exp.Expression: + if isinstance(node, exp.Table) and node.db and node.db.lower().startswith("sqlmesh__"): + node = node.copy() + node.set("db", None) + return node + + return expression.transform(transform, copy=True).sql(dialect=FELDERA_DIALECT) + + +def _insert_to_input_json_payload( + sql: str, +) -> t.Optional[t.Tuple[str, t.List[t.Dict[str, t.Any]]]]: + stripped = _strip_leading_comments(sql) + + try: + expression = parse_one(stripped, dialect=FELDERA_DIALECT) + except (ParseError, ValueError): + return None + + if not isinstance(expression, exp.Insert): + return None + + target = expression.this + if isinstance(target, exp.Schema): + table = target.this + target_columns = [column.name for column in target.expressions] + elif isinstance(target, exp.Table): + table = target + target_columns = [] + else: + return None + + if not isinstance(table, exp.Table): + return None + + query = expression.expression + if not isinstance(query, exp.Query): + return None + + values = _query_values_source(query) + if values is None: + return None + + alias = values.args.get("alias") + if alias is None: + return None + + source_columns = [column.name for column in alias.columns] + if not source_columns: + return None + + if not target_columns: + target_columns = [select.output_name for select in query.selects] + + if len(target_columns) != len(query.selects): + return None + + rows = [] + for row in values.expressions: + if not isinstance(row, exp.Tuple): + return None + + source_row = { + column: _literal_value(value) + for column, value in zip(source_columns, row.expressions) + } + payload_row: t.Dict[str, t.Any] = {} + for target_column, select in zip(target_columns, query.selects): + if not target_column: + return None + + evaluated = _evaluate_insert_value_expression(select, source_row) + if evaluated is _UNSUPPORTED_INGEST_EXPRESSION: + return None + payload_row[target_column] = evaluated + rows.append(payload_row) + + return table.name, rows + + +def _query_values_source(query: exp.Query) -> t.Optional[exp.Values]: + from_expression = query.args.get("from_") + if from_expression is None: + return None + + source = from_expression.this + if isinstance(source, exp.Values): + return source + if isinstance(source, exp.Subquery) and isinstance(source.this, exp.Query): + return _query_values_source(source.this) + return None + + +_UNSUPPORTED_INGEST_EXPRESSION = object() + + +def _evaluate_insert_value_expression( + expression: exp.Expression, source_row: t.Mapping[str, t.Any] +) -> t.Any: + if isinstance(expression, exp.Alias): + return _evaluate_insert_value_expression(expression.this, source_row) + + if isinstance(expression, exp.Cast): + value = _evaluate_insert_value_expression(expression.this, source_row) + if value is _UNSUPPORTED_INGEST_EXPRESSION: + return value + to_type = expression.args.get("to") + if isinstance(to_type, exp.DataType): + return _coerce_input_json_value(value, to_type) + return value + + if isinstance(expression, exp.Column): + return source_row.get(expression.name) + + if isinstance(expression, (exp.Literal, exp.Boolean, exp.Null)): + return _literal_value(expression) + + if isinstance(expression, exp.Paren): + return _evaluate_insert_value_expression(expression.this, source_row) + + if isinstance(expression, exp.Neg): + value = _evaluate_insert_value_expression(expression.this, source_row) + if isinstance(value, (int, float)): + return -value + return _UNSUPPORTED_INGEST_EXPRESSION + + return _UNSUPPORTED_INGEST_EXPRESSION + + +def _literal_value(expression: exp.Expression) -> t.Any: + if isinstance(expression, (exp.Literal, exp.Boolean, exp.Null)): + return expression.to_py() + return _UNSUPPORTED_INGEST_EXPRESSION + + +def _coerce_input_json_value(value: t.Any, data_type: exp.DataType) -> t.Any: + if value is None: + return None + + dtype = data_type.this + if dtype in { + exp.DataType.Type.TINYINT, + exp.DataType.Type.SMALLINT, + exp.DataType.Type.INT, + exp.DataType.Type.BIGINT, + }: + return int(value) + + if dtype in { + exp.DataType.Type.FLOAT, + exp.DataType.Type.DOUBLE, + exp.DataType.Type.DECIMAL, + }: + return float(value) + + if dtype == exp.DataType.Type.BOOLEAN: + if isinstance(value, str): + normalized = value.strip().lower() + if normalized in {"true", "t", "1"}: + return True + if normalized in {"false", "f", "0"}: + return False + return bool(value) + + if dtype in { + exp.DataType.Type.CHAR, + exp.DataType.Type.NCHAR, + exp.DataType.Type.TEXT, + exp.DataType.Type.VARCHAR, + exp.DataType.Type.NVARCHAR, + exp.DataType.Type.DATE, + exp.DataType.Type.TIME, + exp.DataType.Type.TIMESTAMP, + exp.DataType.Type.TIMESTAMPTZ, + }: + return str(value) + + return value def _unqualify_table(node: exp.Expression) -> exp.Expression: @@ -627,10 +918,14 @@ def execute(self, sql: str, parameters: t.Optional[t.Any] = None) -> None: "Feldera DB-API does not support query parameters" ) - sql = _strip_table_qualifiers(sql) - sql = _normalize_ddl(sql) - - intent = _classify(sql) + original_sql = sql + normalized_sql = _normalize_ddl(sql) + intent = _classify(normalized_sql) + sql = ( + _normalize_pipeline_ddl(normalized_sql) + if intent == SqlIntent.PIPELINE_DDL + else _strip_table_qualifiers(normalized_sql) + ) logger.debug("Feldera execute (intent=%s): %.200s", intent.value, sql) if intent == SqlIntent.NO_OP: @@ -639,6 +934,10 @@ def execute(self, sql: str, parameters: t.Optional[t.Any] = None) -> None: return if intent == SqlIntent.PIPELINE_DDL: + if _is_virtual_layer_ddl(original_sql): + self._rows = [] + self._columns = [] + return self._state.register_ddl(sql) self._rows = [] self._columns = [] @@ -654,7 +953,13 @@ def execute(self, sql: str, parameters: t.Optional[t.Any] = None) -> None: ) if intent == SqlIntent.DATA_INGRESS: - self._get_pipeline().execute(sql) + pipeline = self._get_pipeline() + payload = _insert_to_input_json_payload(sql) + if payload is not None: + table_name, rows = payload + pipeline.input_json(table_name, rows) + else: + pipeline.execute(sql) self._rows = [] self._columns = [] return @@ -746,13 +1051,21 @@ def rollback(self) -> None: def close(self) -> None: if self._state.has_pending_changes(): - self._state.deploy( - self._client, - self._pipeline_name, - self._workers, - self._compilation_profile, - self._timeout, - ) + try: + self._state.deploy( + self._client, + self._pipeline_name, + self._workers, + self._compilation_profile, + self._timeout, + ) + except Exception as ex: + logger.error( + "Feldera pending DDL failed during connection close for pipeline %s:\n%s", + self._pipeline_name, + ex, + ) + raise return None diff --git a/tests/core/engine_adapter/test_feldera.py b/tests/core/engine_adapter/test_feldera.py index 41339b892c..fe9894d505 100644 --- a/tests/core/engine_adapter/test_feldera.py +++ b/tests/core/engine_adapter/test_feldera.py @@ -127,6 +127,13 @@ def test_builtin_dialect_registers_feldera_name() -> None: assert parse_one("SELECT 1", dialect="feldera").sql(dialect="feldera") == "SELECT 1" +def test_builtin_dialect_preserves_current_timestamp_keyword() -> None: + assert ( + parse_one("SELECT CURRENT_TIMESTAMP AS ts", dialect="feldera").sql(dialect="feldera") + == "SELECT CURRENT_TIMESTAMP AS ts" + ) + + def test_get_data_objects_marks_materialized_views_from_state( adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch ): @@ -163,6 +170,28 @@ def test_replace_query_creates_table( ] +def test_replace_query_recreates_existing_table( + adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch +): + monkeypatch.setattr( + adapter, + "get_data_object", + lambda table: DataObject(schema="db", name="full_model", type=DataObjectType.TABLE), + ) + + adapter.replace_query( + "db.full_model", + parse_one("SELECT a FROM tbl"), + {"a": exp.DataType.build("INT")}, + ) + + assert to_sql_calls(adapter) == [ + 'DROP TABLE IF EXISTS "db"."full_model"', + 'CREATE TABLE IF NOT EXISTS "db"."full_model" ("a" INTEGER)', + 'INSERT INTO "db"."full_model" ("a") SELECT "a" FROM "tbl"', + ] + + def test_insert_overwrite_by_time_partition_uses_table_operations( adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch ): diff --git a/tests/engines/feldera/test_db_api.py b/tests/engines/feldera/test_db_api.py index 0f6e6176b1..b50b6ba37a 100644 --- a/tests/engines/feldera/test_db_api.py +++ b/tests/engines/feldera/test_db_api.py @@ -13,6 +13,44 @@ def test_classify_treats_comment_prefixed_create_schema_as_pipeline_ddl() -> Non ) +def test_is_virtual_layer_ddl_identifies_environment_alias_view() -> None: + assert db_api._is_virtual_layer_ddl( + 'CREATE VIEW "polymarket__dev"."live_trades" AS ' + 'SELECT * FROM "sqlmesh__polymarket"."polymarket__live_trades__1782741465__dev"' + ) + + assert not db_api._is_virtual_layer_ddl( + 'CREATE MATERIALIZED VIEW "sqlmesh__polymarket"."polymarket__rolling_vwap__1225616675__dev" AS ' + 'SELECT * FROM "sqlmesh__polymarket"."polymarket__live_trades__1782741465__dev"' + ) + + +def test_strip_table_qualifiers_preserves_current_timestamp_keyword() -> None: + sql = ( + 'CREATE MATERIALIZED VIEW "db"."view_model" AS ' + 'SELECT CURRENT_TIMESTAMP AS ts FROM "db"."source"' + ) + + assert db_api._strip_table_qualifiers(sql) == ( + 'CREATE MATERIALIZED VIEW "view_model" AS ' + 'SELECT CURRENT_TIMESTAMP AS ts FROM "source"' + ) + + +def test_normalize_pipeline_ddl_strips_only_sqlmesh_internal_qualifiers() -> None: + sql = ( + 'CREATE MATERIALIZED VIEW "sqlmesh__polymarket"."polymarket__rolling_vwap__1225616675__dev" AS ' + 'SELECT "live_trades"."market_id" AS "market_id" ' + 'FROM "polymarket"."live_trades" AS "live_trades"' + ) + + assert db_api._normalize_pipeline_ddl(sql) == ( + 'CREATE MATERIALIZED VIEW "polymarket__rolling_vwap__1225616675__dev" AS ' + 'SELECT "live_trades"."market_id" AS "market_id" ' + 'FROM "polymarket"."live_trades" AS "live_trades"' + ) + + def test_cursor_defers_pipeline_deploy_until_non_ddl_statement() -> None: class FakeStateManager: def __init__(self) -> None: @@ -54,10 +92,71 @@ def queryable_relation_names(self) -> set[str]: assert cursor.fetchall() == [(1,)] +def test_cursor_ignores_virtual_layer_view_ddl() -> None: + class FakeStateManager: + def __init__(self) -> None: + self.registered_sql = [] + + def register_ddl(self, sql: str) -> None: + self.registered_sql.append(sql) + + def has_pending_changes(self) -> bool: + return False + + state_manager = FakeStateManager() + cursor = db_api.FelderaCursor( + client=object(), + pipeline_name="test_pipeline", + state_manager=state_manager, + ) + + cursor.execute( + 'CREATE VIEW "polymarket__dev"."live_trades" AS ' + 'SELECT * FROM "sqlmesh__polymarket"."polymarket__live_trades__1782741465__dev"' + ) + + assert state_manager.registered_sql == [] + + +def test_cursor_preserves_schema_qualified_references_in_pipeline_ddl() -> None: + class FakeStateManager: + def __init__(self) -> None: + self.registered_sql = [] + + def register_ddl(self, sql: str) -> None: + self.registered_sql.append(sql) + + def has_pending_changes(self) -> bool: + return False + + state_manager = FakeStateManager() + cursor = db_api.FelderaCursor( + client=object(), + pipeline_name="test_pipeline", + state_manager=state_manager, + ) + + cursor.execute( + 'CREATE MATERIALIZED VIEW "sqlmesh__polymarket"."polymarket__trade_price_observations__1956259901__dev" AS ' + 'SELECT "live_trades"."market_id" AS "market_id" ' + 'FROM "polymarket"."live_trades" AS "live_trades"' + ) + + assert state_manager.registered_sql == [ + 'CREATE MATERIALIZED VIEW "polymarket__trade_price_observations__1956259901__dev" AS ' + 'SELECT "live_trades"."market_id" AS "market_id" ' + 'FROM "polymarket"."live_trades" AS "live_trades"' + ] + + def test_hydrate_existing_program_skips_empty_parse_results(monkeypatch) -> None: manager = db_api.PipelineStateManager() - monkeypatch.setattr(db_api, "parse", lambda sql: [None, parse_one("CREATE TABLE foo (id INT)")]) + monkeypatch.setattr( + db_api, + "parse", + lambda sql, **kwargs: [None, parse_one("CREATE TABLE foo (id INT)", **kwargs)], + ) pipeline_module = types.ModuleType("feldera.pipeline") pipeline_module.Pipeline = type( @@ -101,12 +200,99 @@ def test_state_manager_rewrites_table_ctas() -> None: ) assert manager.assemble_program() == ( - 'CREATE TABLE "seed_model" ("id" INT);\n' + 'CREATE TABLE "seed_model" ("id" INTEGER);\n' '\n' 'CREATE MATERIALIZED VIEW "__sqlmesh_query__seed_model" AS SELECT * FROM "seed_model";\n' '\n' - 'INSERT INTO "seed_model" (id) SELECT CAST("id" AS INT) AS "id" FROM (VALUES (1)) AS "t"("id");' + 'INSERT INTO "seed_model" (id) SELECT CAST("id" AS INTEGER) AS "id" FROM (VALUES (1)) AS "t"("id");' + ) + + +def test_evict_hydrated_objects_removes_stale_object_from_compile_error() -> None: + manager = db_api.PipelineStateManager() + manager._views = { + 'polymarket__rolling_vwap__781619724__dev': ( + 'CREATE MATERIALIZED VIEW "polymarket__rolling_vwap__781619724__dev" AS ' + 'SELECT CURRENT_TIMESTAMP AS ts' + ), + 'polymarket__rolling_vwap__1225616675__dev': ( + 'CREATE MATERIALIZED VIEW "polymarket__rolling_vwap__1225616675__dev" AS ' + 'SELECT NOW() AS ts' + ), + } + manager._hydrated_object_keys = { + 'polymarket__rolling_vwap__781619724__dev', + 'polymarket__rolling_vwap__1225616675__dev', + } + + removed = manager._evict_hydrated_objects( + 'Compilation error in CREATE MATERIALIZED VIEW "polymarket__rolling_vwap__781619724__dev"' + ) + + assert removed is True + assert 'polymarket__rolling_vwap__781619724__dev' not in manager.pending_views() + assert 'polymarket__rolling_vwap__1225616675__dev' in manager.pending_views() + + +def test_format_compile_error_preserves_sql_compilation_details() -> None: + manager = db_api.PipelineStateManager() + + class FakeClient: + def get_pipeline(self, pipeline_name: str, field_selector: object) -> object: + return types.SimpleNamespace( + program_status="SqlError", + program_error={ + "sql_compilation": { + "messages": [ + { + "error_type": "Compilation error", + "message": "Object 'polymarket__live_trades__1782741465__dev' not found", + "snippet": '1|CREATE MATERIALIZED VIEW "polymarket__rolling_vwap__1225616675__dev" AS SELECT ...', + } + ] + } + }, + ) + + error = manager._format_compile_error( + FakeClient(), + "polymarket", + RuntimeError("The program failed to compile: SqlError"), + ) + + assert str(error) == ( + "Pipeline polymarket failed to compile:\n" + "Compilation error\n" + "Object 'polymarket__live_trades__1782741465__dev' not found\n" + 'Code snippet:\n1|CREATE MATERIALIZED VIEW "polymarket__rolling_vwap__1225616675__dev" AS SELECT ...' + ) + + +def test_connection_close_logs_pending_compile_error(caplog: pytest.LogCaptureFixture) -> None: + class FakeStateManager: + def has_pending_changes(self) -> bool: + return True + + def deploy(self, *args: object, **kwargs: object) -> object: + raise RuntimeError( + "Pipeline polymarket failed to compile:\n" + "Compilation error\n" + "TIMESTAMP_TRUNC problem" + ) + + connection = db_api.FelderaConnection( + client=object(), + host="http://localhost:8080", + pipeline_name="polymarket", ) + connection._state = FakeStateManager() + + with caplog.at_level("ERROR", logger="sqlmesh.engines.feldera.db_api"): + with pytest.raises(RuntimeError, match="TIMESTAMP_TRUNC problem"): + connection.close() + + assert "Feldera pending DDL failed during connection close for pipeline polymarket" in caplog.text + assert "TIMESTAMP_TRUNC problem" in caplog.text def test_state_manager_adds_query_mirrors_for_non_materialized_relations() -> None: @@ -125,6 +311,33 @@ def test_state_manager_adds_query_mirrors_for_non_materialized_relations() -> No assert 'CREATE MATERIALIZED VIEW "__sqlmesh_query__materialized_view_model"' not in program +def test_state_manager_query_mirror_strips_schema_qualifiers() -> None: + manager = db_api.PipelineStateManager() + + manager.register_ddl( + 'CREATE TABLE "polymarket"."markets_sample_seed" AS ' + 'SELECT CAST("id" AS INTEGER) AS "id" FROM (VALUES (1)) AS "t"("id")' + ) + + program = manager.assemble_program() + + assert 'CREATE MATERIALIZED VIEW "__sqlmesh_query__markets_sample_seed" AS SELECT * FROM "markets_sample_seed";' in program + assert 'SELECT * FROM "polymarket"."markets_sample_seed"' not in program + + +def test_state_manager_skips_query_mirrors_for_sqlmesh_internal_objects() -> None: + manager = db_api.PipelineStateManager() + + manager.register_ddl( + 'CREATE TABLE "sqlmesh__polymarket"."polymarket__l1_orderbook__441175831__dev" ' + '("market_id" VARCHAR, "best_bid" DOUBLE)' + ) + + program = manager.assemble_program() + + assert '__sqlmesh_query__polymarket__l1_orderbook__441175831__dev' not in program + + def test_hydrate_existing_program_skips_query_mirrors(monkeypatch) -> None: manager = db_api.PipelineStateManager() @@ -188,6 +401,81 @@ def current_pipeline(self) -> object: assert cursor.fetchone() == (1,) +def test_insert_to_input_json_payload_rewrites_seed_values() -> None: + payload = db_api._insert_to_input_json_payload( + 'INSERT INTO "seed_model" ("market_id", "volume_24h", "end_date") ' + 'SELECT CAST("market_id" AS TEXT) AS "market_id", ' + 'CAST("volume_24h" AS DOUBLE) AS "volume_24h", ' + 'CAST("end_date" AS TIMESTAMP) AS "end_date" ' + 'FROM (VALUES ' + "('123', '4.5', '2026-05-09 01:00:00'), " + "('456', '7.0', NULL)" + ') AS "t"("market_id", "volume_24h", "end_date")' + ) + + assert payload == ( + "seed_model", + [ + { + "market_id": "123", + "volume_24h": 4.5, + "end_date": "2026-05-09 01:00:00", + }, + { + "market_id": "456", + "volume_24h": 7.0, + "end_date": None, + }, + ], + ) + + +def test_cursor_uses_input_json_for_seed_inserts() -> None: + captured_rows = [] + executed_sql = [] + + class FakePipeline: + def input_json(self, table_name: str, data: object, **kwargs: object) -> None: + captured_rows.append((table_name, data, kwargs)) + + def execute(self, sql: str) -> None: + executed_sql.append(sql) + + class FakeStateManager: + def has_pending_changes(self) -> bool: + return False + + def current_pipeline(self) -> object: + return FakePipeline() + + def queryable_relation_names(self) -> set[str]: + return set() + + cursor = db_api.FelderaCursor( + client=object(), + pipeline_name="test_pipeline", + state_manager=FakeStateManager(), + ) + + cursor.execute( + 'INSERT INTO "seed_model" ("a", "b") ' + 'SELECT CAST("a" AS INT) AS "a", CAST("b" AS INT) AS "b" ' + 'FROM (VALUES (1, 4), (2, 5)) AS "t"("a", "b")' + ) + + assert executed_sql == [] + assert captured_rows == [ + ( + "seed_model", + [ + {"a": 1, "b": 4}, + {"a": 2, "b": 5}, + ], + {}, + ) + ] + + def test_cursor_raises_execution_error_from_query_rows() -> None: class FakeStateManager: def has_pending_changes(self) -> bool: From 7e27f16c979b31cf23a211ec4954f2c78d6d3ae6 Mon Sep 17 00:00:00 2001 From: fresioAS Date: Tue, 12 May 2026 21:38:21 +0200 Subject: [PATCH 06/10] pipeline identifier --- sqlmesh/engines/feldera/db_api.py | 51 +++++++- tests/engines/feldera/test_db_api.py | 184 +++++++++++++++++++-------- 2 files changed, 175 insertions(+), 60 deletions(-) diff --git a/sqlmesh/engines/feldera/db_api.py b/sqlmesh/engines/feldera/db_api.py index 7d7bb63fb5..d5b53dc523 100644 --- a/sqlmesh/engines/feldera/db_api.py +++ b/sqlmesh/engines/feldera/db_api.py @@ -1,6 +1,5 @@ from __future__ import annotations -import copy import logging import threading import typing as t @@ -689,10 +688,50 @@ def _strip_table_qualifiers(sql: str) -> str: except (ParseError, ValueError): return stripped - expression = expression.transform(_unqualify_table) + def transform(node: exp.Expression) -> exp.Expression: + if isinstance(node, exp.Table): + node = _canonicalize_snapshot_table(node) + node = _unqualify_table(node) + return node + + expression = expression.transform(transform, copy=True) return expression.sql(dialect=FELDERA_DIALECT) +def _snapshot_name_parts(name: str) -> t.Optional[t.Tuple[str, str]]: + parts = name.split("__") + if len(parts) < 3: + return None + + if parts[-1].isdigit(): + schema_name = parts[0] + model_name = "__".join(parts[1:-1]) + return (schema_name, model_name) if model_name else None + + if len(parts) >= 4 and parts[-2].isdigit(): + schema_name = parts[0] + model_name = "__".join(parts[1:-2]) + return (schema_name, model_name) if model_name else None + + return None + + +def _canonicalize_snapshot_table(node: exp.Table) -> exp.Table: + if _is_query_mirror_name(node.name): + return node + + snapshot_parts = _snapshot_name_parts(node.name) + if snapshot_parts is None: + return node + + _, model_name = snapshot_parts + canonicalized = node.copy() + canonicalized.set("this", exp.to_identifier(model_name, quoted=True)) + canonicalized.set("db", None) + canonicalized.set("catalog", None) + return canonicalized + + def _normalize_pipeline_ddl(sql: str) -> str: stripped = _strip_leading_comments(sql) @@ -702,9 +741,9 @@ def _normalize_pipeline_ddl(sql: str) -> str: return stripped def transform(node: exp.Expression) -> exp.Expression: - if isinstance(node, exp.Table) and node.db and node.db.lower().startswith("sqlmesh__"): - node = node.copy() - node.set("db", None) + if isinstance(node, exp.Table): + node = _canonicalize_snapshot_table(node) + node = _unqualify_table(node) return node return expression.transform(transform, copy=True).sql(dialect=FELDERA_DIALECT) @@ -736,6 +775,8 @@ def _insert_to_input_json_payload( if not isinstance(table, exp.Table): return None + table = _canonicalize_snapshot_table(table) + query = expression.expression if not isinstance(query, exp.Query): return None diff --git a/tests/engines/feldera/test_db_api.py b/tests/engines/feldera/test_db_api.py index b50b6ba37a..2e7e0ce332 100644 --- a/tests/engines/feldera/test_db_api.py +++ b/tests/engines/feldera/test_db_api.py @@ -15,13 +15,13 @@ def test_classify_treats_comment_prefixed_create_schema_as_pipeline_ddl() -> Non def test_is_virtual_layer_ddl_identifies_environment_alias_view() -> None: assert db_api._is_virtual_layer_ddl( - 'CREATE VIEW "polymarket__dev"."live_trades" AS ' - 'SELECT * FROM "sqlmesh__polymarket"."polymarket__live_trades__1782741465__dev"' + 'CREATE VIEW "analytics__dev"."source_events" AS ' + 'SELECT * FROM "sqlmesh__analytics"."analytics__source_events__1782741465__dev"' ) assert not db_api._is_virtual_layer_ddl( - 'CREATE MATERIALIZED VIEW "sqlmesh__polymarket"."polymarket__rolling_vwap__1225616675__dev" AS ' - 'SELECT * FROM "sqlmesh__polymarket"."polymarket__live_trades__1782741465__dev"' + 'CREATE MATERIALIZED VIEW "sqlmesh__analytics"."analytics__aggregate_view__1225616675__dev" AS ' + 'SELECT * FROM "sqlmesh__analytics"."analytics__source_events__1782741465__dev"' ) @@ -37,17 +37,28 @@ def test_strip_table_qualifiers_preserves_current_timestamp_keyword() -> None: ) -def test_normalize_pipeline_ddl_strips_only_sqlmesh_internal_qualifiers() -> None: +def test_normalize_pipeline_ddl_canonicalizes_snapshot_names_to_logical_names() -> None: sql = ( - 'CREATE MATERIALIZED VIEW "sqlmesh__polymarket"."polymarket__rolling_vwap__1225616675__dev" AS ' - 'SELECT "live_trades"."market_id" AS "market_id" ' - 'FROM "polymarket"."live_trades" AS "live_trades"' + 'CREATE MATERIALIZED VIEW "sqlmesh__analytics"."analytics__aggregate_view__1225616675__dev" AS ' + 'SELECT "source_events"."entity_id" AS "entity_id" ' + 'FROM "sqlmesh__analytics"."analytics__source_events__1782741465__dev" AS "source_events"' ) assert db_api._normalize_pipeline_ddl(sql) == ( - 'CREATE MATERIALIZED VIEW "polymarket__rolling_vwap__1225616675__dev" AS ' - 'SELECT "live_trades"."market_id" AS "market_id" ' - 'FROM "polymarket"."live_trades" AS "live_trades"' + 'CREATE MATERIALIZED VIEW "aggregate_view" AS ' + 'SELECT "source_events"."entity_id" AS "entity_id" ' + 'FROM "source_events" AS "source_events"' + ) + + +def test_normalize_pipeline_ddl_strips_schema_from_logical_tables() -> None: + sql = ( + 'CREATE TABLE "analytics"."sample_seed" ' + '("entity_id" VARCHAR, "description" VARCHAR)' + ) + + assert db_api._normalize_pipeline_ddl(sql) == ( + 'CREATE TABLE "sample_seed" ("entity_id" VARCHAR, "description" VARCHAR)' ) @@ -111,14 +122,14 @@ def has_pending_changes(self) -> bool: ) cursor.execute( - 'CREATE VIEW "polymarket__dev"."live_trades" AS ' - 'SELECT * FROM "sqlmesh__polymarket"."polymarket__live_trades__1782741465__dev"' + 'CREATE VIEW "analytics__dev"."source_events" AS ' + 'SELECT * FROM "sqlmesh__analytics"."analytics__source_events__1782741465__dev"' ) assert state_manager.registered_sql == [] -def test_cursor_preserves_schema_qualified_references_in_pipeline_ddl() -> None: +def test_cursor_registers_logical_model_names_in_pipeline_ddl() -> None: class FakeStateManager: def __init__(self) -> None: self.registered_sql = [] @@ -137,15 +148,15 @@ def has_pending_changes(self) -> bool: ) cursor.execute( - 'CREATE MATERIALIZED VIEW "sqlmesh__polymarket"."polymarket__trade_price_observations__1956259901__dev" AS ' - 'SELECT "live_trades"."market_id" AS "market_id" ' - 'FROM "polymarket"."live_trades" AS "live_trades"' + 'CREATE MATERIALIZED VIEW "sqlmesh__analytics"."analytics__aggregated_observations__1956259901__dev" AS ' + 'SELECT "source_events"."entity_id" AS "entity_id" ' + 'FROM "sqlmesh__analytics"."analytics__source_events__1782741465__dev" AS "source_events"' ) assert state_manager.registered_sql == [ - 'CREATE MATERIALIZED VIEW "polymarket__trade_price_observations__1956259901__dev" AS ' - 'SELECT "live_trades"."market_id" AS "market_id" ' - 'FROM "polymarket"."live_trades" AS "live_trades"' + 'CREATE MATERIALIZED VIEW "aggregated_observations" AS ' + 'SELECT "source_events"."entity_id" AS "entity_id" ' + 'FROM "source_events" AS "source_events"' ] @@ -211,27 +222,27 @@ def test_state_manager_rewrites_table_ctas() -> None: def test_evict_hydrated_objects_removes_stale_object_from_compile_error() -> None: manager = db_api.PipelineStateManager() manager._views = { - 'polymarket__rolling_vwap__781619724__dev': ( - 'CREATE MATERIALIZED VIEW "polymarket__rolling_vwap__781619724__dev" AS ' + 'analytics__aggregate_view__781619724__dev': ( + 'CREATE MATERIALIZED VIEW "analytics__aggregate_view__781619724__dev" AS ' 'SELECT CURRENT_TIMESTAMP AS ts' ), - 'polymarket__rolling_vwap__1225616675__dev': ( - 'CREATE MATERIALIZED VIEW "polymarket__rolling_vwap__1225616675__dev" AS ' + 'analytics__aggregate_view__1225616675__dev': ( + 'CREATE MATERIALIZED VIEW "analytics__aggregate_view__1225616675__dev" AS ' 'SELECT NOW() AS ts' ), } manager._hydrated_object_keys = { - 'polymarket__rolling_vwap__781619724__dev', - 'polymarket__rolling_vwap__1225616675__dev', + 'analytics__aggregate_view__781619724__dev', + 'analytics__aggregate_view__1225616675__dev', } removed = manager._evict_hydrated_objects( - 'Compilation error in CREATE MATERIALIZED VIEW "polymarket__rolling_vwap__781619724__dev"' + 'Compilation error in CREATE MATERIALIZED VIEW "analytics__aggregate_view__781619724__dev"' ) assert removed is True - assert 'polymarket__rolling_vwap__781619724__dev' not in manager.pending_views() - assert 'polymarket__rolling_vwap__1225616675__dev' in manager.pending_views() + assert 'analytics__aggregate_view__781619724__dev' not in manager.pending_views() + assert 'analytics__aggregate_view__1225616675__dev' in manager.pending_views() def test_format_compile_error_preserves_sql_compilation_details() -> None: @@ -246,8 +257,8 @@ def get_pipeline(self, pipeline_name: str, field_selector: object) -> object: "messages": [ { "error_type": "Compilation error", - "message": "Object 'polymarket__live_trades__1782741465__dev' not found", - "snippet": '1|CREATE MATERIALIZED VIEW "polymarket__rolling_vwap__1225616675__dev" AS SELECT ...', + "message": "Object 'analytics__source_events__1782741465__dev' not found", + "snippet": '1|CREATE MATERIALIZED VIEW "analytics__aggregate_view__1225616675__dev" AS SELECT ...', } ] } @@ -256,15 +267,15 @@ def get_pipeline(self, pipeline_name: str, field_selector: object) -> object: error = manager._format_compile_error( FakeClient(), - "polymarket", + "test_pipeline", RuntimeError("The program failed to compile: SqlError"), ) assert str(error) == ( - "Pipeline polymarket failed to compile:\n" + "Pipeline test_pipeline failed to compile:\n" "Compilation error\n" - "Object 'polymarket__live_trades__1782741465__dev' not found\n" - 'Code snippet:\n1|CREATE MATERIALIZED VIEW "polymarket__rolling_vwap__1225616675__dev" AS SELECT ...' + "Object 'analytics__source_events__1782741465__dev' not found\n" + 'Code snippet:\n1|CREATE MATERIALIZED VIEW "analytics__aggregate_view__1225616675__dev" AS SELECT ...' ) @@ -275,7 +286,7 @@ def has_pending_changes(self) -> bool: def deploy(self, *args: object, **kwargs: object) -> object: raise RuntimeError( - "Pipeline polymarket failed to compile:\n" + "Pipeline test_pipeline failed to compile:\n" "Compilation error\n" "TIMESTAMP_TRUNC problem" ) @@ -283,7 +294,7 @@ def deploy(self, *args: object, **kwargs: object) -> object: connection = db_api.FelderaConnection( client=object(), host="http://localhost:8080", - pipeline_name="polymarket", + pipeline_name="test_pipeline", ) connection._state = FakeStateManager() @@ -291,7 +302,7 @@ def deploy(self, *args: object, **kwargs: object) -> object: with pytest.raises(RuntimeError, match="TIMESTAMP_TRUNC problem"): connection.close() - assert "Feldera pending DDL failed during connection close for pipeline polymarket" in caplog.text + assert "Feldera pending DDL failed during connection close for pipeline test_pipeline" in caplog.text assert "TIMESTAMP_TRUNC problem" in caplog.text @@ -315,27 +326,44 @@ def test_state_manager_query_mirror_strips_schema_qualifiers() -> None: manager = db_api.PipelineStateManager() manager.register_ddl( - 'CREATE TABLE "polymarket"."markets_sample_seed" AS ' + 'CREATE TABLE "analytics"."sample_seed" AS ' 'SELECT CAST("id" AS INTEGER) AS "id" FROM (VALUES (1)) AS "t"("id")' ) program = manager.assemble_program() - assert 'CREATE MATERIALIZED VIEW "__sqlmesh_query__markets_sample_seed" AS SELECT * FROM "markets_sample_seed";' in program - assert 'SELECT * FROM "polymarket"."markets_sample_seed"' not in program + assert 'CREATE MATERIALIZED VIEW "__sqlmesh_query__sample_seed" AS SELECT * FROM "sample_seed";' in program + assert 'SELECT * FROM "analytics"."sample_seed"' not in program def test_state_manager_skips_query_mirrors_for_sqlmesh_internal_objects() -> None: manager = db_api.PipelineStateManager() manager.register_ddl( - 'CREATE TABLE "sqlmesh__polymarket"."polymarket__l1_orderbook__441175831__dev" ' - '("market_id" VARCHAR, "best_bid" DOUBLE)' + 'CREATE TABLE "sqlmesh__analytics"."analytics__source_snapshot__441175831__dev" ' + '("entity_id" VARCHAR, "metric_value" DOUBLE)' ) program = manager.assemble_program() - assert '__sqlmesh_query__polymarket__l1_orderbook__441175831__dev' not in program + assert '__sqlmesh_query__analytics__source_snapshot__441175831__dev' not in program + + +def test_state_manager_canonicalizes_snapshot_tables_to_logical_names() -> None: + manager = db_api.PipelineStateManager() + + manager.register_ddl( + db_api._normalize_pipeline_ddl( + 'CREATE TABLE "sqlmesh__analytics"."analytics__records__849752499__dev" ' + '("entity_id" VARCHAR, "description" VARCHAR)' + ) + ) + + program = manager.assemble_program() + + assert 'CREATE TABLE "records" ("entity_id" VARCHAR, "description" VARCHAR);' in program + assert 'CREATE MATERIALIZED VIEW "__sqlmesh_query__records" AS SELECT * FROM "records";' in program + assert 'analytics__records__849752499__dev' not in program def test_hydrate_existing_program_skips_query_mirrors(monkeypatch) -> None: @@ -401,35 +429,81 @@ def current_pipeline(self) -> object: assert cursor.fetchone() == (1,) +def test_cursor_rewrites_snapshot_queries_to_logical_query_mirrors() -> None: + captured_queries = [] + + class FakeStateManager: + def has_pending_changes(self) -> bool: + return False + + def queryable_relation_names(self) -> set[str]: + return {"source_events"} + + def current_pipeline(self) -> object: + return types.SimpleNamespace( + query=lambda sql: captured_queries.append(sql) or [{"count": 1}] + ) + + cursor = db_api.FelderaCursor( + client=object(), + pipeline_name="test_pipeline", + state_manager=FakeStateManager(), + ) + + cursor.execute('SELECT COUNT(*) FROM "sqlmesh__analytics"."analytics__source_events__1782741465__dev"') + + assert parse_one(captured_queries[0]).sql() == parse_one( + 'SELECT COUNT(*) FROM "__sqlmesh_query__source_events"' + ).sql() + assert cursor.fetchone() == (1,) + + def test_insert_to_input_json_payload_rewrites_seed_values() -> None: payload = db_api._insert_to_input_json_payload( - 'INSERT INTO "seed_model" ("market_id", "volume_24h", "end_date") ' - 'SELECT CAST("market_id" AS TEXT) AS "market_id", ' - 'CAST("volume_24h" AS DOUBLE) AS "volume_24h", ' - 'CAST("end_date" AS TIMESTAMP) AS "end_date" ' + 'INSERT INTO "seed_model" ("entity_id", "score", "event_ts") ' + 'SELECT CAST("entity_id" AS TEXT) AS "entity_id", ' + 'CAST("score" AS DOUBLE) AS "score", ' + 'CAST("event_ts" AS TIMESTAMP) AS "event_ts" ' 'FROM (VALUES ' "('123', '4.5', '2026-05-09 01:00:00'), " "('456', '7.0', NULL)" - ') AS "t"("market_id", "volume_24h", "end_date")' + ') AS "t"("entity_id", "score", "event_ts")' ) assert payload == ( "seed_model", [ { - "market_id": "123", - "volume_24h": 4.5, - "end_date": "2026-05-09 01:00:00", + "entity_id": "123", + "score": 4.5, + "event_ts": "2026-05-09 01:00:00", }, { - "market_id": "456", - "volume_24h": 7.0, - "end_date": None, + "entity_id": "456", + "score": 7.0, + "event_ts": None, }, ], ) +def test_insert_to_input_json_payload_canonicalizes_snapshot_target() -> None: + payload = db_api._insert_to_input_json_payload( + 'INSERT INTO "sqlmesh__analytics"."analytics__sample_seed__2883946936__dev" ("entity_id") ' + 'SELECT CAST("entity_id" AS TEXT) AS "entity_id" ' + 'FROM (VALUES (\'123\')) AS "t"("entity_id")' + ) + + assert payload == ( + "sample_seed", + [ + { + "entity_id": "123", + } + ], + ) + + def test_cursor_uses_input_json_for_seed_inserts() -> None: captured_rows = [] executed_sql = [] From 7a0008494869cc58f02ddbd9bf3a78efec76ae73 Mon Sep 17 00:00:00 2001 From: fresioAS Date: Tue, 12 May 2026 23:29:21 +0200 Subject: [PATCH 07/10] make style --- pyproject.toml | 2 + sqlmesh/core/config/connection.py | 3 +- sqlmesh/core/engine_adapter/feldera.py | 19 +-- sqlmesh/engines/feldera/__init__.py | 2 +- sqlmesh/engines/feldera/db_api.py | 55 +++---- tests/core/engine_adapter/test_feldera.py | 10 +- tests/engines/feldera/test_db_api.py | 178 ++++++++++++---------- 7 files changed, 144 insertions(+), 125 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7f8d1b50b5..28aff0b439 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,7 @@ dev = [ "dbt-redshift", "dbt-trino", "Faker", + "feldera", "google-auth", "google-cloud-bigquery", "google-cloud-bigquery-storage", @@ -219,6 +220,7 @@ module = [ "bs4.*", "pydantic_core.*", "dlt.*", + "feldera.*", "bigframes.*", "json_stream.*", "duckdb.*", diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index 057a321d46..de0e0cdb01 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -2341,7 +2341,7 @@ def init(cursor: t.Any) -> None: return init - + class FelderaConnectionConfig(ConnectionConfig): """Feldera connection configuration.""" @@ -2389,6 +2389,7 @@ def _connection_factory(self) -> t.Callable: def get_catalog(self) -> t.Optional[str]: return None + _CONNECTION_CONFIG_EXCLUDE: t.Set[t.Type[ConnectionConfig]] = { ConnectionConfig, # type: ignore[type-abstract] BaseDuckDBConnectionConfig, # type: ignore[type-abstract] diff --git a/sqlmesh/core/engine_adapter/feldera.py b/sqlmesh/core/engine_adapter/feldera.py index 692ec61b02..c4dfa85890 100644 --- a/sqlmesh/core/engine_adapter/feldera.py +++ b/sqlmesh/core/engine_adapter/feldera.py @@ -5,7 +5,7 @@ from sqlglot import exp from sqlglot.dialects.dialect import Dialect -from sqlglot.generator import Generator +from sqlglot.generator import Generator as SQLGlotGenerator from sqlmesh.core.dialect import to_schema from sqlmesh.core.engine_adapter.base import EngineAdapter @@ -42,14 +42,14 @@ def _view_type(state: t.Any, object_name: str) -> DataObjectType: class _SQLMeshFeldera(Dialect): - class Generator(Generator): + class Generator(SQLGlotGenerator): TYPE_MAPPING = { - **Generator.TYPE_MAPPING, + **SQLGlotGenerator.TYPE_MAPPING, exp.DataType.Type.FLOAT: "REAL", exp.DataType.Type.INT: "INTEGER", } TRANSFORMS = { - **Generator.TRANSFORMS, + **SQLGlotGenerator.TRANSFORMS, exp.CurrentTimestamp: lambda self, expression: ( self.func("CURRENT_TIMESTAMP", expression.this) if expression.this @@ -65,7 +65,7 @@ class Generator(Generator): Dialect.classes["feldera"] = _SQLMeshFeldera -_FELDERA_TO_EXP_TYPE: t.Dict[str, exp.DataType.Type] = { +_FELDERA_TO_EXP_TYPE: t.Dict[str, t.Any] = { "BOOLEAN": exp.DataType.Type.BOOLEAN, "TINYINT": exp.DataType.Type.TINYINT, "SMALLINT": exp.DataType.Type.SMALLINT, @@ -102,7 +102,7 @@ class FelderaEngineAdapter(EngineAdapter): def _fetch_native_df( self, - query: t.Union[exp.Expression, str], + query: t.Union[exp.Expr, str], quote_identifiers: bool = False, ) -> pd.DataFrame: with self.transaction(): @@ -118,9 +118,7 @@ def _get_data_objects( connection = self.connection pipeline_name = to_schema(schema_name).db - lower_object_names = ( - {name.lower() for name in object_names} if object_names else None - ) + lower_object_names = {name.lower() for name in object_names} if object_names else None try: pipeline = Pipeline.get(pipeline_name, connection._client) @@ -199,8 +197,7 @@ def columns( } raise SQLMeshError( - "Table/view " - f"'{target}' not found in pipeline '{connection._pipeline_name}'" + f"Table/view '{target}' not found in pipeline '{connection._pipeline_name}'" ) def get_current_catalog(self) -> t.Optional[str]: diff --git a/sqlmesh/engines/feldera/__init__.py b/sqlmesh/engines/feldera/__init__.py index 6c43ea250f..9d48db4f9f 100644 --- a/sqlmesh/engines/feldera/__init__.py +++ b/sqlmesh/engines/feldera/__init__.py @@ -1 +1 @@ -from __future__ import annotations \ No newline at end of file +from __future__ import annotations diff --git a/sqlmesh/engines/feldera/db_api.py b/sqlmesh/engines/feldera/db_api.py index d5b53dc523..731b1c0467 100644 --- a/sqlmesh/engines/feldera/db_api.py +++ b/sqlmesh/engines/feldera/db_api.py @@ -180,11 +180,7 @@ def queryable_relation_names(self) -> t.Set[str]: with self._lock: return { *self._tables, - *( - name - for name, sql in self._views.items() - if not _is_materialized_view_sql(sql) - ), + *(name for name, sql in self._views.items() if not _is_materialized_view_sql(sql)), } def assemble_program(self) -> str: @@ -391,7 +387,9 @@ def _compile_program( pipeline._inner = inner_pipeline return pipeline - def _format_compile_error(self, client: t.Any, pipeline_name: str, error: Exception) -> RuntimeError: + def _format_compile_error( + self, client: t.Any, pipeline_name: str, error: Exception + ) -> RuntimeError: error_message = str(error) try: @@ -417,7 +415,9 @@ def _format_compile_error(self, client: t.Any, pipeline_name: str, error: Except rust_error = program_error.get("rust_compilation") system_error = program_error.get("system_error") if rust_error or system_error: - message = f"The program failed to compile: {getattr(pipeline, 'program_status', 'unknown')}\n" + message = ( + f"The program failed to compile: {getattr(pipeline, 'program_status', 'unknown')}\n" + ) if rust_error is not None: message += f"Rust Error: {rust_error}\n" if system_error is not None: @@ -481,7 +481,9 @@ def _query_mirror_table(table: exp.Table) -> exp.Table: def _ddl_statements_with_query_mirror(sql: str) -> t.List[str]: try: expressions = [ - expression for expression in parse(sql, dialect=FELDERA_DIALECT) if expression is not None + expression + for expression in parse(sql, dialect=FELDERA_DIALECT) + if expression is not None ] except (ParseError, ValueError): expressions = [] @@ -549,9 +551,7 @@ def _rewrite_query_for_query_mirrors(sql: str, relation_names: t.Set[str]) -> st return sql cte_names = { - cte.alias_or_name.lower() - for cte in expression.find_all(exp.CTE) - if cte.alias_or_name + cte.alias_or_name.lower() for cte in expression.find_all(exp.CTE) if cte.alias_or_name } def transform(node: exp.Expression) -> exp.Expression: @@ -567,7 +567,7 @@ def transform(node: exp.Expression) -> exp.Expression: return expression.transform(transform, copy=True).sql(dialect=FELDERA_DIALECT) -def _execution_error(rows: t.List[t.Mapping[str, t.Any]]) -> t.Optional[str]: +def _execution_error(rows: t.Sequence[t.Mapping[str, t.Any]]) -> t.Optional[str]: for row in rows: for value in row.values(): if isinstance(value, str) and value.startswith("Execution error:"): @@ -634,10 +634,7 @@ def _rewrite_table_ctas_sql(sql: str) -> str: properties=expression.args.get("properties"), ) insert_exp = exp.insert(query.copy(), target.copy(), columns=list(columns_to_types)) - return ( - f"{create_exp.sql(dialect=FELDERA_DIALECT)};\n" - f"{insert_exp.sql(dialect=FELDERA_DIALECT)}" - ) + return f"{create_exp.sql(dialect=FELDERA_DIALECT)};\n{insert_exp.sql(dialect=FELDERA_DIALECT)}" def _select_columns_to_types(query: exp.Expression) -> t.Optional[t.Dict[str, exp.DataType]]: @@ -649,7 +646,9 @@ def _select_columns_to_types(query: exp.Expression) -> t.Optional[t.Dict[str, ex for select in query.selects: output_name = select.output_name - data_type = _projection_type(select) or (select.type or unknown).copy() + data_type = ( + _projection_type(t.cast(exp.Expression, select)) or (select.type or unknown).copy() + ) if not output_name or output_name in columns_to_types or data_type == unknown: return None @@ -805,15 +804,16 @@ def _insert_to_input_json_payload( return None source_row = { - column: _literal_value(value) - for column, value in zip(source_columns, row.expressions) + column: _literal_value(value) for column, value in zip(source_columns, row.expressions) } payload_row: t.Dict[str, t.Any] = {} for target_column, select in zip(target_columns, query.selects): if not target_column: return None - evaluated = _evaluate_insert_value_expression(select, source_row) + evaluated: t.Any = _evaluate_insert_value_expression( + t.cast(exp.Expression, select), source_row + ) if evaluated is _UNSUPPORTED_INGEST_EXPRESSION: return None payload_row[target_column] = evaluated @@ -937,7 +937,7 @@ def __init__( self, client: t.Any, pipeline_name: str, - state_manager: PipelineStateManager, + state_manager: t.Any, workers: int = 4, compilation_profile: str = "dev", timeout: int = 300, @@ -955,9 +955,7 @@ def __init__( def execute(self, sql: str, parameters: t.Optional[t.Any] = None) -> None: if parameters is not None: - raise NotImplementedError( - "Feldera DB-API does not support query parameters" - ) + raise NotImplementedError("Feldera DB-API does not support query parameters") original_sql = sql normalized_sql = _normalize_ddl(sql) @@ -1005,9 +1003,7 @@ def execute(self, sql: str, parameters: t.Optional[t.Any] = None) -> None: self._columns = [] return - query_sql = _rewrite_query_for_query_mirrors( - sql, self._state.queryable_relation_names() - ) + query_sql = _rewrite_query_for_query_mirrors(sql, self._state.queryable_relation_names()) rows = list(self._get_pipeline().query(query_sql)) if error := _execution_error(rows): raise RuntimeError(error) @@ -1015,8 +1011,7 @@ def execute(self, sql: str, parameters: t.Optional[t.Any] = None) -> None: self._columns = list(rows[0].keys()) if rows else [] self.rowcount = len(rows) self.description = [ - (column, None, None, None, None, None, None) - for column in self._columns + (column, None, None, None, None, None, None) for column in self._columns ] def _get_pipeline(self) -> t.Any: @@ -1072,7 +1067,7 @@ def __init__( self._timeout = timeout state_key = (host, pipeline_name) with self._state_lock: - self._state = self._shared_states.setdefault(state_key, PipelineStateManager()) + self._state: t.Any = self._shared_states.setdefault(state_key, PipelineStateManager()) def cursor(self) -> FelderaCursor: return FelderaCursor( diff --git a/tests/core/engine_adapter/test_feldera.py b/tests/core/engine_adapter/test_feldera.py index fe9894d505..580c2466f7 100644 --- a/tests/core/engine_adapter/test_feldera.py +++ b/tests/core/engine_adapter/test_feldera.py @@ -43,7 +43,9 @@ def _install_feldera_pipeline(monkeypatch: pytest.MonkeyPatch, pipeline: t.Any) monkeypatch.setitem(sys.modules, "feldera.pipeline", pipeline_module) -def test_columns_uses_pipeline_metadata(adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch): +def test_columns_uses_pipeline_metadata( + adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch +): connection = adapter._connection_pool.get() connection._client = object() connection._pipeline_name = "configured_pipeline" @@ -216,9 +218,7 @@ def test_insert_overwrite_by_time_partition_uses_table_operations( ] -def test_create_view_creates_view( - adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch -): +def test_create_view_creates_view(adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(adapter, "_get_data_objects", lambda schema_name, object_names=None: []) adapter.create_view( "db.view_model", @@ -333,4 +333,4 @@ def test_drop_table_drops_table(adapter: FelderaEngineAdapter, monkeypatch: pyte adapter.drop_table("db.full_model") - assert to_sql_calls(adapter) == ['DROP TABLE IF EXISTS "db"."full_model"'] \ No newline at end of file + assert to_sql_calls(adapter) == ['DROP TABLE IF EXISTS "db"."full_model"'] diff --git a/tests/engines/feldera/test_db_api.py b/tests/engines/feldera/test_db_api.py index 2e7e0ce332..2cbd196dd2 100644 --- a/tests/engines/feldera/test_db_api.py +++ b/tests/engines/feldera/test_db_api.py @@ -1,4 +1,5 @@ import types +import typing as t import pytest from sqlglot import parse_one @@ -7,10 +8,7 @@ def test_classify_treats_comment_prefixed_create_schema_as_pipeline_ddl() -> None: - assert ( - db_api._classify("/* sqlmesh */ CREATE SCHEMA foo") - == db_api.SqlIntent.PIPELINE_DDL - ) + assert db_api._classify("/* sqlmesh */ CREATE SCHEMA foo") == db_api.SqlIntent.PIPELINE_DDL def test_is_virtual_layer_ddl_identifies_environment_alias_view() -> None: @@ -32,8 +30,7 @@ def test_strip_table_qualifiers_preserves_current_timestamp_keyword() -> None: ) assert db_api._strip_table_qualifiers(sql) == ( - 'CREATE MATERIALIZED VIEW "view_model" AS ' - 'SELECT CURRENT_TIMESTAMP AS ts FROM "source"' + 'CREATE MATERIALIZED VIEW "view_model" AS SELECT CURRENT_TIMESTAMP AS ts FROM "source"' ) @@ -52,10 +49,7 @@ def test_normalize_pipeline_ddl_canonicalizes_snapshot_names_to_logical_names() def test_normalize_pipeline_ddl_strips_schema_from_logical_tables() -> None: - sql = ( - 'CREATE TABLE "analytics"."sample_seed" ' - '("entity_id" VARCHAR, "description" VARCHAR)' - ) + sql = 'CREATE TABLE "analytics"."sample_seed" ("entity_id" VARCHAR, "description" VARCHAR)' assert db_api._normalize_pipeline_ddl(sql) == ( 'CREATE TABLE "sample_seed" ("entity_id" VARCHAR, "description" VARCHAR)' @@ -86,7 +80,7 @@ def current_pipeline(self) -> object: def queryable_relation_names(self) -> set[str]: return set() - state_manager = FakeStateManager() + state_manager = t.cast(t.Any, FakeStateManager()) cursor = db_api.FelderaCursor( client=object(), pipeline_name="test_pipeline", @@ -106,7 +100,7 @@ def queryable_relation_names(self) -> set[str]: def test_cursor_ignores_virtual_layer_view_ddl() -> None: class FakeStateManager: def __init__(self) -> None: - self.registered_sql = [] + self.registered_sql: list[str] = [] def register_ddl(self, sql: str) -> None: self.registered_sql.append(sql) @@ -114,7 +108,7 @@ def register_ddl(self, sql: str) -> None: def has_pending_changes(self) -> bool: return False - state_manager = FakeStateManager() + state_manager = t.cast(t.Any, FakeStateManager()) cursor = db_api.FelderaCursor( client=object(), pipeline_name="test_pipeline", @@ -132,7 +126,7 @@ def has_pending_changes(self) -> bool: def test_cursor_registers_logical_model_names_in_pipeline_ddl() -> None: class FakeStateManager: def __init__(self) -> None: - self.registered_sql = [] + self.registered_sql: list[str] = [] def register_ddl(self, sql: str) -> None: self.registered_sql.append(sql) @@ -140,7 +134,7 @@ def register_ddl(self, sql: str) -> None: def has_pending_changes(self) -> bool: return False - state_manager = FakeStateManager() + state_manager = t.cast(t.Any, FakeStateManager()) cursor = db_api.FelderaCursor( client=object(), pipeline_name="test_pipeline", @@ -170,19 +164,23 @@ def test_hydrate_existing_program_skips_empty_parse_results(monkeypatch) -> None ) pipeline_module = types.ModuleType("feldera.pipeline") - pipeline_module.Pipeline = type( + setattr( + pipeline_module, "Pipeline", - (), - { - "get": staticmethod( - lambda pipeline_name, client: types.SimpleNamespace( - _inner=types.SimpleNamespace(program_code="ignored") + type( + "Pipeline", + (), + { + "get": staticmethod( + lambda pipeline_name, client: types.SimpleNamespace( + _inner=types.SimpleNamespace(program_code="ignored") + ) ) - ) - }, + }, + ), ) feldera_module = types.ModuleType("feldera") - feldera_module.pipeline = pipeline_module + setattr(feldera_module, "pipeline", pipeline_module) monkeypatch.setitem(__import__("sys").modules, "feldera", feldera_module) monkeypatch.setitem(__import__("sys").modules, "feldera.pipeline", pipeline_module) @@ -212,9 +210,9 @@ def test_state_manager_rewrites_table_ctas() -> None: assert manager.assemble_program() == ( 'CREATE TABLE "seed_model" ("id" INTEGER);\n' - '\n' + "\n" 'CREATE MATERIALIZED VIEW "__sqlmesh_query__seed_model" AS SELECT * FROM "seed_model";\n' - '\n' + "\n" 'INSERT INTO "seed_model" (id) SELECT CAST("id" AS INTEGER) AS "id" FROM (VALUES (1)) AS "t"("id");' ) @@ -222,18 +220,18 @@ def test_state_manager_rewrites_table_ctas() -> None: def test_evict_hydrated_objects_removes_stale_object_from_compile_error() -> None: manager = db_api.PipelineStateManager() manager._views = { - 'analytics__aggregate_view__781619724__dev': ( + "analytics__aggregate_view__781619724__dev": ( 'CREATE MATERIALIZED VIEW "analytics__aggregate_view__781619724__dev" AS ' - 'SELECT CURRENT_TIMESTAMP AS ts' + "SELECT CURRENT_TIMESTAMP AS ts" ), - 'analytics__aggregate_view__1225616675__dev': ( + "analytics__aggregate_view__1225616675__dev": ( 'CREATE MATERIALIZED VIEW "analytics__aggregate_view__1225616675__dev" AS ' - 'SELECT NOW() AS ts' + "SELECT NOW() AS ts" ), } manager._hydrated_object_keys = { - 'analytics__aggregate_view__781619724__dev', - 'analytics__aggregate_view__1225616675__dev', + "analytics__aggregate_view__781619724__dev", + "analytics__aggregate_view__1225616675__dev", } removed = manager._evict_hydrated_objects( @@ -241,8 +239,8 @@ def test_evict_hydrated_objects_removes_stale_object_from_compile_error() -> Non ) assert removed is True - assert 'analytics__aggregate_view__781619724__dev' not in manager.pending_views() - assert 'analytics__aggregate_view__1225616675__dev' in manager.pending_views() + assert "analytics__aggregate_view__781619724__dev" not in manager.pending_views() + assert "analytics__aggregate_view__1225616675__dev" in manager.pending_views() def test_format_compile_error_preserves_sql_compilation_details() -> None: @@ -296,13 +294,16 @@ def deploy(self, *args: object, **kwargs: object) -> object: host="http://localhost:8080", pipeline_name="test_pipeline", ) - connection._state = FakeStateManager() + connection._state = t.cast(t.Any, FakeStateManager()) with caplog.at_level("ERROR", logger="sqlmesh.engines.feldera.db_api"): with pytest.raises(RuntimeError, match="TIMESTAMP_TRUNC problem"): connection.close() - assert "Feldera pending DDL failed during connection close for pipeline test_pipeline" in caplog.text + assert ( + "Feldera pending DDL failed during connection close for pipeline test_pipeline" + in caplog.text + ) assert "TIMESTAMP_TRUNC problem" in caplog.text @@ -317,8 +318,14 @@ def test_state_manager_adds_query_mirrors_for_non_materialized_relations() -> No program = manager.assemble_program() - assert 'CREATE MATERIALIZED VIEW "__sqlmesh_query__full_model" AS SELECT * FROM "full_model";' in program - assert 'CREATE MATERIALIZED VIEW "__sqlmesh_query__view_model" AS SELECT * FROM "view_model";' in program + assert ( + 'CREATE MATERIALIZED VIEW "__sqlmesh_query__full_model" AS SELECT * FROM "full_model";' + in program + ) + assert ( + 'CREATE MATERIALIZED VIEW "__sqlmesh_query__view_model" AS SELECT * FROM "view_model";' + in program + ) assert 'CREATE MATERIALIZED VIEW "__sqlmesh_query__materialized_view_model"' not in program @@ -332,7 +339,10 @@ def test_state_manager_query_mirror_strips_schema_qualifiers() -> None: program = manager.assemble_program() - assert 'CREATE MATERIALIZED VIEW "__sqlmesh_query__sample_seed" AS SELECT * FROM "sample_seed";' in program + assert ( + 'CREATE MATERIALIZED VIEW "__sqlmesh_query__sample_seed" AS SELECT * FROM "sample_seed";' + in program + ) assert 'SELECT * FROM "analytics"."sample_seed"' not in program @@ -346,7 +356,7 @@ def test_state_manager_skips_query_mirrors_for_sqlmesh_internal_objects() -> Non program = manager.assemble_program() - assert '__sqlmesh_query__analytics__source_snapshot__441175831__dev' not in program + assert "__sqlmesh_query__analytics__source_snapshot__441175831__dev" not in program def test_state_manager_canonicalizes_snapshot_tables_to_logical_names() -> None: @@ -362,34 +372,40 @@ def test_state_manager_canonicalizes_snapshot_tables_to_logical_names() -> None: program = manager.assemble_program() assert 'CREATE TABLE "records" ("entity_id" VARCHAR, "description" VARCHAR);' in program - assert 'CREATE MATERIALIZED VIEW "__sqlmesh_query__records" AS SELECT * FROM "records";' in program - assert 'analytics__records__849752499__dev' not in program + assert ( + 'CREATE MATERIALIZED VIEW "__sqlmesh_query__records" AS SELECT * FROM "records";' in program + ) + assert "analytics__records__849752499__dev" not in program def test_hydrate_existing_program_skips_query_mirrors(monkeypatch) -> None: manager = db_api.PipelineStateManager() pipeline_module = types.ModuleType("feldera.pipeline") - pipeline_module.Pipeline = type( + setattr( + pipeline_module, "Pipeline", - (), - { - "get": staticmethod( - lambda pipeline_name, client: types.SimpleNamespace( - _inner=types.SimpleNamespace( - program_code=( - 'CREATE TABLE "full_model" ("id" INT);\n' - 'CREATE MATERIALIZED VIEW "__sqlmesh_query__full_model" AS SELECT * FROM "full_model";\n' - 'CREATE VIEW "view_model" AS SELECT "id" FROM "full_model";\n' - 'CREATE MATERIALIZED VIEW "__sqlmesh_query__view_model" AS SELECT * FROM "view_model";' + type( + "Pipeline", + (), + { + "get": staticmethod( + lambda pipeline_name, client: types.SimpleNamespace( + _inner=types.SimpleNamespace( + program_code=( + 'CREATE TABLE "full_model" ("id" INT);\n' + 'CREATE MATERIALIZED VIEW "__sqlmesh_query__full_model" AS SELECT * FROM "full_model";\n' + 'CREATE VIEW "view_model" AS SELECT "id" FROM "full_model";\n' + 'CREATE MATERIALIZED VIEW "__sqlmesh_query__view_model" AS SELECT * FROM "view_model";' + ) ) ) ) - ) - }, + }, + ), ) feldera_module = types.ModuleType("feldera") - feldera_module.pipeline = pipeline_module + setattr(feldera_module, "pipeline", pipeline_module) monkeypatch.setitem(__import__("sys").modules, "feldera", feldera_module) monkeypatch.setitem(__import__("sys").modules, "feldera.pipeline", pipeline_module) @@ -401,7 +417,7 @@ def test_hydrate_existing_program_skips_query_mirrors(monkeypatch) -> None: def test_cursor_rewrites_queries_to_query_mirrors() -> None: - captured_queries = [] + captured_queries: list[str] = [] class FakeStateManager: def has_pending_changes(self) -> bool: @@ -411,26 +427,29 @@ def queryable_relation_names(self) -> set[str]: return {"full_model", "view_model"} def current_pipeline(self) -> object: - return types.SimpleNamespace( - query=lambda sql: captured_queries.append(sql) or [{"count": 1}] - ) + def query(sql: str) -> list[dict[str, int]]: + captured_queries.append(sql) + return [{"count": 1}] + + return types.SimpleNamespace(query=query) cursor = db_api.FelderaCursor( client=object(), pipeline_name="test_pipeline", - state_manager=FakeStateManager(), + state_manager=t.cast(t.Any, FakeStateManager()), ) cursor.execute('SELECT COUNT(*) FROM "full_model"') - assert parse_one(captured_queries[0]).sql() == parse_one( - 'SELECT COUNT(*) FROM "__sqlmesh_query__full_model"' - ).sql() + assert ( + parse_one(captured_queries[0]).sql() + == parse_one('SELECT COUNT(*) FROM "__sqlmesh_query__full_model"').sql() + ) assert cursor.fetchone() == (1,) def test_cursor_rewrites_snapshot_queries_to_logical_query_mirrors() -> None: - captured_queries = [] + captured_queries: list[str] = [] class FakeStateManager: def has_pending_changes(self) -> bool: @@ -440,21 +459,26 @@ def queryable_relation_names(self) -> set[str]: return {"source_events"} def current_pipeline(self) -> object: - return types.SimpleNamespace( - query=lambda sql: captured_queries.append(sql) or [{"count": 1}] - ) + def query(sql: str) -> list[dict[str, int]]: + captured_queries.append(sql) + return [{"count": 1}] + + return types.SimpleNamespace(query=query) cursor = db_api.FelderaCursor( client=object(), pipeline_name="test_pipeline", - state_manager=FakeStateManager(), + state_manager=t.cast(t.Any, FakeStateManager()), ) - cursor.execute('SELECT COUNT(*) FROM "sqlmesh__analytics"."analytics__source_events__1782741465__dev"') + cursor.execute( + 'SELECT COUNT(*) FROM "sqlmesh__analytics"."analytics__source_events__1782741465__dev"' + ) - assert parse_one(captured_queries[0]).sql() == parse_one( - 'SELECT COUNT(*) FROM "__sqlmesh_query__source_events"' - ).sql() + assert ( + parse_one(captured_queries[0]).sql() + == parse_one('SELECT COUNT(*) FROM "__sqlmesh_query__source_events"').sql() + ) assert cursor.fetchone() == (1,) @@ -464,7 +488,7 @@ def test_insert_to_input_json_payload_rewrites_seed_values() -> None: 'SELECT CAST("entity_id" AS TEXT) AS "entity_id", ' 'CAST("score" AS DOUBLE) AS "score", ' 'CAST("event_ts" AS TIMESTAMP) AS "event_ts" ' - 'FROM (VALUES ' + "FROM (VALUES " "('123', '4.5', '2026-05-09 01:00:00'), " "('456', '7.0', NULL)" ') AS "t"("entity_id", "score", "event_ts")' @@ -528,7 +552,7 @@ def queryable_relation_names(self) -> set[str]: cursor = db_api.FelderaCursor( client=object(), pipeline_name="test_pipeline", - state_manager=FakeStateManager(), + state_manager=t.cast(t.Any, FakeStateManager()), ) cursor.execute( @@ -566,8 +590,8 @@ def current_pipeline(self) -> object: cursor = db_api.FelderaCursor( client=object(), pipeline_name="test_pipeline", - state_manager=FakeStateManager(), + state_manager=t.cast(t.Any, FakeStateManager()), ) with pytest.raises(RuntimeError, match="Execution error: test failure"): - cursor.execute("SELECT COUNT(*) FROM full_model") \ No newline at end of file + cursor.execute("SELECT COUNT(*) FROM full_model") From 4c28367e3d7cf885fa84c0b3f5087b02dd33c6c8 Mon Sep 17 00:00:00 2001 From: fresioAS Date: Wed, 13 May 2026 00:10:06 +0200 Subject: [PATCH 08/10] timeout --- sqlmesh/engines/feldera/db_api.py | 22 +-- tests/engines/feldera/test_db_api.py | 205 +++++++++++++++++++++++++++ 2 files changed, 209 insertions(+), 18 deletions(-) diff --git a/sqlmesh/engines/feldera/db_api.py b/sqlmesh/engines/feldera/db_api.py index 731b1c0467..66ec92aba9 100644 --- a/sqlmesh/engines/feldera/db_api.py +++ b/sqlmesh/engines/feldera/db_api.py @@ -214,11 +214,6 @@ def deploy( from feldera.rest.pipeline import Pipeline as InnerPipeline from feldera.runtime_config import RuntimeConfig - try: - from feldera.pipeline import PipelineStatus - except ImportError: - from feldera.rest.pipeline import PipelineStatus - with self._lock: self._hydrate_existing_program(client, pipeline_name) if not self._dirty: @@ -248,7 +243,6 @@ def deploy( timeout, Pipeline, PipelineBuilder, - PipelineStatus, InnerPipeline, FelderaAPIError, ) @@ -263,8 +257,7 @@ def deploy( self._dropped_objects.clear() return self._pipeline - pipeline.start() - pipeline.wait_for_status(PipelineStatus.RUNNING, timeout=timeout) + pipeline.start(timeout_s=timeout) self._pipeline = pipeline self._dirty = False self._dropped_objects.clear() @@ -342,7 +335,6 @@ def _compile_program( timeout: int, Pipeline: t.Any, PipelineBuilder: t.Any, - PipelineStatus: t.Any, InnerPipeline: t.Any, FelderaAPIError: t.Any, ) -> t.Any: @@ -363,8 +355,7 @@ def _compile_program( except RuntimeError as ex: raise self._format_compile_error(client, pipeline_name, ex) from ex - existing_pipeline.stop(force=True) - existing_pipeline.wait_for_status(PipelineStatus.STOPPED, timeout=timeout) + existing_pipeline.stop(force=True, timeout_s=timeout) existing_pipeline.dismiss_error() inner_pipeline = InnerPipeline( @@ -391,15 +382,10 @@ def _format_compile_error( self, client: t.Any, pipeline_name: str, error: Exception ) -> RuntimeError: error_message = str(error) + from feldera.enums import PipelineFieldSelector try: - from feldera.enums import PipelineFieldSelector - except Exception: - PipelineFieldSelector = None - - try: - field_selector = PipelineFieldSelector.ALL if PipelineFieldSelector else None - pipeline = client.get_pipeline(pipeline_name, field_selector) + pipeline = client.get_pipeline(pipeline_name, PipelineFieldSelector.ALL) except Exception: return RuntimeError(error_message) diff --git a/tests/engines/feldera/test_db_api.py b/tests/engines/feldera/test_db_api.py index 2cbd196dd2..5bd4591718 100644 --- a/tests/engines/feldera/test_db_api.py +++ b/tests/engines/feldera/test_db_api.py @@ -277,6 +277,211 @@ def get_pipeline(self, pipeline_name: str, field_selector: object) -> object: ) +def test_format_compile_error_requests_all_pipeline_fields(monkeypatch) -> None: + manager = db_api.PipelineStateManager() + requested_field_selector: list[object] = [] + selector_all = object() + + enums_module = types.ModuleType("feldera.enums") + setattr(enums_module, "PipelineFieldSelector", types.SimpleNamespace(ALL=selector_all)) + monkeypatch.setitem(__import__("sys").modules, "feldera.enums", enums_module) + + class FakeClient: + def get_pipeline(self, pipeline_name: str, field_selector: object) -> object: + requested_field_selector.append(field_selector) + return types.SimpleNamespace(program_error={}, program_status="Unknown") + + manager._format_compile_error(FakeClient(), "test_pipeline", RuntimeError("boom")) + + assert requested_field_selector == [selector_all] + + +def test_deploy_imports_compilation_profile_from_feldera_enums(monkeypatch) -> None: + manager = db_api.PipelineStateManager() + + class CompilationProfile(str): + pass + + class Pipeline: + @staticmethod + def get(name: str, client: object) -> object: + return types.SimpleNamespace(_inner=types.SimpleNamespace(program_code="")) + + class PipelineBuilder: + def __init__(self, *args: object, **kwargs: object) -> None: + pass + + def create_or_replace(self, wait: bool = True) -> object: + return types.SimpleNamespace(start=lambda: None) + + class RuntimeConfig: + def __init__(self) -> None: + self.workers = 0 + + @classmethod + def default(cls) -> "RuntimeConfig": + return cls() + + def to_dict(self) -> dict[str, object]: + return {} + + feldera_module = types.ModuleType("feldera") + enums_module = types.ModuleType("feldera.enums") + pipeline_module = types.ModuleType("feldera.pipeline") + pipeline_builder_module = types.ModuleType("feldera.pipeline_builder") + runtime_config_module = types.ModuleType("feldera.runtime_config") + rest_module = types.ModuleType("feldera.rest") + rest_errors_module = types.ModuleType("feldera.rest.errors") + rest_pipeline_module = types.ModuleType("feldera.rest.pipeline") + + setattr(enums_module, "CompilationProfile", CompilationProfile) + setattr(pipeline_module, "Pipeline", Pipeline) + setattr(pipeline_builder_module, "PipelineBuilder", PipelineBuilder) + setattr(runtime_config_module, "RuntimeConfig", RuntimeConfig) + setattr(rest_errors_module, "FelderaAPIError", RuntimeError) + setattr(rest_pipeline_module, "Pipeline", type("InnerPipeline", (), {})) + setattr(rest_module, "errors", rest_errors_module) + setattr(rest_module, "pipeline", rest_pipeline_module) + setattr(feldera_module, "enums", enums_module) + setattr(feldera_module, "pipeline", pipeline_module) + setattr(feldera_module, "pipeline_builder", pipeline_builder_module) + setattr(feldera_module, "runtime_config", runtime_config_module) + setattr(feldera_module, "rest", rest_module) + + monkeypatch.setitem(__import__("sys").modules, "feldera", feldera_module) + monkeypatch.setitem(__import__("sys").modules, "feldera.enums", enums_module) + monkeypatch.setitem(__import__("sys").modules, "feldera.pipeline", pipeline_module) + monkeypatch.setitem( + __import__("sys").modules, "feldera.pipeline_builder", pipeline_builder_module + ) + monkeypatch.setitem(__import__("sys").modules, "feldera.runtime_config", runtime_config_module) + monkeypatch.setitem(__import__("sys").modules, "feldera.rest", rest_module) + monkeypatch.setitem(__import__("sys").modules, "feldera.rest.errors", rest_errors_module) + monkeypatch.setitem(__import__("sys").modules, "feldera.rest.pipeline", rest_pipeline_module) + + manager.deploy(object(), "test_pipeline") + + +def test_compile_program_does_not_wait_after_stop() -> None: + manager = db_api.PipelineStateManager() + wait_calls: list[tuple[object, int]] = [] + stop_calls: list[tuple[bool, t.Optional[float]]] = [] + dismiss_error_calls = 0 + + class ExistingPipeline: + def stop(self, force: bool = True, timeout_s: t.Optional[float] = None) -> None: + stop_calls.append((force, timeout_s)) + + def wait_for_status(self, status: object, timeout: int = 0) -> None: + wait_calls.append((status, timeout)) + + def dismiss_error(self) -> None: + nonlocal dismiss_error_calls + dismiss_error_calls += 1 + + class Pipeline: + @staticmethod + def get(name: str, client: object) -> object: + return ExistingPipeline() + + def __init__(self, client: object) -> None: + self._inner = None + + class InnerPipeline: + def __init__(self, **kwargs: object) -> None: + pass + + class Profile: + value = "dev" + + class RuntimeConfig: + def to_dict(self) -> dict[str, object]: + return {} + + client = types.SimpleNamespace(create_or_update_pipeline=lambda pipeline, wait=True: pipeline) + + manager._compile_program( + client, + "test_pipeline", + "SELECT 1", + Profile(), + RuntimeConfig(), + 300, + Pipeline, + object, + InnerPipeline, + RuntimeError, + ) + + assert stop_calls == [(True, 300)] + assert dismiss_error_calls == 1 + assert wait_calls == [] + + +def test_deploy_does_not_wait_after_start(monkeypatch) -> None: + manager = db_api.PipelineStateManager() + manager._dirty = True + start_calls: list[t.Optional[float]] = [] + wait_calls: list[tuple[object, int]] = [] + + class Pipeline: + def start(self, timeout_s: t.Optional[float] = None) -> None: + start_calls.append(timeout_s) + + def wait_for_status(self, status: object, timeout: int = 0) -> None: + wait_calls.append((status, timeout)) + + class CompilationProfile(str): + pass + + class RuntimeConfig: + def __init__(self) -> None: + self.workers = 0 + + @classmethod + def default(cls) -> "RuntimeConfig": + return cls() + + def to_dict(self) -> dict[str, object]: + return {} + + enums_module = types.ModuleType("feldera.enums") + pipeline_module = types.ModuleType("feldera.pipeline") + pipeline_builder_module = types.ModuleType("feldera.pipeline_builder") + runtime_config_module = types.ModuleType("feldera.runtime_config") + rest_errors_module = types.ModuleType("feldera.rest.errors") + rest_pipeline_module = types.ModuleType("feldera.rest.pipeline") + + setattr(enums_module, "CompilationProfile", CompilationProfile) + setattr( + pipeline_module, + "Pipeline", + type("PipelineClass", (), {"get": staticmethod(lambda n, c: None)}), + ) + setattr(pipeline_builder_module, "PipelineBuilder", object) + setattr(runtime_config_module, "RuntimeConfig", RuntimeConfig) + setattr(rest_errors_module, "FelderaAPIError", RuntimeError) + setattr(rest_pipeline_module, "Pipeline", type("InnerPipeline", (), {})) + + monkeypatch.setitem(__import__("sys").modules, "feldera.enums", enums_module) + monkeypatch.setitem(__import__("sys").modules, "feldera.pipeline", pipeline_module) + monkeypatch.setitem( + __import__("sys").modules, "feldera.pipeline_builder", pipeline_builder_module + ) + monkeypatch.setitem(__import__("sys").modules, "feldera.runtime_config", runtime_config_module) + monkeypatch.setitem(__import__("sys").modules, "feldera.rest.errors", rest_errors_module) + monkeypatch.setitem(__import__("sys").modules, "feldera.rest.pipeline", rest_pipeline_module) + + monkeypatch.setattr(manager, "_hydrate_existing_program", lambda client, pipeline_name: None) + monkeypatch.setattr(manager, "assemble_program", lambda: "CREATE TABLE x (id INT)") + monkeypatch.setattr(manager, "_compile_program", lambda *args, **kwargs: Pipeline()) + + manager.deploy(object(), "test_pipeline") + + assert start_calls == [300] + assert wait_calls == [] + + def test_connection_close_logs_pending_compile_error(caplog: pytest.LogCaptureFixture) -> None: class FakeStateManager: def has_pending_changes(self) -> bool: From 3b489637120497262b537d1ef478e5f18cf9f8f7 Mon Sep 17 00:00:00 2001 From: fresioAS Date: Wed, 13 May 2026 00:54:42 +0200 Subject: [PATCH 09/10] dialect.py + tests --- pyproject.toml | 1 + sqlmesh/core/config/connection.py | 15 +- sqlmesh/core/engine_adapter/feldera.py | 28 +--- sqlmesh/engines/feldera/db_api.py | 12 +- sqlmesh/engines/feldera/dialect.py | 33 ++++ tests/core/engine_adapter/test_feldera.py | 89 ++++++++++- tests/core/test_connection_config.py | 1 + tests/engines/feldera/test_db_api.py | 186 +++++++++++++++++++--- 8 files changed, 302 insertions(+), 63 deletions(-) create mode 100644 sqlmesh/engines/feldera/dialect.py diff --git a/pyproject.toml b/pyproject.toml index 28aff0b439..137df13a7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -255,6 +255,7 @@ markers = [ "clickhouse_cloud: test for Clickhouse (cloud mode)", "databricks: test for Databricks", "duckdb: test for DuckDB", + "feldera: test for Feldera", "fabric: test for Fabric", "motherduck: test for MotherDuck", "mssql: test for MSSQL", diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index de0e0cdb01..c8a4eda5b7 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -56,6 +56,7 @@ # Do not support row-level operations "spark", "trino", + "feldera", # Nullable types are problematic "clickhouse", } @@ -2343,7 +2344,16 @@ def init(cursor: t.Any) -> None: class FelderaConnectionConfig(ConnectionConfig): - """Feldera connection configuration.""" + """Feldera connection configuration. + + Args: + host: The Feldera API base URL. + api_key: The optional Feldera API key. + pipeline_name: The name of the backing Feldera pipeline. + compilation_profile: The Feldera compilation profile to use during deploys. + workers: The number of workers in the Feldera runtime config. + timeout: The timeout, in seconds, for Feldera API operations. + """ host: str = "http://localhost:8080" api_key: t.Optional[str] = None @@ -2386,9 +2396,6 @@ def _connection_factory(self) -> t.Callable: return connect - def get_catalog(self) -> t.Optional[str]: - return None - _CONNECTION_CONFIG_EXCLUDE: t.Set[t.Type[ConnectionConfig]] = { ConnectionConfig, # type: ignore[type-abstract] diff --git a/sqlmesh/core/engine_adapter/feldera.py b/sqlmesh/core/engine_adapter/feldera.py index c4dfa85890..8af139f07d 100644 --- a/sqlmesh/core/engine_adapter/feldera.py +++ b/sqlmesh/core/engine_adapter/feldera.py @@ -4,8 +4,6 @@ import typing as t from sqlglot import exp -from sqlglot.dialects.dialect import Dialect -from sqlglot.generator import Generator as SQLGlotGenerator from sqlmesh.core.dialect import to_schema from sqlmesh.core.engine_adapter.base import EngineAdapter @@ -14,6 +12,7 @@ CommentCreationView, DataObject, DataObjectType, + set_catalog, ) from sqlmesh.engines.feldera.db_api import QUERY_MIRROR_PREFIX from sqlmesh.utils.errors import SQLMeshError @@ -41,30 +40,6 @@ def _view_type(state: t.Any, object_name: str) -> DataObjectType: return DataObjectType.VIEW -class _SQLMeshFeldera(Dialect): - class Generator(SQLGlotGenerator): - TYPE_MAPPING = { - **SQLGlotGenerator.TYPE_MAPPING, - exp.DataType.Type.FLOAT: "REAL", - exp.DataType.Type.INT: "INTEGER", - } - TRANSFORMS = { - **SQLGlotGenerator.TRANSFORMS, - exp.CurrentTimestamp: lambda self, expression: ( - self.func("CURRENT_TIMESTAMP", expression.this) - if expression.this - else "CURRENT_TIMESTAMP" - ), - exp.DateStrToDate: lambda self, expression: ( - f"CAST({self.sql(expression, 'this')} AS DATE)" - ), - } - - -if Dialect.get("feldera") is None: - Dialect.classes["feldera"] = _SQLMeshFeldera - - _FELDERA_TO_EXP_TYPE: t.Dict[str, t.Any] = { "BOOLEAN": exp.DataType.Type.BOOLEAN, "TINYINT": exp.DataType.Type.TINYINT, @@ -91,6 +66,7 @@ def _feldera_type_to_exp(dtype_str: str) -> exp.DataType: return exp.DataType(this=kind) +@set_catalog() class FelderaEngineAdapter(EngineAdapter): DIALECT = "feldera" SUPPORTS_TRANSACTIONS = False diff --git a/sqlmesh/engines/feldera/db_api.py b/sqlmesh/engines/feldera/db_api.py index 66ec92aba9..851d6582b0 100644 --- a/sqlmesh/engines/feldera/db_api.py +++ b/sqlmesh/engines/feldera/db_api.py @@ -9,6 +9,8 @@ from sqlglot import exp, parse, parse_one from sqlglot.errors import ParseError +import sqlmesh.engines.feldera.dialect + logger = logging.getLogger(__name__) QUERY_MIRROR_PREFIX = "__sqlmesh_query__" @@ -943,6 +945,9 @@ def execute(self, sql: str, parameters: t.Optional[t.Any] = None) -> None: if parameters is not None: raise NotImplementedError("Feldera DB-API does not support query parameters") + self.description = None + self.rowcount = -1 + original_sql = sql normalized_sql = _normalize_ddl(sql) intent = _classify(normalized_sql) @@ -1055,6 +1060,11 @@ def __init__( with self._state_lock: self._state: t.Any = self._shared_states.setdefault(state_key, PipelineStateManager()) + @classmethod + def reset_shared_states(cls) -> None: + with cls._state_lock: + cls._shared_states.clear() + def cursor(self) -> FelderaCursor: return FelderaCursor( self._client, @@ -1101,7 +1111,7 @@ def connect( ) -> FelderaConnection: from feldera.rest.feldera_client import FelderaClient - client = FelderaClient(url=host, api_key=api_key, timeout=float(timeout)) + client = FelderaClient(url=host, api_key=api_key, timeout=timeout) return FelderaConnection( client, host, diff --git a/sqlmesh/engines/feldera/dialect.py b/sqlmesh/engines/feldera/dialect.py new file mode 100644 index 0000000000..b0f794121c --- /dev/null +++ b/sqlmesh/engines/feldera/dialect.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from sqlglot import exp +from sqlglot.dialects.dialect import Dialect +from sqlglot.generator import Generator as SQLGlotGenerator + + +class SQLMeshFelderaDialect(Dialect): + class Generator(SQLGlotGenerator): + TYPE_MAPPING = { + **SQLGlotGenerator.TYPE_MAPPING, + exp.DataType.Type.FLOAT: "REAL", + exp.DataType.Type.INT: "INTEGER", + } + TRANSFORMS = { + **SQLGlotGenerator.TRANSFORMS, + exp.CurrentTimestamp: lambda self, expression: ( + self.func("CURRENT_TIMESTAMP", expression.this) + if expression.this + else "CURRENT_TIMESTAMP" + ), + exp.DateStrToDate: lambda self, expression: ( + f"CAST({self.sql(expression, 'this')} AS DATE)" + ), + } + + +def register_feldera_dialect() -> None: + if Dialect.get("feldera") is None: + Dialect.classes["feldera"] = SQLMeshFelderaDialect + + +register_feldera_dialect() \ No newline at end of file diff --git a/tests/core/engine_adapter/test_feldera.py b/tests/core/engine_adapter/test_feldera.py index 580c2466f7..a97b016276 100644 --- a/tests/core/engine_adapter/test_feldera.py +++ b/tests/core/engine_adapter/test_feldera.py @@ -5,15 +5,17 @@ import typing as t import pytest +from sqlglot import Dialect from sqlglot import parse_one from sqlglot import exp from sqlmesh.core.engine_adapter import FelderaEngineAdapter from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType -from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.engines.feldera.dialect import register_feldera_dialect +from sqlmesh.utils.errors import SQLMeshError, UnsupportedCatalogOperationError from tests.core.engine_adapter import to_sql_calls -pytestmark = [pytest.mark.engine] +pytestmark = [pytest.mark.engine, pytest.mark.feldera] @pytest.fixture @@ -115,27 +117,44 @@ def get_pipeline(pipeline_name: str, client: t.Any) -> t.Any: monkeypatch.setitem(sys.modules, "feldera", feldera_module) monkeypatch.setitem(sys.modules, "feldera.pipeline", pipeline_module) - data_objects = adapter._get_data_objects("catalog.requested_pipeline") + data_objects = adapter._get_data_objects("requested_pipeline") assert requested_pipeline_names == ["requested_pipeline"] assert [(obj.schema_name, obj.name, obj.type) for obj in data_objects] == [ ("requested_pipeline", "source", DataObjectType.TABLE), ("requested_pipeline", "sink", DataObjectType.VIEW), ] + + +def test_adapter_dialect_is_feldera(adapter: FelderaEngineAdapter) -> None: assert adapter.dialect == "feldera" -def test_builtin_dialect_registers_feldera_name() -> None: +def test_feldera_dialect_is_registered() -> None: assert parse_one("SELECT 1", dialect="feldera").sql(dialect="feldera") == "SELECT 1" -def test_builtin_dialect_preserves_current_timestamp_keyword() -> None: +def test_feldera_dialect_preserves_current_timestamp_keyword() -> None: assert ( parse_one("SELECT CURRENT_TIMESTAMP AS ts", dialect="feldera").sql(dialect="feldera") == "SELECT CURRENT_TIMESTAMP AS ts" ) +def test_register_feldera_dialect_registers_custom_type_mapping() -> None: + original = Dialect.classes.pop("feldera", None) + + try: + register_feldera_dialect() + assert exp.DataType.build("FLOAT").sql(dialect="feldera") == "REAL" + finally: + Dialect.classes.pop("feldera", None) + if original is not None: + Dialect.classes["feldera"] = original + else: + register_feldera_dialect() + + def test_get_data_objects_marks_materialized_views_from_state( adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch ): @@ -149,7 +168,7 @@ def test_get_data_objects_marks_materialized_views_from_state( ) _install_feldera_pipeline(monkeypatch, pipeline) - data_objects = adapter._get_data_objects("catalog.requested_pipeline") + data_objects = adapter._get_data_objects("requested_pipeline") assert [(obj.schema_name, obj.name, obj.type) for obj in data_objects] == [ ("requested_pipeline", "sink", DataObjectType.MATERIALIZED_VIEW), @@ -194,7 +213,7 @@ def test_replace_query_recreates_existing_table( ] -def test_insert_overwrite_by_time_partition_uses_table_operations( +def test_insert_overwrite_by_time_partition_uses_delete_insert( adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch ): monkeypatch.setattr(adapter, "_get_data_objects", lambda schema_name, object_names=None: []) @@ -218,6 +237,54 @@ def test_insert_overwrite_by_time_partition_uses_table_operations( ] +def test_insert_overwrite_without_condition_drops_and_recreates_table( + adapter: FelderaEngineAdapter, +) -> None: + recorded_calls: list[tuple[str, tuple[t.Any, ...], dict[str, t.Any]]] = [] + source_queries = [t.cast(t.Any, object())] + target_columns_to_types = {"a": exp.DataType.build("INT")} + + def record_drop_table(*args: t.Any, **kwargs: t.Any) -> None: + recorded_calls.append(("drop_table", args, kwargs)) + + def record_create_table(*args: t.Any, **kwargs: t.Any) -> None: + recorded_calls.append(("create_table", args, kwargs)) + + adapter.drop_table = record_drop_table # type: ignore[method-assign] + adapter._create_table_from_source_queries = record_create_table # type: ignore[method-assign] + + adapter._insert_overwrite_by_condition( + "db.full_model", + source_queries, + target_columns_to_types=target_columns_to_types, + where=None, + ) + + assert recorded_calls == [ + ("drop_table", ("db.full_model",), {}), + ( + "create_table", + ("db.full_model", source_queries), + { + "target_columns_to_types": target_columns_to_types, + "exists": True, + "replace": False, + }, + ), + ] + + +def test_create_table_from_source_queries_requires_known_column_types( + adapter: FelderaEngineAdapter, +) -> None: + with pytest.raises(SQLMeshError, match="requires known column types"): + adapter._create_table_from_source_queries("db.full_model", [], target_columns_to_types=None) + + +def test_create_schema_is_no_op(adapter: FelderaEngineAdapter) -> None: + assert adapter.create_schema("db") is None + + def test_create_view_creates_view(adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(adapter, "_get_data_objects", lambda schema_name, object_names=None: []) adapter.create_view( @@ -232,6 +299,14 @@ def test_create_view_creates_view(adapter: FelderaEngineAdapter, monkeypatch: py ] +def test_create_view_rejects_catalog_qualified_names(adapter: FelderaEngineAdapter) -> None: + with pytest.raises(UnsupportedCatalogOperationError, match="does not support catalogs"): + adapter.create_view( + "catalog.db.view_model", + parse_one("SELECT a FROM tbl"), + ) + + def test_replace_view_drops_then_creates_view( adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch ): diff --git a/tests/core/test_connection_config.py b/tests/core/test_connection_config.py index 1e6deb2959..e3ebf7a251 100644 --- a/tests/core/test_connection_config.py +++ b/tests/core/test_connection_config.py @@ -1961,6 +1961,7 @@ def test_feldera_connection_config(make_config): assert isinstance(config, FelderaConnectionConfig) assert config.DIALECT == "feldera" + assert config.is_forbidden_for_state_sync with patch("sqlmesh.engines.feldera.db_api.connect") as mock_connect: config._connection_factory_with_kwargs() diff --git a/tests/engines/feldera/test_db_api.py b/tests/engines/feldera/test_db_api.py index 5bd4591718..ee4ccb17f6 100644 --- a/tests/engines/feldera/test_db_api.py +++ b/tests/engines/feldera/test_db_api.py @@ -1,3 +1,4 @@ +import sys import types import typing as t @@ -7,6 +8,13 @@ from sqlmesh.engines.feldera import db_api +@pytest.fixture(autouse=True) +def reset_feldera_shared_states() -> t.Iterator[None]: + db_api.FelderaConnection.reset_shared_states() + yield + db_api.FelderaConnection.reset_shared_states() + + def test_classify_treats_comment_prefixed_create_schema_as_pipeline_ddl() -> None: assert db_api._classify("/* sqlmesh */ CREATE SCHEMA foo") == db_api.SqlIntent.PIPELINE_DDL @@ -97,6 +105,44 @@ def queryable_relation_names(self) -> set[str]: assert cursor.fetchall() == [(1,)] +def test_cursor_clears_description_after_non_query_execute() -> None: + class FakePipeline: + def query(self, sql: str) -> list[dict[str, int]]: + return [{"a": 1}] + + def input_json(self, table_name: str, rows: object) -> None: + return None + + def execute(self, sql: str) -> None: + return None + + class FakeStateManager: + def __init__(self) -> None: + self.pipeline = FakePipeline() + + def has_pending_changes(self) -> bool: + return False + + def current_pipeline(self) -> object: + return self.pipeline + + def queryable_relation_names(self) -> set[str]: + return set() + + cursor = db_api.FelderaCursor( + client=object(), + pipeline_name="test_pipeline", + state_manager=t.cast(t.Any, FakeStateManager()), + ) + + cursor.execute("SELECT 1 AS a") + assert cursor.description == [("a", None, None, None, None, None, None)] + + cursor.execute('INSERT INTO "seed_model" SELECT 1 AS "a" FROM (VALUES (1)) AS "t"("a")') + + assert cursor.description is None + + def test_cursor_ignores_virtual_layer_view_ddl() -> None: class FakeStateManager: def __init__(self) -> None: @@ -182,8 +228,8 @@ def test_hydrate_existing_program_skips_empty_parse_results(monkeypatch) -> None feldera_module = types.ModuleType("feldera") setattr(feldera_module, "pipeline", pipeline_module) - monkeypatch.setitem(__import__("sys").modules, "feldera", feldera_module) - monkeypatch.setitem(__import__("sys").modules, "feldera.pipeline", pipeline_module) + monkeypatch.setitem(sys.modules, "feldera", feldera_module) + monkeypatch.setitem(sys.modules, "feldera.pipeline", pipeline_module) manager._hydrate_existing_program(object(), "test_pipeline") @@ -277,6 +323,44 @@ def get_pipeline(self, pipeline_name: str, field_selector: object) -> object: ) +def test_format_compile_error_preserves_rust_and_system_error_details() -> None: + manager = db_api.PipelineStateManager() + + class FakeClient: + def get_pipeline(self, pipeline_name: str, field_selector: object) -> object: + return types.SimpleNamespace( + program_status="RustError", + program_error={ + "rust_compilation": "rust failed", + "system_error": "system failed", + }, + ) + + error = manager._format_compile_error( + FakeClient(), + "test_pipeline", + RuntimeError("The program failed to compile: RustError"), + ) + + assert str(error) == ( + "The program failed to compile: RustError\n" + "Rust Error: rust failed\n" + "System Error: system failed" + ) + + +def test_format_compile_error_returns_original_message_without_program_error() -> None: + manager = db_api.PipelineStateManager() + + class FakeClient: + def get_pipeline(self, pipeline_name: str, field_selector: object) -> object: + return types.SimpleNamespace(program_status="Unknown", program_error=None) + + error = manager._format_compile_error(FakeClient(), "test_pipeline", RuntimeError("boom")) + + assert str(error) == "boom" + + def test_format_compile_error_requests_all_pipeline_fields(monkeypatch) -> None: manager = db_api.PipelineStateManager() requested_field_selector: list[object] = [] @@ -284,7 +368,7 @@ def test_format_compile_error_requests_all_pipeline_fields(monkeypatch) -> None: enums_module = types.ModuleType("feldera.enums") setattr(enums_module, "PipelineFieldSelector", types.SimpleNamespace(ALL=selector_all)) - monkeypatch.setitem(__import__("sys").modules, "feldera.enums", enums_module) + monkeypatch.setitem(sys.modules, "feldera.enums", enums_module) class FakeClient: def get_pipeline(self, pipeline_name: str, field_selector: object) -> object: @@ -296,7 +380,7 @@ def get_pipeline(self, pipeline_name: str, field_selector: object) -> object: assert requested_field_selector == [selector_all] -def test_deploy_imports_compilation_profile_from_feldera_enums(monkeypatch) -> None: +def test_deploy_succeeds_with_mocked_feldera_modules(monkeypatch) -> None: manager = db_api.PipelineStateManager() class CompilationProfile(str): @@ -348,16 +432,14 @@ def to_dict(self) -> dict[str, object]: setattr(feldera_module, "runtime_config", runtime_config_module) setattr(feldera_module, "rest", rest_module) - monkeypatch.setitem(__import__("sys").modules, "feldera", feldera_module) - monkeypatch.setitem(__import__("sys").modules, "feldera.enums", enums_module) - monkeypatch.setitem(__import__("sys").modules, "feldera.pipeline", pipeline_module) - monkeypatch.setitem( - __import__("sys").modules, "feldera.pipeline_builder", pipeline_builder_module - ) - monkeypatch.setitem(__import__("sys").modules, "feldera.runtime_config", runtime_config_module) - monkeypatch.setitem(__import__("sys").modules, "feldera.rest", rest_module) - monkeypatch.setitem(__import__("sys").modules, "feldera.rest.errors", rest_errors_module) - monkeypatch.setitem(__import__("sys").modules, "feldera.rest.pipeline", rest_pipeline_module) + monkeypatch.setitem(sys.modules, "feldera", feldera_module) + monkeypatch.setitem(sys.modules, "feldera.enums", enums_module) + monkeypatch.setitem(sys.modules, "feldera.pipeline", pipeline_module) + monkeypatch.setitem(sys.modules, "feldera.pipeline_builder", pipeline_builder_module) + monkeypatch.setitem(sys.modules, "feldera.runtime_config", runtime_config_module) + monkeypatch.setitem(sys.modules, "feldera.rest", rest_module) + monkeypatch.setitem(sys.modules, "feldera.rest.errors", rest_errors_module) + monkeypatch.setitem(sys.modules, "feldera.rest.pipeline", rest_pipeline_module) manager.deploy(object(), "test_pipeline") @@ -398,6 +480,10 @@ class RuntimeConfig: def to_dict(self) -> dict[str, object]: return {} + class PipelineBuilder: + def __init__(self, *args: object, **kwargs: object) -> None: + raise AssertionError("PipelineBuilder should not be constructed when the pipeline exists") + client = types.SimpleNamespace(create_or_update_pipeline=lambda pipeline, wait=True: pipeline) manager._compile_program( @@ -408,7 +494,7 @@ def to_dict(self) -> dict[str, object]: RuntimeConfig(), 300, Pipeline, - object, + PipelineBuilder, InnerPipeline, RuntimeError, ) @@ -463,14 +549,12 @@ def to_dict(self) -> dict[str, object]: setattr(rest_errors_module, "FelderaAPIError", RuntimeError) setattr(rest_pipeline_module, "Pipeline", type("InnerPipeline", (), {})) - monkeypatch.setitem(__import__("sys").modules, "feldera.enums", enums_module) - monkeypatch.setitem(__import__("sys").modules, "feldera.pipeline", pipeline_module) - monkeypatch.setitem( - __import__("sys").modules, "feldera.pipeline_builder", pipeline_builder_module - ) - monkeypatch.setitem(__import__("sys").modules, "feldera.runtime_config", runtime_config_module) - monkeypatch.setitem(__import__("sys").modules, "feldera.rest.errors", rest_errors_module) - monkeypatch.setitem(__import__("sys").modules, "feldera.rest.pipeline", rest_pipeline_module) + monkeypatch.setitem(sys.modules, "feldera.enums", enums_module) + monkeypatch.setitem(sys.modules, "feldera.pipeline", pipeline_module) + monkeypatch.setitem(sys.modules, "feldera.pipeline_builder", pipeline_builder_module) + monkeypatch.setitem(sys.modules, "feldera.runtime_config", runtime_config_module) + monkeypatch.setitem(sys.modules, "feldera.rest.errors", rest_errors_module) + monkeypatch.setitem(sys.modules, "feldera.rest.pipeline", rest_pipeline_module) monkeypatch.setattr(manager, "_hydrate_existing_program", lambda client, pipeline_name: None) monkeypatch.setattr(manager, "assemble_program", lambda: "CREATE TABLE x (id INT)") @@ -512,6 +596,37 @@ def deploy(self, *args: object, **kwargs: object) -> object: assert "TIMESTAMP_TRUNC problem" in caplog.text +def test_hydrate_existing_program_returns_silently_when_pipeline_lookup_fails( + monkeypatch: pytest.MonkeyPatch, +) -> None: + manager = db_api.PipelineStateManager() + + pipeline_module = types.ModuleType("feldera.pipeline") + setattr( + pipeline_module, + "Pipeline", + type( + "Pipeline", + (), + { + "get": staticmethod( + lambda pipeline_name, client: (_ for _ in ()).throw(RuntimeError("missing")) + ) + }, + ), + ) + feldera_module = types.ModuleType("feldera") + setattr(feldera_module, "pipeline", pipeline_module) + + monkeypatch.setitem(sys.modules, "feldera", feldera_module) + monkeypatch.setitem(sys.modules, "feldera.pipeline", pipeline_module) + + manager._hydrate_existing_program(object(), "test_pipeline") + + assert manager.pending_tables() == set() + assert manager.pending_views() == set() + + def test_state_manager_adds_query_mirrors_for_non_materialized_relations() -> None: manager = db_api.PipelineStateManager() @@ -612,8 +727,8 @@ def test_hydrate_existing_program_skips_query_mirrors(monkeypatch) -> None: feldera_module = types.ModuleType("feldera") setattr(feldera_module, "pipeline", pipeline_module) - monkeypatch.setitem(__import__("sys").modules, "feldera", feldera_module) - monkeypatch.setitem(__import__("sys").modules, "feldera.pipeline", pipeline_module) + monkeypatch.setitem(sys.modules, "feldera", feldera_module) + monkeypatch.setitem(sys.modules, "feldera.pipeline", pipeline_module) manager._hydrate_existing_program(object(), "test_pipeline") @@ -800,3 +915,24 @@ def current_pipeline(self) -> object: with pytest.raises(RuntimeError, match="Execution error: test failure"): cursor.execute("SELECT COUNT(*) FROM full_model") + + +def test_connect_preserves_integer_timeout(monkeypatch) -> None: + captured_kwargs: dict[str, object] = {} + + class FelderaClient: + def __init__(self, **kwargs: object) -> None: + captured_kwargs.update(kwargs) + + feldera_client_module = types.ModuleType("feldera.rest.feldera_client") + setattr(feldera_client_module, "FelderaClient", FelderaClient) + monkeypatch.setitem(sys.modules, "feldera.rest.feldera_client", feldera_client_module) + + db_api.connect( + host="http://localhost:8080", + pipeline_name="test_pipeline", + timeout=123, + ) + + assert captured_kwargs["timeout"] == 123 + assert isinstance(captured_kwargs["timeout"], int) From e43665e254b1e5e32767a6cb7516978752152cc1 Mon Sep 17 00:00:00 2001 From: fresioAS Date: Wed, 13 May 2026 00:57:53 +0200 Subject: [PATCH 10/10] style Signed-off-by: fresioAS --- sqlmesh/engines/feldera/db_api.py | 1 - sqlmesh/engines/feldera/dialect.py | 2 +- tests/engines/feldera/test_db_api.py | 4 +++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sqlmesh/engines/feldera/db_api.py b/sqlmesh/engines/feldera/db_api.py index 851d6582b0..86e85d268e 100644 --- a/sqlmesh/engines/feldera/db_api.py +++ b/sqlmesh/engines/feldera/db_api.py @@ -9,7 +9,6 @@ from sqlglot import exp, parse, parse_one from sqlglot.errors import ParseError -import sqlmesh.engines.feldera.dialect logger = logging.getLogger(__name__) diff --git a/sqlmesh/engines/feldera/dialect.py b/sqlmesh/engines/feldera/dialect.py index b0f794121c..6362202266 100644 --- a/sqlmesh/engines/feldera/dialect.py +++ b/sqlmesh/engines/feldera/dialect.py @@ -30,4 +30,4 @@ def register_feldera_dialect() -> None: Dialect.classes["feldera"] = SQLMeshFelderaDialect -register_feldera_dialect() \ No newline at end of file +register_feldera_dialect() diff --git a/tests/engines/feldera/test_db_api.py b/tests/engines/feldera/test_db_api.py index ee4ccb17f6..6d9cdb0fe2 100644 --- a/tests/engines/feldera/test_db_api.py +++ b/tests/engines/feldera/test_db_api.py @@ -482,7 +482,9 @@ def to_dict(self) -> dict[str, object]: class PipelineBuilder: def __init__(self, *args: object, **kwargs: object) -> None: - raise AssertionError("PipelineBuilder should not be constructed when the pipeline exists") + raise AssertionError( + "PipelineBuilder should not be constructed when the pipeline exists" + ) client = types.SimpleNamespace(create_or_update_pipeline=lambda pipeline, wait=True: pipeline)