From e54eabfd207c26b2047d20739aa53d2cfa8791a3 Mon Sep 17 00:00:00 2001 From: Hadrien David Date: Sat, 10 Jan 2026 08:14:35 -0500 Subject: [PATCH] fix: allow lifespan to receive app=None --- src/fastsqla.py | 6 +++--- tests/unit/test_lifespan.py | 22 ++++++++++++++-------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/src/fastsqla.py b/src/fastsqla.py index dbb1ce3..56461f5 100644 --- a/src/fastsqla.py +++ b/src/fastsqla.py @@ -2,7 +2,7 @@ import os from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable from contextlib import _AsyncGeneratorContextManager, asynccontextmanager -from typing import Annotated, Generic, TypeVar, TypedDict +from typing import Annotated, Generic, TypedDict, TypeVar from fastapi import Depends as BaseDepends from fastapi import FastAPI, Query @@ -90,7 +90,7 @@ class State(TypedDict): def new_lifespan( url: str | None = None, **kw -) -> Callable[[FastAPI], _AsyncGeneratorContextManager[State, None]]: +) -> Callable[[FastAPI | None], _AsyncGeneratorContextManager[State, None]]: """Create a new lifespan async context manager. It expects the exact same parameters as @@ -117,7 +117,7 @@ def new_lifespan( has_config = url is not None @asynccontextmanager - async def lifespan(app: FastAPI) -> AsyncGenerator[State, None]: + async def lifespan(app: FastAPI | None) -> AsyncGenerator[State, None]: if has_config: prefix = "" sqla_config = {**kw, **{"url": url}} diff --git a/tests/unit/test_lifespan.py b/tests/unit/test_lifespan.py index 626d650..d57a6e0 100644 --- a/tests/unit/test_lifespan.py +++ b/tests/unit/test_lifespan.py @@ -1,17 +1,23 @@ from fastapi import FastAPI -from pytest import raises +from pytest import raises, fixture -app = FastAPI() +_app = FastAPI() -async def test_it_returns_state(environ): +@fixture(params=[_app, None]) +def app(request): + # lifespan tests pass whether lifespan receives app or None + return request.param + + +async def test_it_returns_state(environ, app): from fastsqla import lifespan async with lifespan(app) as state: assert "fastsqla_engine" in state -async def test_it_binds_an_sqla_engine_to_sessionmaker(environ): +async def test_it_binds_an_sqla_engine_to_sessionmaker(environ, app): from fastsqla import SessionFactory, lifespan assert SessionFactory.kw["bind"] is None @@ -24,7 +30,7 @@ async def test_it_binds_an_sqla_engine_to_sessionmaker(environ): assert SessionFactory.kw["bind"] is None -async def test_it_fails_on_a_missing_sqlalchemy_url(monkeypatch): +async def test_it_fails_on_a_missing_sqlalchemy_url(monkeypatch, app): from fastsqla import lifespan monkeypatch.delenv("SQLALCHEMY_URL", raising=False) @@ -35,7 +41,7 @@ async def test_it_fails_on_a_missing_sqlalchemy_url(monkeypatch): assert raise_info.value.args[0] == "Missing sqlalchemy_url in environ." -async def test_it_fails_on_not_async_engine(monkeypatch): +async def test_it_fails_on_not_async_engine(monkeypatch, app): from fastsqla import lifespan monkeypatch.setenv("SQLALCHEMY_URL", "sqlite:///:memory:") @@ -46,7 +52,7 @@ async def test_it_fails_on_not_async_engine(monkeypatch): assert "'pysqlite' is not async." in raise_info.value.args[0] -async def test_new_lifespan_with_connect_args(sqlalchemy_url): +async def test_new_lifespan_with_connect_args(sqlalchemy_url, app): from fastsqla import new_lifespan lifespan = new_lifespan(sqlalchemy_url, connect_args={"autocommit": False}) @@ -55,7 +61,7 @@ async def test_new_lifespan_with_connect_args(sqlalchemy_url): pass -async def test_new_lifespan_fails_with_invalid_connect_args(sqlalchemy_url): +async def test_new_lifespan_fails_with_invalid_connect_args(sqlalchemy_url, app): from fastsqla import new_lifespan lifespan = new_lifespan(sqlalchemy_url, connect_args={"this is wrong": False})