From 1022c9119587c800f00b127873628c6b95456cc6 Mon Sep 17 00:00:00 2001 From: SakshiKekre Date: Tue, 3 Mar 2026 20:25:28 +0530 Subject: [PATCH] fix: Update computation module tests for country_id parameter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The computation module functions now include a 7th parameter for country_id (passed as kwarg by run_modules). Some functions accept it explicitly while others use **_kwargs. Updated tests to: - Expect 7 parameters instead of 6 - Accept either 'country_id' or '_kwargs' as the 7th param name - Add **kwargs to tracker functions in mock tests - Include country_id='' in mock assertion expectations 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- tests/test_computation_modules.py | 39 ++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/tests/test_computation_modules.py b/tests/test_computation_modules.py index 5329d39..0d94fdf 100644 --- a/tests/test_computation_modules.py +++ b/tests/test_computation_modules.py @@ -127,9 +127,17 @@ def test_congressional_district_is_us_only(self): class TestModuleFunctionSignatures: - """Tests that all module functions share the expected 6-param signature.""" + """Tests that all module functions share the expected signature pattern. - _EXPECTED_PARAMS = [ + Modules use a common 7-param signature pattern: + (pe_baseline_sim, pe_reform_sim, baseline_sim_id, reform_sim_id, + report_id, session, **kwargs) -> None + + run_modules() passes country_id as a kwarg. Modules that need it (e.g. + compute_decile_module) accept it explicitly; others accept **_kwargs. + """ + + _BASE_PARAMS = [ "pe_baseline_sim", "pe_reform_sim", "baseline_sim_id", @@ -137,6 +145,8 @@ class TestModuleFunctionSignatures: "report_id", "session", ] + # 7th param can be either explicit country_id or **_kwargs + _VALID_7TH_PARAMS = {"country_id", "_kwargs"} def _get_all_unique_functions(self): """Collect all unique module functions from both dispatch tables.""" @@ -150,19 +160,24 @@ def _get_all_unique_functions(self): fns.append(fn) return fns - def test_all_functions_have_6_parameters(self): + def test_all_functions_have_7_parameters(self): for fn in self._get_all_unique_functions(): sig = inspect.signature(fn) - assert len(sig.parameters) == 6, ( - f"{fn.__name__} has {len(sig.parameters)} params, expected 6" + assert len(sig.parameters) == 7, ( + f"{fn.__name__} has {len(sig.parameters)} params, expected 7" ) def test_all_functions_have_expected_param_names(self): for fn in self._get_all_unique_functions(): sig = inspect.signature(fn) param_names = list(sig.parameters.keys()) - assert param_names == self._EXPECTED_PARAMS, ( - f"{fn.__name__} params {param_names} != {self._EXPECTED_PARAMS}" + # First 6 params must match exactly + assert param_names[:6] == self._BASE_PARAMS, ( + f"{fn.__name__} first 6 params {param_names[:6]} != {self._BASE_PARAMS}" + ) + # 7th param can be country_id or _kwargs + assert param_names[6] in self._VALID_7TH_PARAMS, ( + f"{fn.__name__} 7th param '{param_names[6]}' not in {self._VALID_7TH_PARAMS}" ) def test_all_functions_return_none(self): @@ -189,7 +204,9 @@ def test_runs_all_when_modules_is_none(self): run_modules(dispatch, None, "bl", "rf", ids[0], ids[1], ids[2], session) for fn in dispatch.values(): - fn.assert_called_once_with("bl", "rf", ids[0], ids[1], ids[2], session) + fn.assert_called_once_with( + "bl", "rf", ids[0], ids[1], ids[2], session, country_id="" + ) def test_runs_only_requested_modules(self): dispatch = self._make_mock_dispatch(["a", "b", "c"]) @@ -229,7 +246,7 @@ def test_preserves_call_order(self): call_order = [] def make_tracker(name): - def fn(*args): + def fn(*args, **kwargs): call_order.append(name) return fn @@ -248,7 +265,7 @@ def test_none_modules_runs_all_in_dispatch_key_order(self): call_order = [] def make_tracker(name): - def fn(*args): + def fn(*args, **kwargs): call_order.append(name) return fn @@ -268,7 +285,7 @@ def test_passes_all_args_correctly(self): run_modules(dispatch, ["test_mod"], bl, rf, b_id, r_id, rep_id, session) - mock_fn.assert_called_once_with(bl, rf, b_id, r_id, rep_id, session) + mock_fn.assert_called_once_with(bl, rf, b_id, r_id, rep_id, session, country_id="") def test_duplicate_module_name_runs_twice(self): dispatch = self._make_mock_dispatch(["a"])