diff --git a/docs/examples/patterns/migrations_with_schema.py b/docs/examples/patterns/migrations_with_schema.py new file mode 100644 index 000000000..a1a6a61ab --- /dev/null +++ b/docs/examples/patterns/migrations_with_schema.py @@ -0,0 +1,72 @@ +from pathlib import Path + +__all__ = ("test_migrations_with_schema",) + + +def test_migrations_with_schema(tmp_path: Path) -> None: + # start-example + from sqlspec.adapters.duckdb import DuckDBConfig + from sqlspec.migrations.commands import SyncMigrationCommands + + migration_dir = tmp_path / "migrations" + db_path = tmp_path / "app.duckdb" + + config = DuckDBConfig( + connection_config={"database": str(db_path)}, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": "schema_versions", + "default_schema": "app_schema", + "version_table_schema": "admin_schema", + }, + ) + + try: + with config.provide_session() as session: + session.execute("CREATE SCHEMA app_schema") + session.execute("CREATE SCHEMA admin_schema") + + commands = SyncMigrationCommands(config) + commands.init(str(migration_dir), package=True) + + (migration_dir / "0001_create_users.py").write_text( + '''"""Create users.""" + + +def up(): + """Create an unqualified table in app_schema.""" + return ["CREATE TABLE users (id INTEGER PRIMARY KEY, name VARCHAR NOT NULL)"] + + +def down(): + """Drop the unqualified table from app_schema.""" + return ["DROP TABLE IF EXISTS users"] +''' + ) + + commands.upgrade() + + with config.provide_session() as session: + users_table = session.select_value( + """ + SELECT table_name + FROM information_schema.tables + WHERE table_schema = ? AND table_name = ? + """, + ("app_schema", "users"), + ) + tracker_table = session.select_value( + """ + SELECT table_name + FROM information_schema.tables + WHERE table_schema = ? AND table_name = ? + """, + ("admin_schema", "schema_versions"), + ) + + assert users_table == "users" + assert tracker_table == "schema_versions" + finally: + if config.connection_instance: + config.close_pool() + # end-example diff --git a/docs/usage/migrations.rst b/docs/usage/migrations.rst index 2e5cd1b62..fecd38ab0 100644 --- a/docs/usage/migrations.rst +++ b/docs/usage/migrations.rst @@ -74,6 +74,68 @@ specific extension, ``migration_config["include_extensions"]`` to opt in explicitly by extension name, or ``migration_config["enabled"] = False`` to disable migrations entirely for a database config. +Configuring a Default Schema +---------------------------- + +Use ``migration_config["default_schema"]`` when migration SQL should run +against a pre-existing schema or dataset without qualifying every table in each +migration file. SQLSpec validates the schema before creating the tracker table +or applying DDL, then configures the migration session before each migration is +executed. + +Use ``migration_config["version_table_schema"]`` when the migration tracker +table should live somewhere different from the objects managed by migrations. +If ``version_table_schema`` is not set, the tracker schema resolves to +``default_schema``. If neither field is set, the tracker table is unqualified and +uses the adapter's normal default namespace. + +.. code-block:: python + + from sqlspec.adapters.asyncpg import AsyncpgConfig + + config = AsyncpgConfig( + connection_config={"dsn": "postgresql://localhost/app"}, + migration_config={ + "script_location": "migrations/postgres", + "version_table_name": "schema_versions", + "default_schema": "app_schema", + "version_table_schema": "admin_schema", + }, + ) + +The operator must create the target schema or dataset before running +migrations. The migration role also needs the database-specific privileges to +create objects there. For PostgreSQL, that usually means ``USAGE`` and +``CREATE`` on the target schema, plus permission to create or update the +tracker table. + +Adapter support: + +.. list-table:: + :header-rows: 1 + + * - Adapter + - Behavior + * - ``asyncpg``, ``psycopg``, ``psqlpy``, ADBC PostgreSQL + - Uses PostgreSQL ``search_path`` and validates ``information_schema.schemata``. + * - ``oracledb`` + - Uses ``ALTER SESSION SET CURRENT_SCHEMA`` and validates Oracle users. + * - ``duckdb`` + - Uses ``SET search_path`` and validates ``information_schema.schemata``. + * - ``bigquery`` + - Treats schemas as datasets and sets the BigQuery job ``default_dataset``. + * - ``sqlite``, ``aiosqlite``, ``asyncmy`` + - Accept the setting as an explicit no-op and log that default schemas are unsupported. + * - ADBC SQL Server + - Accepts the setting as a no-op; configure the default schema at the user or login level. + +Example with unqualified DDL: + +.. literalinclude:: /examples/patterns/migrations_with_schema.py + :language: python + :start-after: # start-example + :end-before: # end-example + Logging and Echo Controls ------------------------- diff --git a/sqlspec/adapters/adbc/config.py b/sqlspec/adapters/adbc/config.py index edf9f3ba6..b30126cb2 100644 --- a/sqlspec/adapters/adbc/config.py +++ b/sqlspec/adapters/adbc/config.py @@ -240,6 +240,9 @@ def __init__( observability_config=observability_config, **kwargs, ) + object.__setattr__( + self, "supports_migration_schemas", is_postgres_dialect(resolve_dialect_from_config(self.connection_config)) + ) def create_connection(self) -> AdbcConnection: """Create and return a new connection using the specified driver. diff --git a/sqlspec/adapters/adbc/driver.py b/sqlspec/adapters/adbc/driver.py index fb67fe468..1fa2f753d 100644 --- a/sqlspec/adapters/adbc/driver.py +++ b/sqlspec/adapters/adbc/driver.py @@ -28,6 +28,7 @@ from sqlspec.core import SQL, StatementConfig, build_arrow_result_from_table, get_cache_config, register_driver_profile from sqlspec.driver import BaseSyncExceptionHandler, SyncDriverAdapterBase from sqlspec.exceptions import DatabaseConnectionError, SQLSpecError +from sqlspec.migrations.utils import quote_migration_identifier from sqlspec.utils.logging import get_logger from sqlspec.utils.module_loader import ensure_pyarrow from sqlspec.utils.serializers import to_json @@ -282,6 +283,40 @@ def rollback(self) -> None: msg = f"Failed to rollback transaction: {e}" raise SQLSpecError(msg) from e + def set_migration_session_schema(self, schema: str) -> None: + """Set the PostgreSQL search path for migration SQL when using ADBC PostgreSQL.""" + if not self._is_postgres: + super().set_migration_session_schema(schema) + if self._dialect_name in {"mssql", "sqlserver", "tsql"}: + logger.debug( + "SQL Server schema support not yet implemented for ADBC; configure default schema at the " + "user/login level; ignoring default_schema=%r", + schema, + ) + else: + logger.debug("%s driver does not support default schemas; ignoring default_schema=%r", "ADBC", schema) + return + quoted_schema = quote_migration_identifier(schema) + with self.with_cursor(self.connection) as cursor: + cursor.execute(f'SET search_path TO {quoted_schema}, "$user", public') + + def has_schema(self, schema: str) -> bool: + """Return whether a PostgreSQL schema exists when using ADBC PostgreSQL.""" + if not self._is_postgres: + super().has_schema(schema) + if self._dialect_name in {"mssql", "sqlserver", "tsql"}: + logger.debug( + "SQL Server schema support not yet implemented for ADBC; configure default schema at the " + "user/login level; accepting default_schema=%r", + schema, + ) + else: + logger.debug("%s driver does not support default schemas; accepting default_schema=%r", "ADBC", schema) + return True + with self.with_cursor(self.connection) as cursor: + cursor.execute("SELECT 1 FROM information_schema.schemata WHERE schema_name = $1", parameters=[schema]) + return cursor.fetchone() is not None + def with_cursor(self, connection: "AdbcConnection") -> "AdbcCursor": """Create context manager for cursor. diff --git a/sqlspec/adapters/aiosqlite/driver.py b/sqlspec/adapters/aiosqlite/driver.py index b6f195fd3..ab27d2da8 100644 --- a/sqlspec/adapters/aiosqlite/driver.py +++ b/sqlspec/adapters/aiosqlite/driver.py @@ -23,6 +23,7 @@ from sqlspec.core import ArrowResult, get_cache_config, register_driver_profile from sqlspec.driver import AsyncDriverAdapterBase, BaseAsyncExceptionHandler from sqlspec.exceptions import SQLSpecError +from sqlspec.utils.logging import get_logger if TYPE_CHECKING: from sqlspec.adapters.aiosqlite._typing import AiosqliteConnection @@ -47,6 +48,8 @@ SQLITE_IOERR_CODE = 10 SQLITE_MISMATCH_CODE = 20 +logger = get_logger(__name__) + class AiosqliteExceptionHandler(BaseAsyncExceptionHandler): """Async context manager for handling aiosqlite database exceptions. @@ -181,6 +184,17 @@ async def rollback(self) -> None: msg = f"Failed to rollback transaction: {e}" raise SQLSpecError(msg) from e + async def set_migration_session_schema(self, schema: str) -> None: + """Ignore migration default schema for aiosqlite.""" + await super().set_migration_session_schema(schema) + logger.debug("%s driver does not support default schemas; ignoring default_schema=%r", "aiosqlite", schema) + + async def has_schema(self, schema: str) -> bool: + """Return True because SQLite has no separate schema namespace.""" + await super().has_schema(schema) + logger.debug("%s driver does not support default schemas; accepting default_schema=%r", "aiosqlite", schema) + return True + def with_cursor(self, connection: "AiosqliteConnection") -> "AiosqliteCursor": """Create async context manager for AIOSQLite cursor.""" return AiosqliteCursor(connection) diff --git a/sqlspec/adapters/asyncmy/driver.py b/sqlspec/adapters/asyncmy/driver.py index 375e55a7f..6193c0a38 100644 --- a/sqlspec/adapters/asyncmy/driver.py +++ b/sqlspec/adapters/asyncmy/driver.py @@ -241,6 +241,17 @@ async def rollback(self) -> None: msg = f"Failed to rollback MySQL transaction: {e}" raise SQLSpecError(msg) from e + async def set_migration_session_schema(self, schema: str) -> None: + """Ignore migration default schema for asyncmy.""" + await super().set_migration_session_schema(schema) + logger.debug("%s driver does not support default schemas; ignoring default_schema=%r", "asyncmy", schema) + + async def has_schema(self, schema: str) -> bool: + """Return True because asyncmy does not manage migration default schemas.""" + await super().has_schema(schema) + logger.debug("%s driver does not support default schemas; accepting default_schema=%r", "asyncmy", schema) + return True + def with_cursor(self, connection: "AsyncmyConnection") -> "AsyncmyCursor": """Create cursor context manager for the connection. diff --git a/sqlspec/adapters/asyncpg/config.py b/sqlspec/adapters/asyncpg/config.py index b6f8a400e..cbd48a060 100644 --- a/sqlspec/adapters/asyncpg/config.py +++ b/sqlspec/adapters/asyncpg/config.py @@ -260,6 +260,7 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async driver_type: "ClassVar[type[AsyncpgDriver]]" = AsyncpgDriver connection_type: "ClassVar[type[AsyncpgConnection]]" = type(AsyncpgConnection) # type: ignore[assignment] supports_transactional_ddl: "ClassVar[bool]" = True + supports_migration_schemas: "ClassVar[bool]" = True supports_native_arrow_export: "ClassVar[bool]" = True supports_native_arrow_import: "ClassVar[bool]" = True supports_native_parquet_export: "ClassVar[bool]" = True diff --git a/sqlspec/adapters/asyncpg/driver.py b/sqlspec/adapters/asyncpg/driver.py index c00c02b5d..ba6d414c6 100644 --- a/sqlspec/adapters/asyncpg/driver.py +++ b/sqlspec/adapters/asyncpg/driver.py @@ -36,6 +36,7 @@ describe_stack_statement, ) from sqlspec.exceptions import SQLSpecError, StackExecutionError +from sqlspec.migrations.utils import quote_migration_identifier from sqlspec.utils.logging import get_logger from sqlspec.utils.type_guards import has_sqlstate @@ -228,6 +229,17 @@ async def rollback(self) -> None: msg = f"Failed to rollback async transaction: {e}" raise SQLSpecError(msg) from e + async def set_migration_session_schema(self, schema: str) -> None: + """Set the PostgreSQL search path for migration SQL.""" + quoted_schema = quote_migration_identifier(schema) + await self.connection.execute(f'SET LOCAL search_path TO {quoted_schema}, "$user", public') + + async def has_schema(self, schema: str) -> bool: + """Return whether a PostgreSQL schema exists.""" + return bool( + await self.connection.fetchval("SELECT 1 FROM information_schema.schemata WHERE schema_name = $1", schema) + ) + def with_cursor(self, connection: "AsyncpgConnection") -> "AsyncpgCursor": """Create context manager for AsyncPG cursor.""" return AsyncpgCursor(connection) diff --git a/sqlspec/adapters/bigquery/config.py b/sqlspec/adapters/bigquery/config.py index 8fe671512..80a801f53 100644 --- a/sqlspec/adapters/bigquery/config.py +++ b/sqlspec/adapters/bigquery/config.py @@ -164,6 +164,7 @@ class BigQueryConfig(NoPoolSyncConfig[BigQueryConnection, BigQueryDriver]): driver_type: ClassVar[type[BigQueryDriver]] = BigQueryDriver connection_type: "ClassVar[type[BigQueryConnection]]" = BigQueryConnection supports_transactional_ddl: ClassVar[bool] = False + supports_migration_schemas: ClassVar[bool] = True supports_native_parquet_import: ClassVar[bool] = True supports_native_arrow_export: ClassVar[bool] = True supports_native_parquet_export: ClassVar[bool] = True diff --git a/sqlspec/adapters/bigquery/core.py b/sqlspec/adapters/bigquery/core.py index bc267778a..3e02bcd11 100644 --- a/sqlspec/adapters/bigquery/core.py +++ b/sqlspec/adapters/bigquery/core.py @@ -12,6 +12,8 @@ from google.cloud.bigquery import LoadJobConfig, QueryJob, QueryJobConfig from google.cloud.exceptions import GoogleCloudError from sqlglot import exp +from sqlglot.generator import _DISPATCH_CACHE # pyright: ignore[reportPrivateUsage] +from sqlglot.generators.bigquery import BigQueryGenerator from sqlspec.core import ( DriverParameterProfile, @@ -76,6 +78,41 @@ COLUMN_CACHE_MAX_SIZE = 256 +if not getattr(BigQueryGenerator, "_sqlspec_constraint_rendering", False): + _original_primary_key_column_constraint_sql = BigQueryGenerator.primarykeycolumnconstraint_sql + _original_column_def_sql = BigQueryGenerator.columndef_sql + + def _primary_key_column_constraint_sql(self: BigQueryGenerator, expression: exp.PrimaryKeyColumnConstraint) -> str: + rendered = _original_primary_key_column_constraint_sql(self, expression) + if "NOT ENFORCED" not in rendered.upper(): + return f"{rendered} NOT ENFORCED" + return rendered + + def _column_constraint_sort_key(item: tuple[int, exp.ColumnConstraint]) -> tuple[int, int]: + index, constraint = item + kind = constraint.args.get("kind") + if isinstance(kind, exp.DefaultColumnConstraint): + return 1, index + if isinstance(kind, exp.NotNullColumnConstraint): + return 2, index + return 0, index + + def _column_def_sql(self: BigQueryGenerator, expression: exp.ColumnDef, sep: str = " ") -> str: + constraints = expression.args.get("constraints") + if constraints: + expression = expression.copy() + expression.set( + "constraints", + [constraint for _, constraint in sorted(enumerate(constraints), key=_column_constraint_sort_key)], + ) + return _original_column_def_sql(self, expression, sep=sep) + + BigQueryGenerator.primarykeycolumnconstraint_sql = _primary_key_column_constraint_sql # type: ignore[assignment, method-assign] + BigQueryGenerator.columndef_sql = _column_def_sql # type: ignore[assignment, method-assign] + BigQueryGenerator._sqlspec_constraint_rendering = True # type: ignore[attr-defined] + _DISPATCH_CACHE.pop(BigQueryGenerator, None) + + def _identity(value: Any) -> Any: return value diff --git a/sqlspec/adapters/bigquery/driver.py b/sqlspec/adapters/bigquery/driver.py index fa2772dac..221de7dfd 100644 --- a/sqlspec/adapters/bigquery/driver.py +++ b/sqlspec/adapters/bigquery/driver.py @@ -8,7 +8,8 @@ import io from typing import TYPE_CHECKING, Any, cast -from google.cloud.exceptions import GoogleCloudError +from google.cloud.bigquery import QueryJobConfig +from google.cloud.exceptions import GoogleCloudError, NotFound from sqlspec.adapters.bigquery._typing import BigQueryConnection, BigQueryCursor, BigQuerySessionContext from sqlspec.adapters.bigquery.core import ( @@ -18,6 +19,7 @@ build_load_job_telemetry, build_retry, collect_rows, + copy_job_config, create_mapped_exception, default_statement_config, detect_emulator, @@ -47,7 +49,7 @@ if TYPE_CHECKING: from collections.abc import Callable - from google.cloud.bigquery import QueryJob, QueryJobConfig + from google.cloud.bigquery import QueryJob from sqlspec.builder import QueryBuilder from sqlspec.core import SQL, ArrowResult, SQLResult, Statement, StatementFilter @@ -65,6 +67,11 @@ __all__ = ("BigQueryCursor", "BigQueryDriver", "BigQueryExceptionHandler", "BigQuerySessionContext") +def _normalize_bigquery_dataset_name(dataset_name: str) -> str: + """Return a BigQuery dataset path without identifier quoting.""" + return dataset_name.strip().replace("`", "").replace("`.`", ".") + + class BigQueryExceptionHandler(BaseSyncExceptionHandler): """Context manager for handling BigQuery API exceptions. @@ -287,6 +294,24 @@ def commit(self) -> None: def rollback(self) -> None: """Rollback transaction - BigQuery doesn't support transactions.""" + def set_migration_session_schema(self, schema: str) -> None: + """Set BigQuery default dataset for migration query jobs.""" + dataset_path = self._qualify_dataset_path(schema) + job_config = QueryJobConfig() + if self._default_query_job_config is not None: + copy_job_config(self._default_query_job_config, job_config) + job_config.default_dataset = dataset_path + self._default_query_job_config = job_config + + def has_schema(self, schema: str) -> bool: + """Return whether a BigQuery dataset exists.""" + dataset_path = self._qualify_dataset_path(schema) + try: + self.connection.get_dataset(dataset_path) + except NotFound: + return False + return True + def with_cursor(self, connection: "BigQueryConnection") -> "BigQueryCursor": """Create context manager for cursor management. @@ -538,5 +563,25 @@ def _connection_in_transaction(self) -> bool: """ return False + def _qualify_dataset_path(self, dataset_name: str) -> str: + """Return a project-qualified BigQuery dataset path.""" + normalized = _normalize_bigquery_dataset_name(dataset_name) + if "." in normalized: + return normalized + + project = str(getattr(self.connection, "project", "") or "") + if not project and self._default_query_job_config is not None: + default_dataset = self._default_query_job_config.default_dataset + if default_dataset is not None: + default_dataset_path = _normalize_bigquery_dataset_name(str(default_dataset)) + if "." in default_dataset_path: + project = default_dataset_path.split(".", 1)[0] + + if not project: + msg = "BigQuery migration schemas require a configured project" + raise SQLSpecError(msg) + + return f"{project}.{normalized}" + register_driver_profile("bigquery", driver_profile) diff --git a/sqlspec/adapters/duckdb/config.py b/sqlspec/adapters/duckdb/config.py index 5fb03c133..098f26673 100644 --- a/sqlspec/adapters/duckdb/config.py +++ b/sqlspec/adapters/duckdb/config.py @@ -222,6 +222,7 @@ class DuckDBConfig(SyncDatabaseConfig[DuckDBConnection, DuckDBConnectionPool, Du driver_type: "ClassVar[type[DuckDBDriver]]" = DuckDBDriver connection_type: "ClassVar[type[DuckDBConnection]]" = DuckDBConnection supports_transactional_ddl: "ClassVar[bool]" = True + supports_migration_schemas: "ClassVar[bool]" = True supports_native_arrow_export: "ClassVar[bool]" = True supports_native_arrow_import: "ClassVar[bool]" = True supports_native_parquet_export: "ClassVar[bool]" = True diff --git a/sqlspec/adapters/duckdb/driver.py b/sqlspec/adapters/duckdb/driver.py index a08588e53..05868ff41 100644 --- a/sqlspec/adapters/duckdb/driver.py +++ b/sqlspec/adapters/duckdb/driver.py @@ -40,6 +40,11 @@ _type_converter = DuckDBOutputConverter() +def _quote_duckdb_search_path(schema: str) -> str: + """Return a DuckDB string literal for SET search_path.""" + return "'" + schema.replace("'", "''") + "'" + + class DuckDBExceptionHandler(BaseSyncExceptionHandler): """Context manager for handling DuckDB database exceptions. @@ -210,6 +215,17 @@ def rollback(self) -> None: msg = f"Failed to rollback DuckDB transaction: {e}" raise SQLSpecError(msg) from e + def set_migration_session_schema(self, schema: str) -> None: + """Set DuckDB search_path for migration SQL.""" + self.connection.execute(f"SET search_path = {_quote_duckdb_search_path(schema)}") + + def has_schema(self, schema: str) -> bool: + """Return whether a DuckDB schema exists.""" + result = self.connection.execute( + "SELECT 1 FROM information_schema.schemata WHERE schema_name = ?", [schema] + ).fetchone() + return result is not None + def with_cursor(self, connection: "DuckDBConnection") -> "DuckDBCursor": """Create context manager for DuckDB cursor. diff --git a/sqlspec/adapters/oracledb/config.py b/sqlspec/adapters/oracledb/config.py index 69cabfed4..1229452a7 100644 --- a/sqlspec/adapters/oracledb/config.py +++ b/sqlspec/adapters/oracledb/config.py @@ -232,6 +232,7 @@ class OracleSyncConfig(SyncDatabaseConfig[OracleSyncConnection, "OracleSyncConne connection_type: "ClassVar[type[OracleSyncConnection]]" = OracleSyncConnection migration_tracker_type: "ClassVar[type[OracleSyncMigrationTracker]]" = OracleSyncMigrationTracker supports_transactional_ddl: ClassVar[bool] = False + supports_migration_schemas: ClassVar[bool] = True supports_native_arrow_export: ClassVar[bool] = True supports_native_arrow_import: ClassVar[bool] = True supports_native_parquet_export: ClassVar[bool] = True @@ -422,6 +423,7 @@ class OracleAsyncConfig(AsyncDatabaseConfig[OracleAsyncConnection, "OracleAsyncC driver_type: ClassVar[type[OracleAsyncDriver]] = OracleAsyncDriver migration_tracker_type: "ClassVar[type[OracleAsyncMigrationTracker]]" = OracleAsyncMigrationTracker supports_transactional_ddl: ClassVar[bool] = False + supports_migration_schemas: ClassVar[bool] = True supports_native_arrow_export: ClassVar[bool] = True supports_native_arrow_import: ClassVar[bool] = True supports_native_parquet_export: ClassVar[bool] = True diff --git a/sqlspec/adapters/oracledb/driver.py b/sqlspec/adapters/oracledb/driver.py index 0c83c837d..018f5b1c5 100644 --- a/sqlspec/adapters/oracledb/driver.py +++ b/sqlspec/adapters/oracledb/driver.py @@ -51,6 +51,7 @@ hash_stack_operations, ) from sqlspec.exceptions import ImproperConfigurationError, SQLSpecError, StackExecutionError +from sqlspec.migrations.utils import quote_migration_identifier from sqlspec.utils.logging import get_logger, log_with_context from sqlspec.utils.module_loader import ensure_pyarrow from sqlspec.utils.type_guards import has_pipeline_capability @@ -446,6 +447,19 @@ def rollback(self) -> None: msg = f"Failed to rollback Oracle transaction: {e}" raise SQLSpecError(msg) from e + def set_migration_session_schema(self, schema: str) -> None: + """Set Oracle CURRENT_SCHEMA for migration SQL.""" + quoted_schema = quote_migration_identifier(schema.strip().upper()) + with self.with_cursor(self.connection) as cursor: + cursor.execute(f"ALTER SESSION SET CURRENT_SCHEMA = {quoted_schema}") + + def has_schema(self, schema: str) -> bool: + """Return whether an Oracle schema/user exists.""" + schema_name = schema.strip() + with self.with_cursor(self.connection) as cursor: + cursor.execute("SELECT 1 FROM ALL_USERS WHERE USERNAME = UPPER(:schema_name)", {"schema_name": schema_name}) + return cursor.fetchone() is not None + def with_cursor(self, connection: OracleSyncConnection) -> OracleSyncCursor: """Create context manager for Oracle cursor. @@ -944,6 +958,22 @@ async def rollback(self) -> None: msg = f"Failed to rollback Oracle transaction: {e}" raise SQLSpecError(msg) from e + async def set_migration_session_schema(self, schema: str) -> None: + """Set Oracle CURRENT_SCHEMA for migration SQL.""" + quoted_schema = quote_migration_identifier(schema.strip().upper()) + async with self.with_cursor(self.connection) as cursor: + await cursor.execute(f"ALTER SESSION SET CURRENT_SCHEMA = {quoted_schema}") + + async def has_schema(self, schema: str) -> bool: + """Return whether an Oracle schema/user exists.""" + schema_name = schema.strip() + async with self.with_cursor(self.connection) as cursor: + await cursor.execute( + "SELECT 1 FROM ALL_USERS WHERE USERNAME = UPPER(:schema_name)", {"schema_name": schema_name} + ) + row = await cursor.fetchone() + return row is not None + def with_cursor(self, connection: OracleAsyncConnection) -> OracleAsyncCursor: """Create context manager for Oracle cursor. diff --git a/sqlspec/adapters/oracledb/migrations.py b/sqlspec/adapters/oracledb/migrations.py index 0b7186d42..eb3c28c13 100644 --- a/sqlspec/adapters/oracledb/migrations.py +++ b/sqlspec/adapters/oracledb/migrations.py @@ -38,6 +38,22 @@ class OracleMigrationTrackerMixin: __slots__ = () version_table: str + version_table_name: str + version_table_schema: str | None + + def _qualify_version_table(self, version_table_name: str, version_table_schema: str | None) -> str: + """Return Oracle tracker table name, uppercasing qualified identifiers.""" + if version_table_schema: + return f"{version_table_schema.upper()}.{version_table_name.upper()}" + return version_table_name + + def _get_create_table_builder(self) -> CreateTable: + """Return an Oracle CREATE TABLE builder for the tracker table.""" + table_name = self.version_table_name.upper() if self.version_table_schema else self.version_table_name + builder = sql.create_table(table_name) + if self.version_table_schema: + builder.in_schema(self.version_table_schema.upper()) + return builder def _get_create_table_sql(self) -> CreateTable: """Get Oracle-specific SQL builder for creating the tracking table. @@ -51,8 +67,8 @@ def _get_create_table_sql(self) -> CreateTable: SQL builder object for Oracle table creation. """ return ( - sql - .create_table(self.version_table) + self + ._get_create_table_builder() .column("version_num", "VARCHAR2(32)", primary_key=True) .column("version_type", "VARCHAR2(16)") .column("execution_sequence", "INTEGER") @@ -125,10 +141,14 @@ def _get_existing_columns_sql(self) -> str: Returns: Raw SQL string for Oracle's USER_TAB_COLUMNS query. """ + table_name = self.version_table_name.upper() + owner_filter = "" + if self.version_table_schema: + owner_filter = f"\n AND owner = '{self.version_table_schema.upper()}'" return f""" SELECT column_name - FROM user_tab_columns - WHERE table_name = '{self.version_table.upper()}' + FROM all_tab_columns + WHERE table_name = '{table_name}'{owner_filter} """ def _detect_missing_columns(self, existing_columns: "set[str]") -> "set[str]": diff --git a/sqlspec/adapters/psqlpy/config.py b/sqlspec/adapters/psqlpy/config.py index 724523c27..2c3caa440 100644 --- a/sqlspec/adapters/psqlpy/config.py +++ b/sqlspec/adapters/psqlpy/config.py @@ -191,6 +191,7 @@ class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyD driver_type: ClassVar[type[PsqlpyDriver]] = PsqlpyDriver connection_type: "ClassVar[type[PsqlpyConnection]]" = PsqlpyConnection supports_transactional_ddl: "ClassVar[bool]" = True + supports_migration_schemas: "ClassVar[bool]" = True supports_native_arrow_export: ClassVar[bool] = True supports_native_arrow_import: ClassVar[bool] = True supports_native_parquet_export: ClassVar[bool] = True diff --git a/sqlspec/adapters/psqlpy/driver.py b/sqlspec/adapters/psqlpy/driver.py index 85cb616b7..0e707c079 100644 --- a/sqlspec/adapters/psqlpy/driver.py +++ b/sqlspec/adapters/psqlpy/driver.py @@ -32,6 +32,7 @@ from sqlspec.core import SQL, StatementConfig, get_cache_config, register_driver_profile from sqlspec.driver import AsyncDriverAdapterBase, BaseAsyncExceptionHandler from sqlspec.exceptions import SQLSpecError +from sqlspec.migrations.utils import quote_migration_identifier from sqlspec.utils.logging import get_logger if TYPE_CHECKING: @@ -219,6 +220,17 @@ async def rollback(self) -> None: msg = f"Failed to rollback psqlpy transaction: {e}" raise SQLSpecError(msg) from e + async def set_migration_session_schema(self, schema: str) -> None: + """Set the PostgreSQL search path for migration SQL.""" + quoted_schema = quote_migration_identifier(schema) + await self.connection.execute(f'SET LOCAL search_path TO {quoted_schema}, "$user", public') + + async def has_schema(self, schema: str) -> bool: + """Return whether a PostgreSQL schema exists.""" + rows = await self.connection.fetch("SELECT 1 FROM information_schema.schemata WHERE schema_name = $1", [schema]) + data, _ = collect_rows(rows) + return bool(data) + def with_cursor(self, connection: "PsqlpyConnection") -> "PsqlpyCursor": """Create context manager for psqlpy cursor. diff --git a/sqlspec/adapters/psycopg/config.py b/sqlspec/adapters/psycopg/config.py index 5ebdd864b..82edee4a5 100644 --- a/sqlspec/adapters/psycopg/config.py +++ b/sqlspec/adapters/psycopg/config.py @@ -196,6 +196,7 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool driver_type: "ClassVar[type[PsycopgSyncDriver]]" = PsycopgSyncDriver connection_type: "ClassVar[type[PsycopgSyncConnection]]" = PsycopgSyncConnection supports_transactional_ddl: "ClassVar[bool]" = True + supports_migration_schemas: "ClassVar[bool]" = True supports_native_arrow_export: "ClassVar[bool]" = True supports_native_arrow_import: "ClassVar[bool]" = True supports_native_parquet_export: "ClassVar[bool]" = True @@ -477,6 +478,7 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec driver_type: ClassVar[type[PsycopgAsyncDriver]] = PsycopgAsyncDriver connection_type: "ClassVar[type[PsycopgAsyncConnection]]" = PsycopgAsyncConnection supports_transactional_ddl: "ClassVar[bool]" = True + supports_migration_schemas: "ClassVar[bool]" = True supports_native_arrow_export: ClassVar[bool] = True supports_native_arrow_import: ClassVar[bool] = True supports_native_parquet_export: ClassVar[bool] = True diff --git a/sqlspec/adapters/psycopg/driver.py b/sqlspec/adapters/psycopg/driver.py index 47ebf08f3..04cf201f4 100644 --- a/sqlspec/adapters/psycopg/driver.py +++ b/sqlspec/adapters/psycopg/driver.py @@ -56,6 +56,7 @@ describe_stack_statement, ) from sqlspec.exceptions import SQLSpecError, StackExecutionError +from sqlspec.migrations.utils import quote_migration_identifier from sqlspec.utils.logging import get_logger from sqlspec.utils.type_guards import is_readable @@ -339,6 +340,19 @@ def rollback(self) -> None: msg = f"Failed to rollback transaction: {e}" raise SQLSpecError(msg) from e + def set_migration_session_schema(self, schema: str) -> None: + """Set the PostgreSQL search path for migration SQL.""" + quoted_schema = quote_migration_identifier(schema) + sql = cast("LiteralString", f'SET LOCAL search_path TO {quoted_schema}, "$user", public') # type: ignore[redundant-cast] + with self.with_cursor(self.connection) as cursor: + cursor.execute(sql) + + def has_schema(self, schema: str) -> bool: + """Return whether a PostgreSQL schema exists.""" + with self.with_cursor(self.connection) as cursor: + cursor.execute("SELECT 1 FROM information_schema.schemata WHERE schema_name = %s", (schema,)) + return cursor.fetchone() is not None + def with_cursor(self, connection: PsycopgSyncConnection) -> PsycopgSyncCursor: """Create context manager for PostgreSQL cursor.""" return PsycopgSyncCursor(connection) @@ -796,6 +810,20 @@ async def rollback(self) -> None: msg = f"Failed to rollback transaction: {e}" raise SQLSpecError(msg) from e + async def set_migration_session_schema(self, schema: str) -> None: + """Set the PostgreSQL search path for migration SQL.""" + quoted_schema = quote_migration_identifier(schema) + sql = cast("LiteralString", f'SET LOCAL search_path TO {quoted_schema}, "$user", public') # type: ignore[redundant-cast] + async with self.with_cursor(self.connection) as cursor: + await cursor.execute(sql) + + async def has_schema(self, schema: str) -> bool: + """Return whether a PostgreSQL schema exists.""" + async with self.with_cursor(self.connection) as cursor: + await cursor.execute("SELECT 1 FROM information_schema.schemata WHERE schema_name = %s", (schema,)) + row = await cursor.fetchone() + return row is not None + def with_cursor(self, connection: "PsycopgAsyncConnection") -> "PsycopgAsyncCursor": """Create async context manager for PostgreSQL cursor.""" return PsycopgAsyncCursor(connection) diff --git a/sqlspec/adapters/sqlite/driver.py b/sqlspec/adapters/sqlite/driver.py index 5b58f76dd..747817345 100644 --- a/sqlspec/adapters/sqlite/driver.py +++ b/sqlspec/adapters/sqlite/driver.py @@ -20,6 +20,7 @@ from sqlspec.core.result import DMLResult from sqlspec.driver import BaseSyncExceptionHandler, SyncDriverAdapterBase from sqlspec.exceptions import SQLSpecError +from sqlspec.utils.logging import get_logger if TYPE_CHECKING: from collections.abc import Sequence @@ -35,6 +36,8 @@ __all__ = ("SqliteCursor", "SqliteDriver", "SqliteExceptionHandler", "SqliteSessionContext") +logger = get_logger(__name__) + class SqliteExceptionHandler(BaseSyncExceptionHandler): """Context manager for handling SQLite database exceptions. @@ -388,6 +391,17 @@ def rollback(self) -> None: msg = f"Failed to rollback transaction: {e}" raise SQLSpecError(msg) from e + def set_migration_session_schema(self, schema: str) -> None: + """Ignore migration default schema for SQLite.""" + super().set_migration_session_schema(schema) + logger.debug("%s driver does not support default schemas; ignoring default_schema=%r", "SQLite", schema) + + def has_schema(self, schema: str) -> bool: + """Return True because SQLite has no separate schema namespace.""" + super().has_schema(schema) + logger.debug("%s driver does not support default schemas; accepting default_schema=%r", "SQLite", schema) + return True + def with_cursor(self, connection: "SqliteConnection") -> "SqliteCursor": """Create context manager for SQLite cursor. diff --git a/sqlspec/builder/_dml.py b/sqlspec/builder/_dml.py index 00f97c1d8..617cb9f5a 100644 --- a/sqlspec/builder/_dml.py +++ b/sqlspec/builder/_dml.py @@ -111,8 +111,8 @@ def columns(self, *columns: str | exp.Expr) -> Self: if columns: identifiers = [exp.to_identifier(col) if isinstance(col, str) else col for col in columns] - table_name = current_this.this - current_expr.set("this", exp.Schema(this=table_name, expressions=identifiers)) + table_expression = current_this.this if isinstance(current_this, exp.Schema) else current_this + current_expr.set("this", exp.Schema(this=table_expression.copy(), expressions=identifiers)) elif isinstance(current_this, exp.Schema): table_name = current_this.this current_expr.set("this", exp.Table(this=table_name)) diff --git a/sqlspec/config.py b/sqlspec/config.py index 3e201b3d2..dfb11f44f 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -144,6 +144,12 @@ class MigrationConfig(TypedDict): version_table_name: NotRequired[str] """Name of the table used to track applied migrations. Defaults to 'sqlspec_migrations'.""" + default_schema: NotRequired[str] + """Schema applied to migration sessions before user migration SQL runs, when supported by the adapter.""" + + version_table_schema: NotRequired[str] + """Schema that stores the migration tracking table. Defaults to default_schema when omitted.""" + project_root: NotRequired[str] """Path to the project root directory. Used for relative path resolution.""" @@ -1130,6 +1136,7 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]): supports_transactional_ddl: "ClassVar[bool]" = False supports_native_arrow_import: "ClassVar[bool]" = False supports_native_arrow_export: "ClassVar[bool]" = False + supports_migration_schemas: "ClassVar[bool]" = False supports_native_parquet_import: "ClassVar[bool]" = False supports_native_parquet_export: "ClassVar[bool]" = False requires_staging_for_load: "ClassVar[bool]" = False diff --git a/sqlspec/driver/_async.py b/sqlspec/driver/_async.py index 620d08814..8d070a4f7 100644 --- a/sqlspec/driver/_async.py +++ b/sqlspec/driver/_async.py @@ -163,6 +163,26 @@ def data_dictionary(self) -> "AsyncDataDictionaryBase": """ + async def set_migration_session_schema(self, schema: str) -> None: + """Set the default schema for migration SQL when supported. + + Args: + schema: Schema requested for the current migration session. + """ + logger.debug("migration.schema.noop", extra={"schema": schema, "driver": type(self).__name__}) + + async def has_schema(self, schema: str) -> bool: + """Return whether the schema exists for migration validation. + + Args: + schema: Schema name to validate. + + Returns: + True when the adapter does not provide schema validation. + """ + logger.debug("migration.schema.validation.noop", extra={"schema": schema, "driver": type(self).__name__}) + return True + # ───────────────────────────────────────────────────────────────────────────── # CORE DISPATCH METHODS - The Execution Engine # ───────────────────────────────────────────────────────────────────────────── diff --git a/sqlspec/driver/_sync.py b/sqlspec/driver/_sync.py index 3a1fa89f7..250ad776b 100644 --- a/sqlspec/driver/_sync.py +++ b/sqlspec/driver/_sync.py @@ -144,6 +144,26 @@ def data_dictionary(self) -> "SyncDataDictionaryBase": """ + def set_migration_session_schema(self, schema: str) -> None: + """Set the default schema for migration SQL when supported. + + Args: + schema: Schema requested for the current migration session. + """ + logger.debug("migration.schema.noop", extra={"schema": schema, "driver": type(self).__name__}) + + def has_schema(self, schema: str) -> bool: + """Return whether the schema exists for migration validation. + + Args: + schema: Schema name to validate. + + Returns: + True when the adapter does not provide schema validation. + """ + logger.debug("migration.schema.validation.noop", extra={"schema": schema, "driver": type(self).__name__}) + return True + # ───────────────────────────────────────────────────────────────────────────── # CORE DISPATCH METHODS - The Execution Engine # ───────────────────────────────────────────────────────────────────────────── diff --git a/sqlspec/migrations/base.py b/sqlspec/migrations/base.py index 34180d6d0..7942cf066 100644 --- a/sqlspec/migrations/base.py +++ b/sqlspec/migrations/base.py @@ -14,6 +14,7 @@ from sqlspec.migrations.context import MigrationContext from sqlspec.migrations.loaders import get_migration_loader from sqlspec.migrations.templates import MigrationTemplateSettings, TemplateDescriptionHints, build_template_settings +from sqlspec.migrations.utils import resolve_tracker_schema as _resolve_tracker_schema from sqlspec.migrations.version import parse_version from sqlspec.utils.logging import get_logger from sqlspec.utils.module_loader import module_to_os_path @@ -39,17 +40,33 @@ class BaseMigrationTracker(ABC, Generic[DriverT]): """Base class for migration version tracking.""" - __slots__ = ("_output_policy", "version_table") + __slots__ = ("_output_policy", "version_table", "version_table_name", "version_table_schema") - def __init__(self, version_table_name: str = "ddl_migrations") -> None: + def __init__(self, version_table_name: str = "ddl_migrations", version_table_schema: str | None = None) -> None: """Initialize the migration tracker. Args: version_table_name: Name of the table to track migrations. + version_table_schema: Optional schema that stores the tracking table. """ - self.version_table = version_table_name + self.version_table_name = version_table_name + self.version_table_schema = version_table_schema + self.version_table = self._qualify_version_table(version_table_name, version_table_schema) self._output_policy = {"use_logger": False, "echo": True, "summary_only": False} + def _qualify_version_table(self, version_table_name: str, version_table_schema: str | None) -> str: + """Return the tracker table name, qualified with schema when configured.""" + if version_table_schema: + return f"{version_table_schema}.{version_table_name}" + return version_table_name + + def _get_create_table_builder(self) -> CreateTable: + """Return a CREATE TABLE builder for the tracker table.""" + builder = sql.create_table(self.version_table_name) + if self.version_table_schema: + builder.in_schema(self.version_table_schema) + return builder + def set_output_policy(self, *, use_logger: bool, echo: bool, summary_only: bool) -> None: """Set output policy for tracker console/logging behavior.""" self._output_policy = {"use_logger": use_logger, "echo": echo, "summary_only": summary_only} @@ -79,8 +96,8 @@ def _get_create_table_sql(self) -> CreateTable: SQL builder object for table creation. """ return ( - sql - .create_table(self.version_table) + self + ._get_create_table_builder() .if_not_exists() .column("version_num", "VARCHAR(32)", primary_key=True) .column("version_type", "VARCHAR(16)") @@ -635,6 +652,23 @@ def __init__(self, config: ConfigT) -> None: self._last_command_error: Exception | None = None self._last_command_metrics: dict[str, float] | None = None + def _get_migration_config(self) -> "dict[str, Any]": + """Return migration config as a plain dictionary.""" + return cast("dict[str, Any]", self.config.migration_config) or {} + + def _resolve_default_schema(self) -> str | None: + """Return the configured default migration schema.""" + default_schema = self._get_migration_config().get("default_schema") + if isinstance(default_schema, str) and default_schema: + return default_schema + return None + + def _resolve_tracker_schema(self) -> str | None: + """Return tracker schema only for adapters that support schema-qualified migration tables.""" + if not bool(getattr(self.config, "supports_migration_schemas", False)): + return None + return _resolve_tracker_schema(self._get_migration_config()) + def _parse_extension_configs(self) -> "dict[str, dict[str, Any]]": """Parse extension configurations from include_extensions. diff --git a/sqlspec/migrations/commands.py b/sqlspec/migrations/commands.py index 7df3bec66..ed78cebe7 100644 --- a/sqlspec/migrations/commands.py +++ b/sqlspec/migrations/commands.py @@ -14,6 +14,7 @@ from rich.table import Table from sqlspec.builder import sql +from sqlspec.exceptions import MigrationError from sqlspec.migrations.base import BaseMigrationCommands from sqlspec.migrations.context import MigrationContext from sqlspec.migrations.fix import MigrationFixer @@ -264,7 +265,9 @@ def __init__(self, config: "SyncConfigT") -> None: config: The SQLSpec configuration. """ super().__init__(config) - self.tracker = config.migration_tracker_type(self.version_table) + self.tracker = config.migration_tracker_type( + self.version_table, version_table_schema=self._resolve_tracker_schema() + ) # Create context with extension configurations context = MigrationContext.from_config(config) @@ -279,6 +282,15 @@ def __init__(self, config: "SyncConfigT") -> None: description_hints=self._template_settings.description_hints, ) + def _validate_migration_schema(self, driver: Any) -> None: + """Validate the configured migration schema exists before issuing DDL.""" + default_schema = self._resolve_default_schema() + if default_schema is None: + return + if not driver.has_schema(default_schema): + msg = f"Configured schema '{default_schema}' does not exist" + raise MigrationError(msg) + def init(self, directory: str, package: bool = True) -> None: """Initialize migration directory structure. @@ -708,6 +720,7 @@ def upgrade( with self.config.provide_session() as driver: db_system = resolve_db_system(type(driver).__name__) + self._validate_migration_schema(driver) self.tracker.ensure_tracking_table(driver) if auto_sync and self.config.migration_config.get("auto_sync", True): @@ -846,6 +859,7 @@ def downgrade( with self.config.provide_session() as driver: db_system = resolve_db_system(type(driver).__name__) + self._validate_migration_schema(driver) self.tracker.ensure_tracking_table(driver) applied = self.tracker.get_applied_migrations(driver) if runtime is not None: @@ -1191,7 +1205,9 @@ def __init__(self, config: "AsyncConfigT") -> None: config: The SQLSpec configuration. """ super().__init__(config) - self.tracker = config.migration_tracker_type(self.version_table) + self.tracker = config.migration_tracker_type( + self.version_table, version_table_schema=self._resolve_tracker_schema() + ) # Create context with extension configurations context = MigrationContext.from_config(config) @@ -1206,6 +1222,15 @@ def __init__(self, config: "AsyncConfigT") -> None: description_hints=self._template_settings.description_hints, ) + async def _validate_migration_schema(self, driver: Any) -> None: + """Validate the configured migration schema exists before issuing DDL.""" + default_schema = self._resolve_default_schema() + if default_schema is None: + return + if not await driver.has_schema(default_schema): + msg = f"Configured schema '{default_schema}' does not exist" + raise MigrationError(msg) + async def init(self, directory: str, package: bool = True) -> None: """Initialize migration directory structure. @@ -1635,6 +1660,7 @@ async def upgrade( async with self.config.provide_session() as driver: db_system = resolve_db_system(type(driver).__name__) + await self._validate_migration_schema(driver) await self.tracker.ensure_tracking_table(driver) if auto_sync and self.config.migration_config.get("auto_sync", True): @@ -1775,6 +1801,7 @@ async def downgrade( async with self.config.provide_session() as driver: db_system = resolve_db_system(type(driver).__name__) + await self._validate_migration_schema(driver) await self.tracker.ensure_tracking_table(driver) applied = await self.tracker.get_applied_migrations(driver) diff --git a/sqlspec/migrations/runner.py b/sqlspec/migrations/runner.py index 77fc44f07..8e1ad9ca6 100644 --- a/sqlspec/migrations/runner.py +++ b/sqlspec/migrations/runner.py @@ -474,6 +474,15 @@ def should_use_transaction( migration_config = cast("dict[str, Any]", config.migration_config) or {} return bool(migration_config.get("transactional", True)) + def _resolve_default_schema(self) -> str | None: + """Return the configured default schema for migration execution.""" + config = self.context.config if self.context else None + migration_config = cast("dict[str, Any]", getattr(config, "migration_config", None)) or {} + default_schema = migration_config.get("default_schema") + if isinstance(default_schema, str) and default_schema: + return default_schema + return None + class SyncMigrationRunner(BaseMigrationRunner): """Synchronous migration runner with pure sync methods.""" @@ -571,8 +580,11 @@ def execute_upgrade( execution_time = 0 try: + default_schema = self._resolve_default_schema() if use_transaction: driver.begin() + if default_schema: + driver.set_migration_session_schema(default_schema) for sql_statement in upgrade_sql_list: if sql_statement.strip(): driver.execute_script(sql_statement) @@ -581,6 +593,8 @@ def execute_upgrade( on_success(execution_time) driver.commit() else: + if default_schema: + driver.set_migration_session_schema(default_schema) for sql_statement in upgrade_sql_list: if sql_statement.strip(): driver.execute_script(sql_statement) @@ -674,8 +688,11 @@ def execute_downgrade( execution_time = 0 try: + default_schema = self._resolve_default_schema() if use_transaction: driver.begin() + if default_schema: + driver.set_migration_session_schema(default_schema) for sql_statement in downgrade_sql_list: if sql_statement.strip(): driver.execute_script(sql_statement) @@ -684,6 +701,8 @@ def execute_downgrade( on_success(execution_time) driver.commit() else: + if default_schema: + driver.set_migration_session_schema(default_schema) for sql_statement in downgrade_sql_list: if sql_statement.strip(): driver.execute_script(sql_statement) @@ -901,8 +920,11 @@ async def execute_upgrade( execution_time = 0 try: + default_schema = self._resolve_default_schema() if use_transaction: await driver.begin() + if default_schema: + await driver.set_migration_session_schema(default_schema) for sql_statement in upgrade_sql_list: if sql_statement.strip(): await driver.execute_script(sql_statement) @@ -911,6 +933,8 @@ async def execute_upgrade( await on_success(execution_time) await driver.commit() else: + if default_schema: + await driver.set_migration_session_schema(default_schema) for sql_statement in upgrade_sql_list: if sql_statement.strip(): await driver.execute_script(sql_statement) @@ -1004,8 +1028,11 @@ async def execute_downgrade( execution_time = 0 try: + default_schema = self._resolve_default_schema() if use_transaction: await driver.begin() + if default_schema: + await driver.set_migration_session_schema(default_schema) for sql_statement in downgrade_sql_list: if sql_statement.strip(): await driver.execute_script(sql_statement) @@ -1014,6 +1041,8 @@ async def execute_downgrade( await on_success(execution_time) await driver.commit() else: + if default_schema: + await driver.set_migration_session_schema(default_schema) for sql_statement in downgrade_sql_list: if sql_statement.strip(): await driver.execute_script(sql_statement) diff --git a/sqlspec/migrations/utils.py b/sqlspec/migrations/utils.py index 09321c379..6a8ae63e6 100644 --- a/sqlspec/migrations/utils.py +++ b/sqlspec/migrations/utils.py @@ -13,16 +13,48 @@ from sqlspec.utils.text import slugify if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Callable, Mapping from sqlspec.config import DatabaseConfigProtocol from sqlspec.driver import AsyncDriverAdapterBase -__all__ = ("create_migration_file", "drop_all", "get_author") +__all__ = ("create_migration_file", "drop_all", "get_author", "quote_migration_identifier", "resolve_tracker_schema") logger = get_logger(__name__) +def resolve_tracker_schema(migration_config: "Mapping[str, Any] | None") -> str | None: + """Resolve the schema for the migration tracking table. + + Args: + migration_config: Migration configuration mapping. + + Returns: + Explicit tracker schema, default migration schema, or None. + """ + if not migration_config: + return None + version_table_schema = migration_config.get("version_table_schema") + if isinstance(version_table_schema, str) and version_table_schema: + return version_table_schema + default_schema = migration_config.get("default_schema") + if isinstance(default_schema, str) and default_schema: + return default_schema + return None + + +def quote_migration_identifier(identifier: str) -> str: + """Quote a SQL identifier for migration schema/session commands. + + Args: + identifier: SQL identifier to quote. + + Returns: + Double-quoted identifier with embedded double quotes escaped. + """ + return '"' + identifier.replace('"', '""') + '"' + + def create_migration_file( migrations_dir: Path, version: str, diff --git a/tests/integration/adapters/_postgres_migration_schema.py b/tests/integration/adapters/_postgres_migration_schema.py new file mode 100644 index 000000000..124aef9a5 --- /dev/null +++ b/tests/integration/adapters/_postgres_migration_schema.py @@ -0,0 +1,65 @@ +"""Shared PostgreSQL migration schema integration helpers.""" + +from pathlib import Path +from typing import Any, Literal +from uuid import uuid4 + +from sqlspec.migrations.utils import quote_migration_identifier + +ParamStyle = Literal["numeric", "pyformat"] + + +def unique_identifier(prefix: str) -> str: + """Return a short PostgreSQL-safe identifier for integration tests.""" + return f"{prefix}_{uuid4().hex[:10]}" + + +def write_unqualified_table_migration(migration_dir: Path, table_name: str) -> None: + """Write a Python migration that creates an unqualified table.""" + migration_content = f'''"""Create an unqualified table.""" + + +def up(): + """Create an unqualified table.""" + return [""" + CREATE TABLE {table_name} ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL + ) + """] + + +def down(): + """Drop the unqualified table.""" + return ["DROP TABLE IF EXISTS {table_name}"] +''' + (migration_dir / "0001_create_unqualified_table.py").write_text(migration_content) + + +def create_schema_sql(schema: str) -> str: + """Return PostgreSQL CREATE SCHEMA SQL for a trusted test identifier.""" + return f"CREATE SCHEMA {quote_migration_identifier(schema)}" + + +def drop_schema_sql(schema: str) -> str: + """Return PostgreSQL DROP SCHEMA SQL for a trusted test identifier.""" + return f"DROP SCHEMA IF EXISTS {quote_migration_identifier(schema)} CASCADE" + + +def table_exists_sql(style: ParamStyle) -> str: + """Return an information_schema table existence query for the adapter parameter style.""" + if style == "pyformat": + return "SELECT 1 FROM information_schema.tables WHERE table_schema = %s AND table_name = %s" + return "SELECT 1 FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2" + + +def sync_table_exists(driver: Any, schema: str, table_name: str, *, style: ParamStyle) -> bool: + """Return whether the table exists using a sync SQLSpec driver.""" + result = driver.execute(table_exists_sql(style), (schema, table_name)) + return bool(result.data) + + +async def async_table_exists(driver: Any, schema: str, table_name: str, *, style: ParamStyle) -> bool: + """Return whether the table exists using an async SQLSpec driver.""" + result = await driver.execute(table_exists_sql(style), (schema, table_name)) + return bool(result.data) diff --git a/tests/integration/adapters/adbc/test_migrations.py b/tests/integration/adapters/adbc/test_migrations.py index bb1e53bc0..831f2b64a 100644 --- a/tests/integration/adapters/adbc/test_migrations.py +++ b/tests/integration/adapters/adbc/test_migrations.py @@ -6,7 +6,16 @@ import pytest from sqlspec.adapters.adbc.config import AdbcConfig +from sqlspec.exceptions import MigrationError from sqlspec.migrations.commands import AsyncMigrationCommands, SyncMigrationCommands, create_migration_commands +from tests.integration.adapters._postgres_migration_schema import ( + create_schema_sql, + drop_schema_sql, + sync_table_exists, + unique_identifier, + write_unqualified_table_migration, +) +from tests.integration.adapters.adbc.conftest import xfail_if_driver_missing # xdist_group is assigned per test based on database backend to enable parallel execution @@ -72,9 +81,130 @@ def down(): @pytest.mark.xdist_group("postgres") -def test_adbc_postgresql_migration_workflow() -> None: - """Test ADBC PostgreSQL migration workflow with test database.""" - pytest.skip("Requires running PostgreSQL") +@xfail_if_driver_missing +def test_adbc_postgresql_migration_default_schema_applies_to_ddl( + tmp_path: Path, adbc_postgres_connection_config: "dict[str, str]" +) -> None: + """ADBC PostgreSQL migrations run unqualified DDL in the configured default schema.""" + schema = unique_identifier("adbc_default") + table_name = unique_identifier("adbc_table") + version_table = unique_identifier("adbc_versions") + migration_dir = tmp_path / "migrations" + connection_config = dict(adbc_postgres_connection_config) + connection_config["driver_name"] = "adbc_driver_postgresql" + + config = AdbcConfig( + connection_config=connection_config, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "default_schema": schema, + }, + ) + commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) + + try: + with config.provide_session() as driver: + driver.execute_script(create_schema_sql(schema)) + driver.commit() + + commands.init(str(migration_dir), package=True) + write_unqualified_table_migration(migration_dir, table_name) + commands.upgrade() + + with config.provide_session() as driver: + assert sync_table_exists(driver, schema, table_name, style="numeric") + assert not sync_table_exists(driver, "public", table_name, style="numeric") + assert sync_table_exists(driver, schema, version_table, style="numeric") + finally: + with config.provide_session() as driver: + driver.execute_script(drop_schema_sql(schema)) + driver.commit() + config.close_pool() + + +@pytest.mark.xdist_group("postgres") +@xfail_if_driver_missing +def test_adbc_postgresql_migration_separable_tracker_and_default_schema( + tmp_path: Path, adbc_postgres_connection_config: "dict[str, str]" +) -> None: + """ADBC PostgreSQL supports separate schemas for migrated DDL and the tracker table.""" + default_schema = unique_identifier("adbc_default") + tracker_schema = unique_identifier("adbc_tracker") + table_name = unique_identifier("adbc_table") + version_table = unique_identifier("adbc_versions") + migration_dir = tmp_path / "migrations" + connection_config = dict(adbc_postgres_connection_config) + connection_config["driver_name"] = "adbc_driver_postgresql" + + config = AdbcConfig( + connection_config=connection_config, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "default_schema": default_schema, + "version_table_schema": tracker_schema, + }, + ) + commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) + + try: + with config.provide_session() as driver: + driver.execute_script(create_schema_sql(default_schema)) + driver.execute_script(create_schema_sql(tracker_schema)) + driver.commit() + + commands.init(str(migration_dir), package=True) + write_unqualified_table_migration(migration_dir, table_name) + commands.upgrade() + + with config.provide_session() as driver: + assert sync_table_exists(driver, default_schema, table_name, style="numeric") + assert sync_table_exists(driver, tracker_schema, version_table, style="numeric") + assert not sync_table_exists(driver, default_schema, version_table, style="numeric") + finally: + with config.provide_session() as driver: + driver.execute_script(drop_schema_sql(default_schema)) + driver.execute_script(drop_schema_sql(tracker_schema)) + driver.commit() + config.close_pool() + + +@pytest.mark.xdist_group("postgres") +@xfail_if_driver_missing +def test_adbc_postgresql_migration_missing_schema_fails_fast( + tmp_path: Path, adbc_postgres_connection_config: "dict[str, str]" +) -> None: + """ADBC PostgreSQL validates the default schema before creating tracker tables or applying DDL.""" + schema = unique_identifier("adbc_missing") + table_name = unique_identifier("adbc_table") + version_table = unique_identifier("adbc_versions") + migration_dir = tmp_path / "migrations" + connection_config = dict(adbc_postgres_connection_config) + connection_config["driver_name"] = "adbc_driver_postgresql" + + config = AdbcConfig( + connection_config=connection_config, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "default_schema": schema, + }, + ) + commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = create_migration_commands(config) + + try: + commands.init(str(migration_dir), package=True) + write_unqualified_table_migration(migration_dir, table_name) + + with pytest.raises(MigrationError, match=f"Configured schema '{schema}' does not exist"): + commands.upgrade() + + with config.provide_session() as driver: + assert not sync_table_exists(driver, "public", version_table, style="numeric") + assert not sync_table_exists(driver, "public", table_name, style="numeric") + finally: + config.close_pool() @pytest.mark.xdist_group("sqlite") diff --git a/tests/integration/adapters/asyncpg/test_migrations.py b/tests/integration/adapters/asyncpg/test_migrations.py index 4013da0a5..725785d6f 100644 --- a/tests/integration/adapters/asyncpg/test_migrations.py +++ b/tests/integration/adapters/asyncpg/test_migrations.py @@ -6,7 +6,15 @@ from pytest_databases.docker.postgres import PostgresService from sqlspec.adapters.asyncpg.config import AsyncpgConfig +from sqlspec.exceptions import MigrationError from sqlspec.migrations.commands import AsyncMigrationCommands +from tests.integration.adapters._postgres_migration_schema import ( + async_table_exists, + create_schema_sql, + drop_schema_sql, + unique_identifier, + write_unqualified_table_migration, +) pytestmark = pytest.mark.xdist_group("postgres") @@ -646,3 +654,133 @@ def down(): finally: if config.connection_instance: await config.close_pool() + + +async def test_asyncpg_migration_default_schema_applies_to_ddl( + tmp_path: Path, postgres_service: "PostgresService" +) -> None: + """AsyncPG migrations run unqualified DDL in the configured default schema.""" + schema = unique_identifier("asyncpg_default") + table_name = unique_identifier("asyncpg_table") + version_table = unique_identifier("asyncpg_versions") + migration_dir = tmp_path / "migrations" + + config = AsyncpgConfig( + connection_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "database": postgres_service.database, + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "default_schema": schema, + }, + ) + commands = AsyncMigrationCommands(config) + + try: + async with config.provide_session() as driver: + await driver.execute_script(create_schema_sql(schema)) + + await commands.init(str(migration_dir), package=True) + write_unqualified_table_migration(migration_dir, table_name) + await commands.upgrade() + + async with config.provide_session() as driver: + assert await async_table_exists(driver, schema, table_name, style="numeric") + assert not await async_table_exists(driver, "public", table_name, style="numeric") + assert await async_table_exists(driver, schema, version_table, style="numeric") + finally: + async with config.provide_session() as driver: + await driver.execute_script(drop_schema_sql(schema)) + if config.connection_instance: + await config.close_pool() + + +async def test_asyncpg_migration_separable_tracker_and_default_schema( + tmp_path: Path, postgres_service: "PostgresService" +) -> None: + """AsyncPG supports separate schemas for migrated DDL and the tracker table.""" + default_schema = unique_identifier("asyncpg_default") + tracker_schema = unique_identifier("asyncpg_tracker") + table_name = unique_identifier("asyncpg_table") + version_table = unique_identifier("asyncpg_versions") + migration_dir = tmp_path / "migrations" + + config = AsyncpgConfig( + connection_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "database": postgres_service.database, + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "default_schema": default_schema, + "version_table_schema": tracker_schema, + }, + ) + commands = AsyncMigrationCommands(config) + + try: + async with config.provide_session() as driver: + await driver.execute_script(create_schema_sql(default_schema)) + await driver.execute_script(create_schema_sql(tracker_schema)) + + await commands.init(str(migration_dir), package=True) + write_unqualified_table_migration(migration_dir, table_name) + await commands.upgrade() + + async with config.provide_session() as driver: + assert await async_table_exists(driver, default_schema, table_name, style="numeric") + assert await async_table_exists(driver, tracker_schema, version_table, style="numeric") + assert not await async_table_exists(driver, default_schema, version_table, style="numeric") + finally: + async with config.provide_session() as driver: + await driver.execute_script(drop_schema_sql(default_schema)) + await driver.execute_script(drop_schema_sql(tracker_schema)) + if config.connection_instance: + await config.close_pool() + + +async def test_asyncpg_migration_missing_schema_fails_fast(tmp_path: Path, postgres_service: "PostgresService") -> None: + """AsyncPG validates the default schema before creating tracker tables or applying DDL.""" + schema = unique_identifier("asyncpg_missing") + table_name = unique_identifier("asyncpg_table") + version_table = unique_identifier("asyncpg_versions") + migration_dir = tmp_path / "migrations" + + config = AsyncpgConfig( + connection_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "database": postgres_service.database, + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "default_schema": schema, + }, + ) + commands = AsyncMigrationCommands(config) + + try: + await commands.init(str(migration_dir), package=True) + write_unqualified_table_migration(migration_dir, table_name) + + with pytest.raises(MigrationError, match=f"Configured schema '{schema}' does not exist"): + await commands.upgrade() + + async with config.provide_session() as driver: + assert not await async_table_exists(driver, "public", version_table, style="numeric") + assert not await async_table_exists(driver, "public", table_name, style="numeric") + finally: + if config.connection_instance: + await config.close_pool() diff --git a/tests/integration/adapters/bigquery/test_migrations.py b/tests/integration/adapters/bigquery/test_migrations.py new file mode 100644 index 000000000..25555363d --- /dev/null +++ b/tests/integration/adapters/bigquery/test_migrations.py @@ -0,0 +1,222 @@ +"""Integration tests for BigQuery migration schema workflow.""" + +import os +from pathlib import Path +from typing import TYPE_CHECKING +from uuid import uuid4 + +import pytest +from google.api_core.client_options import ClientOptions +from google.auth.credentials import AnonymousCredentials +from google.cloud.bigquery import Dataset +from google.cloud.exceptions import NotFound + +from sqlspec.adapters.bigquery import BigQueryConfig +from sqlspec.exceptions import MigrationError +from sqlspec.migrations.commands import SyncMigrationCommands + +if TYPE_CHECKING: + from pytest_databases.docker.bigquery import BigQueryService + +BIGQUERY_ENABLED = os.environ.get("CI") == "true" or os.environ.get("SQLSPEC_ENABLE_BIGQUERY_TESTS") == "1" + +pytestmark = [ + pytest.mark.xdist_group("bigquery"), + pytest.mark.skipif( + not BIGQUERY_ENABLED, + reason="BigQuery emulator is optional locally; set SQLSPEC_ENABLE_BIGQUERY_TESTS=1 to enable", + ), +] + + +def _bigquery_identifier(prefix: str) -> str: + """Return a generated BigQuery identifier.""" + return f"{prefix}_{uuid4().hex[:8]}" + + +def _write_bigquery_unqualified_table_migration(migration_dir: Path, table_name: str) -> None: + migration_content = f'''"""Create an unqualified BigQuery table.""" + + +def up(): + """Create an unqualified table.""" + return [""" + CREATE TABLE {table_name} ( + id INT64, + name STRING NOT NULL + ) + """] + + +def down(): + """Drop the unqualified table.""" + return ["DROP TABLE IF EXISTS {table_name}"] +''' + (migration_dir / "0001_create_unqualified_table.py").write_text(migration_content) + + +def _bigquery_config(bigquery_service: "BigQueryService", *, migration_config: dict[str, object]) -> BigQueryConfig: + return BigQueryConfig( + connection_config={ + "project": bigquery_service.project, + "dataset_id": bigquery_service.dataset, + "client_options": ClientOptions(api_endpoint=f"http://{bigquery_service.host}:{bigquery_service.port}"), + "credentials": AnonymousCredentials(), # type: ignore[no-untyped-call] + }, + migration_config=migration_config, + ) + + +def _create_dataset(config: BigQueryConfig, dataset_name: str) -> None: + with config.provide_session() as driver: + project = str(driver.connection.project) + driver.connection.create_dataset(Dataset(f"{project}.{dataset_name}"), exists_ok=True) + + +def _drop_dataset(config: BigQueryConfig, dataset_name: str) -> None: + if config.connection_instance is None: + return + with config.provide_session() as driver: + project = str(driver.connection.project) + driver.connection.delete_dataset(f"{project}.{dataset_name}", delete_contents=True, not_found_ok=True) + + +def _bigquery_table_exists(config: BigQueryConfig, dataset_name: str, table_name: str) -> bool: + with config.provide_session() as driver: + project = str(driver.connection.project) + try: + driver.connection.get_table(f"{project}.{dataset_name}.{table_name}") + except NotFound: + return False + return True + + +def test_bigquery_migration_default_schema_applies_to_ddl( + tmp_path: Path, native_bigquery_service: "BigQueryService" +) -> None: + """BigQuery migrations run unqualified DDL in the configured default dataset.""" + bigquery_service = native_bigquery_service + dataset_name = _bigquery_identifier("dataset") + table_name = _bigquery_identifier("table") + version_table = _bigquery_identifier("versions") + migration_dir = tmp_path / "migrations" + + config = _bigquery_config( + bigquery_service, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "default_schema": dataset_name, + }, + ) + + try: + _create_dataset(config, dataset_name) + + commands = SyncMigrationCommands(config) + commands.init(str(migration_dir), package=True) + _write_bigquery_unqualified_table_migration(migration_dir, table_name) + commands.upgrade() + + assert _bigquery_table_exists(config, dataset_name, table_name) + assert _bigquery_table_exists(config, dataset_name, version_table) + finally: + _drop_dataset(config, dataset_name) + + +def test_bigquery_migration_tracker_lives_in_configured_schema( + tmp_path: Path, native_bigquery_service: "BigQueryService" +) -> None: + """BigQuery stores the tracker table in version_table_schema when configured.""" + bigquery_service = native_bigquery_service + tracker_dataset = _bigquery_identifier("tracker") + table_name = _bigquery_identifier("table") + version_table = _bigquery_identifier("versions") + migration_dir = tmp_path / "migrations" + + config = _bigquery_config( + bigquery_service, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "version_table_schema": tracker_dataset, + }, + ) + + try: + _create_dataset(config, tracker_dataset) + + commands = SyncMigrationCommands(config) + commands.init(str(migration_dir), package=True) + _write_bigquery_unqualified_table_migration(migration_dir, table_name) + commands.upgrade() + + assert _bigquery_table_exists(config, bigquery_service.dataset, table_name) + assert _bigquery_table_exists(config, tracker_dataset, version_table) + assert not _bigquery_table_exists(config, bigquery_service.dataset, version_table) + finally: + _drop_dataset(config, tracker_dataset) + + +def test_bigquery_migration_separable_tracker_and_default_schema( + tmp_path: Path, native_bigquery_service: "BigQueryService" +) -> None: + """BigQuery supports separate datasets for migrated DDL and the tracker table.""" + bigquery_service = native_bigquery_service + default_dataset = _bigquery_identifier("default") + tracker_dataset = _bigquery_identifier("tracker") + table_name = _bigquery_identifier("table") + version_table = _bigquery_identifier("versions") + migration_dir = tmp_path / "migrations" + + config = _bigquery_config( + bigquery_service, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "default_schema": default_dataset, + "version_table_schema": tracker_dataset, + }, + ) + + try: + _create_dataset(config, default_dataset) + _create_dataset(config, tracker_dataset) + + commands = SyncMigrationCommands(config) + commands.init(str(migration_dir), package=True) + _write_bigquery_unqualified_table_migration(migration_dir, table_name) + commands.upgrade() + + assert _bigquery_table_exists(config, default_dataset, table_name) + assert _bigquery_table_exists(config, tracker_dataset, version_table) + assert not _bigquery_table_exists(config, default_dataset, version_table) + finally: + _drop_dataset(config, default_dataset) + _drop_dataset(config, tracker_dataset) + + +def test_bigquery_migration_missing_schema_fails_fast(tmp_path: Path, bigquery_service: "BigQueryService") -> None: + """BigQuery validates the default dataset before creating tracker tables or applying DDL.""" + missing_dataset = _bigquery_identifier("missing") + table_name = _bigquery_identifier("table") + version_table = _bigquery_identifier("versions") + migration_dir = tmp_path / "migrations" + + config = _bigquery_config( + bigquery_service, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "default_schema": missing_dataset, + }, + ) + commands = SyncMigrationCommands(config) + commands.init(str(migration_dir), package=True) + _write_bigquery_unqualified_table_migration(migration_dir, table_name) + + with pytest.raises(MigrationError, match=f"Configured schema '{missing_dataset}' does not exist"): + commands.upgrade() + + assert not _bigquery_table_exists(config, bigquery_service.dataset, table_name) + assert not _bigquery_table_exists(config, bigquery_service.dataset, version_table) diff --git a/tests/integration/adapters/duckdb/test_migrations.py b/tests/integration/adapters/duckdb/test_migrations.py index 8597c54f3..21f59050a 100644 --- a/tests/integration/adapters/duckdb/test_migrations.py +++ b/tests/integration/adapters/duckdb/test_migrations.py @@ -2,15 +2,190 @@ from pathlib import Path from typing import Any +from uuid import uuid4 import pytest from sqlspec.adapters.duckdb.config import DuckDBConfig +from sqlspec.exceptions import MigrationError from sqlspec.migrations.commands import AsyncMigrationCommands, SyncMigrationCommands, create_migration_commands pytestmark = pytest.mark.xdist_group("duckdb") +def _duckdb_identifier(prefix: str) -> str: + """Return a generated DuckDB identifier.""" + return f"{prefix}_{uuid4().hex[:8]}" + + +def _write_duckdb_unqualified_table_migration(migration_dir: Path, table_name: str) -> None: + migration_content = f'''"""Create an unqualified DuckDB table.""" + + +def up(): + """Create an unqualified table.""" + return [""" + CREATE TABLE {table_name} ( + id INTEGER PRIMARY KEY, + name VARCHAR NOT NULL + ) + """] + + +def down(): + """Drop the unqualified table.""" + return ["DROP TABLE IF EXISTS {table_name}"] +''' + (migration_dir / "0001_create_unqualified_table.py").write_text(migration_content) + + +def _duckdb_table_exists(driver: Any, schema: str, table_name: str) -> bool: + result = driver.execute( + "SELECT 1 FROM information_schema.tables WHERE table_schema = ? AND table_name = ?", (schema, table_name) + ) + return bool(result.data) + + +def test_duckdb_migration_default_schema_applies_to_ddl(tmp_path: Path) -> None: + """DuckDB migrations run unqualified DDL in the configured default schema.""" + schema = _duckdb_identifier("schema") + table_name = _duckdb_identifier("table") + version_table = _duckdb_identifier("versions") + migration_dir = tmp_path / "migrations" + db_path = tmp_path / "test.duckdb" + + config = DuckDBConfig( + connection_config={"database": str(db_path)}, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "default_schema": schema, + }, + ) + commands = SyncMigrationCommands(config) + + try: + with config.provide_session() as driver: + driver.execute(f"CREATE SCHEMA {schema}") + + commands.init(str(migration_dir), package=True) + _write_duckdb_unqualified_table_migration(migration_dir, table_name) + commands.upgrade() + + with config.provide_session() as driver: + assert _duckdb_table_exists(driver, schema, table_name) + assert _duckdb_table_exists(driver, schema, version_table) + finally: + if config.connection_instance: + config.close_pool() + + +def test_duckdb_migration_separable_tracker_and_default_schema(tmp_path: Path) -> None: + """DuckDB supports separate schemas for migrated DDL and the tracker table.""" + default_schema = _duckdb_identifier("default_schema") + tracker_schema = _duckdb_identifier("tracker_schema") + table_name = _duckdb_identifier("table") + version_table = _duckdb_identifier("versions") + migration_dir = tmp_path / "migrations" + db_path = tmp_path / "test.duckdb" + + config = DuckDBConfig( + connection_config={"database": str(db_path)}, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "default_schema": default_schema, + "version_table_schema": tracker_schema, + }, + ) + commands = SyncMigrationCommands(config) + + try: + with config.provide_session() as driver: + driver.execute(f"CREATE SCHEMA {default_schema}") + driver.execute(f"CREATE SCHEMA {tracker_schema}") + + commands.init(str(migration_dir), package=True) + _write_duckdb_unqualified_table_migration(migration_dir, table_name) + commands.upgrade() + + with config.provide_session() as driver: + assert _duckdb_table_exists(driver, default_schema, table_name) + assert _duckdb_table_exists(driver, tracker_schema, version_table) + assert not _duckdb_table_exists(driver, default_schema, version_table) + finally: + if config.connection_instance: + config.close_pool() + + +def test_duckdb_migration_tracker_lives_in_configured_schema(tmp_path: Path) -> None: + """DuckDB stores the tracker table in version_table_schema when configured.""" + tracker_schema = _duckdb_identifier("tracker_schema") + table_name = _duckdb_identifier("table") + version_table = _duckdb_identifier("versions") + migration_dir = tmp_path / "migrations" + db_path = tmp_path / "test.duckdb" + + config = DuckDBConfig( + connection_config={"database": str(db_path)}, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "version_table_schema": tracker_schema, + }, + ) + commands = SyncMigrationCommands(config) + + try: + with config.provide_session() as driver: + driver.execute(f"CREATE SCHEMA {tracker_schema}") + + commands.init(str(migration_dir), package=True) + _write_duckdb_unqualified_table_migration(migration_dir, table_name) + commands.upgrade() + + with config.provide_session() as driver: + assert _duckdb_table_exists(driver, "main", table_name) + assert _duckdb_table_exists(driver, tracker_schema, version_table) + assert not _duckdb_table_exists(driver, "main", version_table) + finally: + if config.connection_instance: + config.close_pool() + + +def test_duckdb_migration_missing_schema_fails_fast(tmp_path: Path) -> None: + """DuckDB validates the default schema before creating tracker tables or applying DDL.""" + schema = _duckdb_identifier("missing_schema") + table_name = _duckdb_identifier("missing_table") + version_table = _duckdb_identifier("missing_versions") + migration_dir = tmp_path / "migrations" + db_path = tmp_path / "test.duckdb" + + config = DuckDBConfig( + connection_config={"database": str(db_path)}, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "default_schema": schema, + }, + ) + commands = SyncMigrationCommands(config) + + try: + commands.init(str(migration_dir), package=True) + _write_duckdb_unqualified_table_migration(migration_dir, table_name) + + with pytest.raises(MigrationError, match=f"Configured schema '{schema}' does not exist"): + commands.upgrade() + + with config.provide_session() as driver: + assert not _duckdb_table_exists(driver, "main", version_table) + assert not _duckdb_table_exists(driver, "main", table_name) + finally: + if config.connection_instance: + config.close_pool() + + def test_duckdb_migration_full_workflow(tmp_path: Path) -> None: """Test full DuckDB migration workflow: init -> create -> upgrade -> downgrade.""" migration_dir = tmp_path / "migrations" diff --git a/tests/integration/adapters/oracledb/test_migrations.py b/tests/integration/adapters/oracledb/test_migrations.py index c2b3e4f24..f14ae849c 100644 --- a/tests/integration/adapters/oracledb/test_migrations.py +++ b/tests/integration/adapters/oracledb/test_migrations.py @@ -2,16 +2,334 @@ from pathlib import Path from typing import Any +from uuid import uuid4 import pytest from pytest_databases.docker.oracle import OracleService from sqlspec.adapters.oracledb.config import OracleAsyncConfig, OracleSyncConfig +from sqlspec.exceptions import MigrationError from sqlspec.migrations.commands import AsyncMigrationCommands, SyncMigrationCommands, create_migration_commands pytestmark = pytest.mark.xdist_group("oracle") +def _oracle_identifier(prefix: str) -> str: + """Return a generated Oracle-safe identifier within the 30 byte limit.""" + return f"{prefix}_{uuid4().hex[:8]}".upper() + + +def _write_oracle_unqualified_table_migration(migration_dir: Path, table_name: str) -> None: + migration_content = f'''"""Create an unqualified Oracle table.""" + + +def up(): + """Create an unqualified table.""" + return [""" + CREATE TABLE {table_name} ( + id NUMBER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY, + name VARCHAR2(255) NOT NULL + ) + """] + + +def down(): + """Drop the unqualified table.""" + return ["DROP TABLE {table_name}"] +''' + (migration_dir / "0001_create_unqualified_table.py").write_text(migration_content) + + +def _oracle_admin_sync_config(oracle_service: OracleService) -> OracleSyncConfig: + return OracleSyncConfig( + connection_config={ + "host": oracle_service.host, + "port": oracle_service.port, + "service_name": oracle_service.service_name, + "user": "system", + "password": oracle_service.system_password, + } + ) + + +def _oracle_admin_async_config(oracle_service: OracleService) -> OracleAsyncConfig: + return OracleAsyncConfig( + connection_config={ + "host": oracle_service.host, + "port": oracle_service.port, + "service_name": oracle_service.service_name, + "user": "system", + "password": oracle_service.system_password, + "min": 1, + "max": 5, + } + ) + + +def _create_oracle_schema(oracle_service: OracleService, schema: str) -> None: + password = f"{schema}_Pwd1" + config = _oracle_admin_sync_config(oracle_service) + try: + with config.provide_session() as driver: + driver.execute_script(f'CREATE USER {schema} IDENTIFIED BY "{password}"') + driver.execute_script( + f"GRANT CREATE SESSION, CREATE TABLE, CREATE SEQUENCE, UNLIMITED TABLESPACE TO {schema}" + ) + driver.execute_script( + "GRANT CREATE ANY TABLE, ALTER ANY TABLE, DROP ANY TABLE, " + "SELECT ANY TABLE, INSERT ANY TABLE, UPDATE ANY TABLE, DELETE ANY TABLE, " + "CREATE ANY INDEX, ALTER ANY INDEX, DROP ANY INDEX, " + f"CREATE ANY SEQUENCE, DROP ANY SEQUENCE, SELECT ANY SEQUENCE TO {oracle_service.user}" + ) + driver.commit() + finally: + if config.connection_instance: + config.close_pool() + + +def _drop_oracle_schema(oracle_service: OracleService, schema: str) -> None: + config = _oracle_admin_sync_config(oracle_service) + try: + with config.provide_session() as driver: + driver.execute_script( + f""" + BEGIN + EXECUTE IMMEDIATE 'DROP USER {schema} CASCADE'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -1918 THEN + RAISE; + END IF; + END; + """ + ) + driver.commit() + finally: + if config.connection_instance: + config.close_pool() + + +async def _async_create_oracle_schema(oracle_service: OracleService, schema: str) -> None: + password = f"{schema}_Pwd1" + config = _oracle_admin_async_config(oracle_service) + try: + async with config.provide_session() as driver: + await driver.execute_script(f'CREATE USER {schema} IDENTIFIED BY "{password}"') + await driver.execute_script( + f"GRANT CREATE SESSION, CREATE TABLE, CREATE SEQUENCE, UNLIMITED TABLESPACE TO {schema}" + ) + await driver.execute_script( + "GRANT CREATE ANY TABLE, ALTER ANY TABLE, DROP ANY TABLE, " + "SELECT ANY TABLE, INSERT ANY TABLE, UPDATE ANY TABLE, DELETE ANY TABLE, " + "CREATE ANY INDEX, ALTER ANY INDEX, DROP ANY INDEX, " + f"CREATE ANY SEQUENCE, DROP ANY SEQUENCE, SELECT ANY SEQUENCE TO {oracle_service.user}" + ) + await driver.commit() + finally: + if config.connection_instance: + await config.close_pool() + + +async def _async_drop_oracle_schema(oracle_service: OracleService, schema: str) -> None: + config = _oracle_admin_async_config(oracle_service) + try: + async with config.provide_session() as driver: + await driver.execute_script( + f""" + BEGIN + EXECUTE IMMEDIATE 'DROP USER {schema} CASCADE'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -1918 THEN + RAISE; + END IF; + END; + """ + ) + await driver.commit() + finally: + if config.connection_instance: + await config.close_pool() + + +def _sync_oracle_table_exists(driver: Any, owner: str, table_name: str) -> bool: + result = driver.execute("SELECT 1 FROM ALL_TABLES WHERE OWNER = :1 AND TABLE_NAME = :2", (owner, table_name)) + return bool(result.data) + + +async def _async_oracle_table_exists(driver: Any, owner: str, table_name: str) -> bool: + result = await driver.execute("SELECT 1 FROM ALL_TABLES WHERE OWNER = :1 AND TABLE_NAME = :2", (owner, table_name)) + return bool(result.data) + + +def test_oracledb_sync_migration_default_schema_applies_to_ddl( + tmp_path: Path, oracle_23ai_service: OracleService +) -> None: + """Oracle sync migrations run unqualified DDL in the configured default schema.""" + schema = _oracle_identifier("OSD") + table_name = _oracle_identifier("TBL") + version_table = _oracle_identifier("OV") + migration_dir = tmp_path / "migrations" + + config = OracleSyncConfig( + connection_config={ + "host": oracle_23ai_service.host, + "port": oracle_23ai_service.port, + "service_name": oracle_23ai_service.service_name, + "user": oracle_23ai_service.user, + "password": oracle_23ai_service.password, + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "default_schema": schema, + }, + ) + commands = SyncMigrationCommands(config) + + try: + _create_oracle_schema(oracle_23ai_service, schema) + + commands.init(str(migration_dir), package=True) + _write_oracle_unqualified_table_migration(migration_dir, table_name) + commands.upgrade() + + with config.provide_session() as driver: + assert _sync_oracle_table_exists(driver, schema, table_name) + assert _sync_oracle_table_exists(driver, schema, version_table) + finally: + _drop_oracle_schema(oracle_23ai_service, schema) + if config.connection_instance: + config.close_pool() + + +async def test_oracledb_async_migration_default_schema_applies_to_ddl( + tmp_path: Path, oracle_23ai_service: OracleService +) -> None: + """Oracle async migrations run unqualified DDL in the configured default schema.""" + schema = _oracle_identifier("OAD") + table_name = _oracle_identifier("OAT") + version_table = _oracle_identifier("OAV") + migration_dir = tmp_path / "migrations" + + config = OracleAsyncConfig( + connection_config={ + "host": oracle_23ai_service.host, + "port": oracle_23ai_service.port, + "service_name": oracle_23ai_service.service_name, + "user": oracle_23ai_service.user, + "password": oracle_23ai_service.password, + "min": 1, + "max": 5, + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "default_schema": schema, + }, + ) + commands = AsyncMigrationCommands(config) + + try: + await _async_create_oracle_schema(oracle_23ai_service, schema) + + await commands.init(str(migration_dir), package=True) + _write_oracle_unqualified_table_migration(migration_dir, table_name) + await commands.upgrade() + + async with config.provide_session() as driver: + assert await _async_oracle_table_exists(driver, schema, table_name) + assert await _async_oracle_table_exists(driver, schema, version_table) + finally: + await _async_drop_oracle_schema(oracle_23ai_service, schema) + if config.connection_instance: + await config.close_pool() + + +def test_oracledb_migration_separable_tracker_and_default_schema( + tmp_path: Path, oracle_23ai_service: OracleService +) -> None: + """Oracle supports separate schemas for migrated DDL and the tracker table.""" + default_schema = _oracle_identifier("ODS") + tracker_schema = _oracle_identifier("OTS") + table_name = _oracle_identifier("TBL") + version_table = _oracle_identifier("OV") + migration_dir = tmp_path / "migrations" + + config = OracleSyncConfig( + connection_config={ + "host": oracle_23ai_service.host, + "port": oracle_23ai_service.port, + "service_name": oracle_23ai_service.service_name, + "user": oracle_23ai_service.user, + "password": oracle_23ai_service.password, + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "default_schema": default_schema, + "version_table_schema": tracker_schema, + }, + ) + commands = SyncMigrationCommands(config) + + try: + _create_oracle_schema(oracle_23ai_service, default_schema) + _create_oracle_schema(oracle_23ai_service, tracker_schema) + + commands.init(str(migration_dir), package=True) + _write_oracle_unqualified_table_migration(migration_dir, table_name) + commands.upgrade() + + with config.provide_session() as driver: + assert _sync_oracle_table_exists(driver, default_schema, table_name) + assert _sync_oracle_table_exists(driver, tracker_schema, version_table) + assert not _sync_oracle_table_exists(driver, default_schema, version_table) + finally: + _drop_oracle_schema(oracle_23ai_service, default_schema) + _drop_oracle_schema(oracle_23ai_service, tracker_schema) + if config.connection_instance: + config.close_pool() + + +def test_oracledb_migration_missing_schema_fails_fast(tmp_path: Path, oracle_23ai_service: OracleService) -> None: + """Oracle validates the default schema before creating tracker tables or applying DDL.""" + schema = _oracle_identifier("OMS") + table_name = _oracle_identifier("OMT") + version_table = _oracle_identifier("OMV") + migration_dir = tmp_path / "migrations" + + config = OracleSyncConfig( + connection_config={ + "host": oracle_23ai_service.host, + "port": oracle_23ai_service.port, + "service_name": oracle_23ai_service.service_name, + "user": oracle_23ai_service.user, + "password": oracle_23ai_service.password, + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "default_schema": schema, + }, + ) + commands = SyncMigrationCommands(config) + + try: + commands.init(str(migration_dir), package=True) + _write_oracle_unqualified_table_migration(migration_dir, table_name) + + with pytest.raises(MigrationError, match=f"Configured schema '{schema}' does not exist"): + commands.upgrade() + + with config.provide_session() as driver: + assert not _sync_oracle_table_exists(driver, oracle_23ai_service.user.upper(), version_table) + assert not _sync_oracle_table_exists(driver, oracle_23ai_service.user.upper(), table_name) + finally: + if config.connection_instance: + config.close_pool() + + def test_oracledb_sync_migration_full_workflow(tmp_path: Path, oracle_23ai_service: OracleService) -> None: """Test full OracleDB sync migration workflow: init -> create -> upgrade -> downgrade.""" diff --git a/tests/integration/adapters/psqlpy/test_migrations.py b/tests/integration/adapters/psqlpy/test_migrations.py index 68d76ccd9..460357857 100644 --- a/tests/integration/adapters/psqlpy/test_migrations.py +++ b/tests/integration/adapters/psqlpy/test_migrations.py @@ -6,7 +6,15 @@ from pytest_databases.docker.postgres import PostgresService from sqlspec.adapters.psqlpy.config import PsqlpyConfig +from sqlspec.exceptions import MigrationError from sqlspec.migrations.commands import AsyncMigrationCommands +from tests.integration.adapters._postgres_migration_schema import ( + async_table_exists, + create_schema_sql, + drop_schema_sql, + unique_identifier, + write_unqualified_table_migration, +) pytestmark = pytest.mark.xdist_group("postgres") @@ -86,6 +94,124 @@ def down(): await config.close_pool() +async def test_psqlpy_migration_default_schema_applies_to_ddl( + tmp_path: Path, postgres_service: "PostgresService" +) -> None: + """Psqlpy migrations run unqualified DDL in the configured default schema.""" + schema = unique_identifier("psqlpy_default") + table_name = unique_identifier("psqlpy_table") + version_table = unique_identifier("psqlpy_versions") + migration_dir = tmp_path / "migrations" + + config = PsqlpyConfig( + connection_config={ + "dsn": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "default_schema": schema, + }, + ) + commands = AsyncMigrationCommands(config) + + try: + async with config.provide_session() as driver: + await driver.execute_script(create_schema_sql(schema)) + + await commands.init(str(migration_dir), package=True) + write_unqualified_table_migration(migration_dir, table_name) + await commands.upgrade() + + async with config.provide_session() as driver: + assert await async_table_exists(driver, schema, table_name, style="numeric") + assert not await async_table_exists(driver, "public", table_name, style="numeric") + assert await async_table_exists(driver, schema, version_table, style="numeric") + finally: + async with config.provide_session() as driver: + await driver.execute_script(drop_schema_sql(schema)) + if config.connection_instance: + await config.close_pool() + + +async def test_psqlpy_migration_separable_tracker_and_default_schema( + tmp_path: Path, postgres_service: "PostgresService" +) -> None: + """Psqlpy supports separate schemas for migrated DDL and the tracker table.""" + default_schema = unique_identifier("psqlpy_default") + tracker_schema = unique_identifier("psqlpy_tracker") + table_name = unique_identifier("psqlpy_table") + version_table = unique_identifier("psqlpy_versions") + migration_dir = tmp_path / "migrations" + + config = PsqlpyConfig( + connection_config={ + "dsn": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "default_schema": default_schema, + "version_table_schema": tracker_schema, + }, + ) + commands = AsyncMigrationCommands(config) + + try: + async with config.provide_session() as driver: + await driver.execute_script(create_schema_sql(default_schema)) + await driver.execute_script(create_schema_sql(tracker_schema)) + + await commands.init(str(migration_dir), package=True) + write_unqualified_table_migration(migration_dir, table_name) + await commands.upgrade() + + async with config.provide_session() as driver: + assert await async_table_exists(driver, default_schema, table_name, style="numeric") + assert await async_table_exists(driver, tracker_schema, version_table, style="numeric") + assert not await async_table_exists(driver, default_schema, version_table, style="numeric") + finally: + async with config.provide_session() as driver: + await driver.execute_script(drop_schema_sql(default_schema)) + await driver.execute_script(drop_schema_sql(tracker_schema)) + if config.connection_instance: + await config.close_pool() + + +async def test_psqlpy_migration_missing_schema_fails_fast(tmp_path: Path, postgres_service: "PostgresService") -> None: + """Psqlpy validates the default schema before creating tracker tables or applying DDL.""" + schema = unique_identifier("psqlpy_missing") + table_name = unique_identifier("psqlpy_table") + version_table = unique_identifier("psqlpy_versions") + migration_dir = tmp_path / "migrations" + + config = PsqlpyConfig( + connection_config={ + "dsn": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "default_schema": schema, + }, + ) + commands = AsyncMigrationCommands(config) + + try: + await commands.init(str(migration_dir), package=True) + write_unqualified_table_migration(migration_dir, table_name) + + with pytest.raises(MigrationError, match=f"Configured schema '{schema}' does not exist"): + await commands.upgrade() + + async with config.provide_session() as driver: + assert not await async_table_exists(driver, "public", version_table, style="numeric") + assert not await async_table_exists(driver, "public", table_name, style="numeric") + finally: + if config.connection_instance: + await config.close_pool() + + async def test_psqlpy_multiple_migrations_workflow(tmp_path: Path, postgres_service: "PostgresService") -> None: """Test Psqlpy workflow with multiple migrations: create -> apply both -> downgrade one -> downgrade all.""" diff --git a/tests/integration/adapters/psycopg/test_migrations.py b/tests/integration/adapters/psycopg/test_migrations.py index 01cb1c948..2d721213b 100644 --- a/tests/integration/adapters/psycopg/test_migrations.py +++ b/tests/integration/adapters/psycopg/test_migrations.py @@ -8,7 +8,16 @@ from sqlspec.adapters.psycopg import PsycopgAsyncConfig from sqlspec.adapters.psycopg.config import PsycopgSyncConfig +from sqlspec.exceptions import MigrationError from sqlspec.migrations.commands import AsyncMigrationCommands, SyncMigrationCommands, create_migration_commands +from tests.integration.adapters._postgres_migration_schema import ( + async_table_exists, + create_schema_sql, + drop_schema_sql, + sync_table_exists, + unique_identifier, + write_unqualified_table_migration, +) pytestmark = pytest.mark.xdist_group("postgres") @@ -161,6 +170,168 @@ def down(): await config.close_pool() +def test_psycopg_sync_migration_default_schema_applies_to_ddl( + tmp_path: Path, postgres_service: "PostgresService" +) -> None: + """Psycopg sync migrations run unqualified DDL in the configured default schema.""" + schema = unique_identifier("psycopg_sync_default") + table_name = unique_identifier("psycopg_sync_table") + version_table = unique_identifier("psycopg_sync_versions") + migration_dir = tmp_path / "migrations" + + config = PsycopgSyncConfig( + connection_config={ + "conninfo": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "default_schema": schema, + }, + ) + commands = SyncMigrationCommands(config) + + try: + with config.provide_session() as driver: + driver.execute_script(create_schema_sql(schema)) + driver.commit() + + commands.init(str(migration_dir), package=True) + write_unqualified_table_migration(migration_dir, table_name) + commands.upgrade() + + with config.provide_session() as driver: + assert sync_table_exists(driver, schema, table_name, style="pyformat") + assert not sync_table_exists(driver, "public", table_name, style="pyformat") + assert sync_table_exists(driver, schema, version_table, style="pyformat") + finally: + with config.provide_session() as driver: + driver.execute_script(drop_schema_sql(schema)) + driver.commit() + if config.connection_instance: + config.close_pool() + + +async def test_psycopg_async_migration_default_schema_applies_to_ddl( + tmp_path: Path, postgres_service: "PostgresService" +) -> None: + """Psycopg async migrations run unqualified DDL in the configured default schema.""" + schema = unique_identifier("psycopg_async_default") + table_name = unique_identifier("psycopg_async_table") + version_table = unique_identifier("psycopg_async_versions") + migration_dir = tmp_path / "migrations" + + config = PsycopgAsyncConfig( + connection_config={ + "conninfo": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "default_schema": schema, + }, + ) + commands = AsyncMigrationCommands(config) + + try: + async with config.provide_session() as driver: + await driver.execute_script(create_schema_sql(schema)) + + await commands.init(str(migration_dir), package=True) + write_unqualified_table_migration(migration_dir, table_name) + await commands.upgrade() + + async with config.provide_session() as driver: + assert await async_table_exists(driver, schema, table_name, style="pyformat") + assert not await async_table_exists(driver, "public", table_name, style="pyformat") + assert await async_table_exists(driver, schema, version_table, style="pyformat") + finally: + async with config.provide_session() as driver: + await driver.execute_script(drop_schema_sql(schema)) + if config.connection_instance: + await config.close_pool() + + +def test_psycopg_migration_separable_tracker_and_default_schema( + tmp_path: Path, postgres_service: "PostgresService" +) -> None: + """Psycopg supports separate schemas for migrated DDL and the tracker table.""" + default_schema = unique_identifier("psycopg_default") + tracker_schema = unique_identifier("psycopg_tracker") + table_name = unique_identifier("psycopg_table") + version_table = unique_identifier("psycopg_versions") + migration_dir = tmp_path / "migrations" + + config = PsycopgSyncConfig( + connection_config={ + "conninfo": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "default_schema": default_schema, + "version_table_schema": tracker_schema, + }, + ) + commands = SyncMigrationCommands(config) + + try: + with config.provide_session() as driver: + driver.execute_script(create_schema_sql(default_schema)) + driver.execute_script(create_schema_sql(tracker_schema)) + driver.commit() + + commands.init(str(migration_dir), package=True) + write_unqualified_table_migration(migration_dir, table_name) + commands.upgrade() + + with config.provide_session() as driver: + assert sync_table_exists(driver, default_schema, table_name, style="pyformat") + assert sync_table_exists(driver, tracker_schema, version_table, style="pyformat") + assert not sync_table_exists(driver, default_schema, version_table, style="pyformat") + finally: + with config.provide_session() as driver: + driver.execute_script(drop_schema_sql(default_schema)) + driver.execute_script(drop_schema_sql(tracker_schema)) + driver.commit() + if config.connection_instance: + config.close_pool() + + +def test_psycopg_migration_missing_schema_fails_fast(tmp_path: Path, postgres_service: "PostgresService") -> None: + """Psycopg validates the default schema before creating tracker tables or applying DDL.""" + schema = unique_identifier("psycopg_missing") + table_name = unique_identifier("psycopg_table") + version_table = unique_identifier("psycopg_versions") + migration_dir = tmp_path / "migrations" + + config = PsycopgSyncConfig( + connection_config={ + "conninfo": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": version_table, + "default_schema": schema, + }, + ) + commands = SyncMigrationCommands(config) + + try: + commands.init(str(migration_dir), package=True) + write_unqualified_table_migration(migration_dir, table_name) + + with pytest.raises(MigrationError, match=f"Configured schema '{schema}' does not exist"): + commands.upgrade() + + with config.provide_session() as driver: + assert not sync_table_exists(driver, "public", version_table, style="pyformat") + assert not sync_table_exists(driver, "public", table_name, style="pyformat") + finally: + if config.connection_instance: + config.close_pool() + + def test_psycopg_sync_multiple_migrations_workflow(tmp_path: Path, postgres_service: "PostgresService") -> None: """Test Psycopg sync workflow with multiple migrations: create -> apply both -> downgrade one -> downgrade all.""" diff --git a/tests/unit/adapters/test_bigquery/test_migration_schema.py b/tests/unit/adapters/test_bigquery/test_migration_schema.py new file mode 100644 index 000000000..fb9700bfd --- /dev/null +++ b/tests/unit/adapters/test_bigquery/test_migration_schema.py @@ -0,0 +1,92 @@ +"""Unit coverage for BigQuery migration schema hooks.""" + +from typing import Any, cast + +from google.cloud.bigquery import QueryJobConfig +from google.cloud.exceptions import NotFound + +from sqlspec.adapters.bigquery.config import BigQueryConfig +from sqlspec.adapters.bigquery.driver import BigQueryDriver +from sqlspec.migrations.tracker import SyncMigrationTracker + + +class FakeBigQueryClient: + project = "test-project" + + def __init__(self, datasets: set[str] | None = None) -> None: + self.datasets = datasets or set() + self.requested_datasets: list[str] = [] + + def get_dataset(self, dataset_path: str) -> object: + self.requested_datasets.append(dataset_path) + if dataset_path not in self.datasets: + raise NotFound("missing dataset") # type: ignore[no-untyped-call] + return object() + + +def _bigquery_driver(client: FakeBigQueryClient, default_config: QueryJobConfig | None = None) -> BigQueryDriver: + config = BigQueryConfig(connection_config={"project": client.project}, connection_instance=client) + features: dict[str, Any] = {} + if default_config is not None: + features["default_query_job_config"] = default_config + return BigQueryDriver( + client, # type: ignore[arg-type] + statement_config=config.statement_config, + driver_features=features, + ) + + +def _compile_bigquery_tracker_create_table_sql() -> str: + client = FakeBigQueryClient() + driver = _bigquery_driver(client) + tracker = SyncMigrationTracker("ddl_migrations", version_table_schema="tenant") + statement = driver.prepare_statement(tracker._get_create_table_sql()) # pyright: ignore[reportPrivateUsage] + sql, _ = driver._get_compiled_sql(statement, driver.statement_config) # pyright: ignore[reportPrivateUsage] + return sql + + +def test_bigquery_uses_shared_sync_migration_tracker() -> None: + assert BigQueryConfig.migration_tracker_type is SyncMigrationTracker + + +def test_bigquery_migration_schema_sets_driver_default_dataset_without_mutating_base_config() -> None: + client = FakeBigQueryClient() + base_config = QueryJobConfig() + base_config.use_query_cache = False + driver = _bigquery_driver(client, base_config) + + driver.set_migration_session_schema("tenant") + + driver_config = cast("Any", driver)._default_query_job_config + assert str(driver_config.default_dataset) == "test-project.tenant" + assert driver_config.use_query_cache is False + assert base_config.default_dataset is None + assert BigQueryConfig.supports_migration_schemas is True + + +def test_bigquery_migration_schema_accepts_project_qualified_dataset() -> None: + client = FakeBigQueryClient() + driver = _bigquery_driver(client) + + driver.set_migration_session_schema("other-project.analytics") + + driver_config = cast("Any", driver)._default_query_job_config + assert str(driver_config.default_dataset) == "other-project.analytics" + + +def test_bigquery_has_schema_uses_client_dataset_lookup() -> None: + client = FakeBigQueryClient({"test-project.tenant"}) + driver = _bigquery_driver(client) + + assert driver.has_schema("tenant") is True + assert driver.has_schema("missing") is False + assert client.requested_datasets == ["test-project.tenant", "test-project.missing"] + + +def test_bigquery_shared_tracker_compiles_valid_create_table_ddl() -> None: + create_sql = _compile_bigquery_tracker_create_table_sql() + + assert "CREATE TABLE IF NOT EXISTS `tenant`.`ddl_migrations`" in create_sql + assert "version_num` STRING(32) PRIMARY KEY NOT ENFORCED" in create_sql + assert "applied_at` DATETIME DEFAULT CURRENT_TIMESTAMP() NOT NULL" in create_sql + assert "execution_sequence` INT64" in create_sql diff --git a/tests/unit/adapters/test_duckdb/test_migration_schema.py b/tests/unit/adapters/test_duckdb/test_migration_schema.py new file mode 100644 index 000000000..4328d7185 --- /dev/null +++ b/tests/unit/adapters/test_duckdb/test_migration_schema.py @@ -0,0 +1,49 @@ +"""Unit coverage for DuckDB migration schema hooks.""" + +from typing import Any + +from sqlspec.adapters.duckdb.config import DuckDBConfig +from sqlspec.adapters.duckdb.driver import DuckDBDriver + + +class FakeDuckDBConnection: + def __init__(self, schema_exists: bool = True) -> None: + self.schema_exists = schema_exists + self.executed: list[tuple[str, Any]] = [] + + def execute(self, sql: str, parameters: Any | None = None) -> "FakeDuckDBConnection": + self.executed.append((sql, parameters)) + return self + + def fetchone(self) -> tuple[int] | None: + return (1,) if self.schema_exists else None + + +def test_duckdb_migration_schema_hooks() -> None: + connection = FakeDuckDBConnection() + driver = DuckDBDriver(connection) # type: ignore[arg-type] + + driver.set_migration_session_schema("tenant") + assert driver.has_schema("tenant") is True + + assert connection.executed == [ + ("SET search_path = 'tenant'", None), + ("SELECT 1 FROM information_schema.schemata WHERE schema_name = ?", ["tenant"]), + ] + assert DuckDBConfig.supports_migration_schemas is True + + +def test_duckdb_has_schema_returns_false_for_missing_schema() -> None: + connection = FakeDuckDBConnection(schema_exists=False) + driver = DuckDBDriver(connection) # type: ignore[arg-type] + + assert driver.has_schema("missing") is False + + +def test_duckdb_migration_schema_escapes_search_path_literal() -> None: + connection = FakeDuckDBConnection() + driver = DuckDBDriver(connection) # type: ignore[arg-type] + + driver.set_migration_session_schema("tenant's") + + assert connection.executed == [("SET search_path = 'tenant''s'", None)] diff --git a/tests/unit/adapters/test_migration_schema_noops.py b/tests/unit/adapters/test_migration_schema_noops.py new file mode 100644 index 000000000..c83e30b80 --- /dev/null +++ b/tests/unit/adapters/test_migration_schema_noops.py @@ -0,0 +1,74 @@ +"""Unit coverage for adapters that accept migration schemas as no-ops.""" + +import pytest + +from sqlspec.adapters.adbc.driver import AdbcDriver +from sqlspec.adapters.aiosqlite.driver import AiosqliteDriver +from sqlspec.adapters.asyncmy.driver import AsyncmyDriver +from sqlspec.adapters.sqlite.driver import SqliteDriver +from sqlspec.core import StatementConfig + + +class FakeConnection: + def adbc_get_info(self) -> dict[str, str]: + return {"driver_name": "sql server"} + + +def test_sqlite_migration_schema_noop(caplog: pytest.LogCaptureFixture) -> None: + caplog.set_level("DEBUG") + driver = SqliteDriver(object()) # type: ignore[arg-type] + + driver.set_migration_session_schema("tenant") + + assert driver.has_schema("tenant") is True + assert "SQLite driver does not support default schemas" in caplog.text + + +@pytest.mark.anyio +async def test_aiosqlite_migration_schema_noop(caplog: pytest.LogCaptureFixture) -> None: + caplog.set_level("DEBUG") + driver = AiosqliteDriver(object()) # type: ignore[arg-type] + + await driver.set_migration_session_schema("tenant") + + assert await driver.has_schema("tenant") is True + assert "aiosqlite driver does not support default schemas" in caplog.text + + +@pytest.mark.anyio +async def test_asyncmy_migration_schema_noop(caplog: pytest.LogCaptureFixture) -> None: + caplog.set_level("DEBUG") + driver = AsyncmyDriver(object()) # type: ignore[arg-type] + + await driver.set_migration_session_schema("tenant") + + assert await driver.has_schema("tenant") is True + assert "asyncmy driver does not support default schemas" in caplog.text + + +def test_adbc_sql_server_migration_schema_noop(caplog: pytest.LogCaptureFixture) -> None: + caplog.set_level("DEBUG") + driver = AdbcDriver( + FakeConnection(), # type: ignore[arg-type] + statement_config=StatementConfig(dialect="tsql"), + driver_features={}, + ) + + driver.set_migration_session_schema("tenant") + + assert driver.has_schema("tenant") is True + assert "SQL Server schema support not yet implemented for ADBC" in caplog.text + + +def test_adbc_non_postgres_migration_schema_noop(caplog: pytest.LogCaptureFixture) -> None: + caplog.set_level("DEBUG") + driver = AdbcDriver( + FakeConnection(), # type: ignore[arg-type] + statement_config=StatementConfig(dialect="sqlite"), + driver_features={}, + ) + + driver.set_migration_session_schema("tenant") + + assert driver.has_schema("tenant") is True + assert "ADBC driver does not support default schemas" in caplog.text diff --git a/tests/unit/adapters/test_oracledb/test_migration_schema.py b/tests/unit/adapters/test_oracledb/test_migration_schema.py new file mode 100644 index 000000000..131eddc22 --- /dev/null +++ b/tests/unit/adapters/test_oracledb/test_migration_schema.py @@ -0,0 +1,91 @@ +"""Unit coverage for Oracle migration schema hooks.""" + +from typing import Any + +import pytest + +from sqlspec.adapters.oracledb.config import OracleAsyncConfig, OracleSyncConfig +from sqlspec.adapters.oracledb.driver import OracleAsyncDriver, OracleSyncDriver + + +class FakeSyncCursor: + def __init__(self, schema_exists: bool = True) -> None: + self.schema_exists = schema_exists + self.executed: list[tuple[str, Any]] = [] + + def execute(self, sql: str, parameters: Any | None = None) -> None: + self.executed.append((sql, parameters)) + + def fetchone(self) -> tuple[int] | None: + return (1,) if self.schema_exists else None + + def close(self) -> None: + return None + + +class FakeSyncConnection: + def __init__(self, cursor: FakeSyncCursor) -> None: + self._cursor = cursor + + def cursor(self) -> FakeSyncCursor: + return self._cursor + + +class FakeAsyncCursor: + def __init__(self, schema_exists: bool = True) -> None: + self.schema_exists = schema_exists + self.executed: list[tuple[str, Any]] = [] + + async def execute(self, sql: str, parameters: Any | None = None) -> None: + self.executed.append((sql, parameters)) + + async def fetchone(self) -> tuple[int] | None: + return (1,) if self.schema_exists else None + + def close(self) -> None: + return None + + +class FakeAsyncConnection: + def __init__(self, cursor: FakeAsyncCursor) -> None: + self._cursor = cursor + + def cursor(self) -> FakeAsyncCursor: + return self._cursor + + +def test_oracle_sync_migration_schema_hooks() -> None: + cursor = FakeSyncCursor() + driver = OracleSyncDriver(FakeSyncConnection(cursor)) # type: ignore[arg-type] + + driver.set_migration_session_schema("tenant") + assert driver.has_schema("tenant") is True + + assert cursor.executed == [ + ('ALTER SESSION SET CURRENT_SCHEMA = "TENANT"', None), + ("SELECT 1 FROM ALL_USERS WHERE USERNAME = UPPER(:schema_name)", {"schema_name": "tenant"}), + ] + assert OracleSyncConfig.supports_migration_schemas is True + + +@pytest.mark.anyio +async def test_oracle_async_migration_schema_hooks() -> None: + cursor = FakeAsyncCursor() + driver = OracleAsyncDriver(FakeAsyncConnection(cursor)) # type: ignore[arg-type] + + await driver.set_migration_session_schema("tenant") + assert await driver.has_schema("tenant") is True + + assert cursor.executed == [ + ('ALTER SESSION SET CURRENT_SCHEMA = "TENANT"', None), + ("SELECT 1 FROM ALL_USERS WHERE USERNAME = UPPER(:schema_name)", {"schema_name": "tenant"}), + ] + assert OracleAsyncConfig.supports_migration_schemas is True + + +@pytest.mark.anyio +async def test_oracle_async_has_schema_returns_false_for_missing_user() -> None: + cursor = FakeAsyncCursor(schema_exists=False) + driver = OracleAsyncDriver(FakeAsyncConnection(cursor)) # type: ignore[arg-type] + + assert await driver.has_schema("missing") is False diff --git a/tests/unit/adapters/test_postgres_migration_schema.py b/tests/unit/adapters/test_postgres_migration_schema.py new file mode 100644 index 000000000..9d0cc6908 --- /dev/null +++ b/tests/unit/adapters/test_postgres_migration_schema.py @@ -0,0 +1,199 @@ +# pyright: reportPrivateUsage=false +"""Unit coverage for PostgreSQL migration schema hooks.""" + +from typing import Any + +import pytest + +from sqlspec.adapters.adbc.config import AdbcConfig +from sqlspec.adapters.adbc.driver import AdbcDriver +from sqlspec.adapters.asyncpg.config import AsyncpgConfig +from sqlspec.adapters.asyncpg.driver import AsyncpgDriver +from sqlspec.adapters.psqlpy.config import PsqlpyConfig +from sqlspec.adapters.psqlpy.driver import PsqlpyDriver +from sqlspec.adapters.psycopg.config import PsycopgAsyncConfig, PsycopgSyncConfig +from sqlspec.adapters.psycopg.driver import PsycopgAsyncDriver, PsycopgSyncDriver + + +class FakeAsyncpgConnection: + def __init__(self, schema_exists: bool = True) -> None: + self.schema_exists = schema_exists + self.executed: list[tuple[str, tuple[Any, ...]]] = [] + + async def execute(self, sql: str, *parameters: Any) -> str: + self.executed.append((sql, parameters)) + return "OK" + + async def fetchval(self, sql: str, *parameters: Any) -> int | None: + self.executed.append((sql, parameters)) + return 1 if self.schema_exists else None + + +class FakePsqlpyConnection: + def __init__(self, schema_exists: bool = True) -> None: + self.schema_exists = schema_exists + self.executed: list[tuple[str, Any]] = [] + + async def execute(self, sql: str, parameters: Any | None = None) -> str: + self.executed.append((sql, parameters)) + return "OK" + + async def fetch(self, sql: str, parameters: Any | None = None) -> list[tuple[int]]: + self.executed.append((sql, parameters)) + return [(1,)] if self.schema_exists else [] + + +class FakeSyncCursor: + def __init__(self, schema_exists: bool = True) -> None: + self.schema_exists = schema_exists + self.executed: list[tuple[str, Any]] = [] + + def execute(self, sql: str, parameters: Any | None = None) -> None: + self.executed.append((sql, parameters)) + + def fetchone(self) -> tuple[int] | None: + return (1,) if self.schema_exists else None + + def close(self) -> None: + return None + + +class FakePsycopgSyncConnection: + def __init__(self, cursor: FakeSyncCursor) -> None: + self._cursor = cursor + + def cursor(self) -> FakeSyncCursor: + return self._cursor + + +class FakeAsyncCursor: + def __init__(self, schema_exists: bool = True) -> None: + self.schema_exists = schema_exists + self.executed: list[tuple[str, Any]] = [] + + async def execute(self, sql: str, parameters: Any | None = None) -> None: + self.executed.append((sql, parameters)) + + async def fetchone(self) -> tuple[int] | None: + return (1,) if self.schema_exists else None + + async def close(self) -> None: + return None + + +class FakePsycopgAsyncConnection: + def __init__(self, cursor: FakeAsyncCursor) -> None: + self._cursor = cursor + + def cursor(self) -> FakeAsyncCursor: + return self._cursor + + +class FakeAdbcConnection: + def __init__(self, cursor: FakeSyncCursor, *, dialect: str = "postgresql") -> None: + self._cursor = cursor + self._dialect = dialect + + def adbc_get_info(self) -> dict[str, str]: + return {"vendor_name": self._dialect, "driver_name": self._dialect} + + def cursor(self) -> FakeSyncCursor: + return self._cursor + + +@pytest.mark.anyio +async def test_asyncpg_migration_schema_hooks() -> None: + connection = FakeAsyncpgConnection() + driver = AsyncpgDriver(connection) # type: ignore[arg-type] + + await driver.set_migration_session_schema('tenant_"one"') + assert await driver.has_schema("tenant") is True + + assert connection.executed == [ + ('SET LOCAL search_path TO "tenant_""one""", "$user", public', ()), + ("SELECT 1 FROM information_schema.schemata WHERE schema_name = $1", ("tenant",)), + ] + assert AsyncpgConfig.supports_migration_schemas is True + + +@pytest.mark.anyio +async def test_psqlpy_migration_schema_hooks() -> None: + connection = FakePsqlpyConnection() + driver = PsqlpyDriver(connection) # type: ignore[arg-type] + + await driver.set_migration_session_schema("tenant") + assert await driver.has_schema("missing") is True + + assert connection.executed == [ + ('SET LOCAL search_path TO "tenant", "$user", public', None), + ("SELECT 1 FROM information_schema.schemata WHERE schema_name = $1", ["missing"]), + ] + assert PsqlpyConfig.supports_migration_schemas is True + + +@pytest.mark.anyio +async def test_psqlpy_has_schema_returns_false_for_empty_result() -> None: + connection = FakePsqlpyConnection(schema_exists=False) + driver = PsqlpyDriver(connection) # type: ignore[arg-type] + + assert await driver.has_schema("missing") is False + + +def test_psycopg_sync_migration_schema_hooks() -> None: + cursor = FakeSyncCursor() + driver = PsycopgSyncDriver(FakePsycopgSyncConnection(cursor)) # type: ignore[arg-type] + + driver.set_migration_session_schema("tenant") + assert driver.has_schema("tenant") is True + + assert cursor.executed == [ + ('SET LOCAL search_path TO "tenant", "$user", public', None), + ("SELECT 1 FROM information_schema.schemata WHERE schema_name = %s", ("tenant",)), + ] + assert PsycopgSyncConfig.supports_migration_schemas is True + + +@pytest.mark.anyio +async def test_psycopg_async_migration_schema_hooks() -> None: + cursor = FakeAsyncCursor() + driver = PsycopgAsyncDriver(FakePsycopgAsyncConnection(cursor)) # type: ignore[arg-type] + + await driver.set_migration_session_schema("tenant") + assert await driver.has_schema("missing") is True + + assert cursor.executed == [ + ('SET LOCAL search_path TO "tenant", "$user", public', None), + ("SELECT 1 FROM information_schema.schemata WHERE schema_name = %s", ("missing",)), + ] + assert PsycopgAsyncConfig.supports_migration_schemas is True + + +def test_adbc_postgres_migration_schema_hooks() -> None: + cursor = FakeSyncCursor() + driver = AdbcDriver(FakeAdbcConnection(cursor)) # type: ignore[arg-type] + + driver.set_migration_session_schema("tenant") + assert driver.has_schema("tenant") is True + + assert cursor.executed == [ + ('SET search_path TO "tenant", "$user", public', None), + ("SELECT 1 FROM information_schema.schemata WHERE schema_name = $1", ["tenant"]), + ] + + +def test_adbc_non_postgres_migration_schema_hooks_are_noops() -> None: + cursor = FakeSyncCursor() + driver = AdbcDriver(FakeAdbcConnection(cursor, dialect="sqlite")) # type: ignore[arg-type] + + driver.set_migration_session_schema("tenant") + assert driver.has_schema("tenant") is True + + assert cursor.executed == [] + + +def test_adbc_config_supports_migration_schemas_for_postgres_only() -> None: + pg_config = AdbcConfig(connection_config={"uri": "postgresql://example.invalid/db"}) + sqlite_config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": ":memory:"}) + + assert pg_config.supports_migration_schemas is True + assert sqlite_config.supports_migration_schemas is False diff --git a/tests/unit/base/test_set_migration_config.py b/tests/unit/base/test_set_migration_config.py index 64238e482..f191a55ec 100644 --- a/tests/unit/base/test_set_migration_config.py +++ b/tests/unit/base/test_set_migration_config.py @@ -44,9 +44,16 @@ def test_set_migration_config_with_typed_dict() -> None: """set_migration_config works with MigrationConfig TypedDict values.""" from sqlspec.config import MigrationConfig - mc: MigrationConfig = {"script_location": "alembic", "enabled": True} + mc: MigrationConfig = { + "script_location": "alembic", + "enabled": True, + "default_schema": "app_schema", + "version_table_schema": "migration_schema", + } config = SqliteConfig(connection_config={"database": ":memory:"}) config.set_migration_config(mc) assert config.migration_config["script_location"] == "alembic" assert config.migration_config["enabled"] is True + assert config.migration_config["default_schema"] == "app_schema" + assert config.migration_config["version_table_schema"] == "migration_schema" diff --git a/tests/unit/builder/test_insert_builder.py b/tests/unit/builder/test_insert_builder.py index c1ffbba06..02c79fdf9 100644 --- a/tests/unit/builder/test_insert_builder.py +++ b/tests/unit/builder/test_insert_builder.py @@ -30,6 +30,15 @@ def test_insert_with_table_in_constructor() -> None: assert "price" in stmt.parameters +def test_insert_preserves_schema_qualified_table() -> None: + """INSERT column handling should keep schema-qualified table targets intact.""" + query = sql.insert("history.ddl_migrations").columns("version_num").values("0001") + stmt = query.build() + + assert '"history"."ddl_migrations"' in stmt.sql + assert stmt.parameters["version_num"] == "0001" + + def test_insert_values_from_dict() -> None: """Test INSERT using values_from_dict method (deprecated; emits DeprecationWarning).""" data = {"id": 1, "name": "John", "status": "active"} diff --git a/tests/unit/migrations/test_migration_commands.py b/tests/unit/migrations/test_migration_commands.py index 59de0345c..51ee9b819 100644 --- a/tests/unit/migrations/test_migration_commands.py +++ b/tests/unit/migrations/test_migration_commands.py @@ -10,13 +10,15 @@ - Command routing and parameter passing """ +import logging from pathlib import Path -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from sqlspec.adapters.aiosqlite.config import AiosqliteConfig from sqlspec.adapters.sqlite.config import SqliteConfig +from sqlspec.exceptions import MigrationError from sqlspec.migrations.commands import AsyncMigrationCommands, SyncMigrationCommands, create_migration_commands @@ -227,6 +229,94 @@ def test_async_migration_commands_initialization(async_config: AiosqliteConfig) assert hasattr(commands, "runner") +class SchemaAwareSqliteConfig(SqliteConfig): + supports_migration_schemas = True + + +class SchemaAwareAiosqliteConfig(AiosqliteConfig): + supports_migration_schemas = True + + +def test_sync_commands_pass_resolved_tracker_schema_when_supported(tmp_path: Path) -> None: + config = SchemaAwareSqliteConfig( + connection_config={"database": ":memory:"}, + migration_config={"script_location": str(tmp_path), "default_schema": "app_schema"}, + ) + + commands = SyncMigrationCommands(config) + + assert commands.tracker.version_table_schema == "app_schema" + assert commands.tracker.version_table == "app_schema.ddl_migrations" + + +def test_async_commands_pass_resolved_tracker_schema_when_supported(tmp_path: Path) -> None: + config = SchemaAwareAiosqliteConfig( + connection_config={"database": ":memory:"}, + migration_config={"script_location": str(tmp_path), "version_table_schema": "history_schema"}, + ) + + commands = AsyncMigrationCommands(config) + + assert commands.tracker.version_table_schema == "history_schema" + assert commands.tracker.version_table == "history_schema.ddl_migrations" + + +def test_sync_validate_migration_schema_raises_for_missing_schema(sync_config: SqliteConfig) -> None: + sync_config.set_migration_config({"default_schema": "missing_schema"}) + commands = SyncMigrationCommands(sync_config) + driver = MagicMock() + driver.has_schema.return_value = False + + with pytest.raises(MigrationError, match="Configured schema 'missing_schema' does not exist"): + commands._validate_migration_schema(driver) + + driver.has_schema.assert_called_once_with("missing_schema") + + +async def test_async_validate_migration_schema_raises_for_missing_schema(async_config: AiosqliteConfig) -> None: + async_config.set_migration_config({"default_schema": "missing_schema"}) + commands = AsyncMigrationCommands(async_config) + driver = AsyncMock() + driver.has_schema.return_value = False + + with pytest.raises(MigrationError, match="Configured schema 'missing_schema' does not exist"): + await commands._validate_migration_schema(driver) + + driver.has_schema.assert_awaited_once_with("missing_schema") + + +def test_sqlite_default_schema_noop_migration_succeeds(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None: + migrations_dir = tmp_path / "migrations" + migrations_dir.mkdir() + migration_file = migrations_dir / "0001_schema_noop.sql" + migration_file.write_text( + """ +-- name: migrate-0001-up +CREATE TABLE schema_noop_example (id INTEGER PRIMARY KEY); + +-- name: migrate-0001-down +DROP TABLE schema_noop_example; +""".strip(), + encoding="utf-8", + ) + config = SqliteConfig( + connection_config={"database": str(tmp_path / "db.sqlite")}, + migration_config={"script_location": str(migrations_dir), "default_schema": "ignored_schema"}, + ) + commands = SyncMigrationCommands(config) + + with caplog.at_level(logging.DEBUG, logger="sqlspec.driver"): + commands.upgrade(echo=False) + + assert any(record.getMessage() == "migration.schema.noop" for record in caplog.records) + assert commands.tracker.version_table == "ddl_migrations" + with config.provide_session() as driver: + result = driver.select_value( + "SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?", ["schema_noop_example"] + ) + assert result == "schema_noop_example" + + def test_sync_migration_commands_init_creates_directory(tmp_path: Path, sync_config: SqliteConfig) -> None: """Test that SyncMigrationCommands init creates migration directory structure.""" commands = SyncMigrationCommands(sync_config) diff --git a/tests/unit/migrations/test_migration_runner.py b/tests/unit/migrations/test_migration_runner.py index 833ad7e83..79766c74e 100644 --- a/tests/unit/migrations/test_migration_runner.py +++ b/tests/unit/migrations/test_migration_runner.py @@ -12,15 +12,53 @@ import time from pathlib import Path from typing import Any, cast -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, call, patch import pytest from sqlspec.loader import SQLFileLoader as CoreSQLFileLoader from sqlspec.migrations import runner as runner_module from sqlspec.migrations.base import BaseMigrationRunner +from sqlspec.migrations.context import MigrationContext from sqlspec.migrations.loaders import SQLFileLoader as MigrationSQLFileLoader -from sqlspec.migrations.runner import SyncMigrationRunner +from sqlspec.migrations.runner import AsyncMigrationRunner, SyncMigrationRunner + + +class _RunnerConfig: + supports_transactional_ddl = True + + def __init__(self, migration_config: dict[str, Any]) -> None: + self.migration_config = migration_config + + +class _SyncMigrationLoader: + def get_up_sql(self, _file_path: Path) -> list[str]: + return ["CREATE TABLE example (id INTEGER)"] + + def get_down_sql(self, _file_path: Path) -> list[str]: + return ["DROP TABLE example"] + + +class _AsyncMigrationLoader: + async def get_up_sql(self, _file_path: Path) -> list[str]: + return ["CREATE TABLE example (id INTEGER)"] + + async def get_down_sql(self, _file_path: Path) -> list[str]: + return ["DROP TABLE example"] + + +def _migration(file_path: Path, loader: Any) -> dict[str, Any]: + return {"version": "0001", "file_path": file_path, "loader": loader, "has_upgrade": True, "has_downgrade": True} + + +def _sync_runner(tmp_path: Path, migration_config: dict[str, Any]) -> SyncMigrationRunner: + context = MigrationContext(config=_RunnerConfig(migration_config)) + return SyncMigrationRunner(tmp_path, {}, context, {}) + + +def _async_runner(tmp_path: Path, migration_config: dict[str, Any]) -> AsyncMigrationRunner: + context = MigrationContext(config=_RunnerConfig(migration_config)) + return AsyncMigrationRunner(tmp_path, {}, context, {}) def create_test_migration_runner(migrations_path: Path = Path("/test")) -> BaseMigrationRunner: @@ -115,6 +153,83 @@ def _write_basic_sql(path: Path, version: str, body: str = "SELECT 1;") -> None: ) +def test_sync_execute_upgrade_sets_migration_schema_after_begin(tmp_path: Path) -> None: + """Configured migration schemas should be applied inside the migration transaction.""" + migration_file = tmp_path / "0001_schema.sql" + _write_basic_sql(migration_file, "0001") + runner = _sync_runner(tmp_path, {"default_schema": "app_schema"}) + driver = Mock() + + runner.execute_upgrade(driver, _migration(migration_file, _SyncMigrationLoader()), use_transaction=True) + + assert driver.mock_calls[:4] == [ + call.begin(), + call.set_migration_session_schema("app_schema"), + call.execute_script("CREATE TABLE example (id INTEGER)"), + call.commit(), + ] + + +def test_sync_execute_upgrade_skips_migration_schema_when_unset(tmp_path: Path) -> None: + """Back-compat path should not call the schema hook when default_schema is unset.""" + migration_file = tmp_path / "0001_schema.sql" + _write_basic_sql(migration_file, "0001") + runner = _sync_runner(tmp_path, {}) + driver = Mock() + + runner.execute_upgrade(driver, _migration(migration_file, _SyncMigrationLoader()), use_transaction=True) + + driver.set_migration_session_schema.assert_not_called() + + +def test_sync_execute_downgrade_sets_migration_schema_after_begin(tmp_path: Path) -> None: + """Downgrades should use the same schema setup order as upgrades.""" + migration_file = tmp_path / "0001_schema.sql" + _write_basic_sql(migration_file, "0001") + runner = _sync_runner(tmp_path, {"default_schema": "app_schema"}) + driver = Mock() + + runner.execute_downgrade(driver, _migration(migration_file, _SyncMigrationLoader()), use_transaction=True) + + assert driver.mock_calls[:4] == [ + call.begin(), + call.set_migration_session_schema("app_schema"), + call.execute_script("DROP TABLE example"), + call.commit(), + ] + + +@pytest.mark.anyio +async def test_async_execute_upgrade_sets_migration_schema_after_begin(tmp_path: Path) -> None: + """Async migrations should await the schema hook inside the transaction.""" + migration_file = tmp_path / "0001_schema.sql" + _write_basic_sql(migration_file, "0001") + runner = _async_runner(tmp_path, {"default_schema": "app_schema"}) + driver = AsyncMock() + + await runner.execute_upgrade(driver, _migration(migration_file, _AsyncMigrationLoader()), use_transaction=True) + + assert driver.mock_calls[:4] == [ + call.begin(), + call.set_migration_session_schema("app_schema"), + call.execute_script("CREATE TABLE example (id INTEGER)"), + call.commit(), + ] + + +@pytest.mark.anyio +async def test_async_execute_upgrade_skips_migration_schema_when_unset(tmp_path: Path) -> None: + """Async back-compat path should not call the schema hook when default_schema is unset.""" + migration_file = tmp_path / "0001_schema.sql" + _write_basic_sql(migration_file, "0001") + runner = _async_runner(tmp_path, {}) + driver = AsyncMock() + + await runner.execute_upgrade(driver, _migration(migration_file, _AsyncMigrationLoader()), use_transaction=True) + + driver.set_migration_session_schema.assert_not_called() + + def test_load_migration_metadata_uses_cache(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: """Ensure metadata caching prevents redundant checksum calculations.""" diff --git a/tests/unit/migrations/test_tracker_idempotency.py b/tests/unit/migrations/test_tracker_idempotency.py index 500b97107..fdd7507bc 100644 --- a/tests/unit/migrations/test_tracker_idempotency.py +++ b/tests/unit/migrations/test_tracker_idempotency.py @@ -1,10 +1,14 @@ -"""Unit tests for idempotent update_version_record behavior.""" +"""Unit tests for migration tracker behavior.""" +import logging from typing import Any from unittest.mock import MagicMock, Mock import pytest +from sqlspec.adapters.oracledb.migrations import OracleSyncMigrationTracker +from sqlspec.driver._async import AsyncDriverAdapterBase +from sqlspec.driver._sync import SyncDriverAdapterBase from sqlspec.migrations.tracker import AsyncMigrationTracker, SyncMigrationTracker @@ -25,6 +29,64 @@ def test_sync_update_version_record_success() -> None: assert "ddl_migrations" in update_sql +def test_sync_tracker_qualifies_table_sql_when_schema_is_configured() -> None: + """The base tracker should render the tracking table in the configured schema.""" + tracker = SyncMigrationTracker(version_table_name="ddl_migrations", version_table_schema="history") + + assert tracker.version_table == "history.ddl_migrations" + rendered_statements = [ + str(tracker._get_create_table_sql()), # pyright: ignore[reportPrivateUsage] + str(tracker._get_current_version_sql()), # pyright: ignore[reportPrivateUsage] + str(tracker._get_applied_migrations_sql()), # pyright: ignore[reportPrivateUsage] + str(tracker._get_next_execution_sequence_sql()), # pyright: ignore[reportPrivateUsage] + str(tracker._get_record_migration_sql("0001", "sequential", 1, "init", 10, "abc", "tester")), # pyright: ignore[reportPrivateUsage] + str(tracker._get_remove_migration_sql("0001")), # pyright: ignore[reportPrivateUsage] + str(tracker._get_update_version_sql("20250101000000", "0001", "sequential")), # pyright: ignore[reportPrivateUsage] + str(tracker._get_delete_versions_sql(["0001"])), # pyright: ignore[reportPrivateUsage] + str(tracker._get_record_squashed_migration_sql("0002", "sequential", 2, "squash", 0, "def", "tester", "0001")), # pyright: ignore[reportPrivateUsage] + str(tracker._get_check_column_exists_sql()), # pyright: ignore[reportPrivateUsage] + ] + + assert all('"history"."ddl_migrations"' in statement for statement in rendered_statements) + + +def test_sync_tracker_keeps_bare_table_sql_without_schema() -> None: + """Existing migration tracker SQL should stay unqualified by default.""" + tracker = SyncMigrationTracker(version_table_name="ddl_migrations") + + assert tracker.version_table == "ddl_migrations" + assert "history.ddl_migrations" not in str(tracker._get_create_table_sql()) # pyright: ignore[reportPrivateUsage] + + +def test_oracle_tracker_uppercases_qualified_tracking_table() -> None: + """Oracle tracker should qualify migration table names using Oracle's uppercase convention.""" + tracker = OracleSyncMigrationTracker(version_table_name="ddl_migrations", version_table_schema="app_owner") + + assert tracker.version_table == "APP_OWNER.DDL_MIGRATIONS" + existing_columns_sql = tracker._get_existing_columns_sql() # pyright: ignore[reportPrivateUsage] + assert "WHERE table_name = 'DDL_MIGRATIONS'" in existing_columns_sql + assert "AND owner = 'APP_OWNER'" in existing_columns_sql + + +def test_sync_driver_migration_schema_hook_logs_noop(caplog: pytest.LogCaptureFixture) -> None: + """Base sync drivers should accept migration schema configuration as a no-op hook.""" + caplog.set_level(logging.DEBUG, logger="sqlspec.driver") + + SyncDriverAdapterBase.set_migration_session_schema(object(), "app_schema") # type: ignore[arg-type] + + assert any(record.getMessage() == "migration.schema.noop" for record in caplog.records) + + +@pytest.mark.anyio +async def test_async_driver_migration_schema_hook_logs_noop(caplog: pytest.LogCaptureFixture) -> None: + """Base async drivers should accept migration schema configuration as a no-op hook.""" + caplog.set_level(logging.DEBUG, logger="sqlspec.driver") + + await AsyncDriverAdapterBase.set_migration_session_schema(object(), "app_schema") # type: ignore[arg-type] + + assert any(record.getMessage() == "migration.schema.noop" for record in caplog.records) + + def test_sync_update_version_record_idempotent_when_already_updated() -> None: """Test sync update is idempotent when version already exists.""" tracker = SyncMigrationTracker() diff --git a/tests/unit/migrations/test_utils.py b/tests/unit/migrations/test_utils.py index 9d79ff3de..4019522ef 100644 --- a/tests/unit/migrations/test_utils.py +++ b/tests/unit/migrations/test_utils.py @@ -17,7 +17,14 @@ from sqlspec.config import DatabaseConfigProtocol from sqlspec.migrations.templates import TemplateValidationError -from sqlspec.migrations.utils import _get_git_config, _get_system_username, create_migration_file, get_author +from sqlspec.migrations.utils import ( + _get_git_config, + _get_system_username, + create_migration_file, + get_author, + quote_migration_identifier, + resolve_tracker_schema, +) def _callable_author(_: Any | None = None) -> str: @@ -261,6 +268,26 @@ class DummyConfig: assert file_path.suffix == ".py" +def test_resolve_tracker_schema_prefers_version_table_schema() -> None: + migration_config = {"default_schema": "app_schema", "version_table_schema": "history_schema"} + + assert resolve_tracker_schema(migration_config) == "history_schema" + + +def test_resolve_tracker_schema_defaults_to_default_schema() -> None: + migration_config = {"default_schema": "app_schema"} + + assert resolve_tracker_schema(migration_config) == "app_schema" + + +def test_resolve_tracker_schema_returns_none_without_schema_config() -> None: + assert resolve_tracker_schema({}) is None + + +def test_quote_migration_identifier_escapes_embedded_quotes() -> None: + assert quote_migration_identifier('tenant_"one"') == '"tenant_""one"""' + + def test_create_migration_file_uses_custom_sql_template(tmp_path: Path) -> None: migrations_dir = tmp_path / "migrations" migrations_dir.mkdir()