diff --git a/pyproject.toml b/pyproject.toml
index bcc69c667e..137df13a7e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,7 +27,7 @@ dependencies = [
     "sqlglot~=30.4.2",
     "tenacity",
     "time-machine",
-    "json-stream"
+    "json-stream",
 ]
 classifiers = [
     "Intended Audience :: Developers",
@@ -42,10 +42,7 @@ classifiers = [
 athena = ["PyAthena[Pandas]"]
 azuresql = ["pymssql"]
 azuresql-odbc = ["pyodbc>=5.0.0"]
-bigquery = [
-    "google-cloud-bigquery[pandas]",
-    "google-cloud-bigquery-storage"
-]
+bigquery = ["google-cloud-bigquery[pandas]", "google-cloud-bigquery-storage"]
 # bigframes has to be separate to support environments with an older google-cloud-bigquery pin
 # this is because that pin pulls in an older bigframes and the bigframes team
 # pinned an older SQLGlot which is incompatible with SQLMesh
@@ -71,6 +68,7 @@ dev = [
     "dbt-redshift",
     "dbt-trino",
     "Faker",
+    "feldera",
     "google-auth",
     "google-cloud-bigquery",
     "google-cloud-bigquery-storage",
@@ -109,6 +107,7 @@ dbt = ["dbt-core<2"]
 dlt = ["dlt"]
 duckdb = []
 fabric = ["pyodbc>=5.0.0"]
+feldera = ["feldera"]
 gcppostgres = ["cloud-sql-python-connector[pg8000]>=1.8.0"]
 github = ["PyGithub>=2.6.0"]
 motherduck = ["duckdb>=1.3.2"]
@@ -183,11 +182,7 @@ no_implicit_optional = true
 disallow_untyped_defs = true
 
 [[tool.mypy.overrides]]
-module = [
-    "examples.*.macros.*",
-    "tests.*",
-    "sqlmesh.migrations.*"
-]
+module = ["examples.*.macros.*", "tests.*", "sqlmesh.migrations.*"]
 disallow_untyped_defs = false
 # Sometimes it's helpful to use types within an "untyped" function because it allows IDE assistance
 # Unfortunately this causes MyPy to print an annoying 'By default the bodies of untyped functions are not checked'
@@ -225,9 +220,10 @@ module = [
     "bs4.*",
     "pydantic_core.*",
     "dlt.*",
+    "feldera.*",
     "bigframes.*",
     "json_stream.*",
-    "duckdb.*"
+    "duckdb.*",
 ]
 ignore_missing_imports = true
@@ -259,6 +255,7 @@ markers = [
     "clickhouse_cloud: test for Clickhouse (cloud mode)",
     "databricks: test for Databricks",
     "duckdb: test for DuckDB",
+    "feldera: test for Feldera",
     "fabric: test for Fabric",
     "motherduck: test for MotherDuck",
     "mssql: test for MSSQL",
@@ -274,7 +271,7 @@ markers = [
 
     # Other
     "set_default_connection",
-    "registry_isolation"
+    "registry_isolation",
 ]
 addopts = "-n 0 --dist=loadgroup"
 asyncio_default_fixture_loop_scope = "session"
@@ -282,7 +279,7 @@ log_cli = false # Set this to true to enable logging during tests
 log_cli_format = "%(asctime)s.%(msecs)03d %(filename)s:%(lineno)d %(levelname)s %(message)s"
 log_cli_level = "INFO"
 filterwarnings = [
-    "ignore:The localize method is no longer necessary, as this time zone supports the fold attribute"
+    "ignore:The localize method is no longer necessary, as this time zone supports the fold attribute",
 ]
 reruns_delay = 10
@@ -290,20 +287,12 @@ reruns_delay = 10
 line-length = 100
 
 [tool.ruff.lint]
-select = [
-    "F401",
-    "RET505",
-    "T100",
-]
+select = ["F401", "RET505", "T100"]
 extend-select = ["TID"]
 
 [tool.ruff.lint.flake8-tidy-imports]
-banned-module-level-imports = [
-    "duckdb",
-    "numpy",
-    "pandas",
-]
+banned-module-level-imports = ["duckdb", "numpy", "pandas"]
 
 # Bans imports from sqlmesh.lsp in files outside of sqlmesh/lsp
 [tool.ruff.lint.flake8-tidy-imports.banned-api]
diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py
index d930537711..c8a4eda5b7 100644
--- a/sqlmesh/core/config/connection.py
+++ b/sqlmesh/core/config/connection.py
@@ -56,6 +56,7 @@
     # Do not support row-level operations
     "spark",
     "trino",
+    "feldera",
    # Nullable types are problematic
     "clickhouse",
 }
@@ -2342,6 +2343,60 @@ def init(cursor: t.Any) -> None:
         return init
 
 
+class FelderaConnectionConfig(ConnectionConfig):
+    """Feldera connection configuration.
+
+    Args:
+        host: The Feldera API base URL.
+        api_key: The optional Feldera API key.
+        pipeline_name: The name of the backing Feldera pipeline.
+        compilation_profile: The Feldera compilation profile to use during deploys.
+        workers: The number of workers in the Feldera runtime config.
+        timeout: The timeout, in seconds, for Feldera API operations.
+    """
+
+    host: str = "http://localhost:8080"
+    api_key: t.Optional[str] = None
+    pipeline_name: str
+    compilation_profile: str = "dev"
+    workers: int = 4
+    timeout: int = 300
+
+    type_: t.Literal["feldera"] = Field(alias="type", default="feldera")
+    DIALECT: t.ClassVar[t.Literal["feldera"]] = "feldera"
+    DISPLAY_NAME: t.ClassVar[t.Literal["Feldera"]] = "Feldera"
+    DISPLAY_ORDER: t.ClassVar[t.Literal[18]] = 18
+
+    concurrent_tasks: int = 1
+    register_comments: t.Literal[False] = False
+    pre_ping: t.Literal[False] = False
+
+    _engine_import_validator = _get_engine_import_validator("feldera", "feldera")
+
+    @property
+    def _connection_kwargs_keys(self) -> t.Set[str]:
+        return {
+            "host",
+            "api_key",
+            "pipeline_name",
+            "workers",
+            "compilation_profile",
+            "timeout",
+        }
+
+    @property
+    def _engine_adapter(self) -> t.Type[EngineAdapter]:
+        from sqlmesh.core.engine_adapter.feldera import FelderaEngineAdapter
+
+        return FelderaEngineAdapter
+
+    @property
+    def _connection_factory(self) -> t.Callable:
+        from sqlmesh.engines.feldera.db_api import connect
+
+        return connect
+
+
 _CONNECTION_CONFIG_EXCLUDE: t.Set[t.Type[ConnectionConfig]] = {
     ConnectionConfig,  # type: ignore[type-abstract]
     BaseDuckDBConnectionConfig,  # type: ignore[type-abstract]
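Reviewer note: a minimal sketch of constructing the new config in Python and checking the defaults above (the `analytics` pipeline name is illustrative). Only `pipeline_name` is required.

```python
from sqlmesh.core.config.connection import FelderaConnectionConfig

config = FelderaConnectionConfig(pipeline_name="analytics")

assert config.host == "http://localhost:8080"  # default Feldera API base URL
assert config.DIALECT == "feldera"
assert config.concurrent_tasks == 1  # all work is serialized through one pipeline
```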
problematic "clickhouse", } @@ -2342,6 +2343,60 @@ def init(cursor: t.Any) -> None: return init +class FelderaConnectionConfig(ConnectionConfig): + """Feldera connection configuration. + + Args: + host: The Feldera API base URL. + api_key: The optional Feldera API key. + pipeline_name: The name of the backing Feldera pipeline. + compilation_profile: The Feldera compilation profile to use during deploys. + workers: The number of workers in the Feldera runtime config. + timeout: The timeout, in seconds, for Feldera API operations. + """ + + host: str = "http://localhost:8080" + api_key: t.Optional[str] = None + pipeline_name: str + compilation_profile: str = "dev" + workers: int = 4 + timeout: int = 300 + + type_: t.Literal["feldera"] = Field(alias="type", default="feldera") + DIALECT: t.ClassVar[t.Literal["feldera"]] = "feldera" + DISPLAY_NAME: t.ClassVar[t.Literal["Feldera"]] = "Feldera" + DISPLAY_ORDER: t.ClassVar[t.Literal[18]] = 18 + + concurrent_tasks: int = 1 + register_comments: t.Literal[False] = False + pre_ping: t.Literal[False] = False + + _engine_import_validator = _get_engine_import_validator("feldera", "feldera") + + @property + def _connection_kwargs_keys(self) -> t.Set[str]: + return { + "host", + "api_key", + "pipeline_name", + "workers", + "compilation_profile", + "timeout", + } + + @property + def _engine_adapter(self) -> t.Type[EngineAdapter]: + from sqlmesh.core.engine_adapter.feldera import FelderaEngineAdapter + + return FelderaEngineAdapter + + @property + def _connection_factory(self) -> t.Callable: + from sqlmesh.engines.feldera.db_api import connect + + return connect + + _CONNECTION_CONFIG_EXCLUDE: t.Set[t.Type[ConnectionConfig]] = { ConnectionConfig, # type: ignore[type-abstract] BaseDuckDBConnectionConfig, # type: ignore[type-abstract] diff --git a/sqlmesh/core/engine_adapter/__init__.py b/sqlmesh/core/engine_adapter/__init__.py index ab29885c7b..cb0e2a8413 100644 --- a/sqlmesh/core/engine_adapter/__init__.py +++ b/sqlmesh/core/engine_adapter/__init__.py @@ -20,6 +20,7 @@ from sqlmesh.core.engine_adapter.athena import AthenaEngineAdapter from sqlmesh.core.engine_adapter.risingwave import RisingwaveEngineAdapter from sqlmesh.core.engine_adapter.fabric import FabricEngineAdapter +from sqlmesh.core.engine_adapter.feldera import FelderaEngineAdapter DIALECT_TO_ENGINE_ADAPTER = { "hive": SparkEngineAdapter, @@ -37,6 +38,7 @@ "athena": AthenaEngineAdapter, "risingwave": RisingwaveEngineAdapter, "fabric": FabricEngineAdapter, + "feldera": FelderaEngineAdapter, } DIALECT_ALIASES = { diff --git a/sqlmesh/core/engine_adapter/feldera.py b/sqlmesh/core/engine_adapter/feldera.py new file mode 100644 index 0000000000..8af139f07d --- /dev/null +++ b/sqlmesh/core/engine_adapter/feldera.py @@ -0,0 +1,322 @@ +from __future__ import annotations + +import logging +import typing as t + +from sqlglot import exp + +from sqlmesh.core.dialect import to_schema +from sqlmesh.core.engine_adapter.base import EngineAdapter +from sqlmesh.core.engine_adapter.shared import ( + CommentCreationTable, + CommentCreationView, + DataObject, + DataObjectType, + set_catalog, +) +from sqlmesh.engines.feldera.db_api import QUERY_MIRROR_PREFIX +from sqlmesh.utils.errors import SQLMeshError + +if t.TYPE_CHECKING: + import pandas as pd + + from sqlmesh.core._typing import SchemaName, TableName + from sqlmesh.core.engine_adapter.shared import SourceQuery + + +logger = logging.getLogger(__name__) + + +def _is_query_mirror_name(name: str) -> bool: + return name.lower().startswith(QUERY_MIRROR_PREFIX) + + +def 
diff --git a/sqlmesh/core/engine_adapter/feldera.py b/sqlmesh/core/engine_adapter/feldera.py
new file mode 100644
index 0000000000..8af139f07d
--- /dev/null
+++ b/sqlmesh/core/engine_adapter/feldera.py
@@ -0,0 +1,322 @@
+from __future__ import annotations
+
+import logging
+import typing as t
+
+from sqlglot import exp
+
+from sqlmesh.core.dialect import to_schema
+from sqlmesh.core.engine_adapter.base import EngineAdapter
+from sqlmesh.core.engine_adapter.shared import (
+    CommentCreationTable,
+    CommentCreationView,
+    DataObject,
+    DataObjectType,
+    set_catalog,
+)
+from sqlmesh.engines.feldera.db_api import QUERY_MIRROR_PREFIX
+from sqlmesh.utils.errors import SQLMeshError
+
+if t.TYPE_CHECKING:
+    import pandas as pd
+
+    from sqlmesh.core._typing import SchemaName, TableName
+    from sqlmesh.core.engine_adapter.shared import SourceQuery
+
+
+logger = logging.getLogger(__name__)
+
+
+def _is_query_mirror_name(name: str) -> bool:
+    return name.lower().startswith(QUERY_MIRROR_PREFIX)
+
+
+def _view_type(state: t.Any, object_name: str) -> DataObjectType:
+    is_materialized_view = getattr(state, "is_materialized_view", None)
+
+    if callable(is_materialized_view) and is_materialized_view(object_name):
+        return DataObjectType.MATERIALIZED_VIEW
+
+    return DataObjectType.VIEW
+
+
+_FELDERA_TO_EXP_TYPE: t.Dict[str, t.Any] = {
+    "BOOLEAN": exp.DataType.Type.BOOLEAN,
+    "TINYINT": exp.DataType.Type.TINYINT,
+    "SMALLINT": exp.DataType.Type.SMALLINT,
+    "INTEGER": exp.DataType.Type.INT,
+    "INT": exp.DataType.Type.INT,
+    "BIGINT": exp.DataType.Type.BIGINT,
+    "REAL": exp.DataType.Type.FLOAT,
+    "DOUBLE": exp.DataType.Type.DOUBLE,
+    "DECIMAL": exp.DataType.Type.DECIMAL,
+    "NUMERIC": exp.DataType.Type.DECIMAL,
+    "VARCHAR": exp.DataType.Type.VARCHAR,
+    "CHAR": exp.DataType.Type.CHAR,
+    "DATE": exp.DataType.Type.DATE,
+    "TIME": exp.DataType.Type.TIME,
+    "TIMESTAMP": exp.DataType.Type.TIMESTAMP,
+    "ARRAY": exp.DataType.Type.ARRAY,
+}
+
+
+def _feldera_type_to_exp(dtype_str: str) -> exp.DataType:
+    base = dtype_str.split("(")[0].strip().upper()
+    kind = _FELDERA_TO_EXP_TYPE.get(base, exp.DataType.Type.TEXT)
+    return exp.DataType(this=kind)
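A quick illustration of the mapping above: parenthesized type arguments are dropped and unknown types fall back to TEXT.

```python
from sqlglot import exp

from sqlmesh.core.engine_adapter.feldera import _feldera_type_to_exp

assert _feldera_type_to_exp("VARCHAR(12)").this == exp.DataType.Type.VARCHAR
assert _feldera_type_to_exp("REAL").this == exp.DataType.Type.FLOAT
assert _feldera_type_to_exp("GEOMETRY").this == exp.DataType.Type.TEXT  # unmapped -> TEXT
```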
+@set_catalog()
+class FelderaEngineAdapter(EngineAdapter):
+    DIALECT = "feldera"
+    SUPPORTS_TRANSACTIONS = False
+    SUPPORTS_INDEXES = False
+    SUPPORTS_MATERIALIZED_VIEWS = True
+    SUPPORTS_REPLACE_TABLE = False
+    COMMENT_CREATION_TABLE = CommentCreationTable.UNSUPPORTED
+    COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED
+
+    def _fetch_native_df(
+        self,
+        query: t.Union[exp.Expression, str],
+        quote_identifiers: bool = False,
+    ) -> pd.DataFrame:
+        with self.transaction():
+            self.execute(query, quote_identifiers=quote_identifiers)
+            return self.cursor.fetchdf()
+
+    def _get_data_objects(
+        self,
+        schema_name: SchemaName,
+        object_names: t.Optional[t.Set[str]] = None,
+    ) -> t.List[DataObject]:
+        from feldera.pipeline import Pipeline
+
+        connection = self.connection
+        pipeline_name = to_schema(schema_name).db
+        lower_object_names = {name.lower() for name in object_names} if object_names else None
+
+        try:
+            pipeline = Pipeline.get(pipeline_name, connection._client)
+        except Exception:
+            return []
+
+        objects_by_name: t.Dict[str, DataObject] = {}
+        for table in pipeline.tables():
+            name = table.name.lower()
+            if lower_object_names and name not in lower_object_names:
+                continue
+            objects_by_name[name] = DataObject(
+                catalog=None,
+                schema=pipeline_name,
+                name=name,
+                type=DataObjectType.TABLE,
+            )
+
+        for view in pipeline.views():
+            name = view.name.lower()
+            if _is_query_mirror_name(name):
+                continue
+            if lower_object_names and name not in lower_object_names:
+                continue
+            objects_by_name[name] = DataObject(
+                catalog=None,
+                schema=pipeline_name,
+                name=name,
+                type=_view_type(connection._state, name),
+            )
+
+        pending_drops = connection._state.pending_drops()
+        for object_name in pending_drops:
+            objects_by_name.pop(object_name, None)
+
+        for object_name in connection._state.pending_tables():
+            if lower_object_names and object_name not in lower_object_names:
+                continue
+            objects_by_name[object_name] = DataObject(
+                catalog=None,
+                schema=pipeline_name,
+                name=object_name,
+                type=DataObjectType.TABLE,
+            )
+
+        for object_name in connection._state.pending_views():
+            if lower_object_names and object_name not in lower_object_names:
+                continue
+            objects_by_name[object_name] = DataObject(
+                catalog=None,
+                schema=pipeline_name,
+                name=object_name,
+                type=_view_type(connection._state, object_name),
+            )
+
+        return list(objects_by_name.values())
+
+    def columns(
+        self, table_name: TableName, include_pseudo_columns: bool = False
+    ) -> t.Dict[str, exp.DataType]:
+        from feldera.pipeline import Pipeline
+
+        connection = self.connection
+        pipeline = Pipeline.get(connection._pipeline_name, connection._client)
+        target = exp.to_table(table_name).name.lower()
+
+        for obj in (*pipeline.tables(), *pipeline.views()):
+            if obj.name.lower() == target:
+                return {
+                    field["name"]: _feldera_type_to_exp(
+                        field.get("columntype", {}).get("type", "VARCHAR")
+                        if isinstance(field.get("columntype"), dict)
+                        else str(field.get("columntype", "VARCHAR"))
+                    )
+                    for field in (obj.fields or [])
+                }
+
+        raise SQLMeshError(
+            f"Table/view '{target}' not found in pipeline '{connection._pipeline_name}'"
+        )
+
+    def get_current_catalog(self) -> t.Optional[str]:
+        return None
+
+    def create_view(
+        self,
+        view_name: TableName,
+        query_or_df: t.Any,
+        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
+        replace: bool = True,
+        materialized: bool = False,
+        materialized_properties: t.Optional[t.Dict[str, t.Any]] = None,
+        table_description: t.Optional[str] = None,
+        column_descriptions: t.Optional[t.Dict[str, str]] = None,
+        view_properties: t.Optional[t.Dict[str, exp.Expression]] = None,
+        source_columns: t.Optional[t.List[str]] = None,
+        **create_kwargs: t.Any,
+    ) -> None:
+        if replace:
+            target_data_object = self.get_data_object(exp.to_table(view_name))
+            if target_data_object is not None:
+                self.drop_data_object(target_data_object, ignore_if_not_exists=True)
+
+        super().create_view(
+            view_name,
+            query_or_df,
+            target_columns_to_types=target_columns_to_types,
+            replace=False,
+            materialized=materialized,
+            materialized_properties=materialized_properties,
+            table_description=table_description,
+            column_descriptions=column_descriptions,
+            view_properties=view_properties,
+            source_columns=source_columns,
+            **create_kwargs,
+        )
+    def _create_table_from_source_queries(
+        self,
+        table_name: TableName,
+        source_queries: t.List[SourceQuery],
+        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
+        exists: bool = True,
+        replace: bool = False,
+        table_description: t.Optional[str] = None,
+        column_descriptions: t.Optional[t.Dict[str, str]] = None,
+        table_kind: t.Optional[str] = None,
+        track_rows_processed: bool = True,
+        **kwargs: t.Any,
+    ) -> None:
+        if replace:
+            return super()._create_table_from_source_queries(
+                table_name,
+                source_queries,
+                target_columns_to_types=target_columns_to_types,
+                exists=exists,
+                replace=replace,
+                table_description=table_description,
+                column_descriptions=column_descriptions,
+                table_kind=table_kind,
+                track_rows_processed=track_rows_processed,
+                **kwargs,
+            )
+
+        if not target_columns_to_types:
+            raise SQLMeshError(
+                "Feldera requires known column types when creating a table from a query."
+            )
+
+        with self.transaction(condition=len(source_queries) > 1):
+            self._create_table_from_columns(
+                table_name,
+                target_columns_to_types,
+                exists=exists,
+                table_description=table_description,
+                column_descriptions=column_descriptions,
+                **kwargs,
+            )
+            for source_query in source_queries:
+                with source_query as query:
+                    self._insert_append_query(
+                        table_name,
+                        query,
+                        target_columns_to_types,
+                        track_rows_processed=track_rows_processed,
+                    )
+
+    def _insert_overwrite_by_condition(
+        self,
+        table_name: TableName,
+        source_queries: t.List[SourceQuery],
+        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
+        where: t.Optional[exp.Condition] = None,
+        insert_overwrite_strategy_override: t.Optional[t.Any] = None,
+        **kwargs: t.Any,
+    ) -> None:
+        # For whole-table replacement, DROP + CREATE is cheaper in Feldera than DELETE + INSERT.
+        if where is None and source_queries:
+            self.drop_table(table_name)
+            self._create_table_from_source_queries(
+                table_name,
+                source_queries,
+                target_columns_to_types=target_columns_to_types,
+                exists=True,
+                replace=False,
+                **kwargs,
+            )
+            return
+
+        super()._insert_overwrite_by_condition(
+            table_name,
+            source_queries,
+            target_columns_to_types=target_columns_to_types,
+            where=where,
+            insert_overwrite_strategy_override=insert_overwrite_strategy_override,
+            **kwargs,
+        )
+
+    def drop_table(self, table_name: TableName, exists: bool = True, **kwargs: t.Any) -> None:
+        target_data_object = self.get_data_object(exp.to_table(table_name))
+        if target_data_object:
+            if target_data_object.type.is_materialized_view:
+                self.drop_view(
+                    table_name,
+                    ignore_if_not_exists=exists,
+                    materialized=True,
+                    **kwargs,
+                )
+                return
+            if target_data_object.type.is_view:
+                self.drop_view(table_name, ignore_if_not_exists=exists, **kwargs)
+                return
+        super().drop_table(table_name, exists=exists, **kwargs)
+
+    def create_schema(
+        self,
+        schema_name: SchemaName,
+        ignore_if_exists: bool = True,
+        warn_on_error: bool = True,
+        properties: t.Optional[t.List[exp.Expression]] = None,
+    ) -> None:
+        return None
+
+    def ping(self) -> None:
+        self.connection._client.get_config()
diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py
index b1ffd4dc26..89c179d19e 100644
--- a/sqlmesh/core/snapshot/evaluator.py
+++ b/sqlmesh/core/snapshot/evaluator.py
@@ -2060,7 +2060,7 @@ def create(
         # Only sql models have queries that can be tested.
         # We also need to make sure that we don't dry run on Redshift because its planner / optimizer sometimes
         # breaks on our CTAS queries due to us relying on the WHERE FALSE LIMIT 0 combo.
-        if model.is_sql and dry_run and self.adapter.dialect != "redshift":
+        if model.is_sql and dry_run and self.adapter.dialect not in {"redshift", "feldera"}:
             logger.info("Dry running model '%s'", model.name)
             self.adapter.fetchall(ctas_query)
         else:
diff --git a/sqlmesh/engines/feldera/__init__.py b/sqlmesh/engines/feldera/__init__.py
new file mode 100644
index 0000000000..9d48db4f9f
--- /dev/null
+++ b/sqlmesh/engines/feldera/__init__.py
@@ -0,0 +1 @@
+from __future__ import annotations
diff --git a/sqlmesh/engines/feldera/db_api.py b/sqlmesh/engines/feldera/db_api.py
new file mode 100644
index 0000000000..86e85d268e
--- /dev/null
+++ b/sqlmesh/engines/feldera/db_api.py
@@ -0,0 +1,1121 @@
+from __future__ import annotations
+
+import logging
+import re
+import threading
+import typing as t
+from enum import Enum
+
+from sqlglot import exp, parse, parse_one
+from sqlglot.errors import ParseError
+
+
+logger = logging.getLogger(__name__)
+
+QUERY_MIRROR_PREFIX = "__sqlmesh_query__"
+FELDERA_DIALECT = "feldera"
+
+if t.TYPE_CHECKING:
+    import pandas as pd
+
+
+class SqlIntent(Enum):
+    ADHOC_QUERY = "adhoc_query"
+    PIPELINE_DDL = "pipeline_ddl"
+    DATA_INGRESS = "data_ingress"
+    NO_OP = "no_op"
+
+
+def _classify(sql: str) -> SqlIntent:
+    """Route SQL to the correct Feldera endpoint."""
+    stripped = _strip_leading_comments(sql)
+    if not stripped:
+        return SqlIntent.NO_OP
+
+    try:
+        expression = parse_one(stripped, dialect=FELDERA_DIALECT)
+    except (ParseError, ValueError):
+        expression = None
+
+    if isinstance(expression, (exp.Create, exp.Drop, exp.Alter)):
+        return SqlIntent.PIPELINE_DDL
+    if isinstance(expression, exp.Insert):
+        return SqlIntent.DATA_INGRESS
+
+    upper = stripped.upper()
+    if upper.startswith(("CREATE", "DROP", "ALTER")):
+        return SqlIntent.PIPELINE_DDL
+    if upper.startswith("INSERT"):
+        return SqlIntent.DATA_INGRESS
+    return SqlIntent.ADHOC_QUERY
+
+
+def _is_virtual_layer_ddl(sql: str) -> bool:
+    stripped = _strip_leading_comments(sql)
+
+    try:
+        expression = parse_one(stripped, dialect=FELDERA_DIALECT)
+    except (ParseError, ValueError):
+        return False
+
+    if not isinstance(expression, (exp.Create, exp.Drop)):
+        return False
+
+    target = expression.this
+    if isinstance(target, exp.Schema):
+        target = target.this
+
+    if not isinstance(target, exp.Table) or not target.db:
+        return False
+
+    target_db = target.db.lower()
+    if target_db.startswith("sqlmesh__"):
+        return False
+
+    referenced_snapshot_tables = [
+        table
+        for table in expression.find_all(exp.Table)
+        if table is not target and table.db and table.db.lower().startswith("sqlmesh__")
+    ]
+    return bool(referenced_snapshot_tables)
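To make the routing concrete, a sketch of how `_classify` buckets statements; note that the string-prefix fallback applies even when parsing fails, so these hold regardless of dialect registration:

```python
from sqlmesh.engines.feldera.db_api import SqlIntent, _classify

assert _classify("/* sqlmesh */ CREATE TABLE t (id INT)") == SqlIntent.PIPELINE_DDL
assert _classify("INSERT INTO t VALUES (1)") == SqlIntent.DATA_INGRESS
assert _classify("SELECT COUNT(*) FROM t") == SqlIntent.ADHOC_QUERY
assert _classify("   ") == SqlIntent.NO_OP
```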
+class PipelineStateManager:
+    """Accumulates DDL and deploys it as a single Feldera pipeline program."""
+
+    def __init__(self) -> None:
+        self._lock = threading.RLock()
+        self._tables: t.Dict[str, str] = {}
+        self._views: t.Dict[str, str] = {}
+        self._dropped_objects: t.Set[str] = set()
+        self._pipeline = None
+        self._dirty = False
+        self._hydrated = False
+        self._hydrated_object_keys: t.Set[str] = set()
+
+    def register_ddl(self, sql: str) -> None:
+        with self._lock:
+            object_key = _extract_name(sql)
+            expression = None
+
+            try:
+                expression = parse_one(_strip_leading_comments(sql), dialect=FELDERA_DIALECT)
+            except (ParseError, ValueError):
+                pass
+
+            self._hydrated_object_keys.discard(object_key)
+
+            if isinstance(expression, exp.Create):
+                kind = str(expression.args.get("kind") or "").upper()
+                if "TABLE" in kind:
+                    sql = _rewrite_table_ctas_sql(sql)
+                    self._dropped_objects.discard(object_key)
+                    self._views.pop(object_key, None)
+                    self._tables[object_key] = sql
+                    self._dirty = True
+                elif "VIEW" in kind:
+                    self._dropped_objects.discard(object_key)
+                    self._tables.pop(object_key, None)
+                    self._views[object_key] = sql
+                    self._dirty = True
+                return
+
+            if isinstance(expression, exp.Drop):
+                kind = str(expression.args.get("kind") or "").upper()
+                if "TABLE" in kind:
+                    self._tables.pop(object_key, None)
+                    self._dropped_objects.add(object_key)
+                    self._dirty = True
+                elif "VIEW" in kind:
+                    self._views.pop(object_key, None)
+                    self._dropped_objects.add(object_key)
+                    self._dirty = True
+                return
+
+            upper = sql.strip().upper()
+            if "CREATE TABLE" in upper or "CREATE MATERIALIZED TABLE" in upper:
+                sql = _rewrite_table_ctas_sql(sql)
+                self._dropped_objects.discard(object_key)
+                self._views.pop(object_key, None)
+                self._tables[object_key] = sql
+                self._dirty = True
+            elif "CREATE VIEW" in upper or "CREATE MATERIALIZED VIEW" in upper:
+                self._dropped_objects.discard(object_key)
+                self._tables.pop(object_key, None)
+                self._views[object_key] = sql
+                self._dirty = True
+            elif "DROP TABLE" in upper:
+                self._tables.pop(object_key, None)
+                self._dropped_objects.add(object_key)
+                self._dirty = True
+            elif "DROP VIEW" in upper:
+                self._views.pop(object_key, None)
+                self._dropped_objects.add(object_key)
+                self._dirty = True
+
+    def has_pending_changes(self) -> bool:
+        with self._lock:
+            return self._dirty
+
+    def pending_tables(self) -> t.Set[str]:
+        with self._lock:
+            return set(self._tables)
+
+    def pending_views(self) -> t.Set[str]:
+        with self._lock:
+            return set(self._views)
+
+    def is_materialized_view(self, object_name: str) -> bool:
+        with self._lock:
+            return _is_materialized_view_sql(self._views.get(object_name.lower(), ""))
+
+    def pending_drops(self) -> t.Set[str]:
+        with self._lock:
+            return set(self._dropped_objects)
+
+    def current_pipeline(self) -> t.Any:
+        with self._lock:
+            return self._pipeline
+
+    def queryable_relation_names(self) -> t.Set[str]:
+        with self._lock:
+            return {
+                *self._tables,
+                *(name for name, sql in self._views.items() if not _is_materialized_view_sql(sql)),
+            }
+
+    def assemble_program(self) -> str:
+        with self._lock:
+            statements = [
+                *(
+                    statement
+                    for ddl in self._tables.values()
+                    for statement in _ddl_statements_with_query_mirror(ddl)
+                ),
+                *(
+                    statement
+                    for ddl in self._views.values()
+                    for statement in _ddl_statements_with_query_mirror(ddl)
+                ),
+            ]
+            return "\n\n".join(statements)
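A minimal sketch of the accumulate-then-assemble flow. The dialect module is imported first so `parse_one(..., dialect="feldera")` succeeds; without it the string-heuristic fallback still classifies the DDL, but object keys degrade to raw SQL prefixes.

```python
import sqlmesh.engines.feldera.dialect  # noqa: F401  (registers the "feldera" dialect)
from sqlmesh.engines.feldera.db_api import PipelineStateManager

state = PipelineStateManager()
state.register_ddl('CREATE TABLE "events" ("id" INTEGER)')
state.register_ddl('CREATE MATERIALIZED VIEW "totals" AS SELECT COUNT(*) AS "n" FROM "events"')

assert state.has_pending_changes()
assert state.pending_tables() == {"events"}
assert state.pending_views() == {"totals"}
assert state.is_materialized_view("totals")

# One SQL program, with a hidden __sqlmesh_query__ mirror for the queryable table:
print(state.assemble_program())
```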
+    def deploy(
+        self,
+        client: t.Any,
+        pipeline_name: str,
+        workers: int = 4,
+        compilation_profile: str = "dev",
+        timeout: int = 300,
+    ) -> t.Any:
+        from feldera.enums import CompilationProfile
+        from feldera.pipeline import Pipeline
+        from feldera.pipeline_builder import PipelineBuilder
+        from feldera.rest.errors import FelderaAPIError
+        from feldera.rest.pipeline import Pipeline as InnerPipeline
+        from feldera.runtime_config import RuntimeConfig
+
+        with self._lock:
+            self._hydrate_existing_program(client, pipeline_name)
+            if not self._dirty:
+                return self._pipeline
+
+            sql = self.assemble_program()
+            if not sql.strip():
+                self._dirty = False
+                self._dropped_objects.clear()
+                return self._pipeline
+
+            runtime_config = RuntimeConfig.default()
+            runtime_config.workers = workers
+
+            profile = compilation_profile
+            if isinstance(profile, str):
+                profile = CompilationProfile(profile)
+
+            while True:
+                try:
+                    pipeline = self._compile_program(
+                        client,
+                        pipeline_name,
+                        sql,
+                        profile,
+                        runtime_config,
+                        timeout,
+                        Pipeline,
+                        PipelineBuilder,
+                        InnerPipeline,
+                        FelderaAPIError,
+                    )
+                    break
+                except RuntimeError as ex:
+                    if not self._evict_hydrated_objects(str(ex)):
+                        raise
+
+                    sql = self.assemble_program()
+                    if not sql.strip():
+                        self._dirty = False
+                        self._dropped_objects.clear()
+                        return self._pipeline
+
+            pipeline.start(timeout_s=timeout)
+            self._pipeline = pipeline
+            self._dirty = False
+            self._dropped_objects.clear()
+            return pipeline
+
+    def _hydrate_existing_program(self, client: t.Any, pipeline_name: str) -> None:
+        if self._hydrated:
+            return
+
+        self._hydrated = True
+
+        try:
+            from feldera.pipeline import Pipeline
+
+            pipeline = Pipeline.get(pipeline_name, client)
+            program_code = getattr(getattr(pipeline, "_inner", None), "program_code", "") or ""
+        except Exception:
+            return
+
+        for expression in parse(program_code, dialect=FELDERA_DIALECT):
+            if expression is None:
+                continue
+
+            sql = expression.sql(dialect=FELDERA_DIALECT)
+
+            if isinstance(expression, exp.Create):
+                target = expression.this
+                if isinstance(target, exp.Schema):
+                    target = target.this
+
+                if isinstance(target, exp.Table) and _is_query_mirror_name(target.name):
+                    continue
+
+                object_key = _extract_name(sql)
+                kind = str(expression.args.get("kind") or "").upper()
+                if "VIEW" in kind:
+                    self._tables.pop(object_key, None)
+                    self._views[object_key] = sql
+                elif "TABLE" in kind:
+                    sql = _rewrite_table_ctas_sql(sql)
+                    self._views.pop(object_key, None)
+                    self._tables[object_key] = sql
+                self._hydrated_object_keys.add(object_key)
+
+    def _evict_hydrated_objects(self, error_message: str) -> bool:
+        if not self._hydrated_object_keys:
+            return False
+
+        normalized_error = error_message.lower()
+        object_keys = [
+            object_key
+            for object_key in self._hydrated_object_keys
+            if object_key in normalized_error
+        ]
+
+        if not object_keys:
+            if "not found" not in normalized_error:
+                return False
+            object_keys = list(self._hydrated_object_keys)
+
+        for object_key in object_keys:
+            self._tables.pop(object_key, None)
+            self._views.pop(object_key, None)
+            self._hydrated_object_keys.discard(object_key)
+
+        return True
+
+    def _compile_program(
+        self,
+        client: t.Any,
+        pipeline_name: str,
+        sql: str,
+        profile: t.Any,
+        runtime_config: t.Any,
+        timeout: int,
+        Pipeline: t.Any,
+        PipelineBuilder: t.Any,
+        InnerPipeline: t.Any,
+        FelderaAPIError: t.Any,
+    ) -> t.Any:
+        try:
+            existing_pipeline = Pipeline.get(pipeline_name, client)
+        except FelderaAPIError:
+            existing_pipeline = None
+
+        if existing_pipeline is None:
+            try:
+                return PipelineBuilder(
+                    client,
+                    name=pipeline_name,
+                    sql=sql,
+                    compilation_profile=profile,
+                    runtime_config=runtime_config,
+                ).create_or_replace(wait=True)
+            except RuntimeError as ex:
+                raise self._format_compile_error(client, pipeline_name, ex) from ex
+
+        existing_pipeline.stop(force=True, timeout_s=timeout)
+        existing_pipeline.dismiss_error()
+
+        inner_pipeline = InnerPipeline(
+            name=pipeline_name,
+            sql=sql,
+            udf_rust="",
+            udf_toml="",
+            description="",
+            program_config={
+                "profile": profile.value,
+                "runtime_version": None,
+            },
+            runtime_config=runtime_config.to_dict(),
+        )
+        try:
+            inner_pipeline = client.create_or_update_pipeline(inner_pipeline, wait=True)
+        except RuntimeError as ex:
+            raise self._format_compile_error(client, pipeline_name, ex) from ex
+        pipeline = Pipeline(client)
+        pipeline._inner = inner_pipeline
+        return pipeline
+    def _format_compile_error(
+        self, client: t.Any, pipeline_name: str, error: Exception
+    ) -> RuntimeError:
+        error_message = str(error)
+        from feldera.enums import PipelineFieldSelector
+
+        try:
+            pipeline = client.get_pipeline(pipeline_name, PipelineFieldSelector.ALL)
+        except Exception:
+            return RuntimeError(error_message)
+
+        program_error = getattr(pipeline, "program_error", None) or {}
+        sql_compilation = program_error.get("sql_compilation") or {}
+        sql_messages = sql_compilation.get("messages") or []
+
+        if sql_messages:
+            details = self._sql_compilation_error_message(pipeline_name, sql_messages)
+            if details != error_message:
+                return RuntimeError(details)
+
+        rust_error = program_error.get("rust_compilation")
+        system_error = program_error.get("system_error")
+        if rust_error or system_error:
+            message = (
+                f"The program failed to compile: {getattr(pipeline, 'program_status', 'unknown')}\n"
+            )
+            if rust_error is not None:
+                message += f"Rust Error: {rust_error}\n"
+            if system_error is not None:
+                message += f"System Error: {system_error}"
+            return RuntimeError(message.rstrip())
+
+        return RuntimeError(error_message)
+
+    @staticmethod
+    def _sql_compilation_error_message(
+        pipeline_name: str, sql_errors: t.Sequence[t.Mapping[str, t.Any]]
+    ) -> str:
+        err_msg = f"Pipeline {pipeline_name} failed to compile:\n"
+        for sql_error in sql_errors:
+            err_msg += f"{sql_error['error_type']}\n{sql_error['message']}\n"
+            err_msg += f"Code snippet:\n{sql_error['snippet']}"
+        return err_msg
+
+
+def _extract_name(sql: str) -> str:
+    stripped = _strip_leading_comments(sql)
+
+    try:
+        expression = parse_one(stripped, dialect=FELDERA_DIALECT)
+    except (ParseError, ValueError):
+        return stripped[:80].lower()
+
+    target = expression.this if isinstance(expression, (exp.Create, exp.Drop)) else None
+    if isinstance(target, exp.Schema):
+        target = target.this
+
+    if isinstance(target, exp.Table):
+        return target.name.lower()
+
+    return stripped[:80].lower()
+
+
+def _normalize_ddl(sql: str) -> str:
+    return re.sub(
+        r"\bCREATE\s+TABLE\s+IF\s+NOT\s+EXISTS\b",
+        "CREATE TABLE",
+        sql,
+        flags=re.IGNORECASE,
+    )
+
+
+def _is_query_mirror_name(name: str) -> bool:
+    return name.lower().startswith(QUERY_MIRROR_PREFIX)
+
+
+def _query_mirror_name(name: str) -> str:
+    return f"{QUERY_MIRROR_PREFIX}{name}"
+
+
+def _query_mirror_table(table: exp.Table) -> exp.Table:
+    mirror = table.copy()
+    mirror.set("this", exp.to_identifier(_query_mirror_name(table.name), quoted=True))
+    return mirror
+
+
+def _ddl_statements_with_query_mirror(sql: str) -> t.List[str]:
+    try:
+        expressions = [
+            expression
+            for expression in parse(sql, dialect=FELDERA_DIALECT)
+            if expression is not None
+        ]
+    except (ParseError, ValueError):
+        expressions = []
+
+    if not expressions:
+        statement = sql.rstrip(";") + ";"
+        mirror_sql = _query_mirror_sql(sql)
+        return [statement, *([mirror_sql.rstrip(";") + ";"] if mirror_sql else [])]
+
+    statements = []
+    for expression in expressions:
+        statement = expression.sql(dialect=FELDERA_DIALECT).rstrip(";") + ";"
+        statements.append(statement)
+        mirror_sql = _query_mirror_sql(statement)
+        if mirror_sql:
+            statements.append(mirror_sql.rstrip(";") + ";")
+
+    return statements
+
+
+def _query_mirror_sql(sql: str) -> t.Optional[str]:
+    stripped = _strip_leading_comments(sql)
+
+    try:
+        expression = parse_one(stripped, dialect=FELDERA_DIALECT)
+    except (ParseError, ValueError):
+        return None
+
+    if not isinstance(expression, exp.Create):
+        return None
+
+    target = expression.this
+    if isinstance(target, exp.Schema):
+        target = target.this
+
+    if not isinstance(target, exp.Table) or _is_query_mirror_name(target.name):
+        return None
+    if target.db and target.db.lower().startswith("sqlmesh__"):
+        return None
+
+    kind = str(expression.args.get("kind") or "").upper()
+    if "TABLE" not in kind and "VIEW" not in kind:
+        return None
+    if "VIEW" in kind and _is_materialized_view_sql(sql):
+        return None
+
+    mirror_sql = exp.Create(
+        this=_query_mirror_table(target),
+        kind="MATERIALIZED VIEW",
+        expression=exp.select("*").from_(target.copy()),
+    ).sql(dialect=FELDERA_DIALECT)
+
+    return _strip_table_qualifiers(mirror_sql)
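The mirror mechanism in one example: plain views get a hidden materialized twin so ad-hoc queries can read them, while materialized views are already queryable and get none (assumes the dialect module has been imported, as above).

```python
import sqlmesh.engines.feldera.dialect  # noqa: F401
from sqlmesh.engines.feldera.db_api import _query_mirror_sql

mirror = _query_mirror_sql('CREATE VIEW "totals" AS SELECT 1 AS "n"')
assert mirror is not None and "__sqlmesh_query__totals" in mirror

# Materialized views need no mirror:
assert _query_mirror_sql('CREATE MATERIALIZED VIEW "t" AS SELECT 1 AS "n"') is None
```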
+def _rewrite_query_for_query_mirrors(sql: str, relation_names: t.Set[str]) -> str:
+    if not relation_names:
+        return sql
+
+    stripped = _strip_leading_comments(sql)
+
+    try:
+        expression = parse_one(stripped, dialect=FELDERA_DIALECT)
+    except (ParseError, ValueError):
+        return sql
+
+    cte_names = {
+        cte.alias_or_name.lower() for cte in expression.find_all(exp.CTE) if cte.alias_or_name
+    }
+
+    def transform(node: exp.Expression) -> exp.Expression:
+        if (
+            isinstance(node, exp.Table)
+            and not _is_query_mirror_name(node.name)
+            and node.name.lower() in relation_names
+            and node.name.lower() not in cte_names
+        ):
+            return _query_mirror_table(node)
+        return node
+
+    return expression.transform(transform, copy=True).sql(dialect=FELDERA_DIALECT)
+
+
+def _execution_error(rows: t.Sequence[t.Mapping[str, t.Any]]) -> t.Optional[str]:
+    for row in rows:
+        for value in row.values():
+            if isinstance(value, str) and value.startswith("Execution error:"):
+                return value
+
+    return None
+
+
+def _is_materialized_view_sql(sql: str) -> bool:
+    stripped = _strip_leading_comments(sql)
+    upper = stripped.upper()
+
+    if "CREATE MATERIALIZED VIEW" in upper:
+        return True
+    if "CREATE VIEW" in upper:
+        return False
+
+    try:
+        expression = parse_one(stripped, dialect=FELDERA_DIALECT)
+    except (ParseError, ValueError):
+        expression = None
+
+    if isinstance(expression, exp.Create):
+        kind = str(expression.args.get("kind") or "").upper()
+        return "MATERIALIZED" in kind and "VIEW" in kind
+
+    return False
+
+
+def _rewrite_table_ctas_sql(sql: str) -> str:
+    stripped = _strip_leading_comments(sql)
+
+    try:
+        expression = parse_one(stripped, dialect=FELDERA_DIALECT)
+    except (ParseError, ValueError):
+        return sql
+
+    if not isinstance(expression, exp.Create):
+        return sql
+
+    kind = str(expression.args.get("kind") or "").upper()
+    query = expression.expression
+    target = expression.this
+
+    if "TABLE" not in kind or query is None or not isinstance(target, exp.Table):
+        return sql
+
+    columns_to_types = _select_columns_to_types(query)
+    if not columns_to_types:
+        return sql
+
+    schema = exp.Schema(
+        this=target.copy(),
+        expressions=[
+            exp.ColumnDef(this=exp.to_identifier(column, quoted=True), kind=data_type.copy())
+            for column, data_type in columns_to_types.items()
+        ],
+    )
+    create_exp = exp.Create(
+        this=schema,
+        kind=expression.args.get("kind") or "TABLE",
+        replace=bool(expression.args.get("replace")),
+        exists=bool(expression.args.get("exists")),
+        properties=expression.args.get("properties"),
+    )
+    insert_exp = exp.insert(query.copy(), target.copy(), columns=list(columns_to_types))
+    return f"{create_exp.sql(dialect=FELDERA_DIALECT)};\n{insert_exp.sql(dialect=FELDERA_DIALECT)}"
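What the CTAS rewrite produces for a fully typed SELECT; untyped projections make `_select_columns_to_types` bail, so the original SQL passes through unchanged.

```python
import sqlmesh.engines.feldera.dialect  # noqa: F401
from sqlmesh.engines.feldera.db_api import _rewrite_table_ctas_sql

rewritten = _rewrite_table_ctas_sql('CREATE TABLE "t" AS SELECT CAST(1 AS INT) AS "id"')
# Roughly: CREATE TABLE "t" ("id" INTEGER);
#          INSERT INTO "t" ("id") SELECT CAST(1 AS INTEGER) AS "id"
assert rewritten.startswith("CREATE TABLE")
assert "INSERT INTO" in rewritten

# No explicit types -> no rewrite:
untyped = 'CREATE TABLE "t" AS SELECT 1 AS "id"'
assert _rewrite_table_ctas_sql(untyped) == untyped
```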
+def _select_columns_to_types(query: exp.Expression) -> t.Optional[t.Dict[str, exp.DataType]]:
+    if not isinstance(query, exp.Query):
+        return None
+
+    columns_to_types: t.Dict[str, exp.DataType] = {}
+    unknown = exp.DataType.build("unknown")
+
+    for select in query.selects:
+        output_name = select.output_name
+        data_type = (
+            _projection_type(t.cast(exp.Expression, select)) or (select.type or unknown).copy()
+        )
+
+        if not output_name or output_name in columns_to_types or data_type == unknown:
+            return None
+
+        columns_to_types[output_name] = data_type
+
+    return columns_to_types or None
+
+
+def _projection_type(select: exp.Expression) -> t.Optional[exp.DataType]:
+    expression = select
+    if isinstance(select, exp.Alias):
+        expression = select.this
+
+    if isinstance(expression, exp.Cast) and isinstance(expression.args.get("to"), exp.DataType):
+        return expression.args["to"].copy()
+
+    return None
+
+
+def _strip_leading_comments(sql: str) -> str:
+    stripped = sql.lstrip()
+    while stripped.startswith("/*"):
+        comment_end = stripped.find("*/")
+        if comment_end == -1:
+            return stripped
+        stripped = stripped[comment_end + 2 :].lstrip()
+    return stripped
+
+
+def _strip_table_qualifiers(sql: str) -> str:
+    stripped = _strip_leading_comments(sql)
+
+    try:
+        expression = parse_one(stripped, dialect=FELDERA_DIALECT)
+    except (ParseError, ValueError):
+        return stripped
+
+    def transform(node: exp.Expression) -> exp.Expression:
+        if isinstance(node, exp.Table):
+            node = _canonicalize_snapshot_table(node)
+            node = _unqualify_table(node)
+        return node
+
+    expression = expression.transform(transform, copy=True)
+    return expression.sql(dialect=FELDERA_DIALECT)
+
+
+def _snapshot_name_parts(name: str) -> t.Optional[t.Tuple[str, str]]:
+    parts = name.split("__")
+    if len(parts) < 3:
+        return None
+
+    if parts[-1].isdigit():
+        schema_name = parts[0]
+        model_name = "__".join(parts[1:-1])
+        return (schema_name, model_name) if model_name else None
+
+    if len(parts) >= 4 and parts[-2].isdigit():
+        schema_name = parts[0]
+        model_name = "__".join(parts[1:-2])
+        return (schema_name, model_name) if model_name else None
+
+    return None
+
+
+def _canonicalize_snapshot_table(node: exp.Table) -> exp.Table:
+    if _is_query_mirror_name(node.name):
+        return node
+
+    snapshot_parts = _snapshot_name_parts(node.name)
+    if snapshot_parts is None:
+        return node
+
+    _, model_name = snapshot_parts
+    canonicalized = node.copy()
+    canonicalized.set("this", exp.to_identifier(model_name, quoted=True))
+    canonicalized.set("db", None)
+    canonicalized.set("catalog", None)
+    return canonicalized
+
+
+def _normalize_pipeline_ddl(sql: str) -> str:
+    stripped = _strip_leading_comments(sql)
+
+    try:
+        expression = parse_one(stripped, dialect=FELDERA_DIALECT)
+    except (ParseError, ValueError):
+        return stripped
+
+    def transform(node: exp.Expression) -> exp.Expression:
+        if isinstance(node, exp.Table):
+            node = _canonicalize_snapshot_table(node)
+            node = _unqualify_table(node)
+        return node
+
+    return expression.transform(transform, copy=True).sql(dialect=FELDERA_DIALECT)
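How snapshot table names collapse back to logical names; the digits are illustrative snapshot fingerprints, and the second assertion mirrors the unit tests further down.

```python
import sqlmesh.engines.feldera.dialect  # noqa: F401
from sqlmesh.engines.feldera.db_api import _normalize_pipeline_ddl, _snapshot_name_parts

assert _snapshot_name_parts("analytics__source_events__1782741465__dev") == (
    "analytics",
    "source_events",
)
assert _normalize_pipeline_ddl(
    'CREATE TABLE "analytics"."sample_seed" ("entity_id" VARCHAR)'
) == 'CREATE TABLE "sample_seed" ("entity_id" VARCHAR)'
```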
+def _insert_to_input_json_payload(
+    sql: str,
+) -> t.Optional[t.Tuple[str, t.List[t.Dict[str, t.Any]]]]:
+    stripped = _strip_leading_comments(sql)
+
+    try:
+        expression = parse_one(stripped, dialect=FELDERA_DIALECT)
+    except (ParseError, ValueError):
+        return None
+
+    if not isinstance(expression, exp.Insert):
+        return None
+
+    target = expression.this
+    if isinstance(target, exp.Schema):
+        table = target.this
+        target_columns = [column.name for column in target.expressions]
+    elif isinstance(target, exp.Table):
+        table = target
+        target_columns = []
+    else:
+        return None
+
+    if not isinstance(table, exp.Table):
+        return None
+
+    table = _canonicalize_snapshot_table(table)
+
+    query = expression.expression
+    if not isinstance(query, exp.Query):
+        return None
+
+    values = _query_values_source(query)
+    if values is None:
+        return None
+
+    alias = values.args.get("alias")
+    if alias is None:
+        return None
+
+    source_columns = [column.name for column in alias.columns]
+    if not source_columns:
+        return None
+
+    if not target_columns:
+        target_columns = [select.output_name for select in query.selects]
+
+    if len(target_columns) != len(query.selects):
+        return None
+
+    rows = []
+    for row in values.expressions:
+        if not isinstance(row, exp.Tuple):
+            return None
+
+        source_row = {
+            column: _literal_value(value) for column, value in zip(source_columns, row.expressions)
+        }
+        payload_row: t.Dict[str, t.Any] = {}
+        for target_column, select in zip(target_columns, query.selects):
+            if not target_column:
+                return None
+
+            evaluated: t.Any = _evaluate_insert_value_expression(
+                t.cast(exp.Expression, select), source_row
+            )
+            if evaluated is _UNSUPPORTED_INGEST_EXPRESSION:
+                return None
+            payload_row[target_column] = evaluated
+        rows.append(payload_row)
+
+    return table.name, rows
+
+
+def _query_values_source(query: exp.Query) -> t.Optional[exp.Values]:
+    from_expression = query.args.get("from")
+    if from_expression is None:
+        return None
+
+    source = from_expression.this
+    if isinstance(source, exp.Values):
+        return source
+    if isinstance(source, exp.Subquery) and isinstance(source.this, exp.Query):
+        return _query_values_source(source.this)
+    return None
+
+
+_UNSUPPORTED_INGEST_EXPRESSION = object()
+
+
+def _evaluate_insert_value_expression(
+    expression: exp.Expression, source_row: t.Mapping[str, t.Any]
+) -> t.Any:
+    if isinstance(expression, exp.Alias):
+        return _evaluate_insert_value_expression(expression.this, source_row)
+
+    if isinstance(expression, exp.Cast):
+        value = _evaluate_insert_value_expression(expression.this, source_row)
+        if value is _UNSUPPORTED_INGEST_EXPRESSION:
+            return value
+        to_type = expression.args.get("to")
+        if isinstance(to_type, exp.DataType):
+            return _coerce_input_json_value(value, to_type)
+        return value
+
+    if isinstance(expression, exp.Column):
+        return source_row.get(expression.name)
+
+    if isinstance(expression, (exp.Literal, exp.Boolean, exp.Null)):
+        return _literal_value(expression)
+
+    if isinstance(expression, exp.Paren):
+        return _evaluate_insert_value_expression(expression.this, source_row)
+
+    if isinstance(expression, exp.Neg):
+        value = _evaluate_insert_value_expression(expression.this, source_row)
+        if isinstance(value, (int, float)):
+            return -value
+        return _UNSUPPORTED_INGEST_EXPRESSION
+
+    return _UNSUPPORTED_INGEST_EXPRESSION
+
+
+def _literal_value(expression: exp.Expression) -> t.Any:
+    if isinstance(expression, (exp.Literal, exp.Boolean, exp.Null)):
+        return expression.to_py()
+    return _UNSUPPORTED_INGEST_EXPRESSION
+
+
+def _coerce_input_json_value(value: t.Any, data_type: exp.DataType) -> t.Any:
+    if value is None:
+        return None
+
+    dtype = data_type.this
+    if dtype in {
+        exp.DataType.Type.TINYINT,
+        exp.DataType.Type.SMALLINT,
+        exp.DataType.Type.INT,
+        exp.DataType.Type.BIGINT,
+    }:
+        return int(value)
+
+    if dtype in {
+        exp.DataType.Type.FLOAT,
+        exp.DataType.Type.DOUBLE,
+        exp.DataType.Type.DECIMAL,
+    }:
+        return float(value)
+
+    if dtype == exp.DataType.Type.BOOLEAN:
+        if isinstance(value, str):
+            normalized = value.strip().lower()
+            if normalized in {"true", "t", "1"}:
+                return True
+            if normalized in {"false", "f", "0"}:
+                return False
+        return bool(value)
+
+    if dtype in {
+        exp.DataType.Type.CHAR,
+        exp.DataType.Type.NCHAR,
+        exp.DataType.Type.TEXT,
+        exp.DataType.Type.VARCHAR,
+        exp.DataType.Type.NVARCHAR,
+        exp.DataType.Type.DATE,
+        exp.DataType.Type.TIME,
+        exp.DataType.Type.TIMESTAMP,
+        exp.DataType.Type.TIMESTAMPTZ,
+    }:
+        return str(value)
+
+    return value
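Note: `_query_values_source` previously looked up `args.get("from_")`, but sqlglot stores a SELECT's FROM clause under the `"from"` key, so the payload conversion always fell back to `pipeline.execute`. With the lookup fixed, the VALUES-backed INSERTs that SQLMesh generates convert into `input_json` payloads; a sketch (parse shapes can vary by sqlglot version):

```python
import sqlmesh.engines.feldera.dialect  # noqa: F401
from sqlmesh.engines.feldera.db_api import _insert_to_input_json_payload

payload = _insert_to_input_json_payload(
    'INSERT INTO "orders" ("id", "name") '
    'SELECT CAST("id" AS INT) AS "id", "name" '
    "FROM (VALUES (1, 'a'), (2, 'b')) AS \"t\"(\"id\", \"name\")"
)
assert payload == ("orders", [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}])
```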
+def _unqualify_table(node: exp.Expression) -> exp.Expression:
+    if isinstance(node, exp.Table):
+        node = node.copy()
+        node.set("db", None)
+        node.set("catalog", None)
+    return node
+
+
+class FelderaCursor:
+    """DB-API 2.0 cursor backed by Feldera's REST API."""
+
+    def __init__(
+        self,
+        client: t.Any,
+        pipeline_name: str,
+        state_manager: t.Any,
+        workers: int = 4,
+        compilation_profile: str = "dev",
+        timeout: int = 300,
+    ) -> None:
+        self._client = client
+        self._pipeline_name = pipeline_name
+        self._state = state_manager
+        self._workers = workers
+        self._compilation_profile = compilation_profile
+        self._timeout = timeout
+        self._rows: t.List[t.Dict[str, t.Any]] = []
+        self._columns: t.List[str] = []
+        self.rowcount = -1
+        self.description: t.Optional[t.List] = None
+
+    def execute(self, sql: str, parameters: t.Optional[t.Any] = None) -> None:
+        if parameters is not None:
+            raise NotImplementedError("Feldera DB-API does not support query parameters")
+
+        self.description = None
+        self.rowcount = -1
+
+        original_sql = sql
+        normalized_sql = _normalize_ddl(sql)
+        intent = _classify(normalized_sql)
+        sql = (
+            _normalize_pipeline_ddl(normalized_sql)
+            if intent == SqlIntent.PIPELINE_DDL
+            else _strip_table_qualifiers(normalized_sql)
+        )
+        logger.debug("Feldera execute (intent=%s): %.200s", intent.value, sql)
+
+        if intent == SqlIntent.NO_OP:
+            self._rows = []
+            self._columns = []
+            return
+
+        if intent == SqlIntent.PIPELINE_DDL:
+            if _is_virtual_layer_ddl(original_sql):
+                self._rows = []
+                self._columns = []
+                return
+            self._state.register_ddl(sql)
+            self._rows = []
+            self._columns = []
+            return
+
+        if self._state.has_pending_changes():
+            self._state.deploy(
+                self._client,
+                self._pipeline_name,
+                self._workers,
+                self._compilation_profile,
+                self._timeout,
+            )
+
+        if intent == SqlIntent.DATA_INGRESS:
+            pipeline = self._get_pipeline()
+            payload = _insert_to_input_json_payload(sql)
+            if payload is not None:
+                table_name, rows = payload
+                pipeline.input_json(table_name, rows)
+            else:
+                pipeline.execute(sql)
+            self._rows = []
+            self._columns = []
+            return
+
+        query_sql = _rewrite_query_for_query_mirrors(sql, self._state.queryable_relation_names())
+        rows = list(self._get_pipeline().query(query_sql))
+        if error := _execution_error(rows):
+            raise RuntimeError(error)
+        self._rows = rows
+        self._columns = list(rows[0].keys()) if rows else []
+        self.rowcount = len(rows)
+        self.description = [
+            (column, None, None, None, None, None, None) for column in self._columns
+        ]
+
+    def _get_pipeline(self) -> t.Any:
+        pipeline = self._state.current_pipeline()
+        if pipeline is not None:
+            return pipeline
+
+        from feldera.pipeline import Pipeline
+
+        return Pipeline.get(self._pipeline_name, self._client)
+
+    def fetchone(self) -> t.Optional[t.Tuple]:
+        return tuple(self._rows.pop(0).values()) if self._rows else None
+
+    def fetchmany(self, size: int = 1) -> t.List[t.Tuple]:
+        batch, self._rows = self._rows[:size], self._rows[size:]
+        return [tuple(row.values()) for row in batch]
+
+    def fetchall(self) -> t.List[t.Tuple]:
+        rows, self._rows = self._rows, []
+        return [tuple(row.values()) for row in rows]
+
+    def fetchdf(self) -> pd.DataFrame:
+        import pandas as pd
+
+        rows, self._rows = self._rows, []
+        return pd.DataFrame(rows)
+
+    def close(self) -> None:
+        self._rows = []
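A usage sketch of the cursor protocol against a running Feldera instance at the default URL (the `demo` pipeline name is hypothetical). DDL is buffered and only deployed when the first non-DDL statement arrives.

```python
from sqlmesh.engines.feldera.db_api import connect

conn = connect(host="http://localhost:8080", pipeline_name="demo")
cur = conn.cursor()

cur.execute('CREATE TABLE "events" ("id" INTEGER)')  # buffered, no deploy yet
cur.execute("SELECT 1 AS x")  # deploys pending DDL, then runs the query
print(cur.fetchall())         # [(1,)]
conn.close()                  # flushes any still-pending DDL
```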
+class FelderaConnection:
+    """DB-API 2.0 connection wrapper around FelderaClient."""
+
+    _state_lock = threading.Lock()
+    _shared_states: t.Dict[t.Tuple[str, str], PipelineStateManager] = {}
+
+    def __init__(
+        self,
+        client: t.Any,
+        host: str,
+        pipeline_name: str,
+        workers: int = 4,
+        compilation_profile: str = "dev",
+        timeout: int = 300,
+    ) -> None:
+        self._client = client
+        self._host = host
+        self._pipeline_name = pipeline_name
+        self._workers = workers
+        self._compilation_profile = compilation_profile
+        self._timeout = timeout
+        state_key = (host, pipeline_name)
+        with self._state_lock:
+            self._state: t.Any = self._shared_states.setdefault(state_key, PipelineStateManager())
+
+    @classmethod
+    def reset_shared_states(cls) -> None:
+        with cls._state_lock:
+            cls._shared_states.clear()
+
+    def cursor(self) -> FelderaCursor:
+        return FelderaCursor(
+            self._client,
+            self._pipeline_name,
+            self._state,
+            self._workers,
+            self._compilation_profile,
+            self._timeout,
+        )
+
+    def commit(self) -> None:
+        return None
+
+    def rollback(self) -> None:
+        return None
+
+    def close(self) -> None:
+        if self._state.has_pending_changes():
+            try:
+                self._state.deploy(
+                    self._client,
+                    self._pipeline_name,
+                    self._workers,
+                    self._compilation_profile,
+                    self._timeout,
+                )
+            except Exception as ex:
+                logger.error(
+                    "Feldera pending DDL failed during connection close for pipeline %s:\n%s",
+                    self._pipeline_name,
+                    ex,
+                )
+                raise
+        return None
+
+
+def connect(
+    host: str,
+    pipeline_name: str,
+    api_key: t.Optional[str] = None,
+    timeout: int = 300,
+    workers: int = 4,
+    compilation_profile: str = "dev",
+) -> FelderaConnection:
+    from feldera.rest.feldera_client import FelderaClient
+
+    client = FelderaClient(url=host, api_key=api_key, timeout=timeout)
+    return FelderaConnection(
+        client,
+        host,
+        pipeline_name,
+        workers,
+        compilation_profile,
+        timeout,
+    )
diff --git a/sqlmesh/engines/feldera/dialect.py b/sqlmesh/engines/feldera/dialect.py
new file mode 100644
index 0000000000..6362202266
--- /dev/null
+++ b/sqlmesh/engines/feldera/dialect.py
@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+from sqlglot import exp
+from sqlglot.dialects.dialect import Dialect
+from sqlglot.generator import Generator as SQLGlotGenerator
+
+
+class SQLMeshFelderaDialect(Dialect):
+    class Generator(SQLGlotGenerator):
+        TYPE_MAPPING = {
+            **SQLGlotGenerator.TYPE_MAPPING,
+            exp.DataType.Type.FLOAT: "REAL",
+            exp.DataType.Type.INT: "INTEGER",
+        }
+        TRANSFORMS = {
+            **SQLGlotGenerator.TRANSFORMS,
+            exp.CurrentTimestamp: lambda self, expression: (
+                self.func("CURRENT_TIMESTAMP", expression.this)
+                if expression.this
+                else "CURRENT_TIMESTAMP"
+            ),
+            exp.DateStrToDate: lambda self, expression: (
+                f"CAST({self.sql(expression, 'this')} AS DATE)"
+            ),
+        }
+
+
+def register_feldera_dialect() -> None:
+    if Dialect.get("feldera") is None:
+        Dialect.classes["feldera"] = SQLMeshFelderaDialect
+
+
+register_feldera_dialect()
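The custom dialect in action: importing the module registers it, after which SQL generation picks up the type overrides. The first assertion matches `test_register_feldera_dialect_registers_custom_type_mapping` below.

```python
import sqlmesh.engines.feldera.dialect  # noqa: F401  (registration happens at import)
from sqlglot import exp, parse_one

assert exp.DataType.build("FLOAT").sql(dialect="feldera") == "REAL"
assert parse_one("SELECT CAST(x AS INT)").sql(dialect="feldera") == "SELECT CAST(x AS INTEGER)"
```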
diff --git a/tests/core/engine_adapter/test_feldera.py b/tests/core/engine_adapter/test_feldera.py
new file mode 100644
index 0000000000..a97b016276
--- /dev/null
+++ b/tests/core/engine_adapter/test_feldera.py
@@ -0,0 +1,411 @@
+# type: ignore
+
+import sys
+import types
+import typing as t
+
+import pytest
+from sqlglot import Dialect
+from sqlglot import parse_one
+from sqlglot import exp
+
+from sqlmesh.core.engine_adapter import FelderaEngineAdapter
+from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType
+from sqlmesh.engines.feldera.dialect import register_feldera_dialect
+from sqlmesh.utils.errors import SQLMeshError, UnsupportedCatalogOperationError
+from tests.core.engine_adapter import to_sql_calls
+
+pytestmark = [pytest.mark.engine, pytest.mark.feldera]
+
+
+@pytest.fixture
+def adapter(make_mocked_engine_adapter: t.Callable) -> FelderaEngineAdapter:
+    adapter = make_mocked_engine_adapter(FelderaEngineAdapter, patch_get_data_objects=False)
+    connection = adapter._connection_pool.get()
+    connection._state.pending_drops.return_value = set()
+    connection._state.pending_tables.return_value = set()
+    connection._state.pending_views.return_value = set()
+    connection._state.is_materialized_view.side_effect = lambda _: False
+    return adapter
+
+
+def _install_feldera_pipeline(monkeypatch: pytest.MonkeyPatch, pipeline: t.Any) -> None:
+    pipeline_cls = type(
+        "Pipeline",
+        (),
+        {"get": staticmethod(lambda pipeline_name, client: pipeline)},
+    )
+    pipeline_module = types.ModuleType("feldera.pipeline")
+    pipeline_module.Pipeline = pipeline_cls
+
+    feldera_module = types.ModuleType("feldera")
+    feldera_module.pipeline = pipeline_module
+
+    monkeypatch.setitem(sys.modules, "feldera", feldera_module)
+    monkeypatch.setitem(sys.modules, "feldera.pipeline", pipeline_module)
+
+
+def test_columns_uses_pipeline_metadata(
+    adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch
+):
+    connection = adapter._connection_pool.get()
+    connection._client = object()
+    connection._pipeline_name = "configured_pipeline"
+
+    pipeline = types.SimpleNamespace(
+        tables=lambda: [
+            types.SimpleNamespace(
+                name="orders",
+                fields=[
+                    {"name": "id", "columntype": {"type": "INTEGER"}},
+                    {"name": "ratio", "columntype": {"type": "REAL"}},
+                    {"name": "name", "columntype": "VARCHAR(12)"},
+                ],
+            )
+        ],
+        views=lambda: [],
+    )
+    _install_feldera_pipeline(monkeypatch, pipeline)
+
+    assert adapter.columns("orders") == {
+        "id": exp.DataType.build("INT"),
+        "ratio": exp.DataType.build("FLOAT"),
+        "name": exp.DataType.build("VARCHAR"),
+    }
+
+
+def test_columns_raises_sqlmesh_error_for_missing_object(
+    adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch
+):
+    connection = adapter._connection_pool.get()
+    connection._client = object()
+    connection._pipeline_name = "configured_pipeline"
+
+    pipeline = types.SimpleNamespace(tables=lambda: [], views=lambda: [])
+    _install_feldera_pipeline(monkeypatch, pipeline)
+
+    with pytest.raises(SQLMeshError, match="missing"):
+        adapter.columns("missing")
+
+
+def test_get_data_objects_uses_requested_pipeline_name(
+    adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch
+):
+    connection = adapter._connection_pool.get()
+    connection._client = object()
+    connection._pipeline_name = "configured_pipeline"
+
+    requested_pipeline_names: t.List[str] = []
+
+    def get_pipeline(pipeline_name: str, client: t.Any) -> t.Any:
+        requested_pipeline_names.append(pipeline_name)
+        return types.SimpleNamespace(
+            tables=lambda: [types.SimpleNamespace(name="source")],
+            views=lambda: [
+                types.SimpleNamespace(name="sink"),
+                types.SimpleNamespace(name="__sqlmesh_query__source"),
+            ],
+        )
+
+    pipeline_cls = type("Pipeline", (), {"get": staticmethod(get_pipeline)})
+    pipeline_module = types.ModuleType("feldera.pipeline")
+    pipeline_module.Pipeline = pipeline_cls
+
+    feldera_module = types.ModuleType("feldera")
+    feldera_module.pipeline = pipeline_module
+
+    monkeypatch.setitem(sys.modules, "feldera", feldera_module)
+    monkeypatch.setitem(sys.modules, "feldera.pipeline", pipeline_module)
+
+    data_objects = adapter._get_data_objects("requested_pipeline")
+
+    assert requested_pipeline_names == ["requested_pipeline"]
+    assert [(obj.schema_name, obj.name, obj.type) for obj in data_objects] == [
+        ("requested_pipeline", "source", DataObjectType.TABLE),
+        ("requested_pipeline", "sink", DataObjectType.VIEW),
+    ]
== ["requested_pipeline"] + assert [(obj.schema_name, obj.name, obj.type) for obj in data_objects] == [ + ("requested_pipeline", "source", DataObjectType.TABLE), + ("requested_pipeline", "sink", DataObjectType.VIEW), + ] + + +def test_adapter_dialect_is_feldera(adapter: FelderaEngineAdapter) -> None: + assert adapter.dialect == "feldera" + + +def test_feldera_dialect_is_registered() -> None: + assert parse_one("SELECT 1", dialect="feldera").sql(dialect="feldera") == "SELECT 1" + + +def test_feldera_dialect_preserves_current_timestamp_keyword() -> None: + assert ( + parse_one("SELECT CURRENT_TIMESTAMP AS ts", dialect="feldera").sql(dialect="feldera") + == "SELECT CURRENT_TIMESTAMP AS ts" + ) + + +def test_register_feldera_dialect_registers_custom_type_mapping() -> None: + original = Dialect.classes.pop("feldera", None) + + try: + register_feldera_dialect() + assert exp.DataType.build("FLOAT").sql(dialect="feldera") == "REAL" + finally: + Dialect.classes.pop("feldera", None) + if original is not None: + Dialect.classes["feldera"] = original + else: + register_feldera_dialect() + + +def test_get_data_objects_marks_materialized_views_from_state( + adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch +): + connection = adapter._connection_pool.get() + connection._client = object() + connection._state.is_materialized_view.side_effect = lambda name: name == "sink" + + pipeline = types.SimpleNamespace( + tables=lambda: [], + views=lambda: [types.SimpleNamespace(name="sink")], + ) + _install_feldera_pipeline(monkeypatch, pipeline) + + data_objects = adapter._get_data_objects("requested_pipeline") + + assert [(obj.schema_name, obj.name, obj.type) for obj in data_objects] == [ + ("requested_pipeline", "sink", DataObjectType.MATERIALIZED_VIEW), + ] + + +def test_replace_query_creates_table( + adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch +): + monkeypatch.setattr(adapter, "_get_data_objects", lambda schema_name, object_names=None: []) + adapter.replace_query( + "db.full_model", + parse_one("SELECT a FROM tbl"), + {"a": exp.DataType.build("INT")}, + ) + + assert to_sql_calls(adapter) == [ + 'CREATE TABLE IF NOT EXISTS "db"."full_model" ("a" INTEGER)', + 'INSERT INTO "db"."full_model" ("a") SELECT "a" FROM "tbl"', + ] + + +def test_replace_query_recreates_existing_table( + adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch +): + monkeypatch.setattr( + adapter, + "get_data_object", + lambda table: DataObject(schema="db", name="full_model", type=DataObjectType.TABLE), + ) + + adapter.replace_query( + "db.full_model", + parse_one("SELECT a FROM tbl"), + {"a": exp.DataType.build("INT")}, + ) + + assert to_sql_calls(adapter) == [ + 'DROP TABLE IF EXISTS "db"."full_model"', + 'CREATE TABLE IF NOT EXISTS "db"."full_model" ("a" INTEGER)', + 'INSERT INTO "db"."full_model" ("a") SELECT "a" FROM "tbl"', + ] + + +def test_insert_overwrite_by_time_partition_uses_delete_insert( + adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch +): + monkeypatch.setattr(adapter, "_get_data_objects", lambda schema_name, object_names=None: []) + + adapter.insert_overwrite_by_time_partition( + "db.incremental_model", + parse_one("SELECT id, ds FROM source"), + start="2024-01-01", + end="2024-01-02", + time_column="ds", + time_formatter=lambda x, _: exp.Literal.string(str(x)[:10]), + target_columns_to_types={ + "id": exp.DataType.build("INT"), + "ds": exp.DataType.build("DATE"), + }, + ) + + assert to_sql_calls(adapter) == [ + 'DELETE FROM "db"."incremental_model" WHERE "ds" BETWEEN 
+def test_insert_overwrite_without_condition_drops_and_recreates_table(
+    adapter: FelderaEngineAdapter,
+) -> None:
+    recorded_calls: list[tuple[str, tuple[t.Any, ...], dict[str, t.Any]]] = []
+    source_queries = [t.cast(t.Any, object())]
+    target_columns_to_types = {"a": exp.DataType.build("INT")}
+
+    def record_drop_table(*args: t.Any, **kwargs: t.Any) -> None:
+        recorded_calls.append(("drop_table", args, kwargs))
+
+    def record_create_table(*args: t.Any, **kwargs: t.Any) -> None:
+        recorded_calls.append(("create_table", args, kwargs))
+
+    adapter.drop_table = record_drop_table  # type: ignore[method-assign]
+    adapter._create_table_from_source_queries = record_create_table  # type: ignore[method-assign]
+
+    adapter._insert_overwrite_by_condition(
+        "db.full_model",
+        source_queries,
+        target_columns_to_types=target_columns_to_types,
+        where=None,
+    )
+
+    assert recorded_calls == [
+        ("drop_table", ("db.full_model",), {}),
+        (
+            "create_table",
+            ("db.full_model", source_queries),
+            {
+                "target_columns_to_types": target_columns_to_types,
+                "exists": True,
+                "replace": False,
+            },
+        ),
+    ]
+
+
+def test_create_table_from_source_queries_requires_known_column_types(
+    adapter: FelderaEngineAdapter,
+) -> None:
+    with pytest.raises(SQLMeshError, match="requires known column types"):
+        adapter._create_table_from_source_queries("db.full_model", [], target_columns_to_types=None)
+
+
+def test_create_schema_is_no_op(adapter: FelderaEngineAdapter) -> None:
+    assert adapter.create_schema("db") is None
+
+
+def test_create_view_creates_view(adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setattr(adapter, "_get_data_objects", lambda schema_name, object_names=None: [])
+    adapter.create_view(
+        "db.view_model",
+        parse_one("SELECT a FROM tbl"),
+        replace=False,
+        materialized=False,
+    )
+
+    assert to_sql_calls(adapter) == [
+        'CREATE VIEW "db"."view_model" AS SELECT "a" FROM "tbl"',
+    ]
+
+
+def test_create_view_rejects_catalog_qualified_names(adapter: FelderaEngineAdapter) -> None:
+    with pytest.raises(UnsupportedCatalogOperationError, match="does not support catalogs"):
+        adapter.create_view(
+            "catalog.db.view_model",
+            parse_one("SELECT a FROM tbl"),
+        )
+
+
+def test_replace_view_drops_then_creates_view(
+    adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch
+):
+    monkeypatch.setattr(
+        adapter,
+        "get_data_object",
+        lambda table: DataObject(schema="db", name="view_model", type=DataObjectType.VIEW),
+    )
+
+    adapter.create_view(
+        "db.view_model",
+        parse_one("SELECT a FROM tbl"),
+        replace=True,
+        materialized=False,
+    )
+
+    assert to_sql_calls(adapter) == [
+        'DROP VIEW IF EXISTS "db"."view_model"',
+        'CREATE VIEW "db"."view_model" AS SELECT "a" FROM "tbl"',
+    ]
+
+
+def test_create_materialized_view_creates_materialized_view(
+    adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch
+):
+    monkeypatch.setattr(adapter, "_get_data_objects", lambda schema_name, object_names=None: [])
+    adapter.create_view(
+        "db.materialized_view_model",
+        parse_one("SELECT a FROM tbl"),
+        replace=False,
+        materialized=True,
+    )
+
+    assert to_sql_calls(adapter) == [
+        'CREATE MATERIALIZED VIEW "db"."materialized_view_model" AS SELECT "a" FROM "tbl"',
+    ]
FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch +): + monkeypatch.setattr( + adapter, + "get_data_object", + lambda table: DataObject( + schema="db", name="materialized_view_model", type=DataObjectType.MATERIALIZED_VIEW + ), + ) + + adapter.create_view( + "db.materialized_view_model", + parse_one("SELECT a FROM tbl"), + replace=True, + materialized=True, + ) + + assert to_sql_calls(adapter) == [ + 'DROP MATERIALIZED VIEW IF EXISTS "db"."materialized_view_model"', + 'CREATE MATERIALIZED VIEW "db"."materialized_view_model" AS SELECT "a" FROM "tbl"', + ] + + +def test_drop_table_drops_view(adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr( + adapter, + "get_data_object", + lambda table: DataObject(schema="db", name="view_model", type=DataObjectType.VIEW), + ) + + adapter.drop_table("db.view_model") + + assert to_sql_calls(adapter) == ['DROP VIEW IF EXISTS "db"."view_model"'] + + +def test_drop_table_drops_materialized_view( + adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch +): + monkeypatch.setattr( + adapter, + "get_data_object", + lambda table: DataObject( + schema="db", name="materialized_view_model", type=DataObjectType.MATERIALIZED_VIEW + ), + ) + + adapter.drop_table("db.materialized_view_model") + + assert to_sql_calls(adapter) == [ + 'DROP MATERIALIZED VIEW IF EXISTS "db"."materialized_view_model"' + ] + + +def test_drop_table_drops_table(adapter: FelderaEngineAdapter, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr( + adapter, + "get_data_object", + lambda table: DataObject(schema="db", name="full_model", type=DataObjectType.TABLE), + ) + + adapter.drop_table("db.full_model") + + assert to_sql_calls(adapter) == ['DROP TABLE IF EXISTS "db"."full_model"'] diff --git a/tests/core/test_connection_config.py b/tests/core/test_connection_config.py index 2ff95525f7..e3ebf7a251 100644 --- a/tests/core/test_connection_config.py +++ b/tests/core/test_connection_config.py @@ -14,6 +14,7 @@ DatabricksConnectionConfig, DuckDBAttachOptions, FabricConnectionConfig, + FelderaConnectionConfig, DuckDBConnectionConfig, GCPPostgresConnectionConfig, MotherDuckConnectionConfig, @@ -1955,6 +1956,25 @@ def test_fabric_pyodbc_connection_string_generation(): assert call_args[1]["autocommit"] is True +def test_feldera_connection_config(make_config): + config = make_config(type="feldera", pipeline_name="pipeline", check_import=False) + + assert isinstance(config, FelderaConnectionConfig) + assert config.DIALECT == "feldera" + assert config.is_forbidden_for_state_sync + + with patch("sqlmesh.engines.feldera.db_api.connect") as mock_connect: + config._connection_factory_with_kwargs() + + mock_connect.assert_called_once_with( + host="http://localhost:8080", + pipeline_name="pipeline", + workers=4, + compilation_profile="dev", + timeout=300, + ) + + def test_schema_differ_overrides(make_config) -> None: default_config = make_config(type="duckdb") assert default_config.schema_differ_overrides is None diff --git a/tests/engines/feldera/test_db_api.py b/tests/engines/feldera/test_db_api.py new file mode 100644 index 0000000000..6d9cdb0fe2 --- /dev/null +++ b/tests/engines/feldera/test_db_api.py @@ -0,0 +1,940 @@ +import sys +import types +import typing as t + +import pytest +from sqlglot import parse_one + +from sqlmesh.engines.feldera import db_api + + +@pytest.fixture(autouse=True) +def reset_feldera_shared_states() -> t.Iterator[None]: + db_api.FelderaConnection.reset_shared_states() + yield + db_api.FelderaConnection.reset_shared_states() + + +def 
test_classify_treats_comment_prefixed_create_schema_as_pipeline_ddl() -> None: + assert db_api._classify("/* sqlmesh */ CREATE SCHEMA foo") == db_api.SqlIntent.PIPELINE_DDL + + +def test_is_virtual_layer_ddl_identifies_environment_alias_view() -> None: + assert db_api._is_virtual_layer_ddl( + 'CREATE VIEW "analytics__dev"."source_events" AS ' + 'SELECT * FROM "sqlmesh__analytics"."analytics__source_events__1782741465__dev"' + ) + + assert not db_api._is_virtual_layer_ddl( + 'CREATE MATERIALIZED VIEW "sqlmesh__analytics"."analytics__aggregate_view__1225616675__dev" AS ' + 'SELECT * FROM "sqlmesh__analytics"."analytics__source_events__1782741465__dev"' + ) + + +def test_strip_table_qualifiers_preserves_current_timestamp_keyword() -> None: + sql = ( + 'CREATE MATERIALIZED VIEW "db"."view_model" AS ' + 'SELECT CURRENT_TIMESTAMP AS ts FROM "db"."source"' + ) + + assert db_api._strip_table_qualifiers(sql) == ( + 'CREATE MATERIALIZED VIEW "view_model" AS SELECT CURRENT_TIMESTAMP AS ts FROM "source"' + ) + + +def test_normalize_pipeline_ddl_canonicalizes_snapshot_names_to_logical_names() -> None: + sql = ( + 'CREATE MATERIALIZED VIEW "sqlmesh__analytics"."analytics__aggregate_view__1225616675__dev" AS ' + 'SELECT "source_events"."entity_id" AS "entity_id" ' + 'FROM "sqlmesh__analytics"."analytics__source_events__1782741465__dev" AS "source_events"' + ) + + assert db_api._normalize_pipeline_ddl(sql) == ( + 'CREATE MATERIALIZED VIEW "aggregate_view" AS ' + 'SELECT "source_events"."entity_id" AS "entity_id" ' + 'FROM "source_events" AS "source_events"' + ) + + +def test_normalize_pipeline_ddl_strips_schema_from_logical_tables() -> None: + sql = 'CREATE TABLE "analytics"."sample_seed" ("entity_id" VARCHAR, "description" VARCHAR)' + + assert db_api._normalize_pipeline_ddl(sql) == ( + 'CREATE TABLE "sample_seed" ("entity_id" VARCHAR, "description" VARCHAR)' + ) + + +def test_cursor_defers_pipeline_deploy_until_non_ddl_statement() -> None: + class FakeStateManager: + def __init__(self) -> None: + self.pending = False + self.deploy_calls = 0 + self.pipeline = types.SimpleNamespace(query=lambda sql: [{"a": 1}]) + + def register_ddl(self, sql: str) -> None: + self.pending = True + + def has_pending_changes(self) -> bool: + return self.pending + + def deploy(self, *args: object, **kwargs: object) -> object: + self.deploy_calls += 1 + self.pending = False + return self.pipeline + + def current_pipeline(self) -> object: + return self.pipeline + + def queryable_relation_names(self) -> set[str]: + return set() + + state_manager = t.cast(t.Any, FakeStateManager()) + cursor = db_api.FelderaCursor( + client=object(), + pipeline_name="test_pipeline", + state_manager=state_manager, + ) + + cursor.execute("/* sqlmesh */ CREATE TABLE foo (id INT)") + + assert state_manager.deploy_calls == 0 + + cursor.execute("SELECT 1 AS a") + + assert state_manager.deploy_calls == 1 + assert cursor.fetchall() == [(1,)] + + +def test_cursor_clears_description_after_non_query_execute() -> None: + class FakePipeline: + def query(self, sql: str) -> list[dict[str, int]]: + return [{"a": 1}] + + def input_json(self, table_name: str, rows: object) -> None: + return None + + def execute(self, sql: str) -> None: + return None + + class FakeStateManager: + def __init__(self) -> None: + self.pipeline = FakePipeline() + + def has_pending_changes(self) -> bool: + return False + + def current_pipeline(self) -> object: + return self.pipeline + + def queryable_relation_names(self) -> set[str]: + return set() + + cursor = db_api.FelderaCursor( + 
client=object(), + pipeline_name="test_pipeline", + state_manager=t.cast(t.Any, FakeStateManager()), + ) + + cursor.execute("SELECT 1 AS a") + assert cursor.description == [("a", None, None, None, None, None, None)] + + cursor.execute('INSERT INTO "seed_model" SELECT 1 AS "a" FROM (VALUES (1)) AS "t"("a")') + + assert cursor.description is None + + +def test_cursor_ignores_virtual_layer_view_ddl() -> None: + class FakeStateManager: + def __init__(self) -> None: + self.registered_sql: list[str] = [] + + def register_ddl(self, sql: str) -> None: + self.registered_sql.append(sql) + + def has_pending_changes(self) -> bool: + return False + + state_manager = t.cast(t.Any, FakeStateManager()) + cursor = db_api.FelderaCursor( + client=object(), + pipeline_name="test_pipeline", + state_manager=state_manager, + ) + + cursor.execute( + 'CREATE VIEW "analytics__dev"."source_events" AS ' + 'SELECT * FROM "sqlmesh__analytics"."analytics__source_events__1782741465__dev"' + ) + + assert state_manager.registered_sql == [] + + +def test_cursor_registers_logical_model_names_in_pipeline_ddl() -> None: + class FakeStateManager: + def __init__(self) -> None: + self.registered_sql: list[str] = [] + + def register_ddl(self, sql: str) -> None: + self.registered_sql.append(sql) + + def has_pending_changes(self) -> bool: + return False + + state_manager = t.cast(t.Any, FakeStateManager()) + cursor = db_api.FelderaCursor( + client=object(), + pipeline_name="test_pipeline", + state_manager=state_manager, + ) + + cursor.execute( + 'CREATE MATERIALIZED VIEW "sqlmesh__analytics"."analytics__aggregated_observations__1956259901__dev" AS ' + 'SELECT "source_events"."entity_id" AS "entity_id" ' + 'FROM "sqlmesh__analytics"."analytics__source_events__1782741465__dev" AS "source_events"' + ) + + assert state_manager.registered_sql == [ + 'CREATE MATERIALIZED VIEW "aggregated_observations" AS ' + 'SELECT "source_events"."entity_id" AS "entity_id" ' + 'FROM "source_events" AS "source_events"' + ] + + +def test_hydrate_existing_program_skips_empty_parse_results(monkeypatch) -> None: + manager = db_api.PipelineStateManager() + + monkeypatch.setattr( + db_api, + "parse", + lambda sql, **kwargs: [None, parse_one("CREATE TABLE foo (id INT)", **kwargs)], + ) + + pipeline_module = types.ModuleType("feldera.pipeline") + setattr( + pipeline_module, + "Pipeline", + type( + "Pipeline", + (), + { + "get": staticmethod( + lambda pipeline_name, client: types.SimpleNamespace( + _inner=types.SimpleNamespace(program_code="ignored") + ) + ) + }, + ), + ) + feldera_module = types.ModuleType("feldera") + setattr(feldera_module, "pipeline", pipeline_module) + + monkeypatch.setitem(sys.modules, "feldera", feldera_module) + monkeypatch.setitem(sys.modules, "feldera.pipeline", pipeline_module) + + manager._hydrate_existing_program(object(), "test_pipeline") + + assert manager.pending_tables() == {"foo"} + + +def test_state_manager_tracks_view_materialization() -> None: + manager = db_api.PipelineStateManager() + + manager.register_ddl("CREATE VIEW regular_view AS SELECT 1") + manager.register_ddl("CREATE MATERIALIZED VIEW materialized_view AS SELECT 1") + + assert manager.pending_views() == {"regular_view", "materialized_view"} + assert manager.is_materialized_view("regular_view") is False + assert manager.is_materialized_view("materialized_view") is True + + +def test_state_manager_rewrites_table_ctas() -> None: + manager = db_api.PipelineStateManager() + + manager.register_ddl( + 'CREATE TABLE "seed_model" AS SELECT CAST("id" AS INTEGER) AS "id" 
FROM (VALUES (1)) AS "t"("id")' + ) + + assert manager.assemble_program() == ( + 'CREATE TABLE "seed_model" ("id" INTEGER);\n' + "\n" + 'CREATE MATERIALIZED VIEW "__sqlmesh_query__seed_model" AS SELECT * FROM "seed_model";\n' + "\n" + 'INSERT INTO "seed_model" (id) SELECT CAST("id" AS INTEGER) AS "id" FROM (VALUES (1)) AS "t"("id");' + ) + + +def test_evict_hydrated_objects_removes_stale_object_from_compile_error() -> None: + manager = db_api.PipelineStateManager() + manager._views = { + "analytics__aggregate_view__781619724__dev": ( + 'CREATE MATERIALIZED VIEW "analytics__aggregate_view__781619724__dev" AS ' + "SELECT CURRENT_TIMESTAMP AS ts" + ), + "analytics__aggregate_view__1225616675__dev": ( + 'CREATE MATERIALIZED VIEW "analytics__aggregate_view__1225616675__dev" AS ' + "SELECT NOW() AS ts" + ), + } + manager._hydrated_object_keys = { + "analytics__aggregate_view__781619724__dev", + "analytics__aggregate_view__1225616675__dev", + } + + removed = manager._evict_hydrated_objects( + 'Compilation error in CREATE MATERIALIZED VIEW "analytics__aggregate_view__781619724__dev"' + ) + + assert removed is True + assert "analytics__aggregate_view__781619724__dev" not in manager.pending_views() + assert "analytics__aggregate_view__1225616675__dev" in manager.pending_views() + + +def test_format_compile_error_preserves_sql_compilation_details() -> None: + manager = db_api.PipelineStateManager() + + class FakeClient: + def get_pipeline(self, pipeline_name: str, field_selector: object) -> object: + return types.SimpleNamespace( + program_status="SqlError", + program_error={ + "sql_compilation": { + "messages": [ + { + "error_type": "Compilation error", + "message": "Object 'analytics__source_events__1782741465__dev' not found", + "snippet": '1|CREATE MATERIALIZED VIEW "analytics__aggregate_view__1225616675__dev" AS SELECT ...', + } + ] + } + }, + ) + + error = manager._format_compile_error( + FakeClient(), + "test_pipeline", + RuntimeError("The program failed to compile: SqlError"), + ) + + assert str(error) == ( + "Pipeline test_pipeline failed to compile:\n" + "Compilation error\n" + "Object 'analytics__source_events__1782741465__dev' not found\n" + 'Code snippet:\n1|CREATE MATERIALIZED VIEW "analytics__aggregate_view__1225616675__dev" AS SELECT ...' 
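+        # _format_compile_error surfaces Feldera's sql_compilation messages
+        # verbatim: error type, then message, then the offending code snippet.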
+ ) + + +def test_format_compile_error_preserves_rust_and_system_error_details() -> None: + manager = db_api.PipelineStateManager() + + class FakeClient: + def get_pipeline(self, pipeline_name: str, field_selector: object) -> object: + return types.SimpleNamespace( + program_status="RustError", + program_error={ + "rust_compilation": "rust failed", + "system_error": "system failed", + }, + ) + + error = manager._format_compile_error( + FakeClient(), + "test_pipeline", + RuntimeError("The program failed to compile: RustError"), + ) + + assert str(error) == ( + "The program failed to compile: RustError\n" + "Rust Error: rust failed\n" + "System Error: system failed" + ) + + +def test_format_compile_error_returns_original_message_without_program_error() -> None: + manager = db_api.PipelineStateManager() + + class FakeClient: + def get_pipeline(self, pipeline_name: str, field_selector: object) -> object: + return types.SimpleNamespace(program_status="Unknown", program_error=None) + + error = manager._format_compile_error(FakeClient(), "test_pipeline", RuntimeError("boom")) + + assert str(error) == "boom" + + +def test_format_compile_error_requests_all_pipeline_fields(monkeypatch) -> None: + manager = db_api.PipelineStateManager() + requested_field_selector: list[object] = [] + selector_all = object() + + enums_module = types.ModuleType("feldera.enums") + setattr(enums_module, "PipelineFieldSelector", types.SimpleNamespace(ALL=selector_all)) + monkeypatch.setitem(sys.modules, "feldera.enums", enums_module) + + class FakeClient: + def get_pipeline(self, pipeline_name: str, field_selector: object) -> object: + requested_field_selector.append(field_selector) + return types.SimpleNamespace(program_error={}, program_status="Unknown") + + manager._format_compile_error(FakeClient(), "test_pipeline", RuntimeError("boom")) + + assert requested_field_selector == [selector_all] + + +def test_deploy_succeeds_with_mocked_feldera_modules(monkeypatch) -> None: + manager = db_api.PipelineStateManager() + + class CompilationProfile(str): + pass + + class Pipeline: + @staticmethod + def get(name: str, client: object) -> object: + return types.SimpleNamespace(_inner=types.SimpleNamespace(program_code="")) + + class PipelineBuilder: + def __init__(self, *args: object, **kwargs: object) -> None: + pass + + def create_or_replace(self, wait: bool = True) -> object: + return types.SimpleNamespace(start=lambda: None) + + class RuntimeConfig: + def __init__(self) -> None: + self.workers = 0 + + @classmethod + def default(cls) -> "RuntimeConfig": + return cls() + + def to_dict(self) -> dict[str, object]: + return {} + + feldera_module = types.ModuleType("feldera") + enums_module = types.ModuleType("feldera.enums") + pipeline_module = types.ModuleType("feldera.pipeline") + pipeline_builder_module = types.ModuleType("feldera.pipeline_builder") + runtime_config_module = types.ModuleType("feldera.runtime_config") + rest_module = types.ModuleType("feldera.rest") + rest_errors_module = types.ModuleType("feldera.rest.errors") + rest_pipeline_module = types.ModuleType("feldera.rest.pipeline") + + setattr(enums_module, "CompilationProfile", CompilationProfile) + setattr(pipeline_module, "Pipeline", Pipeline) + setattr(pipeline_builder_module, "PipelineBuilder", PipelineBuilder) + setattr(runtime_config_module, "RuntimeConfig", RuntimeConfig) + setattr(rest_errors_module, "FelderaAPIError", RuntimeError) + setattr(rest_pipeline_module, "Pipeline", type("InnerPipeline", (), {})) + setattr(rest_module, "errors", rest_errors_module) 
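+    # Finish stitching the stub submodules onto their fake parent packages;
+    # the sys.modules entries registered below must resolve to these same
+    # module objects for db_api's lazy imports to see a coherent package.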
+ setattr(rest_module, "pipeline", rest_pipeline_module) + setattr(feldera_module, "enums", enums_module) + setattr(feldera_module, "pipeline", pipeline_module) + setattr(feldera_module, "pipeline_builder", pipeline_builder_module) + setattr(feldera_module, "runtime_config", runtime_config_module) + setattr(feldera_module, "rest", rest_module) + + monkeypatch.setitem(sys.modules, "feldera", feldera_module) + monkeypatch.setitem(sys.modules, "feldera.enums", enums_module) + monkeypatch.setitem(sys.modules, "feldera.pipeline", pipeline_module) + monkeypatch.setitem(sys.modules, "feldera.pipeline_builder", pipeline_builder_module) + monkeypatch.setitem(sys.modules, "feldera.runtime_config", runtime_config_module) + monkeypatch.setitem(sys.modules, "feldera.rest", rest_module) + monkeypatch.setitem(sys.modules, "feldera.rest.errors", rest_errors_module) + monkeypatch.setitem(sys.modules, "feldera.rest.pipeline", rest_pipeline_module) + + manager.deploy(object(), "test_pipeline") + + +def test_compile_program_does_not_wait_after_stop() -> None: + manager = db_api.PipelineStateManager() + wait_calls: list[tuple[object, int]] = [] + stop_calls: list[tuple[bool, t.Optional[float]]] = [] + dismiss_error_calls = 0 + + class ExistingPipeline: + def stop(self, force: bool = True, timeout_s: t.Optional[float] = None) -> None: + stop_calls.append((force, timeout_s)) + + def wait_for_status(self, status: object, timeout: int = 0) -> None: + wait_calls.append((status, timeout)) + + def dismiss_error(self) -> None: + nonlocal dismiss_error_calls + dismiss_error_calls += 1 + + class Pipeline: + @staticmethod + def get(name: str, client: object) -> object: + return ExistingPipeline() + + def __init__(self, client: object) -> None: + self._inner = None + + class InnerPipeline: + def __init__(self, **kwargs: object) -> None: + pass + + class Profile: + value = "dev" + + class RuntimeConfig: + def to_dict(self) -> dict[str, object]: + return {} + + class PipelineBuilder: + def __init__(self, *args: object, **kwargs: object) -> None: + raise AssertionError( + "PipelineBuilder should not be constructed when the pipeline exists" + ) + + client = types.SimpleNamespace(create_or_update_pipeline=lambda pipeline, wait=True: pipeline) + + manager._compile_program( + client, + "test_pipeline", + "SELECT 1", + Profile(), + RuntimeConfig(), + 300, + Pipeline, + PipelineBuilder, + InnerPipeline, + RuntimeError, + ) + + assert stop_calls == [(True, 300)] + assert dismiss_error_calls == 1 + assert wait_calls == [] + + +def test_deploy_does_not_wait_after_start(monkeypatch) -> None: + manager = db_api.PipelineStateManager() + manager._dirty = True + start_calls: list[t.Optional[float]] = [] + wait_calls: list[tuple[object, int]] = [] + + class Pipeline: + def start(self, timeout_s: t.Optional[float] = None) -> None: + start_calls.append(timeout_s) + + def wait_for_status(self, status: object, timeout: int = 0) -> None: + wait_calls.append((status, timeout)) + + class CompilationProfile(str): + pass + + class RuntimeConfig: + def __init__(self) -> None: + self.workers = 0 + + @classmethod + def default(cls) -> "RuntimeConfig": + return cls() + + def to_dict(self) -> dict[str, object]: + return {} + + enums_module = types.ModuleType("feldera.enums") + pipeline_module = types.ModuleType("feldera.pipeline") + pipeline_builder_module = types.ModuleType("feldera.pipeline_builder") + runtime_config_module = types.ModuleType("feldera.runtime_config") + rest_errors_module = types.ModuleType("feldera.rest.errors") + 
rest_pipeline_module = types.ModuleType("feldera.rest.pipeline") + + setattr(enums_module, "CompilationProfile", CompilationProfile) + setattr( + pipeline_module, + "Pipeline", + type("PipelineClass", (), {"get": staticmethod(lambda n, c: None)}), + ) + setattr(pipeline_builder_module, "PipelineBuilder", object) + setattr(runtime_config_module, "RuntimeConfig", RuntimeConfig) + setattr(rest_errors_module, "FelderaAPIError", RuntimeError) + setattr(rest_pipeline_module, "Pipeline", type("InnerPipeline", (), {})) + + monkeypatch.setitem(sys.modules, "feldera.enums", enums_module) + monkeypatch.setitem(sys.modules, "feldera.pipeline", pipeline_module) + monkeypatch.setitem(sys.modules, "feldera.pipeline_builder", pipeline_builder_module) + monkeypatch.setitem(sys.modules, "feldera.runtime_config", runtime_config_module) + monkeypatch.setitem(sys.modules, "feldera.rest.errors", rest_errors_module) + monkeypatch.setitem(sys.modules, "feldera.rest.pipeline", rest_pipeline_module) + + monkeypatch.setattr(manager, "_hydrate_existing_program", lambda client, pipeline_name: None) + monkeypatch.setattr(manager, "assemble_program", lambda: "CREATE TABLE x (id INT)") + monkeypatch.setattr(manager, "_compile_program", lambda *args, **kwargs: Pipeline()) + + manager.deploy(object(), "test_pipeline") + + assert start_calls == [300] + assert wait_calls == [] + + +def test_connection_close_logs_pending_compile_error(caplog: pytest.LogCaptureFixture) -> None: + class FakeStateManager: + def has_pending_changes(self) -> bool: + return True + + def deploy(self, *args: object, **kwargs: object) -> object: + raise RuntimeError( + "Pipeline test_pipeline failed to compile:\n" + "Compilation error\n" + "TIMESTAMP_TRUNC problem" + ) + + connection = db_api.FelderaConnection( + client=object(), + host="http://localhost:8080", + pipeline_name="test_pipeline", + ) + connection._state = t.cast(t.Any, FakeStateManager()) + + with caplog.at_level("ERROR", logger="sqlmesh.engines.feldera.db_api"): + with pytest.raises(RuntimeError, match="TIMESTAMP_TRUNC problem"): + connection.close() + + assert ( + "Feldera pending DDL failed during connection close for pipeline test_pipeline" + in caplog.text + ) + assert "TIMESTAMP_TRUNC problem" in caplog.text + + +def test_hydrate_existing_program_returns_silently_when_pipeline_lookup_fails( + monkeypatch: pytest.MonkeyPatch, +) -> None: + manager = db_api.PipelineStateManager() + + pipeline_module = types.ModuleType("feldera.pipeline") + setattr( + pipeline_module, + "Pipeline", + type( + "Pipeline", + (), + { + "get": staticmethod( + lambda pipeline_name, client: (_ for _ in ()).throw(RuntimeError("missing")) + ) + }, + ), + ) + feldera_module = types.ModuleType("feldera") + setattr(feldera_module, "pipeline", pipeline_module) + + monkeypatch.setitem(sys.modules, "feldera", feldera_module) + monkeypatch.setitem(sys.modules, "feldera.pipeline", pipeline_module) + + manager._hydrate_existing_program(object(), "test_pipeline") + + assert manager.pending_tables() == set() + assert manager.pending_views() == set() + + +def test_state_manager_adds_query_mirrors_for_non_materialized_relations() -> None: + manager = db_api.PipelineStateManager() + + manager.register_ddl('CREATE TABLE "full_model" ("id" INT)') + manager.register_ddl('CREATE VIEW "view_model" AS SELECT "id" FROM "full_model"') + manager.register_ddl( + 'CREATE MATERIALIZED VIEW "materialized_view_model" AS SELECT "id" FROM "full_model"' + ) + + program = manager.assemble_program() + + assert ( + 'CREATE MATERIALIZED VIEW 
"__sqlmesh_query__full_model" AS SELECT * FROM "full_model";' + in program + ) + assert ( + 'CREATE MATERIALIZED VIEW "__sqlmesh_query__view_model" AS SELECT * FROM "view_model";' + in program + ) + assert 'CREATE MATERIALIZED VIEW "__sqlmesh_query__materialized_view_model"' not in program + + +def test_state_manager_query_mirror_strips_schema_qualifiers() -> None: + manager = db_api.PipelineStateManager() + + manager.register_ddl( + 'CREATE TABLE "analytics"."sample_seed" AS ' + 'SELECT CAST("id" AS INTEGER) AS "id" FROM (VALUES (1)) AS "t"("id")' + ) + + program = manager.assemble_program() + + assert ( + 'CREATE MATERIALIZED VIEW "__sqlmesh_query__sample_seed" AS SELECT * FROM "sample_seed";' + in program + ) + assert 'SELECT * FROM "analytics"."sample_seed"' not in program + + +def test_state_manager_skips_query_mirrors_for_sqlmesh_internal_objects() -> None: + manager = db_api.PipelineStateManager() + + manager.register_ddl( + 'CREATE TABLE "sqlmesh__analytics"."analytics__source_snapshot__441175831__dev" ' + '("entity_id" VARCHAR, "metric_value" DOUBLE)' + ) + + program = manager.assemble_program() + + assert "__sqlmesh_query__analytics__source_snapshot__441175831__dev" not in program + + +def test_state_manager_canonicalizes_snapshot_tables_to_logical_names() -> None: + manager = db_api.PipelineStateManager() + + manager.register_ddl( + db_api._normalize_pipeline_ddl( + 'CREATE TABLE "sqlmesh__analytics"."analytics__records__849752499__dev" ' + '("entity_id" VARCHAR, "description" VARCHAR)' + ) + ) + + program = manager.assemble_program() + + assert 'CREATE TABLE "records" ("entity_id" VARCHAR, "description" VARCHAR);' in program + assert ( + 'CREATE MATERIALIZED VIEW "__sqlmesh_query__records" AS SELECT * FROM "records";' in program + ) + assert "analytics__records__849752499__dev" not in program + + +def test_hydrate_existing_program_skips_query_mirrors(monkeypatch) -> None: + manager = db_api.PipelineStateManager() + + pipeline_module = types.ModuleType("feldera.pipeline") + setattr( + pipeline_module, + "Pipeline", + type( + "Pipeline", + (), + { + "get": staticmethod( + lambda pipeline_name, client: types.SimpleNamespace( + _inner=types.SimpleNamespace( + program_code=( + 'CREATE TABLE "full_model" ("id" INT);\n' + 'CREATE MATERIALIZED VIEW "__sqlmesh_query__full_model" AS SELECT * FROM "full_model";\n' + 'CREATE VIEW "view_model" AS SELECT "id" FROM "full_model";\n' + 'CREATE MATERIALIZED VIEW "__sqlmesh_query__view_model" AS SELECT * FROM "view_model";' + ) + ) + ) + ) + }, + ), + ) + feldera_module = types.ModuleType("feldera") + setattr(feldera_module, "pipeline", pipeline_module) + + monkeypatch.setitem(sys.modules, "feldera", feldera_module) + monkeypatch.setitem(sys.modules, "feldera.pipeline", pipeline_module) + + manager._hydrate_existing_program(object(), "test_pipeline") + + assert manager.pending_tables() == {"full_model"} + assert manager.pending_views() == {"view_model"} + + +def test_cursor_rewrites_queries_to_query_mirrors() -> None: + captured_queries: list[str] = [] + + class FakeStateManager: + def has_pending_changes(self) -> bool: + return False + + def queryable_relation_names(self) -> set[str]: + return {"full_model", "view_model"} + + def current_pipeline(self) -> object: + def query(sql: str) -> list[dict[str, int]]: + captured_queries.append(sql) + return [{"count": 1}] + + return types.SimpleNamespace(query=query) + + cursor = db_api.FelderaCursor( + client=object(), + pipeline_name="test_pipeline", + state_manager=t.cast(t.Any, FakeStateManager()), + 
) + + cursor.execute('SELECT COUNT(*) FROM "full_model"') + + assert ( + parse_one(captured_queries[0]).sql() + == parse_one('SELECT COUNT(*) FROM "__sqlmesh_query__full_model"').sql() + ) + assert cursor.fetchone() == (1,) + + +def test_cursor_rewrites_snapshot_queries_to_logical_query_mirrors() -> None: + captured_queries: list[str] = [] + + class FakeStateManager: + def has_pending_changes(self) -> bool: + return False + + def queryable_relation_names(self) -> set[str]: + return {"source_events"} + + def current_pipeline(self) -> object: + def query(sql: str) -> list[dict[str, int]]: + captured_queries.append(sql) + return [{"count": 1}] + + return types.SimpleNamespace(query=query) + + cursor = db_api.FelderaCursor( + client=object(), + pipeline_name="test_pipeline", + state_manager=t.cast(t.Any, FakeStateManager()), + ) + + cursor.execute( + 'SELECT COUNT(*) FROM "sqlmesh__analytics"."analytics__source_events__1782741465__dev"' + ) + + assert ( + parse_one(captured_queries[0]).sql() + == parse_one('SELECT COUNT(*) FROM "__sqlmesh_query__source_events"').sql() + ) + assert cursor.fetchone() == (1,) + + +def test_insert_to_input_json_payload_rewrites_seed_values() -> None: + payload = db_api._insert_to_input_json_payload( + 'INSERT INTO "seed_model" ("entity_id", "score", "event_ts") ' + 'SELECT CAST("entity_id" AS TEXT) AS "entity_id", ' + 'CAST("score" AS DOUBLE) AS "score", ' + 'CAST("event_ts" AS TIMESTAMP) AS "event_ts" ' + "FROM (VALUES " + "('123', '4.5', '2026-05-09 01:00:00'), " + "('456', '7.0', NULL)" + ') AS "t"("entity_id", "score", "event_ts")' + ) + + assert payload == ( + "seed_model", + [ + { + "entity_id": "123", + "score": 4.5, + "event_ts": "2026-05-09 01:00:00", + }, + { + "entity_id": "456", + "score": 7.0, + "event_ts": None, + }, + ], + ) + + +def test_insert_to_input_json_payload_canonicalizes_snapshot_target() -> None: + payload = db_api._insert_to_input_json_payload( + 'INSERT INTO "sqlmesh__analytics"."analytics__sample_seed__2883946936__dev" ("entity_id") ' + 'SELECT CAST("entity_id" AS TEXT) AS "entity_id" ' + 'FROM (VALUES (\'123\')) AS "t"("entity_id")' + ) + + assert payload == ( + "sample_seed", + [ + { + "entity_id": "123", + } + ], + ) + + +def test_cursor_uses_input_json_for_seed_inserts() -> None: + captured_rows = [] + executed_sql = [] + + class FakePipeline: + def input_json(self, table_name: str, data: object, **kwargs: object) -> None: + captured_rows.append((table_name, data, kwargs)) + + def execute(self, sql: str) -> None: + executed_sql.append(sql) + + class FakeStateManager: + def has_pending_changes(self) -> bool: + return False + + def current_pipeline(self) -> object: + return FakePipeline() + + def queryable_relation_names(self) -> set[str]: + return set() + + cursor = db_api.FelderaCursor( + client=object(), + pipeline_name="test_pipeline", + state_manager=t.cast(t.Any, FakeStateManager()), + ) + + cursor.execute( + 'INSERT INTO "seed_model" ("a", "b") ' + 'SELECT CAST("a" AS INT) AS "a", CAST("b" AS INT) AS "b" ' + 'FROM (VALUES (1, 4), (2, 5)) AS "t"("a", "b")' + ) + + assert executed_sql == [] + assert captured_rows == [ + ( + "seed_model", + [ + {"a": 1, "b": 4}, + {"a": 2, "b": 5}, + ], + {}, + ) + ] + + +def test_cursor_raises_execution_error_from_query_rows() -> None: + class FakeStateManager: + def has_pending_changes(self) -> bool: + return False + + def queryable_relation_names(self) -> set[str]: + return set() + + def current_pipeline(self) -> object: + return types.SimpleNamespace( + query=lambda sql: [{"COUNT(*)": 
"Execution error: test failure"}] + ) + + cursor = db_api.FelderaCursor( + client=object(), + pipeline_name="test_pipeline", + state_manager=t.cast(t.Any, FakeStateManager()), + ) + + with pytest.raises(RuntimeError, match="Execution error: test failure"): + cursor.execute("SELECT COUNT(*) FROM full_model") + + +def test_connect_preserves_integer_timeout(monkeypatch) -> None: + captured_kwargs: dict[str, object] = {} + + class FelderaClient: + def __init__(self, **kwargs: object) -> None: + captured_kwargs.update(kwargs) + + feldera_client_module = types.ModuleType("feldera.rest.feldera_client") + setattr(feldera_client_module, "FelderaClient", FelderaClient) + monkeypatch.setitem(sys.modules, "feldera.rest.feldera_client", feldera_client_module) + + db_api.connect( + host="http://localhost:8080", + pipeline_name="test_pipeline", + timeout=123, + ) + + assert captured_kwargs["timeout"] == 123 + assert isinstance(captured_kwargs["timeout"], int)