diff --git a/Makefile b/Makefile index bafbcf161..d9bbc1dda 100644 --- a/Makefile +++ b/Makefile @@ -7,6 +7,9 @@ setup-env: debug: FLASK_APP=policyengine_api.api FLASK_DEBUG=1 flask run --without-threads +debug-asgi: + FLASK_DEBUG=1 uvicorn policyengine_api.asgi:app --reload --port 8000 + test-env-vars: pytest tests/env_variables diff --git a/changelog.d/fastapi-shell.added.md b/changelog.d/fastapi-shell.added.md new file mode 100644 index 000000000..49ceb98b7 --- /dev/null +++ b/changelog.d/fastapi-shell.added.md @@ -0,0 +1 @@ +Added a FastAPI ASGI compatibility shell that can serve the existing Flask API through WSGI fallback. diff --git a/docs/engineering/skills/testing.md b/docs/engineering/skills/testing.md index 9a8a2611d..83b4b21a9 100644 --- a/docs/engineering/skills/testing.md +++ b/docs/engineering/skills/testing.md @@ -22,5 +22,24 @@ python scripts/export_migration_contracts.py python -m pytest tests/contract tests/unit/test_migration_flags.py tests/unit/test_migration_contract_artifacts.py tests/unit/test_capture_migration_baseline.py tests/unit/routes/test_migration_context_logging.py -q ``` +For PR 2 FastAPI shell or Flask fallback changes, verify the ASGI entrypoint and +the v1 route contracts together: + +```bash +FLASK_DEBUG=1 python -m pytest tests/unit/test_asgi_factory.py tests/contract/test_v1_route_contracts.py tests/unit/routes/test_migration_context_logging.py -q +``` + +If the change touches service compatibility behavior used by migrated or +candidate endpoints, add the relevant focused service tests. For budget-window +simulation compatibility, run: + +```bash +FLASK_DEBUG=1 python -m pytest tests/unit/services/test_economy_service.py::TestEconomyService::TestGetBudgetWindowEconomicImpact -q +``` + +Regenerate and review `docs/engineering/generated/migration_contracts.md` when +route inventory, migration registry flags, or v1 contract expectations change. +FastAPI shell-only fallback changes should not change the route catalog. + Run `ruff format --check` and `ruff check` on changed Python files before handoff. diff --git a/docs/migration-pr2-fastapi-shell-runbook.md b/docs/migration-pr2-fastapi-shell-runbook.md new file mode 100644 index 000000000..8a5a8d6ed --- /dev/null +++ b/docs/migration-pr2-fastapi-shell-runbook.md @@ -0,0 +1,41 @@ +# PR 2 FastAPI Shell Runbook + +PR 2 adds an ASGI FastAPI shell around the existing Flask API. It is a +compatibility step only. + +## Included + +- Native FastAPI `GET /health`. +- Flask fallback for all existing API v1 routes through WSGI middleware. +- ASGI parity tests for current app-v2 contract routes. +- Local Uvicorn run command. + +## Not Included + +- No production traffic shift. +- No Cloud Run deployment. +- No native FastAPI route migration beyond `GET /health`. +- No Supabase, Alembic, SQLAlchemy, or Modal compute changes. + +## Local Smoke + +Run: + +```bash +FLASK_DEBUG=1 uvicorn policyengine_api.asgi:app --port 8000 +``` + +Smoke-check: + +```bash +curl -i http://localhost:8000/health +curl -i http://localhost:8000/readiness-check +curl -i http://localhost:8000/liveness-check +curl -i http://localhost:8000/zz/metadata +``` + +Expected behavior: + +- `/health` returns FastAPI JSON: `{"status":"healthy"}`. +- `/readiness-check` and `/liveness-check` return existing Flask text `OK`. +- Existing v1 routes continue to use Flask fallback behavior. diff --git a/policyengine_api/asgi.py b/policyengine_api/asgi.py new file mode 100644 index 000000000..f54392915 --- /dev/null +++ b/policyengine_api/asgi.py @@ -0,0 +1,9 @@ +"""ASGI entrypoint for the Stage 2 FastAPI compatibility shell.""" + +from __future__ import annotations + +from policyengine_api.api import app as flask_app +from policyengine_api.asgi_factory import create_asgi_app + + +app = application = create_asgi_app(flask_app) diff --git a/policyengine_api/asgi_factory.py b/policyengine_api/asgi_factory.py new file mode 100644 index 000000000..b94838eaa --- /dev/null +++ b/policyengine_api/asgi_factory.py @@ -0,0 +1,52 @@ +"""FastAPI shell for serving the existing Flask API through ASGI.""" + +from __future__ import annotations + +from typing import Literal + +from a2wsgi import WSGIMiddleware +from fastapi import FastAPI +from pydantic import BaseModel + +from policyengine_api.constants import VERSION + + +class HealthResponse(BaseModel): + status: Literal["healthy"] + + +def _add_vary_origin(response) -> None: + vary = response.headers.get("Vary") + if vary is None: + response.headers["Vary"] = "Origin" + return + if "origin" not in {value.strip().lower() for value in vary.split(",")}: + response.headers["Vary"] = f"{vary}, Origin" + + +def create_asgi_app(wsgi_app) -> FastAPI: + """Create the Stage 2 FastAPI shell around the existing Flask app.""" + + app = FastAPI( + title="PolicyEngine API", + version=VERSION, + docs_url=None, + redoc_url=None, + openapi_url=None, + ) + + @app.middleware("http") + async def add_cors_for_native_routes(request, call_next): + response = await call_next(request) + origin = request.headers.get("origin") + if origin and "access-control-allow-origin" not in response.headers: + response.headers["Access-Control-Allow-Origin"] = origin + _add_vary_origin(response) + return response + + @app.get("/health", response_model=HealthResponse) + def health() -> HealthResponse: + return HealthResponse(status="healthy") + + app.mount("/", WSGIMiddleware(wsgi_app)) + return app diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 3532244d9..59a51403a 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -24,6 +24,7 @@ from policyengine_api.utils import budget_window as budget_window_utils from policyengine.simulation import SimulationOptions from policyengine.utils.data.datasets import get_default_dataset +import httpx import json import datetime import hashlib @@ -77,6 +78,7 @@ class ImpactStatus(Enum): BUDGET_WINDOW_MAX_ACTIVE_YEARS = budget_window_utils.BUDGET_WINDOW_MAX_ACTIVE_YEARS BUDGET_WINDOW_MAX_YEARS = budget_window_utils.BUDGET_WINDOW_MAX_YEARS BUDGET_WINDOW_MAX_END_YEAR = budget_window_utils.BUDGET_WINDOW_MAX_END_YEAR +BUDGET_WINDOW_SUBMISSION_VALIDATION_ERROR_STATUS_CODES = {400, 422} class EconomicImpactSetupOptions(BaseModel): @@ -348,6 +350,18 @@ def get_budget_window_economic_impact( budget_window_cache.store_batch_job_id( cache_key, batch_execution.batch_job_id ) + except httpx.HTTPStatusError as error: + budget_window_cache.clear_starting_claim(cache_key, claim_token) + if ( + error.response.status_code + in BUDGET_WINDOW_SUBMISSION_VALIDATION_ERROR_STATUS_CODES + ): + return BudgetWindowEconomicImpactResult.failed( + self._build_budget_window_submission_error_message(error), + queued_years=years, + cache_status=cache_status, + ) + raise except Exception: budget_window_cache.clear_starting_claim(cache_key, claim_token) raise @@ -443,6 +457,26 @@ def _start_budget_window_batch( return simulation_api.run_budget_window_batch(sim_params) + def _build_budget_window_submission_error_message( + self, error: httpx.HTTPStatusError + ) -> str: + try: + response_json = error.response.json() + except ValueError: + response_json = None + + if isinstance(response_json, dict): + for key in ("detail", "message", "error"): + value = response_json.get(key) + if value: + return str(value) + + response_text = error.response.text.strip() + if response_text: + return response_text + + return str(error) + def _get_budget_window_result_from_batch_job_id( self, *, diff --git a/pyproject.toml b/pyproject.toml index dd2905171..5cc138a46 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,11 +21,13 @@ classifiers = [ "License :: OSI Approved :: GNU Affero General Public License v3", ] dependencies = [ + "a2wsgi>=1.10,<2", "anthropic", "assertpy", "click>=8,<9", "cloud-sql-python-connector", "faiss-cpu", + "fastapi>=0.115,<1", "flask>=3,<4", "flask-cors>=5,<6", "Flask-Caching>=2,<3", @@ -50,6 +52,7 @@ dependencies = [ "rq", "sqlalchemy>=2,<3", "streamlit", + "uvicorn[standard]>=0.32,<1", "werkzeug", ] diff --git a/tests/contract/clients.py b/tests/contract/clients.py new file mode 100644 index 000000000..4fdd98500 --- /dev/null +++ b/tests/contract/clients.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Mapping, Protocol + +from fastapi.testclient import TestClient +from flask import Flask + +from policyengine_api.asgi_factory import create_asgi_app + + +@dataclass(frozen=True) +class ContractResponse: + status_code: int + body: bytes + headers: Mapping[str, str] + content_type: str | None + + @property + def data(self) -> bytes: + return self.body + + +class ContractClient(Protocol): + def open( + self, + path: str, + *, + method: str, + json: dict | None = None, + headers: dict | None = None, + ) -> ContractResponse: ... + + +class FlaskContractClient: + def __init__(self, app: Flask): + self._client = app.test_client() + + def open( + self, + path: str, + *, + method: str, + json: dict | None = None, + headers: dict | None = None, + ) -> ContractResponse: + response = self._client.open( + path, + method=method, + json=json, + headers=headers, + ) + return ContractResponse( + status_code=response.status_code, + body=response.data, + headers=dict(response.headers), + content_type=response.content_type, + ) + + +class ASGIContractClient: + def __init__(self, app: Flask): + self._client = TestClient(create_asgi_app(app)) + + def open( + self, + path: str, + *, + method: str, + json: dict | None = None, + headers: dict | None = None, + ) -> ContractResponse: + response = self._client.request( + method, + path, + json=json, + headers=headers, + ) + return ContractResponse( + status_code=response.status_code, + body=response.content, + headers=dict(response.headers), + content_type=response.headers.get("content-type"), + ) diff --git a/tests/contract/test_v1_route_contracts.py b/tests/contract/test_v1_route_contracts.py index a3fec206a..6b3407bbe 100644 --- a/tests/contract/test_v1_route_contracts.py +++ b/tests/contract/test_v1_route_contracts.py @@ -13,6 +13,11 @@ from policyengine_api.routes.policy_routes import policy_bp from policyengine_api.routes.report_output_routes import report_output_bp from policyengine_api.routes.simulation_routes import simulation_bp +from tests.contract.clients import ( + ASGIContractClient, + ContractClient, + FlaskContractClient, +) from tests.contract.helpers import ( assert_field_path_exists, assert_subset, @@ -121,7 +126,7 @@ def _load_contract_economy_blueprint(): ) -def _client(): +def create_contract_flask_app() -> Flask: app = Flask(__name__) app.config["TESTING"] = True app.register_blueprint(_load_contract_metadata_blueprint()) @@ -141,7 +146,17 @@ def liveness_check(): def readiness_check(): return Response("OK", status=200, mimetype="text/plain") - return app.test_client() + return app + + +@pytest.fixture(params=("flask_direct", "fastapi_fallback")) +def contract_client(request) -> ContractClient: + app = create_contract_flask_app() + if request.param == "flask_direct": + return FlaskContractClient(app) + if request.param == "fastapi_fallback": + return ASGIContractClient(app) + raise AssertionError(f"Unknown contract client: {request.param}") def _resolved_path(path: str) -> str: @@ -375,9 +390,12 @@ def _expected_subset(contract: ContractRequest) -> dict: APP_V2_ROUTE_CONTRACTS, ids=lambda contract: f"{contract.method} {contract.path}", ) -def test_app_v2_api_v1_route_contract(contract): +def test_app_v2_api_v1_route_contract( + contract: ContractRequest, + contract_client: ContractClient, +): with _patched_route_dependencies(): - response = _client().open( + response = contract_client.open( _resolved_path(contract.path), method=contract.method, json=_json_payload(contract), @@ -390,10 +408,9 @@ def test_app_v2_api_v1_route_contract(contract): assert_field_path_exists(payload, field_path) -def test_health_routes_contract(): - client = _client() - liveness = client.get("/liveness-check") - readiness = client.get("/readiness-check") +def test_health_routes_contract(contract_client: ContractClient): + liveness = contract_client.open("/liveness-check", method="GET") + readiness = contract_client.open("/readiness-check", method="GET") assert liveness.status_code == 200 assert liveness.data == b"OK" @@ -403,8 +420,8 @@ def test_health_routes_contract(): assert "text/plain" in readiness.content_type -def test_invalid_country_contract(): - response = _client().get("/zz/metadata") +def test_invalid_country_contract(contract_client: ContractClient): + response = contract_client.open("/zz/metadata", method="GET") assert response.status_code == 400 assert_subset( diff --git a/tests/unit/routes/test_migration_context_logging.py b/tests/unit/routes/test_migration_context_logging.py index e0798d70e..204b284eb 100644 --- a/tests/unit/routes/test_migration_context_logging.py +++ b/tests/unit/routes/test_migration_context_logging.py @@ -1,7 +1,9 @@ from unittest.mock import patch +from fastapi.testclient import TestClient from flask import Flask, Response +from policyengine_api.asgi_factory import create_asgi_app from policyengine_api.migration_logging import register_migration_request_logging @@ -40,3 +42,14 @@ def test_request_logging_failure_does_not_change_response(): assert response.status_code == 200 assert response.data == b"OK" + + +def test_request_logging_runs_for_asgi_fallback_routes(): + with patch("policyengine_api.migration_logging.logger") as mock_logger: + response = TestClient(create_asgi_app(_app())).get("/readiness-check") + + assert response.status_code == 200 + assert response.content == b"OK" + log_payload = mock_logger.log_struct.call_args.args[0] + assert log_payload["path"] == "/readiness-check" + assert log_payload["migration"]["route_group"] == "health" diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index 5b1da4405..c82d2bd31 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -1,5 +1,6 @@ import json import sys +import httpx import pytest from unittest.mock import patch, MagicMock from typing import Literal @@ -110,6 +111,27 @@ def make_mock_budget_impact_data( } +def make_http_status_error( + status_code: int, + payload: dict | None = None, + text: str | None = None, +) -> httpx.HTTPStatusError: + request = httpx.Request( + "POST", + "https://policyengine-staging--policyengine-simulation-gateway-web-app.modal.run/simulate/economy/budget-window", + ) + response = ( + httpx.Response(status_code, json=payload, request=request) + if payload is not None + else httpx.Response(status_code, text=text or "", request=request) + ) + return httpx.HTTPStatusError( + f"Client error '{status_code}'", + request=request, + response=response, + ) + + class TestEconomyService: class TestGetEconomicImpact: @pytest.fixture @@ -1010,6 +1032,103 @@ def test__given_batch_submission_fails__clears_start_claim( "budget-window-cache-key", MOCK_PROCESS_ID ) + @pytest.mark.parametrize("status_code", [400, 422]) + def test__given_modal_rejects_batch_submission_for_validation__returns_failed_result( + self, + economy_service, + base_params, + mock_simulation_api, + mock_budget_window_cache, + status_code, + ): + mock_simulation_api.run_budget_window_batch.side_effect = make_http_status_error( + status_code, + { + "detail": ( + "Invalid Hugging Face dataset URI: " + "'hf://policyengine/nonexistent-budget-window-test.h5@0.0.0'" + ) + }, + ) + + result = economy_service.get_budget_window_economic_impact(**base_params) + + assert result.status == ImpactStatus.ERROR + assert result.data is None + assert result.error == ( + "Invalid Hugging Face dataset URI: " + "'hf://policyengine/nonexistent-budget-window-test.h5@0.0.0'" + ) + assert result.completed_years == [] + assert result.computing_years == [] + assert result.queued_years == ["2026", "2027", "2028"] + assert result.cache_status == "miss" + mock_budget_window_cache.clear_starting_claim.assert_called_once_with( + "budget-window-cache-key", MOCK_PROCESS_ID + ) + mock_budget_window_cache.store_batch_job_id.assert_not_called() + + @pytest.mark.parametrize("status_code", [401, 403, 429, 500]) + def test__given_modal_non_validation_error_on_batch_submission__raises( + self, + economy_service, + base_params, + mock_simulation_api, + mock_budget_window_cache, + status_code, + ): + mock_simulation_api.run_budget_window_batch.side_effect = ( + make_http_status_error(status_code, {"detail": "gateway unavailable"}) + ) + + with pytest.raises(httpx.HTTPStatusError): + economy_service.get_budget_window_economic_impact(**base_params) + + mock_budget_window_cache.clear_starting_claim.assert_called_once_with( + "budget-window-cache-key", MOCK_PROCESS_ID + ) + mock_budget_window_cache.store_batch_job_id.assert_not_called() + + @pytest.mark.parametrize( + ("payload", "expected_message"), + [ + ({"message": "gateway validation failed"}, "gateway validation failed"), + ({"error": "invalid request"}, "invalid request"), + ], + ) + def test__given_modal_validation_json_error__extracts_message( + self, economy_service, payload, expected_message + ): + error = make_http_status_error(400, payload) + + message = economy_service._build_budget_window_submission_error_message( + error + ) + + assert message == expected_message + + def test__given_modal_validation_plain_text_error__extracts_response_text( + self, economy_service + ): + error = make_http_status_error(400, text="plain validation failed") + + message = economy_service._build_budget_window_submission_error_message( + error + ) + + assert message == "plain validation failed" + + def test__given_modal_validation_empty_error__falls_back_to_exception_text( + self, economy_service + ): + error = make_http_status_error(400, text="") + + message = economy_service._build_budget_window_submission_error_message( + error + ) + + assert message == str(error) + def test__given_cliff_target__raises_value_error( self, economy_service, base_params ): diff --git a/tests/unit/test_asgi_factory.py b/tests/unit/test_asgi_factory.py new file mode 100644 index 000000000..015b62ae7 --- /dev/null +++ b/tests/unit/test_asgi_factory.py @@ -0,0 +1,181 @@ +import importlib +import json +import sys + +import pytest +from fastapi.testclient import TestClient +from flask import Flask, Response, jsonify, make_response, request +from flask_cors import CORS +from starlette.responses import Response as ASGIResponse + +from policyengine_api.asgi_factory import _add_vary_origin, create_asgi_app + + +def create_test_wsgi_app() -> Flask: + app = Flask(__name__) + CORS(app) + + @app.get("/fallback") + def fallback(): + response = make_response("flask fallback", 202) + response.headers["X-Fallback"] = "preserved" + response.set_cookie("fallback-cookie", "present") + return response + + @app.get("/request-echo") + def request_echo(): + response = jsonify( + { + "cookie": request.cookies.get("session_id"), + "header": request.headers.get("X-Request-Trace"), + } + ) + response.headers["X-Echo"] = "present" + return response + + @app.get("/readiness-check") + def readiness_check(): + return Response("OK", status=200, mimetype="text/plain") + + @app.get("/liveness-check") + def liveness_check(): + return Response("OK", status=200, mimetype="text/plain") + + @app.get("/specification") + def specification(): + return jsonify({"openapi": "3.0.0", "info": {"title": "fallback"}}) + + return app + + +def test_native_health_route_is_fastapi_json(): + client = TestClient(create_asgi_app(create_test_wsgi_app())) + + response = client.get("/health") + + assert response.status_code == 200 + assert response.json() == {"status": "healthy"} + assert response.headers["content-type"].startswith("application/json") + + +@pytest.mark.parametrize( + ("existing_vary", "expected_vary"), + [ + (None, "Origin"), + ("Accept-Encoding", "Accept-Encoding, Origin"), + ("Origin", "Origin"), + ("Accept-Encoding, origin", "Accept-Encoding, origin"), + ], +) +def test_add_vary_origin_preserves_existing_values(existing_vary, expected_vary): + response = ASGIResponse() + if existing_vary is not None: + response.headers["Vary"] = existing_vary + + _add_vary_origin(response) + + assert response.headers["Vary"] == expected_vary + + +def test_asgi_entrypoint_imports_and_serves_health(monkeypatch): + monkeypatch.setenv("FLASK_DEBUG", "1") + sys.modules.pop("policyengine_api.asgi", None) + + asgi_module = importlib.import_module("policyengine_api.asgi") + response = TestClient(asgi_module.app).get("/health") + + assert asgi_module.application is asgi_module.app + assert response.status_code == 200 + assert response.json() == {"status": "healthy"} + + +def test_fastapi_documentation_routes_fall_through_to_flask_404(): + client = TestClient(create_asgi_app(create_test_wsgi_app())) + + for path in ("/docs", "/redoc", "/openapi.json"): + response = client.get(path) + + assert response.status_code == 404 + assert "text/html" in response.headers["content-type"] + assert "swagger" not in response.text.lower() + + +def test_flask_fallback_preserves_status_body_headers_and_cookies(): + client = TestClient(create_asgi_app(create_test_wsgi_app())) + + response = client.get("/fallback") + + assert response.status_code == 202 + assert response.content == b"flask fallback" + assert response.headers["x-fallback"] == "preserved" + assert response.headers["set-cookie"].startswith("fallback-cookie=present") + assert response.headers["content-type"].startswith("text/html") + + +def test_request_headers_and_cookies_pass_through_to_flask_fallback(): + client = TestClient(create_asgi_app(create_test_wsgi_app())) + client.cookies.set("session_id", "session-123") + + response = client.get( + "/request-echo", + headers={"X-Request-Trace": "trace-123"}, + ) + + assert response.status_code == 200 + assert response.json() == { + "cookie": "session-123", + "header": "trace-123", + } + assert response.headers["x-echo"] == "present" + + +def test_flask_cors_behavior_is_preserved_for_fallback_routes(): + client = TestClient(create_asgi_app(create_test_wsgi_app())) + + response = client.get( + "/fallback", + headers={"Origin": "https://app.policyengine.org"}, + ) + + assert response.status_code == 202 + assert ( + response.headers["access-control-allow-origin"] + == "https://app.policyengine.org" + ) + assert response.headers["vary"] == "Origin" + + +def test_health_route_uses_same_reflected_cors_policy(): + client = TestClient(create_asgi_app(create_test_wsgi_app())) + + response = client.get( + "/health", + headers={"Origin": "https://app.policyengine.org"}, + ) + + assert response.status_code == 200 + assert ( + response.headers["access-control-allow-origin"] + == "https://app.policyengine.org" + ) + assert response.headers["vary"] == "Origin" + + +def test_existing_health_and_specification_paths_fall_back_to_flask(): + client = TestClient(create_asgi_app(create_test_wsgi_app())) + + readiness = client.get("/readiness-check") + liveness = client.get("/liveness-check") + specification = client.get("/specification") + + assert readiness.status_code == 200 + assert readiness.content == b"OK" + assert readiness.headers["content-type"].startswith("text/plain") + assert liveness.status_code == 200 + assert liveness.content == b"OK" + assert liveness.headers["content-type"].startswith("text/plain") + assert specification.status_code == 200 + assert json.loads(specification.content) == { + "openapi": "3.0.0", + "info": {"title": "fallback"}, + }