diff --git a/panel/simdec_app.py b/panel/simdec_app.py index d351ac4..458617a 100644 --- a/panel/simdec_app.py +++ b/panel/simdec_app.py @@ -70,6 +70,7 @@ def load_data(text_fname): return pd.read_csv(io.BytesIO(raw)) except Exception: pn.state.notifications.error(GENERIC_ERROR_MSG, duration=0) + return pd.read_csv(DEFAULT_STRESS_CSV) @pn.cache @@ -183,11 +184,11 @@ def explained_variance_80(sensitivity_indices_table): si_values = df["Value"].tolist()[1:] input_names = df["Inputs"].tolist()[1:] - # Ensuring explained variance is at least 80% of the total - target = 0.8 * sum(si_values) - pos_80 = bisect.bisect_right(np.cumsum(si_values), target) - - return input_names[: pos_80 + 1] + # Find the variables needed to reach 80% of explained variance + total = sum(si_values) + pos = bisect.bisect_left(np.cumsum(si_values), 0.8 * total) + n_vars = min(pos + 1, 4) + return input_names[:n_vars] @pn.cache diff --git a/src/simdec/decomposition.py b/src/simdec/decomposition.py index 958d2aa..08deb78 100644 --- a/src/simdec/decomposition.py +++ b/src/simdec/decomposition.py @@ -138,7 +138,7 @@ def decomposition( n_var_dec = sensitivity_indices.size n_var_dec = max(1, n_var_dec) # keep at least one variable - n_var_dec = min(5, n_var_dec) # use at most 5 variables + n_var_dec = min(4, n_var_dec) # use at most 4 variables else: n_var_dec = inputs.shape[1] diff --git a/tests/test_decomposition.py b/tests/test_decomposition.py index fa0989c..0e4a56a 100644 --- a/tests/test_decomposition.py +++ b/tests/test_decomposition.py @@ -24,3 +24,52 @@ def test_decomposition(): assert res.states == [2, 2, 2, 2] assert res.statistic.shape == (2, 2, 2, 2) npt.assert_allclose(res.bins.describe().T["mean"], res.statistic.flatten()) + + +def test_auto_ordering_single_dominant_variable(): + fname = path_data / "stress.csv" + data = pd.read_csv(fname) + output_name, *v_names = list(data.columns) + inputs, output = data[v_names], data[output_name] + + si = np.array([0.90, 0.05, 0.03, 0.02]) + res = sd.decomposition(inputs=inputs, output=output, sensitivity_indices=si) + assert len(res.var_names) == 1 + + +def test_auto_ordering_two_variables_cross_threshold(): + fname = path_data / "stress.csv" + data = pd.read_csv(fname) + output_name, *v_names = list(data.columns) + inputs, output = data[v_names], data[output_name] + + # sum = 1.0, cumsum = [0.75, 0.81, ...] -> crosses 0.8 after 2nd variable + si = np.array([0.75, 0.06, 0.10, 0.09]) + res = sd.decomposition(inputs=inputs, output=output, sensitivity_indices=si) + assert len(res.var_names) == 2 + + +def test_auto_ordering_cap_at_four(): + """Even if more than 4 variables are needed to reach 0.8, cap at 4.""" + fname = path_data / "stress.csv" + data = pd.read_csv(fname) + output_name, *v_names = list(data.columns) + inputs, output = data[v_names], data[output_name] + + # sum = 1.0, each variable contributes equally -> need all 4 to reach 0.8 + si = np.array([0.25, 0.25, 0.25, 0.25]) + res = sd.decomposition(inputs=inputs, output=output, sensitivity_indices=si) + assert len(res.var_names) == 4 + + +def test_auto_ordering_si_not_summing_to_one(): + """Threshold is relative to sum(si), not hardcoded 1.0.""" + fname = path_data / "stress.csv" + data = pd.read_csv(fname) + output_name, *v_names = list(data.columns) + inputs, output = data[v_names], data[output_name] + + # sum = 2.0, 0.8 * 2.0 = 1.6, cumsum = [1.8, ...] -> crosses after 1st variable + si = np.array([1.80, 0.10, 0.05, 0.05]) + res = sd.decomposition(inputs=inputs, output=output, sensitivity_indices=si) + assert len(res.var_names) == 1