Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions tests/test_computation_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,26 @@ def test_congressional_district_is_us_only(self):


class TestModuleFunctionSignatures:
"""Tests that all module functions share the expected 6-param signature."""
"""Tests that all module functions share the expected signature pattern.

_EXPECTED_PARAMS = [
Modules use a common 7-param signature pattern:
(pe_baseline_sim, pe_reform_sim, baseline_sim_id, reform_sim_id,
report_id, session, **kwargs) -> None

run_modules() passes country_id as a kwarg. Modules that need it (e.g.
compute_decile_module) accept it explicitly; others accept **_kwargs.
"""

_BASE_PARAMS = [
"pe_baseline_sim",
"pe_reform_sim",
"baseline_sim_id",
"reform_sim_id",
"report_id",
"session",
]
# 7th param can be either explicit country_id or **_kwargs
_VALID_7TH_PARAMS = {"country_id", "_kwargs"}

def _get_all_unique_functions(self):
"""Collect all unique module functions from both dispatch tables."""
Expand All @@ -150,19 +160,24 @@ def _get_all_unique_functions(self):
fns.append(fn)
return fns

def test_all_functions_have_6_parameters(self):
def test_all_functions_have_7_parameters(self):
for fn in self._get_all_unique_functions():
sig = inspect.signature(fn)
assert len(sig.parameters) == 6, (
f"{fn.__name__} has {len(sig.parameters)} params, expected 6"
assert len(sig.parameters) == 7, (
f"{fn.__name__} has {len(sig.parameters)} params, expected 7"
)

def test_all_functions_have_expected_param_names(self):
for fn in self._get_all_unique_functions():
sig = inspect.signature(fn)
param_names = list(sig.parameters.keys())
assert param_names == self._EXPECTED_PARAMS, (
f"{fn.__name__} params {param_names} != {self._EXPECTED_PARAMS}"
# First 6 params must match exactly
assert param_names[:6] == self._BASE_PARAMS, (
f"{fn.__name__} first 6 params {param_names[:6]} != {self._BASE_PARAMS}"
)
# 7th param can be country_id or _kwargs
assert param_names[6] in self._VALID_7TH_PARAMS, (
f"{fn.__name__} 7th param '{param_names[6]}' not in {self._VALID_7TH_PARAMS}"
)

def test_all_functions_return_none(self):
Expand All @@ -189,7 +204,9 @@ def test_runs_all_when_modules_is_none(self):
run_modules(dispatch, None, "bl", "rf", ids[0], ids[1], ids[2], session)

for fn in dispatch.values():
fn.assert_called_once_with("bl", "rf", ids[0], ids[1], ids[2], session)
fn.assert_called_once_with(
"bl", "rf", ids[0], ids[1], ids[2], session, country_id=""
)

def test_runs_only_requested_modules(self):
dispatch = self._make_mock_dispatch(["a", "b", "c"])
Expand Down Expand Up @@ -229,7 +246,7 @@ def test_preserves_call_order(self):
call_order = []

def make_tracker(name):
def fn(*args):
def fn(*args, **kwargs):
call_order.append(name)

return fn
Expand All @@ -248,7 +265,7 @@ def test_none_modules_runs_all_in_dispatch_key_order(self):
call_order = []

def make_tracker(name):
def fn(*args):
def fn(*args, **kwargs):
call_order.append(name)

return fn
Expand All @@ -268,7 +285,7 @@ def test_passes_all_args_correctly(self):

run_modules(dispatch, ["test_mod"], bl, rf, b_id, r_id, rep_id, session)

mock_fn.assert_called_once_with(bl, rf, b_id, r_id, rep_id, session)
mock_fn.assert_called_once_with(bl, rf, b_id, r_id, rep_id, session, country_id="")

def test_duplicate_module_name_runs_twice(self):
dispatch = self._make_mock_dispatch(["a"])
Expand Down