diff --git a/src/database/migrations/env.py b/src/database/migrations/env.py index c6f5e81..6912b0b 100644 --- a/src/database/migrations/env.py +++ b/src/database/migrations/env.py @@ -2,7 +2,7 @@ from alembic.script import ScriptDirectory from ddcDatabases import get_postgresql_settings from logging.config import fileConfig -from sqlalchemy import engine_from_config, pool, text +from sqlalchemy import create_engine, engine_from_config, pool, text from sqlalchemy.schema import SchemaItem from src.bot.constants.settings import get_bot_settings from src.database.models import BotBase @@ -17,17 +17,36 @@ _project_settings = get_bot_settings() _postgres_settings = get_postgresql_settings() -_password = quote_plus(_postgres_settings.password).replace("%", "%%") -_conn_url = ( - f"{_postgres_settings.sync_driver}://" - f"{_postgres_settings.user}:" - f"{_password}@" - f"{_postgres_settings.host}:" - f"{_postgres_settings.port}/" - f"{_postgres_settings.database}" - f"?sslmode={_postgres_settings.ssl_mode}" -) -config.set_main_option("sqlalchemy.url", _conn_url) +_password = quote_plus(_postgres_settings.password) + + +def _build_url(database: str) -> str: + return ( + f"{_postgres_settings.sync_driver}://" + f"{_postgres_settings.user}:" + f"{_password}@" + f"{_postgres_settings.host}:" + f"{_postgres_settings.port}/" + f"{database}" + f"?sslmode={_postgres_settings.ssl_mode}" + ) + + +def _ensure_database_exists() -> None: + """Connect to the 'postgres' maintenance DB and create the target database if it doesn't exist.""" + db_name = _postgres_settings.database + engine = create_engine(_build_url("postgres"), isolation_level="AUTOCOMMIT") + with engine.connect() as conn: + result = conn.execute(text("SELECT 1 FROM pg_database WHERE datname = :db"), {"db": db_name}) + if not result.scalar(): + conn.execute(text(f'CREATE DATABASE "{db_name}"')) + engine.dispose() + + +_ensure_database_exists() + +# set_main_option uses %-interpolation, so escape any % in the password +config.set_main_option("sqlalchemy.url", _build_url(_postgres_settings.database).replace("%", "%%")) _schemas = {s.strip() for s in (_postgres_settings.schema or "public").split(",")}