From c0e719eca91c974cdbd7c77aa69de14bdafb19f8 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 6 Mar 2026 07:53:13 -0500 Subject: [PATCH] Switch code formatter from Black to Ruff Replace Black with Ruff for code formatting across the project: - Update Makefile format target to use `ruff format .` - Replace Black GitHub Actions with `ruff format --check .` in CI - Add ruff>=0.9.0 as a dev dependency in setup.py - Reformat all Python files with ruff defaults Co-Authored-By: Claude Opus 4.6 --- .github/workflows/pr.yml | 8 +- .github/workflows/push.yml | 10 +- Makefile | 2 +- changelog_entry.yaml | 4 + dashboard/app.py | 20 +-- dashboard/experiments/gpt4_api_user.py | 14 +- gcp/bump_country_package.py | 10 +- gcp/export.py | 8 +- .../ai_prompts/simulation_analysis_prompt.py | 12 +- policyengine_api/api.py | 16 +- policyengine_api/country.py | 50 +++--- .../data/congressional_districts.py | 4 +- policyengine_api/data/data.py | 8 +- policyengine_api/data/model_setup.py | 8 +- policyengine_api/data/places.py | 3 +- policyengine_api/endpoints/economy/compare.py | 74 +++------ policyengine_api/endpoints/household.py | 10 +- policyengine_api/endpoints/policy.py | 32 +--- policyengine_api/routes/economy_routes.py | 38 ++--- policyengine_api/routes/household_routes.py | 20 +-- policyengine_api/routes/metadata_routes.py | 4 +- policyengine_api/routes/policy_routes.py | 4 +- .../routes/report_output_routes.py | 12 +- .../routes/simulation_analysis_routes.py | 4 +- policyengine_api/routes/simulation_routes.py | 4 +- .../services/ai_analysis_service.py | 4 +- .../services/ai_prompt_service.py | 1 - policyengine_api/services/economy_service.py | 71 ++++----- .../services/household_service.py | 14 +- .../services/report_output_service.py | 13 +- .../services/simulation_analysis_service.py | 12 +- .../services/simulation_service.py | 17 +-- .../services/tracer_analysis_service.py | 16 +- .../validate_household_payload.py | 4 +- .../validate_set_policy_payload.py | 4 +- 
policyengine_api/utils/singleton.py | 4 +- setup.py | 2 +- .../test_environment_variables.py | 22 ++- tests/fixtures/integration/simulations.py | 12 +- .../fixtures/services/ai_analysis_service.py | 12 +- tests/fixtures/services/economy_service.py | 15 +- tests/fixtures/services/household_fixtures.py | 4 +- tests/fixtures/services/policy_service.py | 4 +- tests/integration/test_simulations.py | 21 +-- tests/to_refactor/api/test_api.py | 8 +- .../to_refactor_household_fixtures.py | 8 +- .../python/test_ai_analysis_service_old.py | 4 +- .../python/test_household_routes.py | 12 +- .../python/test_policy_service_old.py | 33 +--- .../python/test_simulation_analysis_routes.py | 12 +- .../python/test_tracer_analysis_routes.py | 12 +- .../python/test_us_policy_macro.py | 16 +- .../python/test_user_profile_routes.py | 16 +- .../python/test_validate_household_payload.py | 8 +- .../python/test_yearly_var_removal.py | 42 ++--- .../test_simulation_analysis_prompt.py | 13 +- .../unit/data/test_congressional_districts.py | 95 +++--------- tests/unit/data/test_sqlalchemy_v2.py | 24 +-- tests/unit/endpoints/economy/test_compare.py | 144 +++++------------- tests/unit/libs/test_simulation_api_modal.py | 33 +--- .../unit/services/test_ai_analysis_service.py | 4 +- tests/unit/services/test_create_profile.py | 1 - tests/unit/services/test_economy_service.py | 142 +++++------------ tests/unit/services/test_household_service.py | 5 +- tests/unit/services/test_metadata_service.py | 21 +-- tests/unit/services/test_policy_service.py | 33 +--- .../services/test_report_output_service.py | 38 ++--- .../unit/services/test_simulation_service.py | 21 +-- .../services/test_tracer_analysis_service.py | 8 +- tests/unit/services/test_tracer_service.py | 4 +- .../services/test_update_profile_service.py | 13 +- tests/unit/services/test_user_service.py | 5 +- tests/unit/test_country.py | 16 +- 73 files changed, 403 insertions(+), 1019 deletions(-) diff --git a/.github/workflows/pr.yml 
b/.github/workflows/pr.yml index b55f92cfc..cf3611eee 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -14,10 +14,10 @@ jobs: uses: actions/checkout@v4 - name: Setup Python uses: actions/setup-python@v5 - - name: Format with Black - uses: psf/black@stable - with: - options: ". -l 79 --check" + - name: Install ruff + run: pip install "ruff>=0.9.0" + - name: Format check with ruff + run: ruff format --check . check-version: name: Check version runs-on: ubuntu-latest diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 65d0ec85e..2834e43f6 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -20,10 +20,12 @@ jobs: steps: - name: Checkout repo uses: actions/checkout@v4 - - name: Check formatting - uses: "lgeiger/black-action@master" - with: - args: ". -l 79 --check" + - name: Setup Python + uses: actions/setup-python@v5 + - name: Install ruff + run: pip install "ruff>=0.9.0" + - name: Format check with ruff + run: ruff format --check . ensure-model-version-aligns-with-sim-api: name: Ensure model version aligns with simulation API runs-on: ubuntu-latest diff --git a/Makefile b/Makefile index fef3abcf6..bece7d1b1 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ debug-test: MAX_HOUSEHOLDS=1000 FLASK_DEBUG=1 pytest -vv --durations=0 tests format: - black . -l 79 + ruff format . deploy: python gcp/export.py diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29bb..2533340d6 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: minor + changes: + changed: + - Switched code formatter from Black to Ruff. 
diff --git a/dashboard/app.py b/dashboard/app.py index 79da1fa4a..cdfcf33e2 100644 --- a/dashboard/app.py +++ b/dashboard/app.py @@ -19,12 +19,8 @@ st.subheader("Look up a policy") -policy_id = int( - st.text_input("Enter a policy ID", "1", key="policy_lookup_text") -) -country_id = st.text_input( - "Enter a country ID", "uk", key="policy_lookup_country" -) +policy_id = int(st.text_input("Enter a policy ID", "1", key="policy_lookup_text")) +country_id = st.text_input("Enter a country ID", "uk", key="policy_lookup_country") if st.button("Look up policy", key="policy_lookup"): try: results = database.query( @@ -40,9 +36,7 @@ policy_id = int(st.text_input("Enter a policy ID", "1")) country_id = st.text_input("Enter a country ID", "uk") -new_label = st.text_input( - "Enter a new label", "New label", key="policy_label_text" -) +new_label = st.text_input("Enter a new label", "New label", key="policy_label_text") if st.button("Set policy label", key="policy_label"): try: database.set_policy_label(policy_id, country_id, new_label) @@ -54,12 +48,8 @@ st.subheader("Delete a policy") -policy_id = int( - st.text_input("Enter a policy ID", "1", key="policy_delete_text") -) -country_id = st.text_input( - "Enter a country ID", "uk", key="policy_delete_country" -) +policy_id = int(st.text_input("Enter a policy ID", "1", key="policy_delete_text")) +country_id = st.text_input("Enter a country ID", "uk", key="policy_delete_country") if st.button("Delete policy", key="policy_delete"): try: database.delete_policy(policy_id, country_id) diff --git a/dashboard/experiments/gpt4_api_user.py b/dashboard/experiments/gpt4_api_user.py index 8746be739..37a2656d4 100644 --- a/dashboard/experiments/gpt4_api_user.py +++ b/dashboard/experiments/gpt4_api_user.py @@ -13,11 +13,9 @@ @st.cache_data def get_embeddings_df(): - embeddings_df = pd.read_csv( - "parameter_embeddings.csv.gz", compression="gzip" - ) - embeddings_df.parameter_embedding = ( - embeddings_df.parameter_embedding.apply(lambda x: 
eval(x)) + embeddings_df = pd.read_csv("parameter_embeddings.csv.gz", compression="gzip") + embeddings_df.parameter_embedding = embeddings_df.parameter_embedding.apply( + lambda x: eval(x) ) return embeddings_df @@ -55,11 +53,7 @@ def embed(prompt, engine="text-embedding-ada-002"): lambda x: cosine_similarity(x, embedding) ) -top5 = ( - embeddings_df.sort_values("similarities", ascending=False) - .head(5)["json"] - .values -) +top5 = embeddings_df.sort_values("similarities", ascending=False).head(5)["json"].values # display in streamlit diff --git a/gcp/bump_country_package.py b/gcp/bump_country_package.py index 3c5341a7c..c73f3e7c8 100644 --- a/gcp/bump_country_package.py +++ b/gcp/bump_country_package.py @@ -16,9 +16,7 @@ def main(): help="Country package to bump", choices=["policyengine-uk", "policyengine-us", "policyengine-canada"], ) - parser.add_argument( - "--version", type=str, required=True, help="Version to bump to" - ) + parser.add_argument("--version", type=str, required=True, help="Version to bump to") args = parser.parse_args() country = args.country version = args.version @@ -44,9 +42,9 @@ def bump_country_package(country, version): with open(setup_py_path, "w") as f: f.write(setup_py) - country_package_full_name = country.replace( - "policyengine", "PolicyEngine" - ).replace("-", " ") + country_package_full_name = country.replace("policyengine", "PolicyEngine").replace( + "-", " " + ) country_id = country.replace("policyengine-", "") country_package_full_name = country_package_full_name.replace( country_id, country_id.upper() diff --git a/gcp/export.py b/gcp/export.py index a1694db48..b4eaafc00 100644 --- a/gcp/export.py +++ b/gcp/export.py @@ -32,13 +32,9 @@ dockerfile = dockerfile.replace( ".github_microdata_token", GITHUB_MICRODATA_TOKEN ) - dockerfile = dockerfile.replace( - ".anthropic_api_key", ANTHROPIC_API_KEY - ) + dockerfile = dockerfile.replace(".anthropic_api_key", ANTHROPIC_API_KEY) dockerfile = dockerfile.replace(".openai_api_key", 
OPENAI_API_KEY) - dockerfile = dockerfile.replace( - ".hugging_face_token", HUGGING_FACE_TOKEN - ) + dockerfile = dockerfile.replace(".hugging_face_token", HUGGING_FACE_TOKEN) with open(dockerfile_location, "w") as f: f.write(dockerfile) diff --git a/policyengine_api/ai_prompts/simulation_analysis_prompt.py b/policyengine_api/ai_prompts/simulation_analysis_prompt.py index e7605771f..dc809312e 100644 --- a/policyengine_api/ai_prompts/simulation_analysis_prompt.py +++ b/policyengine_api/ai_prompts/simulation_analysis_prompt.py @@ -95,18 +95,12 @@ def generate_simulation_analysis_prompt(params: InboundParameters) -> str: ) impact_budget: str = json.dumps(parameters.impact["budget"]) - impact_intra_decile: dict[str, Any] = json.dumps( - parameters.impact["intra_decile"] - ) + impact_intra_decile: dict[str, Any] = json.dumps(parameters.impact["intra_decile"]) impact_decile: str = json.dumps(parameters.impact["decile"]) impact_inequality: str = json.dumps(parameters.impact["inequality"]) impact_poverty: str = json.dumps(parameters.impact["poverty"]["poverty"]) - impact_deep_poverty: str = json.dumps( - parameters.impact["poverty"]["deep_poverty"] - ) - impact_poverty_by_gender: str = json.dumps( - parameters.impact["poverty_by_gender"] - ) + impact_deep_poverty: str = json.dumps(parameters.impact["poverty"]["deep_poverty"]) + impact_poverty_by_gender: str = json.dumps(parameters.impact["poverty_by_gender"]) all_parameters: AllParameters = AllParameters.model_validate( { diff --git a/policyengine_api/api.py b/policyengine_api/api.py index b22529b31..112cce9ac 100644 --- a/policyengine_api/api.py +++ b/policyengine_api/api.py @@ -132,9 +132,7 @@ def log_timing(message): app.route("//calculate-full", methods=["POST"])( cache.cached(make_cache_key=make_cache_key)( - lambda *args, **kwargs: get_calculate( - *args, **kwargs, add_missing=True - ) + lambda *args, **kwargs: get_calculate(*args, **kwargs, add_missing=True) ) ) log_timing("Calculate-full endpoint registered") @@ 
-153,9 +151,7 @@ def log_timing(message): app.route("//user-policy", methods=["PUT"])(update_user_policy) log_timing("User policy update endpoint registered") -app.route("//user-policy/", methods=["GET"])( - get_user_policy -) +app.route("//user-policy/", methods=["GET"])(get_user_policy) log_timing("User policy get endpoint registered") app.register_blueprint(user_profile_bp) @@ -177,9 +173,7 @@ def log_timing(message): @app.route("/liveness-check", methods=["GET"]) def liveness_check(): - return flask.Response( - "OK", status=200, headers={"Content-Type": "text/plain"} - ) + return flask.Response("OK", status=200, headers={"Content-Type": "text/plain"}) log_timing("Liveness check endpoint registered") @@ -187,9 +181,7 @@ def liveness_check(): @app.route("/readiness-check", methods=["GET"]) def readiness_check(): - return flask.Response( - "OK", status=200, headers={"Content-Type": "text/plain"} - ) + return flask.Response("OK", status=200, headers={"Content-Type": "text/plain"}) log_timing("Readiness check endpoint registered") diff --git a/policyengine_api/country.py b/policyengine_api/country.py index 33e0b00e0..befa49851 100644 --- a/policyengine_api/country.py +++ b/policyengine_api/country.py @@ -60,9 +60,7 @@ def build_metadata(self): }[self.country_id], basicInputs=self.tax_benefit_system.basic_inputs, modelled_policies=self.tax_benefit_system.modelled_policies, - version=get_package_version( - self.country_package_name.replace("_", "-") - ), + version=get_package_version(self.country_package_name.replace("_", "-")), ) def build_microsimulation_options(self) -> dict: @@ -77,13 +75,9 @@ def build_microsimulation_options(self) -> dict: region = [ dict(name="uk", label="the UK", type="national"), dict(name="country/england", label="England", type="country"), - dict( - name="country/scotland", label="Scotland", type="country" - ), + dict(name="country/scotland", label="Scotland", type="country"), dict(name="country/wales", label="Wales", type="country"), - 
dict( - name="country/ni", label="Northern Ireland", type="country" - ), + dict(name="country/ni", label="Northern Ireland", type="country"), ] for i in range(len(constituency_names)): region.append( @@ -130,9 +124,7 @@ def build_microsimulation_options(self) -> dict: dict(name="state/co", label="Colorado", type="state"), dict(name="state/ct", label="Connecticut", type="state"), dict(name="state/de", label="Delaware", type="state"), - dict( - name="state/dc", label="District of Columbia", type="state" - ), + dict(name="state/dc", label="District of Columbia", type="state"), dict(name="state/fl", label="Florida", type="state"), dict(name="state/ga", label="Georgia", type="state"), dict(name="state/hi", label="Hawaii", type="state"), @@ -276,9 +268,9 @@ def build_variables(self) -> dict: dict(value=value.name, label=value.value) for value in variable.possible_values ] - variable_data[variable_name][ - "defaultValue" - ] = variable.default_value.name + variable_data[variable_name]["defaultValue"] = ( + variable.default_value.name + ) return variable_data def build_parameters(self) -> dict: @@ -299,9 +291,7 @@ def build_parameters(self) -> dict: ), } elif isinstance(parameter, ParameterScaleBracket): - bracket_index = int( - parameter.name[parameter.name.index("[") + 1 : -1] - ) + bracket_index = int(parameter.name[parameter.name.index("[") + 1 : -1]) # Set the label to 'first bracket' for the first bracket, 'second bracket' for the second, etc. 
bracket_label = f"bracket {bracket_index + 1}" parameter_data[parameter.name] = { @@ -378,9 +368,7 @@ def calculate( for parameter_name in reform: for time_period, value in reform[parameter_name].items(): start_instant, end_instant = time_period.split(".") - parameter = get_parameter( - system.parameters, parameter_name - ) + parameter = get_parameter(system.parameters, parameter_name) node_type = type(parameter.values_list[-1].value) if node_type == int: node_type = float @@ -434,9 +422,9 @@ def calculate( if any([math.isinf(value) for value in result]): raise ValueError("Infinite value") else: - household[entity_plural][entity_id][variable_name][ - period - ] = result + household[entity_plural][entity_id][variable_name][period] = ( + result + ) else: entity_index = population.get_index(entity_id) if variable.value_type == Enum: @@ -453,19 +441,15 @@ def calculate( else: entity_result = result.tolist()[entity_index] - household[entity_plural][entity_id][variable_name][ - period - ] = entity_result + household[entity_plural][entity_id][variable_name][period] = ( + entity_result + ) except Exception as e: if "axes" in household: pass else: - household[entity_plural][entity_id][variable_name][ - period - ] = None - print( - f"Error computing {variable_name} for {entity_id}: {e}" - ) + household[entity_plural][entity_id][variable_name][period] = None + print(f"Error computing {variable_name} for {entity_id}: {e}") tracer_output = simulation.tracer.computation_log log_lines = tracer_output.lines(aggregate=False, max_depth=10) diff --git a/policyengine_api/data/congressional_districts.py b/policyengine_api/data/congressional_districts.py index 2e728991b..f5218fd36 100644 --- a/policyengine_api/data/congressional_districts.py +++ b/policyengine_api/data/congressional_districts.py @@ -683,9 +683,7 @@ def build_congressional_district_metadata() -> list[dict]: return [ { "name": _build_district_name(district.state_code, district.number), - "label": _build_district_label( - 
district.state_code, district.number - ), + "label": _build_district_label(district.state_code, district.number), "type": "congressional_district", "state_abbreviation": district.state_code, "state_name": STATE_CODE_TO_NAME[district.state_code], diff --git a/policyengine_api/data/data.py b/policyengine_api/data/data.py index 2318fc43a..7dcb96c43 100644 --- a/policyengine_api/data/data.py +++ b/policyengine_api/data/data.py @@ -58,9 +58,7 @@ def __init__( self.local = local if local: # Local development uses a sqlite database. - self.db_url = ( - REPO / "policyengine_api" / "data" / "policyengine.db" - ) + self.db_url = REPO / "policyengine_api" / "data" / "policyengine.db" if initialize or not Path(self.db_url).exists(): self.initialize() else: @@ -69,9 +67,7 @@ def __init__( self.initialize() def _create_pool(self): - instance_connection_name = ( - "policyengine-api:us-central1:policyengine-api-data" - ) + instance_connection_name = "policyengine-api:us-central1:policyengine-api-data" self.connector = Connector() db_user = "policyengine" db_pass = os.environ["POLICYENGINE_DB_PASSWORD"] diff --git a/policyengine_api/data/model_setup.py b/policyengine_api/data/model_setup.py index a2a6a3ee7..739f7bbcc 100644 --- a/policyengine_api/data/model_setup.py +++ b/policyengine_api/data/model_setup.py @@ -37,11 +37,7 @@ def get_dataset_version(country_id: str) -> str | None: for dataset in datasets["uk"]: - datasets["uk"][ - dataset - ] = f"{datasets['uk'][dataset]}@{get_dataset_version('uk')}" + datasets["uk"][dataset] = f"{datasets['uk'][dataset]}@{get_dataset_version('uk')}" for dataset in datasets["us"]: - datasets["us"][ - dataset - ] = f"{datasets['us'][dataset]}@{get_dataset_version('us')}" + datasets["us"][dataset] = f"{datasets['us'][dataset]}@{get_dataset_version('us')}" diff --git a/policyengine_api/data/places.py b/policyengine_api/data/places.py index e588489fe..f24467f83 100644 --- a/policyengine_api/data/places.py +++ b/policyengine_api/data/places.py @@ -46,6 
+46,5 @@ def validate_place_code(place_code: str) -> None: if not place_fips.isdigit() or len(place_fips) != 5: raise ValueError( - f"Invalid FIPS code in place: '{place_fips}'. " - "Expected 5-digit FIPS code" + f"Invalid FIPS code in place: '{place_fips}'. Expected 5-digit FIPS code" ) diff --git a/policyengine_api/endpoints/economy/compare.py b/policyengine_api/endpoints/economy/compare.py index c453e8063..3d3cf9c27 100644 --- a/policyengine_api/endpoints/economy/compare.py +++ b/policyengine_api/endpoints/economy/compare.py @@ -10,12 +10,8 @@ def budgetary_impact(baseline: dict, reform: dict) -> dict: tax_revenue_impact = reform["total_tax"] - baseline["total_tax"] - state_tax_revenue_impact = ( - reform["total_state_tax"] - baseline["total_state_tax"] - ) - benefit_spending_impact = ( - reform["total_benefits"] - baseline["total_benefits"] - ) + state_tax_revenue_impact = reform["total_state_tax"] - baseline["total_state_tax"] + benefit_spending_impact = reform["total_benefits"] - baseline["total_benefits"] budgetary_impact = tax_revenue_impact - benefit_spending_impact return dict( budgetary_impact=budgetary_impact, @@ -28,14 +24,10 @@ def budgetary_impact(baseline: dict, reform: dict) -> dict: def labor_supply_response(baseline: dict, reform: dict) -> dict: - substitution_lsr = ( - reform["substitution_lsr"] - baseline["substitution_lsr"] - ) + substitution_lsr = reform["substitution_lsr"] - baseline["substitution_lsr"] income_lsr = reform["income_lsr"] - baseline["income_lsr"] total_change = substitution_lsr + income_lsr - revenue_change = ( - reform["budgetary_impact_lsr"] - baseline["budgetary_impact_lsr"] - ) + revenue_change = reform["budgetary_impact_lsr"] - baseline["budgetary_impact_lsr"] substitution_lsr_hh = np.array(reform["substitution_lsr_hh"]) - np.array( baseline["substitution_lsr_hh"] @@ -48,17 +40,13 @@ def labor_supply_response(baseline: dict, reform: dict) -> dict: total_lsr_hh = substitution_lsr_hh + income_lsr_hh - emp_income = 
MicroSeries( - baseline["employment_income_hh"], weights=household_weight - ) + emp_income = MicroSeries(baseline["employment_income_hh"], weights=household_weight) self_emp_income = MicroSeries( baseline["self_employment_income_hh"], weights=household_weight ) earnings = emp_income + self_emp_income original_earnings = earnings - total_lsr_hh - substitution_lsr_hh = MicroSeries( - substitution_lsr_hh, weights=household_weight - ) + substitution_lsr_hh = MicroSeries(substitution_lsr_hh, weights=household_weight) income_lsr_hh = MicroSeries(income_lsr_hh, weights=household_weight) decile_avg = dict( @@ -81,9 +69,7 @@ def labor_supply_response(baseline: dict, reform: dict) -> dict: substitution=(substitution_lsr_hh.sum() / original_earnings.sum()), ) - decile_rel["income"] = { - int(k): v for k, v in decile_rel["income"].items() if k > 0 - } + decile_rel["income"] = {int(k): v for k, v in decile_rel["income"].items() if k > 0} decile_rel["substitution"] = { int(k): v for k, v in decile_rel["substitution"].items() if k > 0 } @@ -112,9 +98,7 @@ def labor_supply_response(baseline: dict, reform: dict) -> dict: ) -def detailed_budgetary_impact( - baseline: dict, reform: dict, country_id: str -) -> dict: +def detailed_budgetary_impact(baseline: dict, reform: dict, country_id: str) -> dict: result = {} if country_id == "uk": for program in baseline["programs"]: @@ -122,8 +106,7 @@ def detailed_budgetary_impact( result[program] = dict( baseline=baseline["programs"][program], reform=reform["programs"][program], - difference=reform["programs"][program] - - baseline["programs"][program], + difference=reform["programs"][program] - baseline["programs"][program], ) return result @@ -289,9 +272,7 @@ def poverty_impact(baseline: dict, reform: dict) -> dict: reform=float(reform_deep_poverty[age < 18].mean()), ), adult=dict( - baseline=float( - baseline_deep_poverty[(age >= 18) & (age < 65)].mean() - ), + baseline=float(baseline_deep_poverty[(age >= 18) & (age < 65)].mean()), 
reform=float(reform_deep_poverty[(age >= 18) & (age < 65)].mean()), ), senior=dict( @@ -329,9 +310,7 @@ def intra_decile_impact(baseline: dict, reform: dict) -> dict: baseline["household_count_people"], weights=baseline_income.weights ) decile = MicroSeries(baseline["household_income_decile"]).values - income_change = compute_income_change( - baseline_income.values, reform_income.values - ) + income_change = compute_income_change(baseline_income.values, reform_income.values) # Within each decile, calculate the percentage of people who: # 1. Gained more than 5% of their income @@ -353,7 +332,6 @@ def intra_decile_impact(baseline: dict, reform: dict) -> dict: for lower, upper, label in zip(BOUNDS[:-1], BOUNDS[1:], LABELS): outcome_groups[label] = [] for i in range(1, 11): - in_decile: bool = decile == i in_group: bool = (income_change > lower) & (income_change <= upper) in_both: bool = in_decile & in_group @@ -365,9 +343,7 @@ def intra_decile_impact(baseline: dict, reform: dict) -> dict: if people_in_decile == 0 and people_in_both == 0: people_in_proportion: float = 0.0 else: - people_in_proportion: float = float( - people_in_both / people_in_decile - ) + people_in_proportion: float = float(people_in_both / people_in_decile) outcome_groups[label].append(people_in_proportion) @@ -386,9 +362,7 @@ def intra_wealth_decile_impact(baseline: dict, reform: dict) -> dict: baseline["household_count_people"], weights=baseline_income.weights ) decile = MicroSeries(baseline["household_wealth_decile"]).values - income_change = compute_income_change( - baseline_income.values, reform_income.values - ) + income_change = compute_income_change(baseline_income.values, reform_income.values) # Within each decile, calculate the percentage of people who: # 1. 
Gained more than 5% of their income @@ -410,7 +384,6 @@ def intra_wealth_decile_impact(baseline: dict, reform: dict) -> dict: for lower, upper, label in zip(BOUNDS[:-1], BOUNDS[1:], LABELS): outcome_groups[label] = [] for i in range(1, 11): - in_decile: bool = decile == i in_group: bool = (income_change > lower) & (income_change <= upper) in_both: bool = in_decile & in_group @@ -422,9 +395,7 @@ def intra_wealth_decile_impact(baseline: dict, reform: dict) -> dict: if people_in_decile == 0 and people_in_both == 0: people_in_proportion = 0 else: - people_in_proportion: float = float( - people_in_both / people_in_decile - ) + people_in_proportion: float = float(people_in_both / people_in_decile) outcome_groups[label].append(people_in_proportion) @@ -506,9 +477,7 @@ def poverty_racial_breakdown(baseline: dict, reform: dict) -> dict: reform_poverty = MicroSeries( reform["person_in_poverty"], weights=baseline_poverty.weights ) - race = MicroSeries( - baseline["race"] - ) # Can be WHITE, BLACK, HISPANIC, or OTHER. + race = MicroSeries(baseline["race"]) # Can be WHITE, BLACK, HISPANIC, or OTHER. poverty = dict( white=dict( @@ -604,7 +573,9 @@ def uk_constituency_breakdown( repo_filename="parliamentary_constituency_weights.h5", ) with h5py.File(constituency_weights_path, "r") as f: - weights = f["2025"][ + weights = f[ + "2025" + ][ ... 
] # {2025: array(650, 100180) where cell i, j is the weight of household record i in constituency j} @@ -750,10 +721,7 @@ def uk_local_authority_breakdown( continue elif selected_country == "wales" and not code.startswith("W"): continue - elif ( - selected_country == "northern_ireland" - and not code.startswith("N") - ): + elif selected_country == "northern_ireland" and not code.startswith("N"): continue weight: np.ndarray = weights[i] @@ -839,9 +807,7 @@ def compare_economic_outputs( uk_local_authority_breakdown(baseline, reform, country_id, region) ) if local_authority_impact_data is not None: - local_authority_impact_data = ( - local_authority_impact_data.model_dump() - ) + local_authority_impact_data = local_authority_impact_data.model_dump() try: wealth_decile_impact_data = wealth_decile_impact(baseline, reform) intra_wealth_decile_impact_data = intra_wealth_decile_impact( diff --git a/policyengine_api/endpoints/household.py b/policyengine_api/endpoints/household.py index b841c5e10..edd647906 100644 --- a/policyengine_api/endpoints/household.py +++ b/policyengine_api/endpoints/household.py @@ -41,11 +41,7 @@ def add_yearly_variables(household, country_id): if variables[variable]["isInputVariable"]: household[entity_plural][entity][ variables[variable]["name"] - ] = { - household_year: variables[variable][ - "defaultValue" - ] - } + ] = {household_year: variables[variable]["defaultValue"]} else: household[entity_plural][entity][ variables[variable]["name"] @@ -75,9 +71,7 @@ def get_household_year(household): @validate_country -def get_household_under_policy( - country_id: str, household_id: str, policy_id: str -): +def get_household_under_policy(country_id: str, household_id: str, policy_id: str): """Get a household's output data under a given policy. 
Args: diff --git a/policyengine_api/endpoints/policy.py b/policyengine_api/endpoints/policy.py index d88aed3e1..290fe895e 100644 --- a/policyengine_api/endpoints/policy.py +++ b/policyengine_api/endpoints/policy.py @@ -30,9 +30,7 @@ def get_policy_search(country_id: str) -> dict: query = request.args.get("query", "") # The "json.loads" default type is added to convert lowercase # "true" and "false" to Python-friendly bool values - unique_only = request.args.get( - "unique_only", default=False, type=json.loads - ) + unique_only = request.args.get("unique_only", default=False, type=json.loads) try: results = database.query( @@ -47,9 +45,7 @@ def get_policy_search(country_id: str) -> dict: status="error", message=f"No policies found for country {country_id} for query '{query}", ) - return Response( - json.dumps(body), status=404, mimetype="application/json" - ) + return Response(json.dumps(body), status=404, mimetype="application/json") # If unique_only is true, filter results to only include # items where everything except ID is unique @@ -70,22 +66,16 @@ def get_policy_search(country_id: str) -> dict: results = new_results # Format into: [{ id: 1, label: "My policy" }, ...] 
- policies = [ - dict(id=result["id"], label=result["label"]) for result in results - ] + policies = [dict(id=result["id"], label=result["label"]) for result in results] body = dict( status="ok", message="Policies found", result=policies, ) - return Response( - json.dumps(body), status=200, mimetype="application/json" - ) + return Response(json.dumps(body), status=200, mimetype="application/json") except Exception as e: body = dict(status="error", message=f"Internal server error: {e}") - return Response( - json.dumps(body), status=500, mimetype="application/json" - ) + return Response(json.dumps(body), status=500, mimetype="application/json") @validate_country @@ -177,9 +167,7 @@ def set_user_policy(country_id: str) -> dict: except Exception as e: return Response( json.dumps( - { - "message": f"Internal database error: {e}; please try again later." - } + {"message": f"Internal database error: {e}; please try again later."} ), status=500, mimetype="application/json", @@ -240,9 +228,7 @@ def set_user_policy(country_id: str) -> dict: except Exception as e: return Response( json.dumps( - { - "message": f"Internal database error: {e}; please try again later." - } + {"message": f"Internal database error: {e}; please try again later."} ), status=500, mimetype="application/json", @@ -354,9 +340,7 @@ def update_user_policy(country_id: str) -> dict: except Exception as e: return Response( json.dumps( - { - "message": f"Internal database error: {e}; please try again later." 
- } + {"message": f"Internal database error: {e}; please try again later."} ), status=500, mimetype="application/json", diff --git a/policyengine_api/routes/economy_routes.py b/policyengine_api/routes/economy_routes.py index 59e7fdf4c..4279a1b1b 100644 --- a/policyengine_api/routes/economy_routes.py +++ b/policyengine_api/routes/economy_routes.py @@ -18,9 +18,7 @@ "//economy//over/", methods=["GET"], ) -def get_economic_impact( - country_id: str, policy_id: int, baseline_policy_id: int -): +def get_economic_impact(country_id: str, policy_id: int, baseline_policy_id: int): policy_id = int(policy_id or get_current_law_policy_id(country_id)) baseline_policy_id = int( @@ -39,33 +37,25 @@ def get_economic_impact( include_district_breakdowns_raw = options.pop( "include_district_breakdowns", "false" ) - include_district_breakdowns = ( - include_district_breakdowns_raw.lower() == "true" - ) + include_district_breakdowns = include_district_breakdowns_raw.lower() == "true" if include_district_breakdowns and country_id == "us" and region == "us": dataset = "national-with-breakdowns" target: Literal["general", "cliff"] = options.pop("target", "general") - api_version = options.pop( - "version", COUNTRY_PACKAGE_VERSIONS.get(country_id) - ) + api_version = options.pop("version", COUNTRY_PACKAGE_VERSIONS.get(country_id)) - economic_impact_result: EconomicImpactResult = ( - economy_service.get_economic_impact( - country_id=country_id, - policy_id=policy_id, - baseline_policy_id=baseline_policy_id, - region=region, - dataset=dataset, - time_period=time_period, - options=options, - api_version=api_version, - target=target, - ) + economic_impact_result: EconomicImpactResult = economy_service.get_economic_impact( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + time_period=time_period, + options=options, + api_version=api_version, + target=target, ) - result_dict: dict[str, str | dict | None] = ( - 
economic_impact_result.to_dict() - ) + result_dict: dict[str, str | dict | None] = economic_impact_result.to_dict() return Response( json.dumps( diff --git a/policyengine_api/routes/household_routes.py b/policyengine_api/routes/household_routes.py index 0961d8cd3..d7420b39e 100644 --- a/policyengine_api/routes/household_routes.py +++ b/policyengine_api/routes/household_routes.py @@ -12,9 +12,7 @@ household_service = HouseholdService() -@household_bp.route( - "//household/", methods=["GET"] -) +@household_bp.route("//household/", methods=["GET"]) @validate_country def get_household(country_id: str, household_id: int) -> Response: """ @@ -26,9 +24,7 @@ def get_household(country_id: str, household_id: int) -> Response: """ print(f"Got request for household {household_id} in country {country_id}") - household: dict | None = household_service.get_household( - country_id, household_id - ) + household: dict | None = household_service.get_household(country_id, household_id) if household is None: raise NotFound(f"Household #{household_id} not found.") else: @@ -66,9 +62,7 @@ def post_household(country_id: str) -> Response: label: str | None = payload.get("label") household_json: dict = payload.get("data") - household_id = household_service.create_household( - country_id, household_json, label - ) + household_id = household_service.create_household(country_id, household_json, label) return Response( json.dumps( @@ -85,9 +79,7 @@ def post_household(country_id: str) -> Response: ) -@household_bp.route( - "//household/", methods=["PUT"] -) +@household_bp.route("//household/", methods=["PUT"]) @validate_country def update_household(country_id: str, household_id: int) -> Response: """ @@ -110,9 +102,7 @@ def update_household(country_id: str, household_id: int) -> Response: label: str | None = payload.get("label") household_json: dict = payload.get("data") - household: dict | None = household_service.get_household( - country_id, household_id - ) + household: dict | None = 
household_service.get_household(country_id, household_id) if household is None: raise NotFound(f"Household #{household_id} not found.") diff --git a/policyengine_api/routes/metadata_routes.py b/policyengine_api/routes/metadata_routes.py index 496d9556d..8dd5465e4 100644 --- a/policyengine_api/routes/metadata_routes.py +++ b/policyengine_api/routes/metadata_routes.py @@ -20,9 +20,7 @@ def get_metadata(country_id: str) -> Response: # Retrieve country metadata and add status and message to the response country_metadata = metadata_service.get_metadata(country_id) return Response( - json.dumps( - {"status": "ok", "message": None, "result": country_metadata} - ), + json.dumps({"status": "ok", "message": None, "result": country_metadata}), status=200, mimetype="application/json", ) diff --git a/policyengine_api/routes/policy_routes.py b/policyengine_api/routes/policy_routes.py index 913eb105c..3fc88fbf4 100644 --- a/policyengine_api/routes/policy_routes.py +++ b/policyengine_api/routes/policy_routes.py @@ -76,6 +76,4 @@ def set_policy(country_id: str) -> Response: ) code = 200 if is_existing_policy else 201 - return Response( - json.dumps(response_body), status=code, mimetype="application/json" - ) + return Response(json.dumps(response_body), status=code, mimetype="application/json") diff --git a/policyengine_api/routes/report_output_routes.py b/policyengine_api/routes/report_output_routes.py index b2c5502d1..94638c53e 100644 --- a/policyengine_api/routes/report_output_routes.py +++ b/policyengine_api/routes/report_output_routes.py @@ -32,9 +32,7 @@ def create_report_output(country_id: str) -> Response: # Extract required fields simulation_1_id = payload.get("simulation_1_id") simulation_2_id = payload.get("simulation_2_id") # Optional - year = payload.get( - "year", CURRENT_YEAR - ) # Default to current year as string + year = payload.get("year", CURRENT_YEAR) # Default to current year as string # Validate required fields if simulation_1_id is None: @@ -94,9 +92,7 @@ def 
create_report_output(country_id: str) -> Response: raise BadRequest(f"Failed to create report output: {str(e)}") -@report_output_bp.route( - "//report/", methods=["GET"] -) +@report_output_bp.route("//report/", methods=["GET"]) @validate_country def get_report_output(country_id: str, report_id: int) -> Response: """ @@ -108,9 +104,7 @@ def get_report_output(country_id: str, report_id: int) -> Response: """ print(f"Getting report output {report_id} for country {country_id}") - report_output: dict | None = report_output_service.get_report_output( - report_id - ) + report_output: dict | None = report_output_service.get_report_output(report_id) if report_output is None: raise NotFound(f"Report #{report_id} not found.") diff --git a/policyengine_api/routes/simulation_analysis_routes.py b/policyengine_api/routes/simulation_analysis_routes.py index 893d7cae4..5157b807d 100644 --- a/policyengine_api/routes/simulation_analysis_routes.py +++ b/policyengine_api/routes/simulation_analysis_routes.py @@ -16,9 +16,7 @@ simulation_analysis_service = SimulationAnalysisService() -@simulation_analysis_bp.route( - "//simulation-analysis", methods=["POST"] -) +@simulation_analysis_bp.route("//simulation-analysis", methods=["POST"]) @validate_country def execute_simulation_analysis(country_id): print("Got POST request for simulation analysis") diff --git a/policyengine_api/routes/simulation_routes.py b/policyengine_api/routes/simulation_routes.py index 86e6f0ddf..b5aaff19f 100644 --- a/policyengine_api/routes/simulation_routes.py +++ b/policyengine_api/routes/simulation_routes.py @@ -95,9 +95,7 @@ def create_simulation(country_id: str) -> Response: raise BadRequest(f"Failed to create simulation: {str(e)}") -@simulation_bp.route( - "//simulation/", methods=["GET"] -) +@simulation_bp.route("//simulation/", methods=["GET"]) @validate_country def get_simulation(country_id: str, simulation_id: int) -> Response: """ diff --git a/policyengine_api/services/ai_analysis_service.py 
b/policyengine_api/services/ai_analysis_service.py index fa6c56db4..f2fc3c710 100644 --- a/policyengine_api/services/ai_analysis_service.py +++ b/policyengine_api/services/ai_analysis_service.py @@ -45,9 +45,7 @@ def get_existing_analysis(self, prompt: str) -> Optional[str]: def trigger_ai_analysis(self, prompt: str) -> Generator[str, None, None]: # Configure a Claude client - claude_client = anthropic.Anthropic( - api_key=os.getenv("ANTHROPIC_API_KEY") - ) + claude_client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) def generate(): response_text = "" diff --git a/policyengine_api/services/ai_prompt_service.py b/policyengine_api/services/ai_prompt_service.py index a696190bf..f4e88bffa 100644 --- a/policyengine_api/services/ai_prompt_service.py +++ b/policyengine_api/services/ai_prompt_service.py @@ -11,7 +11,6 @@ class AIPromptService: - def get_prompt(self, name: str, input_data: dict) -> str | None: """ Get an AI prompt with a given name, filled with the given data. diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 6f7ab5ab1..031696286 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -147,7 +147,6 @@ def get_economic_impact( """ try: - # Normalize region early for US; this allows us to accommodate legacy # regions that don't contain a region prefix. 
if country_id == "us": @@ -165,24 +164,22 @@ def get_economic_impact( if country_id == "uk": country_package_version = None - economic_impact_setup_options = ( - EconomicImpactSetupOptions.model_validate( - { - "process_id": process_id, - "country_id": country_id, - "reform_policy_id": policy_id, - "baseline_policy_id": baseline_policy_id, - "region": region, - "dataset": dataset, - "time_period": time_period, - "options": options, - "api_version": api_version, - "target": target, - "model_version": country_package_version, - "data_version": get_dataset_version(country_id), - "options_hash": options_hash, - } - ) + economic_impact_setup_options = EconomicImpactSetupOptions.model_validate( + { + "process_id": process_id, + "country_id": country_id, + "reform_policy_id": policy_id, + "baseline_policy_id": baseline_policy_id, + "region": region, + "dataset": dataset, + "time_period": time_period, + "options": options, + "api_version": api_version, + "target": target, + "model_version": country_package_version, + "data_version": get_dataset_version(country_id), + "options_hash": options_hash, + } ) # Logging that we've received a request @@ -260,17 +257,15 @@ def _get_previous_impacts( Fetch any previous simulation runs for the given policy reform. 
""" - previous_impacts: list[Any] = ( - reform_impacts_service.get_all_reform_impacts( - country_id, - policy_id, - baseline_policy_id, - region, - dataset, - time_period, - options_hash, - api_version, - ) + previous_impacts: list[Any] = reform_impacts_service.get_all_reform_impacts( + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options_hash, + api_version, ) return previous_impacts @@ -349,9 +344,7 @@ def _handle_execution_state( and hasattr(execution, "error") and execution.error ): - error_message = ( - f"Simulation API execution failed: {execution.error}" - ) + error_message = f"Simulation API execution failed: {execution.error}" self._set_reform_impact_error( setup_options=setup_options, @@ -372,9 +365,7 @@ def _handle_execution_state( return EconomicImpactResult.computing() else: - raise ValueError( - f"Unexpected sim API execution state: {execution_state}" - ) + raise ValueError(f"Unexpected sim API execution state: {execution_state}") def _handle_completed_impact( self, @@ -486,9 +477,7 @@ def _setup_sim_options( "baseline": json.loads(baseline_policy), "time_period": time_period, "include_cliffs": include_cliffs, - "region": self._setup_region( - country_id=country_id, region=region - ), + "region": self._setup_region(country_id=country_id, region=region), "data": self._setup_data( country_id=country_id, region=region, dataset=dataset ), @@ -527,9 +516,7 @@ def _validate_us_region(self, region: str) -> None: elif region.startswith("congressional_district/"): district_id = region[len("congressional_district/") :] if district_id.lower() not in get_valid_congressional_districts(): - raise ValueError( - f"Invalid congressional district: '{district_id}'" - ) + raise ValueError(f"Invalid congressional district: '{district_id}'") else: raise ValueError(f"Invalid US region: '{region}'") diff --git a/policyengine_api/services/household_service.py b/policyengine_api/services/household_service.py index 8b3e658f4..bedf64400 
100644 --- a/policyengine_api/services/household_service.py +++ b/policyengine_api/services/household_service.py @@ -7,7 +7,6 @@ class HouseholdService: - def get_household(self, country_id: str, household_id: int) -> dict | None: """ Get a household's input data with a given ID. @@ -40,9 +39,7 @@ def get_household(self, country_id: str, household_id: int) -> dict | None: return household except Exception as e: - print( - f"Error fetching household #{household_id}. Details: {str(e)}" - ) + print(f"Error fetching household #{household_id}. Details: {str(e)}") raise e def create_household( @@ -107,7 +104,6 @@ def update_household( print("Updating household") try: - household_hash: str = hash_object(household_json) api_version: str = COUNTRY_PACKAGE_VERSIONS.get(country_id) @@ -123,12 +119,8 @@ def update_household( ) # Fetch the updated JSON back from the table - updated_household: dict = self.get_household( - country_id, household_id - ) + updated_household: dict = self.get_household(country_id, household_id) return updated_household except Exception as e: - print( - f"Error updating household #{household_id}. Details: {str(e)}" - ) + print(f"Error updating household #{household_id}. 
Details: {str(e)}") raise e diff --git a/policyengine_api/services/report_output_service.py b/policyengine_api/services/report_output_service.py index c1d1765fb..c34c62f79 100644 --- a/policyengine_api/services/report_output_service.py +++ b/policyengine_api/services/report_output_service.py @@ -5,7 +5,6 @@ class ReportOutputService: - def find_existing_report_output( self, country_id: str, @@ -43,17 +42,13 @@ def find_existing_report_output( existing_report = None if row is not None: existing_report = dict(row) - print( - f"Found existing report output with ID: {existing_report['id']}" - ) + print(f"Found existing report output with ID: {existing_report['id']}") # Keep output as JSON string - frontend expects string format return existing_report except Exception as e: - print( - f"Error checking for existing report output. Details: {str(e)}" - ) + print(f"Error checking for existing report output. Details: {str(e)}") raise e def create_report_output( @@ -217,7 +212,5 @@ def update_report_output( return True except Exception as e: - print( - f"Error updating report output #{report_id}. Details: {str(e)}" - ) + print(f"Error updating report output #{report_id}. 
Details: {str(e)}") raise e diff --git a/policyengine_api/services/simulation_analysis_service.py b/policyengine_api/services/simulation_analysis_service.py index 8949bf2ae..140fe4987 100644 --- a/policyengine_api/services/simulation_analysis_service.py +++ b/policyengine_api/services/simulation_analysis_service.py @@ -29,9 +29,7 @@ def execute_analysis( relevant_parameters: list[dict], relevant_parameter_baseline_values: list[dict], audience: str | None, - ) -> tuple[ - Generator[str, None, None] | str, Literal["streaming", "static"] - ]: + ) -> tuple[Generator[str, None, None] | str, Literal["streaming", "static"]]: """ Execute AI analysis for economy-wide simulation @@ -67,9 +65,7 @@ def execute_analysis( if existing_analysis is not None: return existing_analysis, "static" - print( - "Found no existing AI analysis; triggering new analysis with Claude" - ) + print("Found no existing AI analysis; triggering new analysis with Claude") # Otherwise, pass prompt to Claude, then return streaming function try: analysis = self.trigger_ai_analysis(prompt) @@ -109,9 +105,7 @@ def _generate_simulation_analysis_prompt( } try: - prompt = ai_prompt_service.get_prompt( - "simulation_analysis", prompt_data - ) + prompt = ai_prompt_service.get_prompt("simulation_analysis", prompt_data) return prompt except Exception as e: diff --git a/policyengine_api/services/simulation_service.py b/policyengine_api/services/simulation_service.py index bb5b5d290..7b83689e5 100644 --- a/policyengine_api/services/simulation_service.py +++ b/policyengine_api/services/simulation_service.py @@ -6,7 +6,6 @@ class SimulationService: - def find_existing_simulation( self, country_id: str, @@ -38,9 +37,7 @@ def find_existing_simulation( existing_simulation = None if row is not None: existing_simulation = dict(row) - print( - f"Found existing simulation with ID: {existing_simulation['id']}" - ) + print(f"Found existing simulation with ID: {existing_simulation['id']}") return existing_simulation @@ -98,9 
+95,7 @@ def create_simulation( print(f"Error creating simulation. Details: {str(e)}") raise e - def get_simulation( - self, country_id: str, simulation_id: int - ) -> dict | None: + def get_simulation(self, country_id: str, simulation_id: int) -> dict | None: """ Get a simulation record by ID. @@ -131,9 +126,7 @@ def get_simulation( return simulation except Exception as e: - print( - f"Error fetching simulation #{simulation_id}. Details: {str(e)}" - ) + print(f"Error fetching simulation #{simulation_id}. Details: {str(e)}") raise e def update_simulation( @@ -198,7 +191,5 @@ def update_simulation( return True except Exception as e: - print( - f"Error updating simulation #{simulation_id}. Details: {str(e)}" - ) + print(f"Error updating simulation #{simulation_id}. Details: {str(e)}") raise e diff --git a/policyengine_api/services/tracer_analysis_service.py b/policyengine_api/services/tracer_analysis_service.py index 5857fcef6..2fd072f83 100644 --- a/policyengine_api/services/tracer_analysis_service.py +++ b/policyengine_api/services/tracer_analysis_service.py @@ -18,9 +18,7 @@ def execute_analysis( household_id: str, policy_id: str, variable: str, - ) -> tuple[ - Generator[str, None, None] | str, Literal["static", "streaming"] - ]: + ) -> tuple[Generator[str, None, None] | str, Literal["static", "streaming"]]: """ Executes tracer analysis for a variable in a household @@ -44,9 +42,7 @@ def execute_analysis( # Parse the tracer output for our given variable try: - tracer_segment: list[str] = self._parse_tracer_output( - tracer, variable - ) + tracer_segment: list[str] = self._parse_tracer_output(tracer, variable) except Exception as e: print(f"Error parsing tracer output: {str(e)}") raise e @@ -107,17 +103,13 @@ def _parse_tracer_output(self, tracer_output, target_variable): capturing = False # Input validation - if not isinstance(target_variable, str) or not isinstance( - tracer_output, list - ): + if not isinstance(target_variable, str) or not 
isinstance(tracer_output, list): return result # Create a regex pattern to match the exact variable name # This will match the variable name followed by optional whitespace, # then optional angle brackets with any content, then optional whitespace - pattern = ( - rf"^(\s*)({re.escape(target_variable)})(?!\w)\s*(?:<[^>]*>)?\s*" - ) + pattern = rf"^(\s*)({re.escape(target_variable)})(?!\w)\s*(?:<[^>]*>)?\s*" for line in tracer_output: # Count leading spaces to determine indentation level diff --git a/policyengine_api/utils/payload_validators/validate_household_payload.py b/policyengine_api/utils/payload_validators/validate_household_payload.py index 7b4f7d951..c66f15e26 100644 --- a/policyengine_api/utils/payload_validators/validate_household_payload.py +++ b/policyengine_api/utils/payload_validators/validate_household_payload.py @@ -19,9 +19,7 @@ def validate_household_payload(payload): # Check that label is either string or None, if present if "label" in payload: - if payload["label"] is not None and not isinstance( - payload["label"], str - ): + if payload["label"] is not None and not isinstance(payload["label"], str): return False, "Label must be a string or None" # Check that data is a dictionary diff --git a/policyengine_api/utils/payload_validators/validate_set_policy_payload.py b/policyengine_api/utils/payload_validators/validate_set_policy_payload.py index a48c75bda..f90f80d17 100644 --- a/policyengine_api/utils/payload_validators/validate_set_policy_payload.py +++ b/policyengine_api/utils/payload_validators/validate_set_policy_payload.py @@ -8,9 +8,7 @@ def validate_set_policy_payload(payload: dict) -> tuple[bool, str | None]: # Check that label is either string or None if "label" in payload: - if payload["label"] is not None and not isinstance( - payload["label"], str - ): + if payload["label"] is not None and not isinstance(payload["label"], str): return False, "Label must be a string or None" # Check that data is a dictionary diff --git 
a/policyengine_api/utils/singleton.py b/policyengine_api/utils/singleton.py index 28e8a0984..3776cb92d 100644 --- a/policyengine_api/utils/singleton.py +++ b/policyengine_api/utils/singleton.py @@ -3,7 +3,5 @@ class Singleton(type): def __call__(cls, *args, **kwargs): if cls not in cls._instances: - cls._instances[cls] = super(Singleton, cls).__call__( - *args, **kwargs - ) + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) return cls._instances[cls] diff --git a/setup.py b/setup.py index 52dd7170f..031b17ec8 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,7 @@ "microdf_python>=1.0.0", ], extras_require={ - "dev": ["pytest-timeout", "coverage", "pytest-snapshot"], + "dev": ["pytest-timeout", "coverage", "pytest-snapshot", "ruff>=0.9.0"], }, # script policyengine-api-setup -> policyengine_api.setup_data:setup_data entry_points={ diff --git a/tests/env_variables/test_environment_variables.py b/tests/env_variables/test_environment_variables.py index 23a21ea1d..e1c5cec4d 100644 --- a/tests/env_variables/test_environment_variables.py +++ b/tests/env_variables/test_environment_variables.py @@ -27,9 +27,9 @@ def test_hugging_face_token(self): timeout=5, ) - assert ( - token_validation_response.status_code == 200 - ), f"Invalid HUGGING_FACE_TOKEN: {token_validation_response.text}" + assert token_validation_response.status_code == 200, ( + f"Invalid HUGGING_FACE_TOKEN: {token_validation_response.text}" + ) @pytest.mark.skipif( do_not_run_in_debug(), @@ -39,9 +39,7 @@ def test_github_microdata_auth_token(self): """Test if POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN is valid by querying GitHub user API.""" token = os.getenv("POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN") - assert ( - token is not None - ), "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN is not set" + assert token is not None, "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN is not set" headers = { "Authorization": f"Bearer {token}", @@ -55,11 +53,11 @@ def test_github_microdata_auth_token(self): timeout=5, 
) - assert ( - token_validation_response.status_code == 200 - ), f"Invalid POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN: {token_validation_response.text}" + assert token_validation_response.status_code == 200, ( + f"Invalid POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN: {token_validation_response.text}" + ) token_user_details = token_validation_response.json() - assert ( - "login" in token_user_details - ), "Token is valid but did not return expected user details" + assert "login" in token_user_details, ( + "Token is valid but did not return expected user details" + ) diff --git a/tests/fixtures/integration/simulations.py b/tests/fixtures/integration/simulations.py index aefddc9fe..741676047 100644 --- a/tests/fixtures/integration/simulations.py +++ b/tests/fixtures/integration/simulations.py @@ -6,7 +6,9 @@ from unittest.mock import Mock, MagicMock, patch from policyengine_api.endpoints.household import add_yearly_variables -STANDARD_AXES_COUNT = 401 # Not formally defined anywhere, but this value is used throughout the API +STANDARD_AXES_COUNT = ( + 401 # Not formally defined anywhere, but this value is used throughout the API +) SMALL_AXES_COUNT = 5 TEST_YEAR = "2025" TEST_STATE = "NY" @@ -67,10 +69,6 @@ def create_household_with_axes(base_household, axes_config): def setup_small_axes_household(base_household, small_axes_config): """Fixture to setup a household with small axes for testing""" - household_with_axes = create_household_with_axes( - base_household, small_axes_config - ) - household_with_axes = add_yearly_variables( - household_with_axes, TEST_COUNTRY_ID - ) + household_with_axes = create_household_with_axes(base_household, small_axes_config) + household_with_axes = add_yearly_variables(household_with_axes, TEST_COUNTRY_ID) return household_with_axes diff --git a/tests/fixtures/services/ai_analysis_service.py b/tests/fixtures/services/ai_analysis_service.py index a2f4d21c4..95bba3039 100644 --- a/tests/fixtures/services/ai_analysis_service.py +++ 
b/tests/fixtures/services/ai_analysis_service.py @@ -39,14 +39,10 @@ def _configure(text_chunks: list[str]): # Set up mock stream mock_stream = MagicMock() - mock_client.messages.stream.return_value.__enter__.return_value = ( - mock_stream - ) + mock_client.messages.stream.return_value.__enter__.return_value = mock_stream # Configure stream to yield text events - events = [ - MockEvent(event_type="text", text=chunk) for chunk in text_chunks - ] + events = [MockEvent(event_type="text", text=chunk) for chunk in text_chunks] mock_stream.__iter__.return_value = events return mock_client @@ -67,9 +63,7 @@ def _configure(error_type: str): # Set up mock stream mock_stream = MagicMock() - mock_client.messages.stream.return_value.__enter__.return_value = ( - mock_stream - ) + mock_client.messages.stream.return_value.__enter__.return_value = mock_stream # Configure stream to yield an error event error_event = MockEvent(event_type="error", error={"type": error_type}) diff --git a/tests/fixtures/services/economy_service.py b/tests/fixtures/services/economy_service.py index 9ea1ca24a..687a82a48 100644 --- a/tests/fixtures/services/economy_service.py +++ b/tests/fixtures/services/economy_service.py @@ -24,9 +24,7 @@ MOCK_MODEL_VERSION = "1.2.3" MOCK_DATA_VERSION = None -MOCK_REFORM_POLICY_JSON = json.dumps( - {"sample_param": {"2024-01-01.2100-12-31": 15}} -) +MOCK_REFORM_POLICY_JSON = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) MOCK_BASELINE_POLICY_JSON = json.dumps({}) @@ -133,9 +131,7 @@ def mock_logger(): def mock_datetime(): """Mock datetime.datetime.now().""" mock_now = datetime.datetime(2025, 6, 26, 12, 0, 0) - with patch( - "policyengine_api.services.economy_service.datetime.datetime" - ) as mock: + with patch("policyengine_api.services.economy_service.datetime.datetime") as mock: mock.now.return_value = mock_now yield mock @@ -165,14 +161,11 @@ def create_mock_reform_impact( "options_hash": MOCK_OPTIONS_HASH, "status": status, "api_version": 
MOCK_API_VERSION, - "reform_impact_json": reform_impact_json - or json.dumps(MOCK_REFORM_IMPACT_DATA), + "reform_impact_json": reform_impact_json or json.dumps(MOCK_REFORM_IMPACT_DATA), "execution_id": execution_id, "start_time": datetime.datetime(2025, 6, 26, 12, 0, 0), "end_time": ( - datetime.datetime(2025, 6, 26, 12, 5, 0) - if status == "ok" - else None + datetime.datetime(2025, 6, 26, 12, 5, 0) if status == "ok" else None ), } diff --git a/tests/fixtures/services/household_fixtures.py b/tests/fixtures/services/household_fixtures.py index 54d49291f..063778c7a 100644 --- a/tests/fixtures/services/household_fixtures.py +++ b/tests/fixtures/services/household_fixtures.py @@ -22,9 +22,7 @@ @pytest.fixture def mock_hash_object(): """Mock the hash_object function.""" - with patch( - "policyengine_api.services.household_service.hash_object" - ) as mock: + with patch("policyengine_api.services.household_service.hash_object") as mock: mock.return_value = valid_hash_value yield mock diff --git a/tests/fixtures/services/policy_service.py b/tests/fixtures/services/policy_service.py index 18ee9071e..6c4a27f66 100644 --- a/tests/fixtures/services/policy_service.py +++ b/tests/fixtures/services/policy_service.py @@ -3,9 +3,7 @@ from unittest.mock import patch valid_policy_json = { - "data": { - "gov.irs.income.bracket.rates.2": {"2024-01-01.2024-12-31": 0.2433} - }, + "data": {"gov.irs.income.bracket.rates.2": {"2024-01-01.2024-12-31": 0.2433}}, } valid_hash_value = "NgJhpeuRVnIAwgYWuJsd2fI/N88rIE6Kcj8q4TPD/i4=" diff --git a/tests/integration/test_simulations.py b/tests/integration/test_simulations.py index 36056f239..9be9ba7e2 100644 --- a/tests/integration/test_simulations.py +++ b/tests/integration/test_simulations.py @@ -13,7 +13,6 @@ class TestSimsWithAxes: - def test__given_any_number_of_axes__sim_returns_valid_arrays( self, ): # , patched_add_yearly_variables, patched_countries_get): @@ -40,20 +39,16 @@ def test__given_any_number_of_axes__sim_returns_valid_arrays( 
print("Variable name: ", variable_name) if variable_name in FORBIDDEN_VARIABLES: continue - for period in result[entity_type][entity_id][ - variable_name - ]: + for period in result[entity_type][entity_id][variable_name]: print("Period: ", period) - value = result[entity_type][entity_id][variable_name][ - period - ] + value = result[entity_type][entity_id][variable_name][period] print(f"Value: {value}") if isinstance(value, list): # Assert no Nones - assert all( - v is not None for v in value - ), f"None found in {variable_name} for {entity_id} in {period}" + assert all(v is not None for v in value), ( + f"None found in {variable_name} for {entity_id} in {period}" + ) # Assert correct length - assert ( - len(value) == SMALL_AXES_COUNT - ), f"Expected {SMALL_AXES_COUNT} values for {variable_name}, got {len(value)}" + assert len(value) == SMALL_AXES_COUNT, ( + f"Expected {SMALL_AXES_COUNT} values for {variable_name}, got {len(value)}" + ) diff --git a/tests/to_refactor/api/test_api.py b/tests/to_refactor/api/test_api.py index 74f3e2bd6..f0855a6ec 100644 --- a/tests/to_refactor/api/test_api.py +++ b/tests/to_refactor/api/test_api.py @@ -23,9 +23,7 @@ def client(): # - expected_result: the expected result of the endpoint test_paths = [ - path - for path in (Path(__file__).parent).rglob("*") - if path.suffix == ".yaml" + path for path in (Path(__file__).parent).rglob("*") if path.suffix == ".yaml" ] test_data = [yaml.safe_load(path.read_text()) for path in test_paths] test_names = [test["name"] for test in test_data] @@ -70,6 +68,4 @@ def test_response(client, test: dict): json.loads(response.data), test.get("response", {}).get("data", {}) ) elif "html" in test.get("response", {}): - assert response.data.decode("utf-8") == test.get("response", {}).get( - "html", "" - ) + assert response.data.decode("utf-8") == test.get("response", {}).get("html", "") diff --git a/tests/to_refactor/fixtures/to_refactor_household_fixtures.py 
b/tests/to_refactor/fixtures/to_refactor_household_fixtures.py index 89b854f19..5fa6af91c 100644 --- a/tests/to_refactor/fixtures/to_refactor_household_fixtures.py +++ b/tests/to_refactor/fixtures/to_refactor_household_fixtures.py @@ -22,9 +22,7 @@ @pytest.fixture def mock_hash_object(): """Mock the hash_object function.""" - with patch( - "policyengine_api.services.household_service.hash_object" - ) as mock: + with patch("policyengine_api.services.household_service.hash_object") as mock: mock.return_value = valid_hash_value yield mock @@ -32,7 +30,5 @@ def mock_hash_object(): @pytest.fixture def mock_database(): """Mock the database module.""" - with patch( - "policyengine_api.services.household_service.database" - ) as mock_db: + with patch("policyengine_api.services.household_service.database") as mock_db: yield mock_db diff --git a/tests/to_refactor/python/test_ai_analysis_service_old.py b/tests/to_refactor/python/test_ai_analysis_service_old.py index aa8c825e3..0df3928ca 100644 --- a/tests/to_refactor/python/test_ai_analysis_service_old.py +++ b/tests/to_refactor/python/test_ai_analysis_service_old.py @@ -9,9 +9,7 @@ @patch("policyengine_api.services.ai_analysis_service.local_database") def test_get_existing_analysis_found(mock_db): - mock_db.query.return_value.fetchone.return_value = { - "analysis": "Existing analysis" - } + mock_db.query.return_value.fetchone.return_value = {"analysis": "Existing analysis"} prompt = "Test prompt" output = test_ai_service.get_existing_analysis(prompt) diff --git a/tests/to_refactor/python/test_household_routes.py b/tests/to_refactor/python/test_household_routes.py index 78d766fb8..3fa5af319 100644 --- a/tests/to_refactor/python/test_household_routes.py +++ b/tests/to_refactor/python/test_household_routes.py @@ -46,9 +46,7 @@ def test_get_household_invalid_id(self, rest_client): response = rest_client.get("/us/household/invalid") assert response.status_code == 404 - assert ( - b"The requested URL was not found on the server" 
in response.data - ) + assert b"The requested URL was not found on the server" in response.data class TestCreateHousehold: @@ -116,9 +114,7 @@ def test_update_household_success( mock_row.keys.return_value = valid_db_row.keys() mock_database.query().fetchone.return_value = mock_row - updated_household = { - "people": {"person1": {"age": 31, "income": 55000}} - } + updated_household = {"people": {"person1": {"age": 31, "income": 55000}}} updated_data = { "data": updated_household, @@ -182,9 +178,7 @@ def test_update_household_invalid_payload(self, rest_client): class TestHouseholdRouteServiceErrors: """Test handling of service-level errors in routes.""" - @patch( - "policyengine_api.services.household_service.HouseholdService.get_household" - ) + @patch("policyengine_api.services.household_service.HouseholdService.get_household") def test_get_household_service_error(self, mock_get, rest_client): """Test GET endpoint when service raises an error.""" mock_get.side_effect = Exception("Database connection failed") diff --git a/tests/to_refactor/python/test_policy_service_old.py b/tests/to_refactor/python/test_policy_service_old.py index a90680d80..a84e1b1b0 100644 --- a/tests/to_refactor/python/test_policy_service_old.py +++ b/tests/to_refactor/python/test_policy_service_old.py @@ -29,18 +29,13 @@ def policy_service(): class TestPolicyService: - - a_test_policy_id = ( - 8 # Pre-seeded current law policies occupy IDs 1 through 5 - ) + a_test_policy_id = 8 # Pre-seeded current law policies occupy IDs 1 through 5 def test_get_policy_success( self, policy_service, mock_database, sample_policy_data ): # Setup mock - mock_database.query.return_value.fetchone.return_value = ( - sample_policy_data - ) + mock_database.query.return_value.fetchone.return_value = sample_policy_data # Test result = policy_service.get_policy("us", self.a_test_policy_id) @@ -64,9 +59,7 @@ def test_get_policy_not_found(self, policy_service, mock_database): assert result is None 
mock_database.query.assert_called_once() - def test_get_policy_json( - self, policy_service, mock_database, sample_policy_data - ): + def test_get_policy_json(self, policy_service, mock_database, sample_policy_data): # Setup mock mock_database.query.return_value.fetchone.return_value = { "policy_json": sample_policy_data["policy_json"] @@ -131,9 +124,7 @@ def test_set_policy_existing( self, policy_service, mock_database, sample_policy_data ): # Setup mock - mock_database.query.return_value.fetchone.return_value = ( - sample_policy_data - ) + mock_database.query.return_value.fetchone.return_value = sample_policy_data # Test policy_id, message, exists = policy_service.set_policy( @@ -152,9 +143,7 @@ def test_get_unique_policy_with_label( self, policy_service, mock_database, sample_policy_data ): # Setup mock - mock_database.query.return_value.fetchone.return_value = ( - sample_policy_data - ) + mock_database.query.return_value.fetchone.return_value = sample_policy_data # Test result = policy_service._get_unique_policy_with_label( @@ -167,16 +156,12 @@ def test_get_unique_policy_with_label( assert result == sample_policy_data mock_database.query.assert_called_once() - def test_get_unique_policy_with_null_label( - self, policy_service, mock_database - ): + def test_get_unique_policy_with_null_label(self, policy_service, mock_database): # Setup mock mock_database.query.return_value.fetchone.return_value = None # Test - result = policy_service._get_unique_policy_with_label( - "us", "hash123", None - ) + result = policy_service._get_unique_policy_with_label("us", "hash123", None) # Verify assert result is None @@ -207,8 +192,6 @@ def test_error_handling(self, policy_service, mock_database, error_method): elif error_method == "set_policy": policy_service.set_policy("us", "label", {}) else: - policy_service._get_unique_policy_with_label( - "us", "hash", "label" - ) + policy_service._get_unique_policy_with_label("us", "hash", "label") assert str(exc_info.value) == "Database 
error" diff --git a/tests/to_refactor/python/test_simulation_analysis_routes.py b/tests/to_refactor/python/test_simulation_analysis_routes.py index 0a4812e31..f1f2ab6f1 100644 --- a/tests/to_refactor/python/test_simulation_analysis_routes.py +++ b/tests/to_refactor/python/test_simulation_analysis_routes.py @@ -40,9 +40,7 @@ def test_execute_simulation_analysis_new_analysis(rest_client): ) as mock_trigger: mock_trigger.return_value = (s for s in ["New analysis"]) - response = rest_client.post( - "/us/simulation-analysis", json=test_json - ) + response = rest_client.post("/us/simulation-analysis", json=test_json) assert response.status_code == 200 assert b"New analysis" in response.data @@ -58,9 +56,7 @@ def test_execute_simulation_analysis_error(rest_client): ) as mock_trigger: mock_trigger.side_effect = Exception("Test error") - response = rest_client.post( - "/us/simulation-analysis", json=test_json - ) + response = rest_client.post("/us/simulation-analysis", json=test_json) assert response.status_code == 500 assert "Test error" in response.json.get("message") @@ -95,9 +91,7 @@ def test_execute_simulation_analysis_enhanced_cps(rest_client): with patch( "policyengine_api.services.ai_analysis_service.AIAnalysisService.trigger_ai_analysis" ) as mock_trigger: - mock_trigger.return_value = ( - s for s in ["Enhanced CPS analysis"] - ) + mock_trigger.return_value = (s for s in ["Enhanced CPS analysis"]) response = rest_client.post( "/us/simulation-analysis", json=test_json_enhanced_cps diff --git a/tests/to_refactor/python/test_tracer_analysis_routes.py b/tests/to_refactor/python/test_tracer_analysis_routes.py index f88805f8d..83f7bde23 100644 --- a/tests/to_refactor/python/test_tracer_analysis_routes.py +++ b/tests/to_refactor/python/test_tracer_analysis_routes.py @@ -58,8 +58,7 @@ def test_execute_tracer_analysis_no_tracer(mock_db, rest_client): assert response.status_code == 404 assert ( - "No household simulation tracer found" - in 
json.loads(response.data)["message"] + "No household simulation tracer found" in json.loads(response.data)["message"] ) @@ -115,9 +114,7 @@ def test_invalid_variable_types(mock_db, rest_client): }, ) assert response.status_code == 400 - assert ( - "variable must be a string" in json.loads(response.data)["message"] - ) + assert "variable must be a string" in json.loads(response.data)["message"] # Test invalid country @@ -218,7 +215,4 @@ def test_validate_tracer_analysis_payload_failure(rest_client): }, ) assert response.status_code == 400 - assert ( - "Missing required key: variable" - in json.loads(response.data)["message"] - ) + assert "Missing required key: variable" in json.loads(response.data)["message"] diff --git a/tests/to_refactor/python/test_us_policy_macro.py b/tests/to_refactor/python/test_us_policy_macro.py index 1f49a1298..f4228523c 100644 --- a/tests/to_refactor/python/test_us_policy_macro.py +++ b/tests/to_refactor/python/test_us_policy_macro.py @@ -52,7 +52,7 @@ def utah_reform_runner(rest_client, region: str = "us"): assert economy_response.status_code == 200 assert economy_response.json["status"] == "computing", ( f'Expected first answer status to be "computing" but it is ' - f'{str(economy_response.json["status"])}' + f"{str(economy_response.json['status'])}" ) while economy_response.json["status"] == "computing": print("Before sleep:", datetime.datetime.now()) @@ -60,9 +60,9 @@ def utah_reform_runner(rest_client, region: str = "us"): print("After sleep:", datetime.datetime.now()) economy_response = rest_client.get(query) print(json.dumps(economy_response.json)) - assert ( - economy_response.json["status"] == "ok" - ), f'Expected status "ok", got {economy_response.json["status"]} with message "{economy_response.json}"' + assert economy_response.json["status"] == "ok", ( + f'Expected status "ok", got {economy_response.json["status"]} with message "{economy_response.json}"' + ) result = economy_response.json["result"] @@ -70,15 +70,11 @@ def 
utah_reform_runner(rest_client, region: str = "us"): # Ensure that there is some budgetary impact cost = round(result["budget"]["budgetary_impact"] / 1e6, 1) - assert ( - cost / 1867.4 - 1 - ) < 0.01, ( + assert (cost / 1867.4 - 1) < 0.01, ( f"Expected budgetary impact to be 1867.4 million, got {cost} million" ) - assert ( - result["intra_decile"]["all"]["Lose less than 5%"] / 0.534 - 1 - ) < 0.01, ( + assert (result["intra_decile"]["all"]["Lose less than 5%"] / 0.534 - 1) < 0.01, ( f"Expected 53.4% of people to lose less than 5%, got " f"{result['intra_decile']['all']['Lose less than 5%']}" ) diff --git a/tests/to_refactor/python/test_user_profile_routes.py b/tests/to_refactor/python/test_user_profile_routes.py index a3d873dbb..ec30f9eef 100644 --- a/tests/to_refactor/python/test_user_profile_routes.py +++ b/tests/to_refactor/python/test_user_profile_routes.py @@ -42,9 +42,7 @@ def test_set_and_get_record(self, rest_client): assert res.status_code == 200 assert return_object["status"] == "ok" assert return_object["result"]["auth0_id"] == self.auth0_id - assert ( - return_object["result"]["primary_country"] == self.primary_country - ) + assert return_object["result"]["primary_country"] == self.primary_country assert return_object["result"]["username"] == None user_id = return_object["result"]["user_id"] @@ -54,9 +52,7 @@ def test_set_and_get_record(self, rest_client): assert res.status_code == 200 assert return_object["status"] == "ok" - assert ( - return_object["result"]["primary_country"] == self.primary_country - ) + assert return_object["result"]["primary_country"] == self.primary_country assert return_object["result"].get("auth0_id") is None assert return_object["result"]["username"] == None @@ -77,9 +73,7 @@ def test_set_and_get_record(self, rest_client): malicious_updated_profile = {**updated_profile, "auth0_id": "BOGUS"} - res = rest_client.put( - "/us/user-profile", json=malicious_updated_profile - ) + res = rest_client.put("/us/user-profile", 
json=malicious_updated_profile) return_object = json.loads(res.text) assert res.status_code == 200 @@ -99,9 +93,7 @@ def test_set_and_get_record(self, rest_client): def test_non_existent_record(self, rest_client): non_existent_auth0_id = "non-existent-auth0-id" - res = rest_client.get( - f"/us/user-profile?auth0_id={non_existent_auth0_id}" - ) + res = rest_client.get(f"/us/user-profile?auth0_id={non_existent_auth0_id}") return_object = json.loads(res.text) assert res.status_code == 404 diff --git a/tests/to_refactor/python/test_validate_household_payload.py b/tests/to_refactor/python/test_validate_household_payload.py index 42e6a0708..d45363d0d 100644 --- a/tests/to_refactor/python/test_validate_household_payload.py +++ b/tests/to_refactor/python/test_validate_household_payload.py @@ -14,9 +14,7 @@ class TestHouseholdRouteValidation: {"data": {}, "label": 123}, # Invalid label type ], ) - def test_post_household_invalid_payload( - self, rest_client, invalid_payload - ): + def test_post_household_invalid_payload(self, rest_client, invalid_payload): """Test POST endpoint with various invalid payloads.""" response = rest_client.post( "/us/household", @@ -40,9 +38,7 @@ def test_get_household_invalid_id(self, rest_client, invalid_id): # Default Werkzeug validation returns 404, not 400 assert response.status_code == 404 - assert ( - b"The requested URL was not found on the server" in response.data - ) + assert b"The requested URL was not found on the server" in response.data @pytest.mark.parametrize( "country_id", diff --git a/tests/to_refactor/python/test_yearly_var_removal.py b/tests/to_refactor/python/test_yearly_var_removal.py index e4f463e19..b0d9211d9 100644 --- a/tests/to_refactor/python/test_yearly_var_removal.py +++ b/tests/to_refactor/python/test_yearly_var_removal.py @@ -139,8 +139,8 @@ def interface_test_household_under_policy( # Create a set of all variables listed within the metadata that are yearly, # as well as one that will store all variables accessed 
while looping # Note: This removes issues with SNAP variables, which are calculated monthly - var_filter = ( - lambda x: (metadata["variables"][x]["definitionPeriod"] == "year") + var_filter = lambda x: ( + (metadata["variables"][x]["definitionPeriod"] == "year") and x not in excluded_vars ) metadata_var_set = set(filter(var_filter, metadata["variables"].keys())) @@ -154,17 +154,14 @@ def interface_test_household_under_policy( # Skip ignored variables if ( variable in excluded_vars - or metadata["variables"][variable]["definitionPeriod"] - != "year" + or metadata["variables"][variable]["definitionPeriod"] != "year" ): continue # Ensure that the variable exists in both # result_object and test_object if variable not in metadata["variables"]: - print( - f"Failing due to variable {variable} not in metadata" - ) + print(f"Failing due to variable {variable} not in metadata") is_test_passing = False break @@ -188,14 +185,10 @@ def interface_test_household_under_policy( results_diff = result_var_set.difference(metadata_var_set) metadata_diff = metadata_var_set.difference(result_var_set) if len(results_diff) > 0: - print( - "Error: The following values are only present in the result object:" - ) + print("Error: The following values are only present in the result object:") print(results_diff) if len(metadata_diff) > 0: - print( - "Error: The following values are only present in the metadata:" - ) + print("Error: The following values are only present in the metadata:") print(metadata_diff) is_test_passing = False @@ -207,9 +200,7 @@ def test_us_household_under_policy(): Test that a US household under current law is created correctly """ - is_test_passing = interface_test_household_under_policy( - "us", "2", ["members"] - ) + is_test_passing = interface_test_household_under_policy("us", "2", ["members"]) assert is_test_passing == True @@ -270,8 +261,8 @@ def test_get_calculate(client): # Create a set of all variables listed within the metadata that are yearly, # as well as 
one that will store all variables accessed while looping # Note: This removes issues with SNAP variables, which are calculated monthly - var_filter = ( - lambda x: (metadata["variables"][x]["definitionPeriod"] == "year") + var_filter = lambda x: ( + (metadata["variables"][x]["definitionPeriod"] == "year") and x not in excluded_vars ) metadata_var_set = set(filter(var_filter, metadata["variables"].keys())) @@ -285,17 +276,14 @@ def test_get_calculate(client): # Skip ignored variables if ( variable in excluded_vars - or metadata["variables"][variable]["definitionPeriod"] - != "year" + or metadata["variables"][variable]["definitionPeriod"] != "year" ): continue # Ensure that the variable exists in both # result_object and test_object if variable not in metadata["variables"]: - print( - f"Failing due to variable {variable} not in metadata" - ) + print(f"Failing due to variable {variable} not in metadata") is_test_passing = False break @@ -319,14 +307,10 @@ def test_get_calculate(client): results_diff = result_var_set.difference(metadata_var_set) metadata_diff = metadata_var_set.difference(result_var_set) if len(results_diff) > 0: - print( - "Error: The following values are only present in the result object:" - ) + print("Error: The following values are only present in the result object:") print(results_diff) if len(metadata_diff) > 0: - print( - "Error: The following values are only present in the metadata:" - ) + print("Error: The following values are only present in the metadata:") print(metadata_diff) is_test_passing = False diff --git a/tests/unit/ai_prompts/test_simulation_analysis_prompt.py b/tests/unit/ai_prompts/test_simulation_analysis_prompt.py index 05f1931e7..db74f8f73 100644 --- a/tests/unit/ai_prompts/test_simulation_analysis_prompt.py +++ b/tests/unit/ai_prompts/test_simulation_analysis_prompt.py @@ -11,7 +11,6 @@ class TestGenerateSimulationAnalysisPrompt: - def test_given_valid_us_input(self, snapshot): snapshot.snapshot_dir = "tests/snapshots" @@ 
-29,13 +28,11 @@ def test_given_valid_uk_input(self, snapshot): def test_given_dataset_is_enhanced_cps(self, snapshot): snapshot.snapshot_dir = "tests/snapshots" - valid_enhanced_cps_input_data = ( - given_valid_data_and_dataset_is_enhanced_cps(valid_input_us) + valid_enhanced_cps_input_data = given_valid_data_and_dataset_is_enhanced_cps( + valid_input_us ) - prompt = generate_simulation_analysis_prompt( - valid_enhanced_cps_input_data - ) + prompt = generate_simulation_analysis_prompt(valid_enhanced_cps_input_data) snapshot.assert_match( prompt, "simulation_analysis_prompt_dataset_enhanced_cps.txt" ) @@ -46,6 +43,4 @@ def test_given_missing_input_field(self): Exception, match="1 validation error for InboundParameters\ntime_period\n Field required", ): - generate_simulation_analysis_prompt( - invalid_data_missing_input_field - ) + generate_simulation_analysis_prompt(invalid_data_missing_input_field) diff --git a/tests/unit/data/test_congressional_districts.py b/tests/unit/data/test_congressional_districts.py index 91b415705..7fd1bfb42 100644 --- a/tests/unit/data/test_congressional_districts.py +++ b/tests/unit/data/test_congressional_districts.py @@ -78,15 +78,11 @@ def test__all_state_codes_are_in_state_code_to_name(self): assert district.state_code in STATE_CODE_TO_NAME def test__california_has_52_districts(self): - ca_districts = [ - d for d in CONGRESSIONAL_DISTRICTS if d.state_code == "CA" - ] + ca_districts = [d for d in CONGRESSIONAL_DISTRICTS if d.state_code == "CA"] assert len(ca_districts) == 52 def test__texas_has_38_districts(self): - tx_districts = [ - d for d in CONGRESSIONAL_DISTRICTS if d.state_code == "TX" - ] + tx_districts = [d for d in CONGRESSIONAL_DISTRICTS if d.state_code == "TX"] assert len(tx_districts) == 38 def test__at_large_states_have_1_district(self): @@ -94,31 +90,23 @@ def test__at_large_states_have_1_district(self): at_large_states = [s for s in AT_LARGE_STATES if s != "DC"] for state_code in at_large_states: state_districts = [ - 
d - for d in CONGRESSIONAL_DISTRICTS - if d.state_code == state_code + d for d in CONGRESSIONAL_DISTRICTS if d.state_code == state_code ] assert len(state_districts) == 1 assert state_districts[0].number == 1 def test__dc_has_1_district(self): - dc_districts = [ - d for d in CONGRESSIONAL_DISTRICTS if d.state_code == "DC" - ] + dc_districts = [d for d in CONGRESSIONAL_DISTRICTS if d.state_code == "DC"] assert len(dc_districts) == 1 assert dc_districts[0].number == 1 def test__dc_comes_after_delaware(self): # Find indices de_indices = [ - i - for i, d in enumerate(CONGRESSIONAL_DISTRICTS) - if d.state_code == "DE" + i for i, d in enumerate(CONGRESSIONAL_DISTRICTS) if d.state_code == "DE" ] dc_indices = [ - i - for i, d in enumerate(CONGRESSIONAL_DISTRICTS) - if d.state_code == "DC" + i for i, d in enumerate(CONGRESSIONAL_DISTRICTS) if d.state_code == "DC" ] # DC should come after all DE districts assert min(dc_indices) > max(de_indices) @@ -144,36 +132,27 @@ def test__name_has_correct_format(self): metadata = build_congressional_district_metadata() # Check first California district ca_01 = next( - item - for item in metadata - if item["name"] == "congressional_district/CA-01" + item for item in metadata if item["name"] == "congressional_district/CA-01" ) assert ca_01 is not None def test__label_has_correct_format(self): metadata = build_congressional_district_metadata() ca_01 = next( - item - for item in metadata - if item["name"] == "congressional_district/CA-01" + item for item in metadata if item["name"] == "congressional_district/CA-01" ) assert ca_01["label"] == "California's 1st congressional district" def test__state_abbreviation_is_uppercase(self): metadata = build_congressional_district_metadata() for item in metadata: - assert ( - item["state_abbreviation"] - == item["state_abbreviation"].upper() - ) + assert item["state_abbreviation"] == item["state_abbreviation"].upper() assert len(item["state_abbreviation"]) == 2 def 
test__state_name_matches_abbreviation(self): metadata = build_congressional_district_metadata() ca_01 = next( - item - for item in metadata - if item["name"] == "congressional_district/CA-01" + item for item in metadata if item["name"] == "congressional_district/CA-01" ) assert ca_01["state_abbreviation"] == "CA" assert ca_01["state_name"] == "California" @@ -181,9 +160,7 @@ def test__state_name_matches_abbreviation(self): def test__dc_state_fields(self): metadata = build_congressional_district_metadata() dc_01 = next( - item - for item in metadata - if item["name"] == "congressional_district/DC-01" + item for item in metadata if item["name"] == "congressional_district/DC-01" ) assert dc_01["state_abbreviation"] == "DC" assert dc_01["state_name"] == "District of Columbia" @@ -198,39 +175,25 @@ def test__ordinal_suffixes_are_correct(self): # Find specific districts to test ordinal suffixes ca_01 = next( - item - for item in metadata - if item["name"] == "congressional_district/CA-01" + item for item in metadata if item["name"] == "congressional_district/CA-01" ) ca_02 = next( - item - for item in metadata - if item["name"] == "congressional_district/CA-02" + item for item in metadata if item["name"] == "congressional_district/CA-02" ) ca_03 = next( - item - for item in metadata - if item["name"] == "congressional_district/CA-03" + item for item in metadata if item["name"] == "congressional_district/CA-03" ) ca_11 = next( - item - for item in metadata - if item["name"] == "congressional_district/CA-11" + item for item in metadata if item["name"] == "congressional_district/CA-11" ) ca_12 = next( - item - for item in metadata - if item["name"] == "congressional_district/CA-12" + item for item in metadata if item["name"] == "congressional_district/CA-12" ) ca_21 = next( - item - for item in metadata - if item["name"] == "congressional_district/CA-21" + item for item in metadata if item["name"] == "congressional_district/CA-21" ) ca_22 = next( - item - for item in 
metadata - if item["name"] == "congressional_district/CA-22" + item for item in metadata if item["name"] == "congressional_district/CA-22" ) assert "1st" in ca_01["label"] @@ -245,17 +208,13 @@ def test__district_numbers_have_leading_zeros(self): metadata = build_congressional_district_metadata() # Single digit districts should have leading zero ca_01 = next( - item - for item in metadata - if item["name"] == "congressional_district/CA-01" + item for item in metadata if item["name"] == "congressional_district/CA-01" ) assert ca_01["name"] == "congressional_district/CA-01" # Double digit districts should not have leading zero ca_37 = next( - item - for item in metadata - if item["name"] == "congressional_district/CA-37" + item for item in metadata if item["name"] == "congressional_district/CA-37" ) assert ca_37["name"] == "congressional_district/CA-37" @@ -268,25 +227,21 @@ def test__at_large_states_have_at_large_label(self): for item in metadata if item["name"] == f"congressional_district/{state_code}-01" ) - assert ( - "at-large congressional district" in district["label"] - ), f"{state_code} should have at-large label" + assert "at-large congressional district" in district["label"], ( + f"{state_code} should have at-large label" + ) def test__alaska_at_large_label(self): metadata = build_congressional_district_metadata() ak_01 = next( - item - for item in metadata - if item["name"] == "congressional_district/AK-01" + item for item in metadata if item["name"] == "congressional_district/AK-01" ) assert ak_01["label"] == "Alaska's at-large congressional district" def test__wyoming_at_large_label(self): metadata = build_congressional_district_metadata() wy_01 = next( - item - for item in metadata - if item["name"] == "congressional_district/WY-01" + item for item in metadata if item["name"] == "congressional_district/WY-01" ) assert wy_01["label"] == "Wyoming's at-large congressional district" diff --git a/tests/unit/data/test_sqlalchemy_v2.py 
b/tests/unit/data/test_sqlalchemy_v2.py index fceb898e4..3882bb0f7 100644 --- a/tests/unit/data/test_sqlalchemy_v2.py +++ b/tests/unit/data/test_sqlalchemy_v2.py @@ -21,9 +21,7 @@ class TestSQLAlchemyVersion: def test_sqlalchemy_version_is_v2(self): major = int(sqlalchemy.__version__.split(".")[0]) - assert ( - major >= 2 - ), f"Expected SQLAlchemy v2+, got {sqlalchemy.__version__}" + assert major >= 2, f"Expected SQLAlchemy v2+, got {sqlalchemy.__version__}" class TestResultProxy: @@ -60,9 +58,7 @@ def test_fetchone_returns_none_when_exhausted(self): def test_fetchall_returns_all_rows(self): engine = sqlalchemy.create_engine("sqlite://") with engine.connect() as conn: - conn.exec_driver_sql( - "CREATE TABLE test (id INTEGER PRIMARY KEY, val TEXT)" - ) + conn.exec_driver_sql("CREATE TABLE test (id INTEGER PRIMARY KEY, val TEXT)") conn.exec_driver_sql("INSERT INTO test VALUES (1, 'a')") conn.exec_driver_sql("INSERT INTO test VALUES (2, 'b')") conn.exec_driver_sql("INSERT INTO test VALUES (3, 'c')") @@ -142,9 +138,7 @@ def test_remote_insert_and_select(self): ] ) - result = db._execute_remote( - ["SELECT * FROM test_table WHERE id = ?", (1,)] - ) + result = db._execute_remote(["SELECT * FROM test_table WHERE id = ?", (1,)]) row = result.fetchone() assert row is not None assert row["id"] == 1 @@ -154,9 +148,7 @@ def test_remote_insert_and_select(self): def test_remote_select_no_results(self): db = self._make_remote_db() - result = db._execute_remote( - ["SELECT * FROM test_table WHERE id = ?", (999,)] - ) + result = db._execute_remote(["SELECT * FROM test_table WHERE id = ?", (999,)]) assert result.fetchone() is None def test_remote_update(self): @@ -173,9 +165,7 @@ def test_remote_update(self): ("updated", 1), ] ) - result = db._execute_remote( - ["SELECT * FROM test_table WHERE id = ?", (1,)] - ) + result = db._execute_remote(["SELECT * FROM test_table WHERE id = ?", (1,)]) row = result.fetchone() assert row["name"] == "updated" @@ -188,7 +178,5 @@ def 
test_remote_delete(self): ] ) db._execute_remote(["DELETE FROM test_table WHERE id = ?", (1,)]) - result = db._execute_remote( - ["SELECT * FROM test_table WHERE id = ?", (1,)] - ) + result = db._execute_remote(["SELECT * FROM test_table WHERE id = ?", (1,)]) assert result.fetchone() is None diff --git a/tests/unit/endpoints/economy/test_compare.py b/tests/unit/endpoints/economy/test_compare.py index 76617a69a..759cc7f26 100644 --- a/tests/unit/endpoints/economy/test_compare.py +++ b/tests/unit/endpoints/economy/test_compare.py @@ -121,9 +121,7 @@ def test__given_non_uk_country_canada__returns_none(self): result = uk_local_authority_breakdown({}, {}, "ca") assert result is None - @patch( - "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" - ) + @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_uk_country__returns_breakdown( @@ -138,9 +136,7 @@ def test__given_uk_country__returns_breakdown( # Create mock weights - 3 local authorities, 10 households mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock( - return_value={"2025": mock_weights} - ) + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -228,13 +224,11 @@ def test__outcome_bucket_categorization_logic(self): else: bucket = "Lose more than 5%" - assert ( - bucket == expected_bucket - ), f"Failed for {percent_change}: expected {expected_bucket}, got {bucket}" + assert bucket == expected_bucket, ( + f"Failed for {percent_change}: expected {expected_bucket}, got {bucket}" + ) - @patch( - "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" - ) + 
@patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__outcome_buckets_are_correct( @@ -247,9 +241,7 @@ def test__outcome_buckets_are_correct( mock_weights = np.ones((1, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock( - return_value={"2025": mock_weights} - ) + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -272,9 +264,7 @@ def test__outcome_buckets_are_correct( assert result.outcomes_by_region["uk"]["Gain more than 5%"] == 1 assert result.outcomes_by_region["uk"]["Gain less than 5%"] == 0 - @patch( - "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" - ) + @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__downloads_from_correct_repos( @@ -287,9 +277,7 @@ def test__downloads_from_correct_repos( mock_weights = np.ones((1, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock( - return_value={"2025": mock_weights} - ) + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -310,32 +298,22 @@ def test__downloads_from_correct_repos( # Verify correct repos are used calls = mock_download.call_args_list - assert ( - calls[0][1]["repo"] == "policyengine/policyengine-uk-data-private" - ) + assert calls[0][1]["repo"] == "policyengine/policyengine-uk-data-private" assert calls[0][1]["repo_filename"] == "local_authority_weights.h5" - assert ( - calls[1][1]["repo"] == 
"policyengine/policyengine-uk-data-public" - ) + assert calls[1][1]["repo"] == "policyengine/policyengine-uk-data-public" assert calls[1][1]["repo_filename"] == "local_authorities_2021.csv" def test__given_constituency_region__returns_none(self): """When simulating a constituency, local authority breakdown should not be computed.""" - result = uk_local_authority_breakdown( - {}, {}, "uk", "constituency/Aldershot" - ) + result = uk_local_authority_breakdown({}, {}, "uk", "constituency/Aldershot") assert result is None def test__given_constituency_region_with_code__returns_none(self): """When simulating a constituency by code, local authority breakdown should not be computed.""" - result = uk_local_authority_breakdown( - {}, {}, "uk", "constituency/E12345678" - ) + result = uk_local_authority_breakdown({}, {}, "uk", "constituency/E12345678") assert result is None - @patch( - "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" - ) + @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_specific_la_region__returns_only_that_la( @@ -349,9 +327,7 @@ def test__given_specific_la_region__returns_only_that_la( mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock( - return_value={"2025": mock_weights} - ) + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -378,9 +354,7 @@ def test__given_specific_la_region__returns_only_that_la( assert "Aberdeen City" not in result.by_local_authority assert "Isle of Anglesey" not in result.by_local_authority - @patch( - "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" - ) + 
@patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_country_scotland_region__returns_only_scottish_las( @@ -394,9 +368,7 @@ def test__given_country_scotland_region__returns_only_scottish_las( mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock( - return_value={"2025": mock_weights} - ) + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -423,9 +395,7 @@ def test__given_country_scotland_region__returns_only_scottish_las( assert "Hartlepool" not in result.by_local_authority assert "Isle of Anglesey" not in result.by_local_authority - @patch( - "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" - ) + @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_uk_region__returns_all_las( @@ -439,9 +409,7 @@ def test__given_uk_region__returns_all_las( mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock( - return_value={"2025": mock_weights} - ) + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -466,9 +434,7 @@ def test__given_uk_region__returns_all_las( assert "Aberdeen City" in result.by_local_authority assert "Isle of Anglesey" in result.by_local_authority - @patch( - "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" - ) + @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") 
@patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_no_region__returns_all_las( @@ -482,9 +448,7 @@ def test__given_no_region__returns_all_las( mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock( - return_value={"2025": mock_weights} - ) + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -548,21 +512,15 @@ def test__given_non_uk_country_nigeria__returns_none(self): def test__given_local_authority_region__returns_none(self): """When simulating a local authority, constituency breakdown should not be computed.""" - result = uk_constituency_breakdown( - {}, {}, "uk", "local_authority/Leicester" - ) + result = uk_constituency_breakdown({}, {}, "uk", "local_authority/Leicester") assert result is None def test__given_local_authority_region_with_code__returns_none(self): """When simulating a local authority by code, constituency breakdown should not be computed.""" - result = uk_constituency_breakdown( - {}, {}, "uk", "local_authority/E06000016" - ) + result = uk_constituency_breakdown({}, {}, "uk", "local_authority/E06000016") assert result is None - @patch( - "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" - ) + @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_specific_constituency_region__returns_only_that_constituency( @@ -577,9 +535,7 @@ def test__given_specific_constituency_region__returns_only_that_constituency( # Create mock weights - 3 constituencies, 10 households mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock( - 
return_value={"2025": mock_weights} - ) + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -607,9 +563,7 @@ def test__given_specific_constituency_region__returns_only_that_constituency( assert "Edinburgh East" not in result.by_constituency assert "Cardiff South" not in result.by_constituency - @patch( - "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" - ) + @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_country_scotland_region__returns_only_scottish_constituencies( @@ -623,9 +577,7 @@ def test__given_country_scotland_region__returns_only_scottish_constituencies( mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock( - return_value={"2025": mock_weights} - ) + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -642,9 +594,7 @@ def test__given_country_scotland_region__returns_only_scottish_constituencies( baseline = {"household_net_income": np.array([1000.0] * 10)} reform = {"household_net_income": np.array([1050.0] * 10)} - result = uk_constituency_breakdown( - baseline, reform, "uk", "country/scotland" - ) + result = uk_constituency_breakdown(baseline, reform, "uk", "country/scotland") assert result is not None assert len(result.by_constituency) == 1 @@ -652,9 +602,7 @@ def test__given_country_scotland_region__returns_only_scottish_constituencies( assert "Aldershot" not in result.by_constituency assert "Cardiff South" not in result.by_constituency - @patch( - "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" - ) + 
@patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_uk_region__returns_all_constituencies( @@ -668,9 +616,7 @@ def test__given_uk_region__returns_all_constituencies( mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock( - return_value={"2025": mock_weights} - ) + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -695,9 +641,7 @@ def test__given_uk_region__returns_all_constituencies( assert "Edinburgh East" in result.by_constituency assert "Cardiff South" in result.by_constituency - @patch( - "policyengine_api.endpoints.economy.compare.download_huggingface_dataset" - ) + @patch("policyengine_api.endpoints.economy.compare.download_huggingface_dataset") @patch("policyengine_api.endpoints.economy.compare.h5py.File") @patch("policyengine_api.endpoints.economy.compare.pd.read_csv") def test__given_no_region__returns_all_constituencies( @@ -711,9 +655,7 @@ def test__given_no_region__returns_all_constituencies( mock_weights = np.ones((3, 10)) mock_h5py_context = MagicMock() - mock_h5py_context.__enter__ = MagicMock( - return_value={"2025": mock_weights} - ) + mock_h5py_context.__enter__ = MagicMock(return_value={"2025": mock_weights}) mock_h5py_context.__exit__ = MagicMock(return_value=False) mock_h5py_file.return_value = mock_h5py_context @@ -785,9 +727,7 @@ def test__5pct_gain_classified_below_5pct(self): # Every decile should have 0% in "Gain more than 5%" for pct in result["deciles"]["Gain more than 5%"]: - assert ( - pct == 0.0 - ), f"5% gain incorrectly classified as >5% (got {pct})" + assert pct == 0.0, f"5% gain incorrectly classified as >5% (got {pct})" # Every decile should have 100% in "Gain less than 5%" 
for pct in result["deciles"]["Gain less than 5%"]: assert pct == 1.0, f"5% gain not classified as <5% (got {pct})" @@ -854,9 +794,7 @@ def test__near_zero_baseline_no_division_error(self): # Should not raise; all households gained income total = sum(result["all"][label] for label in result["all"]) - assert ( - abs(total - 1.0) < 1e-9 - ), f"Proportions should sum to 1, got {total}" + assert abs(total - 1.0) < 1e-9, f"Proportions should sum to 1, got {total}" def test__zero_baseline_uses_floor_of_one(self): """When baseline income is 0, the max(B, 1) floor means the @@ -874,9 +812,7 @@ def test__zero_baseline_uses_floor_of_one(self): # $100 gain on a floored baseline of $1 = 10000% change -> >5% for pct in result["deciles"]["Gain more than 5%"]: - assert ( - pct == 1.0 - ), f"Zero baseline with $100 gain should be >5% (got {pct})" + assert pct == 1.0, f"Zero baseline with $100 gain should be >5% (got {pct})" # No NaN or Inf in any bucket for label in result["all"]: assert not np.isnan(result["all"][label]) @@ -932,9 +868,7 @@ def test__4pct_gain_not_doubled_into_above_5pct(self): result = intra_decile_impact(baseline, reform) for pct in result["deciles"]["Gain more than 5%"]: - assert ( - pct == 0.0 - ), "4% gain incorrectly classified as >5% (doubling bug)" + assert pct == 0.0, "4% gain incorrectly classified as >5% (doubling bug)" for pct in result["deciles"]["Gain less than 5%"]: assert pct == 1.0, "4% gain not classified as <5%" @@ -975,9 +909,9 @@ def test__5pct_gain_classified_below_5pct(self): result = intra_wealth_decile_impact(baseline, reform) for pct in result["deciles"]["Gain more than 5%"]: - assert ( - pct == 0.0 - ), f"5% gain incorrectly classified as >5% in wealth decile (got {pct})" + assert pct == 0.0, ( + f"5% gain incorrectly classified as >5% in wealth decile (got {pct})" + ) def test__2pct_gain_not_doubled(self): """A 2% gain must stay in the <5% bucket for wealth deciles too.""" diff --git a/tests/unit/libs/test_simulation_api_modal.py 
b/tests/unit/libs/test_simulation_api_modal.py index 4ba7d0616..d44dde8cb 100644 --- a/tests/unit/libs/test_simulation_api_modal.py +++ b/tests/unit/libs/test_simulation_api_modal.py @@ -39,7 +39,6 @@ class TestModalSimulationExecution: """Tests for the ModalSimulationExecution dataclass.""" class TestNameProperty: - def test__given_job_id__then_name_returns_job_id(self): # Given execution = ModalSimulationExecution( @@ -54,7 +53,6 @@ def test__given_job_id__then_name_returns_job_id(self): assert name == MOCK_MODAL_JOB_ID class TestAttributes: - def test__given_complete_execution__then_all_attributes_accessible( self, ): @@ -92,10 +90,7 @@ class TestSimulationAPIModal: """Tests for the SimulationAPIModal class.""" class TestInit: - - def test__given_env_var_set__then_uses_env_url( - self, mock_httpx_client - ): + def test__given_env_var_set__then_uses_env_url(self, mock_httpx_client): # Given with patch.dict( "os.environ", @@ -107,9 +102,7 @@ def test__given_env_var_set__then_uses_env_url( # Then assert api.base_url == MOCK_MODAL_BASE_URL - def test__given_env_var_not_set__then_uses_default_url( - self, mock_httpx_client - ): + def test__given_env_var_not_set__then_uses_default_url(self, mock_httpx_client): # Given with patch.dict("os.environ", {}, clear=False): import os @@ -124,7 +117,6 @@ def test__given_env_var_not_set__then_uses_default_url( assert "modal.run" in api.base_url class TestRun: - def test__given_valid_payload__then_returns_execution_with_job_id( self, mock_httpx_client, @@ -188,9 +180,7 @@ def test__given_network_error__then_raises_exception( mock_modal_logger, ): # Given - mock_httpx_client.post.side_effect = httpx.RequestError( - "Connection failed" - ) + mock_httpx_client.post.side_effect = httpx.RequestError("Connection failed") api = SimulationAPIModal() # When/Then @@ -198,7 +188,6 @@ def test__given_network_error__then_raises_exception( api.run(MOCK_SIMULATION_PAYLOAD) class TestGetExecutionById: - def 
test__given_running_job__then_returns_running_status( self, mock_httpx_client, @@ -277,10 +266,7 @@ def test__given_job_id__then_polls_correct_endpoint( assert f"/jobs/{MOCK_MODAL_JOB_ID}" in call_args[0][0] class TestGetExecutionId: - - def test__given_execution__then_returns_job_id( - self, mock_httpx_client - ): + def test__given_execution__then_returns_job_id(self, mock_httpx_client): # Given api = SimulationAPIModal() execution = ModalSimulationExecution( @@ -295,10 +281,7 @@ def test__given_execution__then_returns_job_id( assert execution_id == MOCK_MODAL_JOB_ID class TestGetExecutionStatus: - - def test__given_execution__then_returns_status_string( - self, mock_httpx_client - ): + def test__given_execution__then_returns_status_string(self, mock_httpx_client): # Given api = SimulationAPIModal() execution = ModalSimulationExecution( @@ -313,7 +296,6 @@ def test__given_execution__then_returns_status_string( assert status == MODAL_EXECUTION_STATUS_RUNNING class TestGetExecutionResult: - def test__given_complete_execution__then_returns_result( self, mock_httpx_client ): @@ -349,7 +331,6 @@ def test__given_incomplete_execution__then_returns_none( assert result is None class TestHealthCheck: - def test__given_healthy_api__then_returns_true( self, mock_httpx_client, mock_modal_logger ): @@ -386,9 +367,7 @@ def test__given_network_error__then_returns_false( self, mock_httpx_client, mock_modal_logger ): # Given - mock_httpx_client.get.side_effect = httpx.RequestError( - "Connection failed" - ) + mock_httpx_client.get.side_effect = httpx.RequestError("Connection failed") api = SimulationAPIModal() # When diff --git a/tests/unit/services/test_ai_analysis_service.py b/tests/unit/services/test_ai_analysis_service.py index 34810cc2b..e853a1f68 100644 --- a/tests/unit/services/test_ai_analysis_service.py +++ b/tests/unit/services/test_ai_analysis_service.py @@ -13,7 +13,6 @@ class TestTriggerAIAnalysis: - def test_trigger_ai_analysis_given_successful_streaming( self, 
mock_stream_text_events, test_db ): @@ -33,8 +32,7 @@ def test_trigger_ai_analysis_given_successful_streaming( for i, chunk in enumerate(results): if i < len(text_chunks): expected_chunk = ( - json.dumps({"type": "text", "stream": text_chunks[i][:5]}) - + "\n" + json.dumps({"type": "text", "stream": text_chunks[i][:5]}) + "\n" ) assert chunk == expected_chunk diff --git a/tests/unit/services/test_create_profile.py b/tests/unit/services/test_create_profile.py index b16688f33..6217d7633 100644 --- a/tests/unit/services/test_create_profile.py +++ b/tests/unit/services/test_create_profile.py @@ -7,7 +7,6 @@ class TestCreateProfile: - def test_create_profile_valid(self): auth0_id = "test-auth-id" primary_country = "United States" diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index 49c0fe39b..c49783bad 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -36,9 +36,7 @@ class TestEconomyService: - class TestGetEconomicImpact: - @pytest.fixture def economy_service(self): return EconomyService() @@ -175,9 +173,7 @@ def test__given_no_previous_impact__creates_new_simulation( mock_datetime, mock_numpy_random, ): - mock_reform_impacts_service.get_all_reform_impacts.return_value = ( - [] - ) + mock_reform_impacts_service.get_all_reform_impacts.return_value = [] result = economy_service.get_economic_impact(**base_params) @@ -200,9 +196,7 @@ def test__given_no_previous_impact__includes_metadata_in_simulation_params( mock_numpy_random, ): """Verify that _metadata with policy IDs is passed to simulation API.""" - mock_reform_impacts_service.get_all_reform_impacts.return_value = ( - [] - ) + mock_reform_impacts_service.get_all_reform_impacts.return_value = [] economy_service.get_economic_impact(**base_params) @@ -212,12 +206,9 @@ def test__given_no_previous_impact__includes_metadata_in_simulation_params( # Verify _metadata is included with correct values assert 
"_metadata" in sim_params + assert sim_params["_metadata"]["reform_policy_id"] == MOCK_POLICY_ID assert ( - sim_params["_metadata"]["reform_policy_id"] == MOCK_POLICY_ID - ) - assert ( - sim_params["_metadata"]["baseline_policy_id"] - == MOCK_BASELINE_POLICY_ID + sim_params["_metadata"]["baseline_policy_id"] == MOCK_BASELINE_POLICY_ID ) assert sim_params["_metadata"]["process_id"] == MOCK_PROCESS_ID @@ -234,8 +225,8 @@ def test__given_exception__raises_error( mock_datetime, mock_numpy_random, ): - mock_reform_impacts_service.get_all_reform_impacts.side_effect = ( - Exception("Database error") + mock_reform_impacts_service.get_all_reform_impacts.side_effect = Exception( + "Database error" ) with pytest.raises(Exception) as exc_info: @@ -243,7 +234,6 @@ def test__given_exception__raises_error( assert str(exc_info.value) == "Database error" class TestGetPreviousImpacts: - @pytest.fixture def economy_service(self): return EconomyService() @@ -280,7 +270,6 @@ def test_given_valid_parameters_calls_service_correctly( ) class TestGetMostRecentImpact: - @pytest.fixture def economy_service(self): return EconomyService() @@ -308,9 +297,7 @@ def test__given_existing_impacts__returns_first_impact( create_mock_reform_impact(), create_mock_reform_impact(), ] - mock_reform_impacts_service.get_all_reform_impacts.return_value = ( - impacts - ) + mock_reform_impacts_service.get_all_reform_impacts.return_value = impacts result = economy_service._get_most_recent_impact(setup_options) @@ -320,9 +307,7 @@ def test__given_no_impacts__returns_none( self, economy_service, setup_options, mock_reform_impacts_service ): # Arrange - mock_reform_impacts_service.get_all_reform_impacts.return_value = ( - [] - ) + mock_reform_impacts_service.get_all_reform_impacts.return_value = [] # Act result = economy_service._get_most_recent_impact(setup_options) @@ -331,7 +316,6 @@ def test__given_no_impacts__returns_none( assert result is None class TestDetermineImpactAction: - @pytest.fixture def 
economy_service(self): return EconomyService() @@ -355,9 +339,7 @@ def test__given_error_status__returns_completed(self, economy_service): assert result == ImpactAction.COMPLETED - def test__given_computing_status__returns_computing( - self, economy_service - ): + def test__given_computing_status__returns_computing(self, economy_service): impact = create_mock_reform_impact(status="computing") result = economy_service._determine_impact_action(impact) @@ -372,7 +354,6 @@ def test__given_unknown_status__raises_error(self, economy_service): assert "Unknown impact status: unknown" in str(exc_info.value) class TestHandleExecutionState: - @pytest.fixture def economy_service(self): return EconomyService() @@ -453,9 +434,7 @@ def test__given_unknown_state__raises_error( economy_service._handle_execution_state( setup_options, "UNKNOWN", reform_impact ) - assert "Unexpected sim API execution state: UNKNOWN" in str( - exc_info.value - ) + assert "Unexpected sim API execution state: UNKNOWN" in str(exc_info.value) # Modal status tests def test__given_modal_complete_state__then_returns_completed_result( @@ -525,9 +504,7 @@ def test__given_modal_failed_state_with_error_message__then_includes_error_in_me # Then assert result.status == ImpactStatus.ERROR # Verify the error message was passed to the service - call_args = ( - mock_reform_impacts_service.set_error_reform_impact.call_args - ) + call_args = mock_reform_impacts_service.set_error_reform_impact.call_args assert "Simulation timed out" in call_args[1]["message"] def test__given_modal_running_state__then_returns_computing_result( @@ -561,7 +538,6 @@ def test__given_modal_submitted_state__then_returns_computing_result( assert result.data is None class TestCreateProcessId: - @pytest.fixture def economy_service(self): return EconomyService() @@ -577,9 +553,7 @@ def test_given_mocked_datetime_and_random_returns_expected_format( class TestEconomicImpactResult: - class TestToDict: - def 
test__given_completed_result__returns_correct_dict(self): result = EconomicImpactResult.completed(MOCK_REFORM_IMPACT_DATA) @@ -606,7 +580,6 @@ def test__given_error_result__returns_correct_dict(self): assert result_dict == {"status": "error", "data": None} class TestClassMethods: - def test__given_completed__creates_correct_instance(self): result = EconomicImpactResult.completed(MOCK_REFORM_IMPACT_DATA) @@ -631,7 +604,6 @@ def test__given_error__creates_correct_instance_and_logs(self): class TestEconomicImpactSetupOptions: - def test__given_valid_data__creates_instance(self): options = EconomicImpactSetupOptions( process_id=MOCK_PROCESS_ID, @@ -667,9 +639,7 @@ class TestSetupSimOptions: """ test_country_id = "us" - test_reform_policy = json.dumps( - {"sample_param": {"2024-01-01.2100-12-31": 15}} - ) + test_reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) test_current_law_baseline_policy = json.dumps({}) test_region = "us" test_time_period = 2025 @@ -698,16 +668,13 @@ def test__given_us_nationwide__returns_correct_sim_options(self): assert sim_options["time_period"] == self.test_time_period assert sim_options["region"] == "us" assert ( - sim_options["data"] - == "gs://policyengine-us-data/enhanced_cps_2024.h5" + sim_options["data"] == "gs://policyengine-us-data/enhanced_cps_2024.h5" ) def test__given_us_state_ca__returns_correct_sim_options(self): # Test with a normalized US state (prefixed format) country_id = "us" - reform_policy = json.dumps( - {"sample_param": {"2024-01-01.2100-12-31": 15}} - ) + reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) current_law_baseline_policy = json.dumps({}) region = "state/ca" # Pre-normalized time_period = 2025 @@ -727,21 +694,15 @@ def test__given_us_state_ca__returns_correct_sim_options(self): assert sim_options["country"] == country_id assert sim_options["scope"] == scope assert sim_options["reform"] == json.loads(reform_policy) - assert sim_options["baseline"] == 
json.loads( - current_law_baseline_policy - ) + assert sim_options["baseline"] == json.loads(current_law_baseline_policy) assert sim_options["time_period"] == time_period assert sim_options["region"] == "state/ca" - assert ( - sim_options["data"] == "gs://policyengine-us-data/states/CA.h5" - ) + assert sim_options["data"] == "gs://policyengine-us-data/states/CA.h5" def test__given_us_state_utah__returns_correct_sim_options(self): # Test with normalized Utah state country_id = "us" - reform_policy = json.dumps( - {"sample_param": {"2024-01-01.2100-12-31": 15}} - ) + reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) current_law_baseline_policy = json.dumps({}) region = "state/ut" # Pre-normalized time_period = 2025 @@ -761,20 +722,14 @@ def test__given_us_state_utah__returns_correct_sim_options(self): assert sim_options["country"] == country_id assert sim_options["scope"] == scope assert sim_options["reform"] == json.loads(reform_policy) - assert sim_options["baseline"] == json.loads( - current_law_baseline_policy - ) + assert sim_options["baseline"] == json.loads(current_law_baseline_policy) assert sim_options["time_period"] == time_period assert sim_options["region"] == "state/ut" - assert ( - sim_options["data"] == "gs://policyengine-us-data/states/UT.h5" - ) + assert sim_options["data"] == "gs://policyengine-us-data/states/UT.h5" def test__given_cliff_target__returns_correct_sim_options(self): country_id = "us" - reform_policy = json.dumps( - {"sample_param": {"2024-01-01.2100-12-31": 15}} - ) + reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) current_law_baseline_policy = json.dumps({}) region = "us" time_period = 2025 @@ -796,22 +751,17 @@ def test__given_cliff_target__returns_correct_sim_options(self): assert sim_options["country"] == country_id assert sim_options["scope"] == scope assert sim_options["reform"] == json.loads(reform_policy) - assert sim_options["baseline"] == json.loads( - 
current_law_baseline_policy - ) + assert sim_options["baseline"] == json.loads(current_law_baseline_policy) assert sim_options["time_period"] == time_period assert sim_options["region"] == region assert ( - sim_options["data"] - == "gs://policyengine-us-data/enhanced_cps_2024.h5" + sim_options["data"] == "gs://policyengine-us-data/enhanced_cps_2024.h5" ) assert sim_options["include_cliffs"] is True def test__given_uk__returns_correct_sim_options(self): country_id = "uk" - reform_policy = json.dumps( - {"sample_param": {"2024-01-01.2100-12-31": 15}} - ) + reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) current_law_baseline_policy = json.dumps({}) region = "uk" time_period = 2025 @@ -840,9 +790,7 @@ def test__given_congressional_district__returns_correct_sim_options( self, ): country_id = "us" - reform_policy = json.dumps( - {"sample_param": {"2024-01-01.2100-12-31": 15}} - ) + reform_policy = json.dumps({"sample_param": {"2024-01-01.2100-12-31": 15}}) current_law_baseline_policy = json.dumps({}) region = "congressional_district/CA-37" # Pre-normalized time_period = 2025 @@ -861,10 +809,7 @@ def test__given_congressional_district__returns_correct_sim_options( sim_options = sim_options_model.model_dump() assert sim_options["region"] == "congressional_district/CA-37" - assert ( - sim_options["data"] - == "gs://policyengine-us-data/districts/CA-37.h5" - ) + assert sim_options["data"] == "gs://policyengine-us-data/districts/CA-37.h5" class TestSetupRegion: """Tests for _setup_region method. 
@@ -897,18 +842,14 @@ def test__given_prefixed_state_tx__returns_unchanged(self): def test__given_congressional_district__returns_unchanged(self): service = EconomyService() - result = service._setup_region( - "us", "congressional_district/CA-37" - ) + result = service._setup_region("us", "congressional_district/CA-37") assert result == "congressional_district/CA-37" def test__given_lowercase_congressional_district__returns_unchanged( self, ): service = EconomyService() - result = service._setup_region( - "us", "congressional_district/ca-37" - ) + result = service._setup_region("us", "congressional_district/ca-37") assert result == "congressional_district/ca-37" def test__given_invalid_prefixed_state__raises_value_error(self): @@ -923,17 +864,13 @@ def test__given_invalid_congressional_district__raises_value_error( service = EconomyService() with pytest.raises(ValueError) as exc_info: service._setup_region("us", "congressional_district/cruft") - assert "Invalid congressional district: 'cruft'" in str( - exc_info.value - ) + assert "Invalid congressional district: 'cruft'" in str(exc_info.value) def test__given_invalid_prefix__raises_value_error(self): service = EconomyService() with pytest.raises(ValueError) as exc_info: service._setup_region("us", "invalid_prefix/tx") - assert "Invalid US region: 'invalid_prefix/tx'" in str( - exc_info.value - ) + assert "Invalid US region: 'invalid_prefix/tx'" in str(exc_info.value) def test__given_invalid_bare_value__raises_value_error(self): # Bare values without prefix are now invalid (should be normalized first) @@ -1010,10 +947,7 @@ def test__given_uk__returns_efrs_dataset(self): # Test with UK - returns enhanced FRS dataset service = EconomyService() result = service._setup_data("uk", "uk") - assert ( - result - == "gs://policyengine-uk-data-private/enhanced_frs_2023_24.h5" - ) + assert result == "gs://policyengine-uk-data-private/enhanced_frs_2023_24.h5" def test__given_invalid_country__raises_value_error(self, 
mock_logger): # Test with invalid country @@ -1025,9 +959,7 @@ def test__given_invalid_country__raises_value_error(self, mock_logger): def test__given_passthrough_dataset__returns_dataset_directly(self): # Test with passthrough dataset (national-with-breakdowns) service = EconomyService() - result = service._setup_data( - "us", "us", dataset="national-with-breakdowns" - ) + result = service._setup_data("us", "us", dataset="national-with-breakdowns") assert result == "national-with-breakdowns" def test__given_passthrough_test_dataset__returns_dataset_directly( @@ -1049,9 +981,7 @@ def test__given_default_dataset__uses_get_default_dataset(self): def test__given_unknown_dataset__uses_get_default_dataset(self): # Test that unknown dataset values fall through to get_default_dataset service = EconomyService() - result = service._setup_data( - "us", "state/ca", dataset="unknown-dataset" - ) + result = service._setup_data("us", "state/ca", dataset="unknown-dataset") assert result == "gs://policyengine-us-data/states/CA.h5" class TestValidateUsRegion: @@ -1089,14 +1019,10 @@ def test__given_invalid_congressional_district__raises_value_error( service = EconomyService() with pytest.raises(ValueError) as exc_info: service._validate_us_region("congressional_district/CA-99") - assert "Invalid congressional district: 'CA-99'" in str( - exc_info.value - ) + assert "Invalid congressional district: 'CA-99'" in str(exc_info.value) def test__given_nonexistent_district__raises_value_error(self): service = EconomyService() with pytest.raises(ValueError) as exc_info: service._validate_us_region("congressional_district/cruft") - assert "Invalid congressional district: 'cruft'" in str( - exc_info.value - ) + assert "Invalid congressional district: 'cruft'" in str(exc_info.value) diff --git a/tests/unit/services/test_household_service.py b/tests/unit/services/test_household_service.py index 3c0ee9b50..b59c78e50 100644 --- a/tests/unit/services/test_household_service.py +++ 
b/tests/unit/services/test_household_service.py @@ -18,7 +18,6 @@ class TestGetHousehold: - def test_get_household_given_existing_record( self, test_db, existing_household_record ): @@ -26,9 +25,7 @@ def test_get_household_given_existing_record( # GIVEN an existing record... (included as fixture) # WHEN we call get_household for this record... - result = service.get_household( - valid_db_row["country_id"], valid_db_row["id"] - ) + result = service.get_household(valid_db_row["country_id"], valid_db_row["id"]) valid_household_json = valid_request_body["data"] diff --git a/tests/unit/services/test_metadata_service.py b/tests/unit/services/test_metadata_service.py index 40c6805de..255c90189 100644 --- a/tests/unit/services/test_metadata_service.py +++ b/tests/unit/services/test_metadata_service.py @@ -4,7 +4,6 @@ class TestMetadataService: - def test_get_metadata_nonexistent_country(self): service = MetadataService() # GIVEN a non-existent country ID @@ -98,9 +97,9 @@ def test_verify_metadata_for_given_country( assert "region" in metadata["economy_options"] regions = metadata["economy_options"]["region"] for region in test_regions: - assert any( - r["name"] == region for r in regions - ), f"Expected region '{region}' not found" + assert any(r["name"] == region for r in regions), ( + f"Expected region '{region}' not found" + ) # Verify time periods exist and have correct structure assert "time_period" in metadata["economy_options"] @@ -126,9 +125,7 @@ def test_verify_metadata_for_given_country( ("us", ["national", "state", "place", "congressional_district"]), ], ) - def test_verify_region_types_for_given_country( - self, country_id, expected_types - ): + def test_verify_region_types_for_given_country(self, country_id, expected_types): """ Verifies that all regions for UK and US have a 'type' field with valid values. 
@@ -138,9 +135,7 @@ def test_verify_region_types_for_given_country( regions = metadata["economy_options"]["region"] for region in regions: - assert ( - "type" in region - ), f"Region '{region['name']}' missing 'type' field" - assert ( - region["type"] in expected_types - ), f"Region '{region['name']}' has invalid type '{region['type']}'" + assert "type" in region, f"Region '{region['name']}' missing 'type' field" + assert region["type"] in expected_types, ( + f"Region '{region['name']}' has invalid type '{region['type']}'" + ) diff --git a/tests/unit/services/test_policy_service.py b/tests/unit/services/test_policy_service.py index 4530dd9d5..ac358ab71 100644 --- a/tests/unit/services/test_policy_service.py +++ b/tests/unit/services/test_policy_service.py @@ -15,10 +15,7 @@ class TestGetPolicy: - - def test_get_policy_given_existing_record( - self, test_db, existing_policy_record - ): + def test_get_policy_given_existing_record(self, test_db, existing_policy_record): # GIVEN an existing record... (included as fixture) # WHEN we call get_policy for this record... 
@@ -43,9 +40,7 @@ def test_get_policy_given_nonexistent_record(self, test_db): # WHEN we call get_policy for a nonexistent record NO_SUCH_RECORD_ID = 999 - result = service.get_policy( - valid_policy_data["country_id"], NO_SUCH_RECORD_ID - ) + result = service.get_policy(valid_policy_data["country_id"], NO_SUCH_RECORD_ID) # THEN the result should be None assert result is None @@ -60,9 +55,7 @@ def test_get_policy_given_str_id(self): ): # WHEN we call get_policy with the invalid ID # THEN an exception should be raised - service.get_policy( - valid_policy_data["country_id"], INVALID_RECORD_ID - ) + service.get_policy(valid_policy_data["country_id"], INVALID_RECORD_ID) def test_get_policy_given_negative_int_id(self): # GIVEN an invalid ID @@ -74,18 +67,14 @@ def test_get_policy_given_negative_int_id(self): ): # WHEN we call get_policy with the invalid ID # THEN an exception should be raised - service.get_policy( - valid_policy_data["country_id"], INVALID_RECORD_ID - ) + service.get_policy(valid_policy_data["country_id"], INVALID_RECORD_ID) def test_get_policy_given_invalid_country_id(self): # GIVEN an invalid country_id INVALID_COUNTRY_ID = "xx" # Unsupported country code # WHEN we call get_policy with the invalid country_id - result = service.get_policy( - INVALID_COUNTRY_ID, valid_policy_data["id"] - ) + result = service.get_policy(INVALID_COUNTRY_ID, valid_policy_data["id"]) # THEN the result should be None or raise an exception assert result is None @@ -236,9 +225,7 @@ def test_set_policy_existing( existing_policy = existing_policy_record # Setup mock - mock_database.query.return_value.fetchone.return_value = ( - existing_policy - ) + mock_database.query.return_value.fetchone.return_value = existing_policy # Define expected database calls - matches actual implementation expected_calls = [ @@ -277,9 +264,7 @@ def test_set_policy_given_database_insert_failure( # Setup mock to raise exception on insert mock_database.query.return_value.fetchone.side_effect = [ None, # 
First call: policy does not exist - Exception( - "Database insertion failed" - ), # Second call: insertion fails + Exception("Database insertion failed"), # Second call: insertion fails ] # WHEN we call set_policy @@ -300,9 +285,7 @@ def test_set_policy_given_invalid_country_id(self, mock_hash_object): # THEN an exception should be raised service.set_policy(INVALID_COUNTRY_ID, test_label, test_policy) - def test_set_policy_given_empty_label( - self, mock_database, mock_hash_object - ): + def test_set_policy_given_empty_label(self, mock_database, mock_hash_object): # GIVEN an empty label EMPTY_LABEL = "" test_policy = {"param": "value"} diff --git a/tests/unit/services/test_report_output_service.py b/tests/unit/services/test_report_output_service.py index 15f6b8576..c1f6b3e55 100644 --- a/tests/unit/services/test_report_output_service.py +++ b/tests/unit/services/test_report_output_service.py @@ -13,9 +13,7 @@ class TestFindExistingReportOutput: """Test finding existing report outputs in the database.""" - def test_find_existing_report_output_found( - self, test_db, existing_report_record - ): + def test_find_existing_report_output_found(self, test_db, existing_report_record): """Test finding an existing report output.""" # GIVEN an existing report record (from fixture) @@ -29,10 +27,7 @@ def test_find_existing_report_output_found( # THEN the result should contain the existing report assert result is not None assert result["id"] == existing_report_record["id"] - assert ( - result["simulation_1_id"] - == existing_report_record["simulation_1_id"] - ) + assert result["simulation_1_id"] == existing_report_record["simulation_1_id"] assert result["status"] == existing_report_record["status"] def test_find_existing_report_output_not_found(self, test_db): @@ -248,10 +243,7 @@ def test_get_report_output_existing(self, test_db, existing_report_record): # THEN the correct report should be returned assert result is not None assert result["id"] == existing_report_record["id"] - 
assert ( - result["simulation_1_id"] - == existing_report_record["simulation_1_id"] - ) + assert result["simulation_1_id"] == existing_report_record["simulation_1_id"] assert result["status"] == existing_report_record["status"] def test_get_report_output_nonexistent(self, test_db): @@ -335,21 +327,15 @@ def test_duplicate_report_returns_existing(self, test_db): # THEN the same report should be returned (no duplicate created) assert first_report["id"] == second_report["id"] assert first_report["country_id"] == second_report["country_id"] - assert ( - first_report["simulation_1_id"] == second_report["simulation_1_id"] - ) - assert ( - first_report["simulation_2_id"] == second_report["simulation_2_id"] - ) + assert first_report["simulation_1_id"] == second_report["simulation_1_id"] + assert first_report["simulation_2_id"] == second_report["simulation_2_id"] assert first_report["year"] == second_report["year"] class TestUpdateReportOutput: """Test updating report outputs in the database.""" - def test_update_report_output_to_complete( - self, test_db, existing_report_record - ): + def test_update_report_output_to_complete(self, test_db, existing_report_record): """Test updating a report to complete status with output.""" # GIVEN an existing pending report report_id = existing_report_record["id"] @@ -374,9 +360,7 @@ def test_update_report_output_to_complete( assert result["status"] == "complete" assert result["output"] == test_output_json - def test_update_report_output_to_error( - self, test_db, existing_report_record - ): + def test_update_report_output_to_error(self, test_db, existing_report_record): """Test updating a report to error status with message.""" # GIVEN an existing pending report report_id = existing_report_record["id"] @@ -400,9 +384,7 @@ def test_update_report_output_to_error( assert result["status"] == "error" assert result["error_message"] == error_msg - def test_update_report_output_partial_update( - self, test_db, existing_report_record - ): + def 
test_update_report_output_partial_update(self, test_db, existing_report_record): """Test that partial updates work correctly.""" # GIVEN an existing report report_id = existing_report_record["id"] @@ -424,9 +406,7 @@ def test_update_report_output_partial_update( assert result["status"] == "complete" assert result["output"] is None # Should remain unchanged - def test_update_report_output_no_fields( - self, test_db, existing_report_record - ): + def test_update_report_output_no_fields(self, test_db, existing_report_record): """Test that update with no optional fields still updates API version.""" # GIVEN an existing report diff --git a/tests/unit/services/test_simulation_service.py b/tests/unit/services/test_simulation_service.py index 49c8654a3..ac1fbccf6 100644 --- a/tests/unit/services/test_simulation_service.py +++ b/tests/unit/services/test_simulation_service.py @@ -31,9 +31,7 @@ def test_find_existing_simulation_given_existing_record( assert result is not None assert result["id"] == existing_simulation_record["id"] assert result["country_id"] == valid_simulation_data["country_id"] - assert ( - result["population_id"] == valid_simulation_data["population_id"] - ) + assert result["population_id"] == valid_simulation_data["population_id"] assert result["policy_id"] == valid_simulation_data["policy_id"] def test_find_existing_simulation_given_no_match(self, test_db): @@ -154,9 +152,7 @@ def test_create_simulation_retrieves_correct_id(self, test_db): class TestGetSimulation: """Test retrieving simulations from the database.""" - def test_get_simulation_existing( - self, test_db, existing_simulation_record - ): + def test_get_simulation_existing(self, test_db, existing_simulation_record): """Test retrieving an existing simulation.""" # GIVEN an existing simulation record @@ -181,9 +177,7 @@ def test_get_simulation_nonexistent(self, test_db): # THEN None should be returned assert result is None - def test_get_simulation_wrong_country( - self, test_db, 
existing_simulation_record - ): + def test_get_simulation_wrong_country(self, test_db, existing_simulation_record): """Test that simulations are country-specific.""" # GIVEN an existing simulation for 'us' @@ -234,11 +228,6 @@ def test_duplicate_simulation_returns_existing(self, test_db): # THEN the same simulation should be returned (no duplicate created) assert first_simulation["id"] == second_simulation["id"] - assert ( - first_simulation["country_id"] == second_simulation["country_id"] - ) - assert ( - first_simulation["population_id"] - == second_simulation["population_id"] - ) + assert first_simulation["country_id"] == second_simulation["country_id"] + assert first_simulation["population_id"] == second_simulation["population_id"] assert first_simulation["policy_id"] == second_simulation["policy_id"] diff --git a/tests/unit/services/test_tracer_analysis_service.py b/tests/unit/services/test_tracer_analysis_service.py index eeb08c9b0..3cd65cf39 100644 --- a/tests/unit/services/test_tracer_analysis_service.py +++ b/tests/unit/services/test_tracer_analysis_service.py @@ -77,9 +77,7 @@ def test_tracer_output_for_empty_tracer(): valid_target_variable = "snap" # When: Extracting from an empty output - result = test_service._parse_tracer_output( - empty_tracer, valid_target_variable - ) + result = test_service._parse_tracer_output(empty_tracer, valid_target_variable) # Then: It should return an empty list since there is no data to parse expected_output = empty_tracer @@ -137,9 +135,7 @@ def test_tracer_output_for_variable_that_is_substring_of_another(): target_variable = "snap_net_income" # When: Extracting the segment for this variable - result = test_service._parse_tracer_output( - valid_tracer_output, target_variable - ) + result = test_service._parse_tracer_output(valid_tracer_output, target_variable) # Then: It should return only the exact match for "snap_net_income", not "snap_net_income_fpg_ratio" diff --git a/tests/unit/services/test_tracer_service.py 
b/tests/unit/services/test_tracer_service.py index e5436d476..84ece8df3 100644 --- a/tests/unit/services/test_tracer_service.py +++ b/tests/unit/services/test_tracer_service.py @@ -58,6 +58,4 @@ def test_get_tracer_database_error(test_db): valid_api_version, ] with pytest.raises(Exception): - tracer_service.get_tracer( - *missing_parameter_causing_database_exception - ) + tracer_service.get_tracer(*missing_parameter_causing_database_exception) diff --git a/tests/unit/services/test_update_profile_service.py b/tests/unit/services/test_update_profile_service.py index f9fd607b7..5c6016899 100644 --- a/tests/unit/services/test_update_profile_service.py +++ b/tests/unit/services/test_update_profile_service.py @@ -10,10 +10,7 @@ class TestUpdateProfile: - - def test_update_profile_given_existing_record( - self, test_db, existing_user_profile - ): + def test_update_profile_given_existing_record(self, test_db, existing_user_profile): # GIVEN an existing profile record (from fixture) # WHEN we call update_profile with new data @@ -54,9 +51,7 @@ def test_update_profile_given_nonexistent_record(self, test_db): # THEN the result should be False assert result is False - def test_update_profile_with_partial_fields( - self, test_db, existing_user_profile - ): + def test_update_profile_with_partial_fields(self, test_db, existing_user_profile): # GIVEN an existing profile record (from fixture) # WHEN we call update_profile with only some fields provided @@ -93,9 +88,7 @@ def test_update_profile_with_database_error( def mock_db_query_error(*args, **kwargs): raise Exception("Database error") - monkeypatch.setattr( - "policyengine_api.data.database.query", mock_db_query_error - ) + monkeypatch.setattr("policyengine_api.data.database.query", mock_db_query_error) # WHEN we call update_profile # THEN an exception should be raised diff --git a/tests/unit/services/test_user_service.py b/tests/unit/services/test_user_service.py index 75fe4c834..502d7918c 100644 --- 
a/tests/unit/services/test_user_service.py +++ b/tests/unit/services/test_user_service.py @@ -10,7 +10,6 @@ class TestGetProfile: - def test_get_profile_id_not_specified(self): # GIVEN no ID # WHEN we call get_profile with no auth0_id or user_id @@ -33,9 +32,7 @@ def test_get_profile_nonexistent_record(self): def test_get_profile_auth0_id(self, existing_user_profile): # WHEN we call get_profile with auth0_id - result = service.get_profile( - auth0_id=existing_user_profile["auth0_id"] - ) + result = service.get_profile(auth0_id=existing_user_profile["auth0_id"]) # THEN returns record assert result == existing_user_profile diff --git a/tests/unit/test_country.py b/tests/unit/test_country.py index b57e8ceee..55a1f7c70 100644 --- a/tests/unit/test_country.py +++ b/tests/unit/test_country.py @@ -30,9 +30,7 @@ def test__uk_has_360_local_authorities(self, uk_regions): ] assert len(local_authority_regions) == 360 - def test__local_authority_regions_have_correct_name_format( - self, uk_regions - ): + def test__local_authority_regions_have_correct_name_format(self, uk_regions): """Verify local authority region names have the correct prefix.""" local_authority_regions = [ r for r in uk_regions if r.get("type") == "local_authority" @@ -121,9 +119,7 @@ def test__coordinates_are_numeric(self, local_authorities_df): assert local_authorities_df["x"].dtype in ["float64", "int64"] assert local_authorities_df["y"].dtype in ["float64", "int64"] - def test__english_local_authorities_have_e_prefix( - self, local_authorities_df - ): + def test__english_local_authorities_have_e_prefix(self, local_authorities_df): """Verify English local authorities have E prefix codes.""" english_las = local_authorities_df[ local_authorities_df["code"].str.startswith("E") @@ -131,9 +127,7 @@ def test__english_local_authorities_have_e_prefix( # England has 296 local authorities (majority of the 360 total) assert len(english_las) == 296 - def test__scottish_local_authorities_have_s_prefix( - self, 
local_authorities_df - ): + def test__scottish_local_authorities_have_s_prefix(self, local_authorities_df): """Verify Scottish local authorities have S prefix codes.""" scottish_las = local_authorities_df[ local_authorities_df["code"].str.startswith("S") @@ -141,9 +135,7 @@ def test__scottish_local_authorities_have_s_prefix( # Scotland has 32 council areas assert len(scottish_las) == 32 - def test__welsh_local_authorities_have_w_prefix( - self, local_authorities_df - ): + def test__welsh_local_authorities_have_w_prefix(self, local_authorities_df): """Verify Welsh local authorities have W prefix codes.""" welsh_las = local_authorities_df[ local_authorities_df["code"].str.startswith("W")