Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: minor
changes:
changed:
- Upgraded SQLAlchemy from v1 (>=1.4,<2) to v2 (>=2,<3) for Python 3.14 compatibility. Replaced removed engine.execute() with connection-based execution, updated LegacyRow to Row, and added _ResultProxy wrapper for eager result fetching.
48 changes: 46 additions & 2 deletions policyengine_api/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,34 @@
load_dotenv()


class _ResultProxy:
"""Lightweight wrapper that eagerly fetches results from a
SQLAlchemy CursorResult so they survive connection closure.
Provides fetchone()/fetchall() with dict-like row access."""

def __init__(self, cursor_result):
try:
# Use .mappings() so rows behave like dicts
self._rows = list(cursor_result.mappings())
except Exception:
# For non-SELECT statements (INSERT/UPDATE/DELETE)
# there are no rows to fetch
self._rows = []
self._index = 0

def fetchone(self):
if self._index < len(self._rows):
row = self._rows[self._index]
self._index += 1
return row
return None

def fetchall(self):
remaining = self._rows[self._index :]
self._index = len(self._rows)
return remaining


class PolicyEngineDatabase:
"""
A wrapper around the database connection.
Expand Down Expand Up @@ -70,6 +98,22 @@ def _close_pool(self):
except:
pass

def _execute_remote(self, query_args):
"""Execute a query against the remote database using
SQLAlchemy v2 connection-based execution."""
main_query = query_args[0]
params = query_args[1] if len(query_args) > 1 else None
with self.pool.connect() as conn:
if params is not None:
result = conn.exec_driver_sql(main_query, params)
else:
result = conn.exec_driver_sql(main_query)
conn.commit()
# Return a lightweight wrapper that holds
# the fetched results so they survive the
# connection context closing
return _ResultProxy(result)

def query(self, *query):
if self.local:
with sqlite3.connect(self.db_url) as conn:
Expand All @@ -89,7 +133,7 @@ def dict_factory(cursor, row):
main_query = main_query.replace("?", "%s")
query[0] = main_query
try:
return self.pool.execute(*query)
return self._execute_remote(query)
# Except InterfaceError and OperationalError, which are thrown when the connection is lost.
except (
sqlalchemy.exc.InterfaceError,
Expand All @@ -98,7 +142,7 @@ def dict_factory(cursor, row):
try:
self._close_pool()
self._create_pool()
return self.pool.execute(*query)
return self._execute_remote(query)
except Exception as e:
raise e

Expand Down
6 changes: 5 additions & 1 deletion policyengine_api/endpoints/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,13 @@ def set_user_policy(country_id: str) -> dict:
f"AND dataset {dataset_select_str}"
)

params = [country_id, reform_id, baseline_id, user_id, year, geography]
if dataset:
params.append(dataset)

row = database.query(
query,
(country_id, reform_id, baseline_id, user_id, year, geography),
tuple(params),
).fetchone()

except Exception as e:
Expand Down
4 changes: 2 additions & 2 deletions policyengine_api/services/household_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from sqlalchemy.engine.row import LegacyRow
from sqlalchemy.engine.row import Row

from policyengine_api.data import database
from policyengine_api.utils import hash_object
Expand All @@ -24,7 +24,7 @@ def get_household(self, country_id: str, household_id: int) -> dict | None:
f"Invalid household ID: {household_id}. Must be a positive integer."
)

row: LegacyRow | None = database.query(
row: Row | None = database.query(
f"SELECT * FROM household WHERE id = ? AND country_id = ?",
(household_id, country_id),
).fetchone()
Expand Down
4 changes: 2 additions & 2 deletions policyengine_api/services/policy_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from sqlalchemy.engine.row import LegacyRow
from sqlalchemy.engine.row import Row

from policyengine_api.data import database
from policyengine_api.utils import hash_object
Expand Down Expand Up @@ -37,7 +37,7 @@ def get_policy(self, country_id: str, policy_id: int) -> dict | None:
raise ValueError("country_id cannot be empty or None")

# If no policy found, this will return None
row: LegacyRow | None = database.query(
row: Row | None = database.query(
"SELECT * FROM policy WHERE country_id = ? AND id = ?",
(country_id, policy_id),
).fetchone()
Expand Down
4 changes: 2 additions & 2 deletions policyengine_api/services/report_output_service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sqlalchemy.engine.row import LegacyRow
from sqlalchemy.engine.row import Row

from policyengine_api.data import database
from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS
Expand Down Expand Up @@ -137,7 +137,7 @@ def get_report_output(self, report_output_id: int) -> dict | None:
f"Invalid report output ID: {report_output_id}. Must be a positive integer."
)

row: LegacyRow | None = database.query(
row: Row | None = database.query(
"SELECT * FROM report_outputs WHERE id = ?",
(report_output_id,),
).fetchone()
Expand Down
4 changes: 2 additions & 2 deletions policyengine_api/services/simulation_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from sqlalchemy.engine.row import LegacyRow
from sqlalchemy.engine.row import Row

from policyengine_api.data import database
from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS
Expand Down Expand Up @@ -119,7 +119,7 @@ def get_simulation(
f"Invalid simulation ID: {simulation_id}. Must be a positive integer."
)

row: LegacyRow | None = database.query(
row: Row | None = database.query(
"SELECT * FROM simulations WHERE id = ? AND country_id = ?",
(simulation_id, country_id),
).fetchone()
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
"python-dotenv",
"redis",
"rq",
"sqlalchemy>=1.4,<2",
"sqlalchemy>=2,<3",
"streamlit",
"werkzeug",
"Flask-Caching>=2,<3",
Expand Down
10 changes: 5 additions & 5 deletions tests/to_refactor/python/test_household_routes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
import json
from unittest.mock import MagicMock, patch
from sqlalchemy.engine.row import LegacyRow

from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS

Expand All @@ -16,8 +15,9 @@
class TestGetHousehold:
def test_get_existing_household(self, rest_client, mock_database):
"""Test getting an existing household."""
# Mock database response
mock_row = MagicMock(spec=LegacyRow)
# Mock database response as a dict-like object
# (SQLAlchemy v2 Row objects support dict() via ._mapping)
mock_row = MagicMock()
mock_row.__getitem__.side_effect = lambda x: valid_db_row[x]
mock_row.keys.return_value = valid_db_row.keys()
mock_database.query().fetchone.return_value = mock_row
Expand Down Expand Up @@ -57,7 +57,7 @@ def test_create_household_success(
):
"""Test successfully creating a new household."""
# Mock database responses
mock_row = MagicMock(spec=LegacyRow)
mock_row = MagicMock()
mock_row.__getitem__.side_effect = lambda x: {"id": 1}[x]
mock_database.query().fetchone.return_value = mock_row

Expand Down Expand Up @@ -111,7 +111,7 @@ def test_update_household_success(
):
"""Test successfully updating an existing household."""
# Mock getting existing household
mock_row = MagicMock(spec=LegacyRow)
mock_row = MagicMock()
mock_row.__getitem__.side_effect = lambda x: valid_db_row[x]
mock_row.keys.return_value = valid_db_row.keys()
mock_database.query().fetchone.return_value = mock_row
Expand Down
Loading
Loading