From b2cbd649f00cb05aedf2aa03d008e00a513a8a8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathias=20Sabl=C3=A9-Meyer?= Date: Fri, 22 May 2026 00:59:53 +0100 Subject: [PATCH 01/22] Batch estimator call in GeneralizingEstimator.score `_gl_score` invoked the scorer's response method (`predict` / `predict_proba` / `decision_function`) `n_estimators * n_slices` times per fold. Stack X across slices and call the response method once per estimator, then apply the metric per slice on the resulting predictions. The batching saves on overhead and best utilises vectorized operations. Scorers without a recognised `_response_method` (e.g. `scoring=None` or custom callables) fall back to the original nested loop. --- mne/decoding/search_light.py | 75 +++++++++++++++++++++++++++++++----- 1 file changed, 66 insertions(+), 9 deletions(-) diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index 81699ecd5ba..31238fca162 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -743,17 +743,74 @@ def _gl_score(estimators, scoring, X, y, pb): """ # FIXME: The level parallelization may be a bit high, and might be memory # consuming. Perhaps need to lower it down to the loop across X slices. - score_shape = [len(estimators), X.shape[-1]] - for jj in range(X.shape[-1]): - for ii, est in enumerate(estimators): - _score = scoring(est, X[..., jj], y) - # Initialize array of predictions on the first score iteration + n_sample, n_iter = X.shape[0], X.shape[-1] + n_train = len(estimators) + score_shape = [n_train, n_iter] + score = None + + # Detect whether we can batch the estimator. Recognised: + # * predict, + # * predict_proba + # * decision_function + # * "default" (= predict) + # * A tuple of those: roc_auc = ("decision_function", "predict_proba") + score_func = getattr(scoring, "_score_func", None) + rm = getattr(scoring, "_response_method", None) + valid = {"predict", "predict_proba", "decision_function"} + if rm == "default": + response_method = "predict" + elif isinstance(rm, str) and rm in valid: + response_method = rm + elif isinstance(rm, tuple) and all(m in valid for m in rm): + response_method = rm + else: + response_method = None + can_batch = score_func is not None and response_method is not None + + # If we can't batch we do a simple nested loop. + # Covers scoring=None / unrecognised scorers + if not can_batch: + for jj in range(n_iter): + for ii, est in enumerate(estimators): + _score = scoring(est, X[..., jj], y) + if (ii == 0) and (jj == 0): + score = np.zeros(score_shape, type(_score)) + score[ii, jj, ...] = _score + pb.update(jj * n_train + ii + 1) + return score + + # We can batch; the logic is: reshape X, predict once, reshape back, score + # First: stack X across slices for one batched response call per estimator + X_stack = np.moveaxis(X, -1, 1) + X_stack = X_stack.reshape(n_sample * n_iter, *X_stack.shape[2:]) + + # Use the provided response method, or pick the first one supported + # by the estimator + if isinstance(response_method, str): + method = response_method + else: + for m in response_method: + if hasattr(estimators[0], m): + method = m + break + + # Ensures score_func(..., **kwargs) doesn't crash when scoring._kwargs=None + kwargs = scoring._kwargs or {} + + for ii, est in enumerate(estimators): + y_pred = getattr(est, method)(X_stack) + # predict_proba returns probabilities for both classes; use the + # positive-class probabilities expected by binary scorers + if method == "predict_proba" and y_pred.ndim == 2 and y_pred.shape[1] == 2: + y_pred = y_pred[:, 1] + # Now, reshape back the prediction, then score + y_pred = y_pred.reshape((n_sample, n_iter) + y_pred.shape[1:]) + for jj in range(n_iter): + _score = scoring._sign * score_func(y, y_pred[:, jj], **kwargs) if (ii == 0) and (jj == 0): - dtype = type(_score) - score = np.zeros(score_shape, dtype) + score = np.zeros(score_shape, type(_score)) score[ii, jj, ...] = _score - - pb.update(jj * len(estimators) + ii + 1) + pb.update(ii * n_iter + jj + 1) return score From 3cef4a52ab92ce693a48c6ac3fe799f4f82d9b76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathias=20Sabl=C3=A9-Meyer?= Date: Fri, 22 May 2026 10:48:47 +0100 Subject: [PATCH 02/22] ENH: Vectorise accuracy scoring in GeneralizingEstimator.score When the scorer is `accuracy_score` with default kwargs, 1d-output (but can be multi-class), and `response_method == "predict"`, replace the per-slice `accuracy_score(y, y_pred[:, jj])` calls with one numpy reduction per estimator: `(y_pred == y[:, None]).mean(axis=0)`. Other scorers, multi-output `y`, etc. keep nested-loop behavior. --- mne/decoding/search_light.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index 31238fca162..7e1cce71590 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -797,6 +797,15 @@ def _gl_score(estimators, scoring, X, y, pb): # Ensures score_func(..., **kwargs) doesn't crash when scoring._kwargs=None kwargs = scoring._kwargs or {} + # Fast path: vectorised accuracy. Skips the per-jj score_func call and + # computes (y_pred == y[:, None]).mean(axis=0) for all slices at once. + fast_accuracy = ( + getattr(score_func, "__name__", "") == "accuracy_score" + and not kwargs + and response_method == "predict" + and y.ndim == 1 + ) + for ii, est in enumerate(estimators): y_pred = getattr(est, method)(X_stack) # predict_proba returns probabilities for both classes; use the @@ -805,12 +814,20 @@ def _gl_score(estimators, scoring, X, y, pb): y_pred = y_pred[:, 1] # Now, reshape back the prediction, then score y_pred = y_pred.reshape((n_sample, n_iter) + y_pred.shape[1:]) - for jj in range(n_iter): - _score = scoring._sign * score_func(y, y_pred[:, jj], **kwargs) - if (ii == 0) and (jj == 0): - score = np.zeros(score_shape, type(_score)) - score[ii, jj, ...] = _score - pb.update(ii * n_iter + jj + 1) + # Either we can also score with a batch, here, or we loop again, below + if fast_accuracy: + row = scoring._sign * (y_pred == y[:, None]).mean(axis=0) + if ii == 0: + score = np.zeros(score_shape, row.dtype) + score[ii] = row + pb.update((ii + 1) * n_iter) + else: + for jj in range(n_iter): + _score = scoring._sign * score_func(y, y_pred[:, jj], **kwargs) + if (ii == 0) and (jj == 0): + score = np.zeros(score_shape, type(_score)) + score[ii, jj, ...] = _score + pb.update(ii * n_iter + jj + 1) return score From 7a66f157cde7a37884f5cd980116758e5db4b581 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathias=20Sabl=C3=A9-Meyer?= Date: Fri, 22 May 2026 12:53:57 +0100 Subject: [PATCH 03/22] MAINT: Refactor batched fast path in GeneralizingEstimator.score Replace the `fast_accuracy` flag with a `batched_score` which is either a callable if we recognized the scorer, or otherwise set to `None`. The scoring loop then branches on `batched_score`: call `batched_score(y_pred)` if `batched_score` is set, otherwise fall back to the nested loop. --- mne/decoding/search_light.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index 7e1cce71590..0496079c9ac 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -797,14 +797,16 @@ def _gl_score(estimators, scoring, X, y, pb): # Ensures score_func(..., **kwargs) doesn't crash when scoring._kwargs=None kwargs = scoring._kwargs or {} - # Fast path: vectorised accuracy. Skips the per-jj score_func call and - # computes (y_pred == y[:, None]).mean(axis=0) for all slices at once. - fast_accuracy = ( - getattr(score_func, "__name__", "") == "accuracy_score" - and not kwargs - and response_method == "predict" - and y.ndim == 1 - ) + # Batched path: when we recognise score_func, build `batched_score` that + # scores all n_iter slices in a single vectorised reduction. it stays None + # for unrecognised scorers which falls back to nested loops + sign = scoring._sign + batched_score = None + if not kwargs and y.ndim == 1: + name = getattr(score_func, "__name__", "") + if name == "accuracy_score" and response_method == "predict": + def batched_score(y_pred): + return sign * (y_pred == y[:, None]).mean(axis=0) for ii, est in enumerate(estimators): y_pred = getattr(est, method)(X_stack) @@ -814,16 +816,16 @@ def _gl_score(estimators, scoring, X, y, pb): y_pred = y_pred[:, 1] # Now, reshape back the prediction, then score y_pred = y_pred.reshape((n_sample, n_iter) + y_pred.shape[1:]) - # Either we can also score with a batch, here, or we loop again, below - if fast_accuracy: - row = scoring._sign * (y_pred == y[:, None]).mean(axis=0) + # Either we can score with batching (if) or we loop again (else) + if batched_score is not None: + row = batched_score(y_pred) if ii == 0: score = np.zeros(score_shape, row.dtype) score[ii] = row pb.update((ii + 1) * n_iter) else: for jj in range(n_iter): - _score = scoring._sign * score_func(y, y_pred[:, jj], **kwargs) + _score = sign * score_func(y, y_pred[:, jj], **kwargs) if (ii == 0) and (jj == 0): score = np.zeros(score_shape, type(_score)) score[ii, jj, ...] = _score From 32028955ba93eb08b29ea6e48419b336035e6414 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathias=20Sabl=C3=A9-Meyer?= Date: Fri, 22 May 2026 13:01:31 +0100 Subject: [PATCH 04/22] ENH: Vectorise balanced_accuracy in GeneralizingEstimator.score Adds `balanced_accuracy_score` to the `batched_score` dispatch by manually estimating accuracy per class and then averaging. --- mne/decoding/search_light.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index 0496079c9ac..62ff25b541a 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -807,6 +807,12 @@ def _gl_score(estimators, scoring, X, y, pb): if name == "accuracy_score" and response_method == "predict": def batched_score(y_pred): return sign * (y_pred == y[:, None]).mean(axis=0) + elif name == "balanced_accuracy_score" and response_method == "predict": + classes = np.unique(y) + def batched_score(y_pred): + return sign * np.stack( + [(y_pred[y == c] == c).mean(axis=0) for c in classes] + ).mean(axis=0) for ii, est in enumerate(estimators): y_pred = getattr(est, method)(X_stack) From 7a30425fc899684baba9724060ef8e9525ef7b9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathias=20Sabl=C3=A9-Meyer?= Date: Fri, 22 May 2026 13:33:50 +0100 Subject: [PATCH 05/22] ENH: Batch `scoring=None` in GeneralizingEstimator.score When `scoring=None`, sklearn wraps `estimator.score` in a scorer with no `_score_func` so previous code did not batch. But for stock classifiers, this is just `accuracy_score(y, predict(X))`: we now detect this and promote `scoring` to the named "accuracy" scorer which uses the existing batched path. --- mne/decoding/search_light.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index 62ff25b541a..b0833d0d0f4 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -748,6 +748,18 @@ def _gl_score(estimators, scoring, X, y, pb): score_shape = [n_train, n_iter] score = None + # scoring=None goes through sklearn's _PassthroughScorer, which delegates + # to estimator.score(X, y). For a classifier inheriting + # ClassifierMixin.score unchanged, that's accuracy which we now set. We + # compare `type(est).score.__qualname__` rather than `.__name__` because + # the bare name is "score" no matter which class defined the method. A bare + # method has qualname "ClassifierMixin.score", whereas any override + # resolves to ".score". We only take over bare methods. + if len(estimators) and getattr(scoring, "_score_func", None) is None: + qname = getattr(type(estimators[0]).score, "__qualname__", "") + if qname == "ClassifierMixin.score": + scoring = check_scoring(estimators[0], "accuracy") + # Detect whether we can batch the estimator. Recognised: # * predict, # * predict_proba From b23e72dfbb8979bc347ed9436878fc57e4e2b465 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathias=20Sabl=C3=A9-Meyer?= Date: Fri, 22 May 2026 14:43:27 +0100 Subject: [PATCH 06/22] ENH: Vectorise roc_auc in GeneralizingEstimator.score Add `roc_auc_score` to the `batched_score` dispatch via the Mann-Whitney U identity with average-rank tie correction (`scipy.stats.rankdata(..., method="average", axis=0)` ranks all slices at once). Binary classification only: multi-class, or non-default kwargs, revert to nested loops. The rank identity implemented when batching gives the same AUC as sklearn within floating point precision, but implements it with different operations. A bit-exact match would need a loop, defeating the batching. --- mne/decoding/search_light.py | 17 +++++++++++++++++ mne/decoding/tests/test_search_light.py | 11 ++++++++--- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index b0833d0d0f4..ebb19a6cb2d 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -5,6 +5,7 @@ import logging import numpy as np +from scipy.stats import rankdata from sklearn.base import BaseEstimator, MetaEstimatorMixin, clone from sklearn.metrics import check_scoring from sklearn.preprocessing import LabelEncoder @@ -825,6 +826,22 @@ def batched_score(y_pred): return sign * np.stack( [(y_pred[y == c] == c).mean(axis=0) for c in classes] ).mean(axis=0) + elif name == "roc_auc_score" and method in ( + "predict_proba", "decision_function" + ): + classes = np.unique(y) + if len(classes) == 2: # multi-class needs ovr/ovo; defer + pos = y == classes[1] + n_pos, n_neg = int(pos.sum()), int((~pos).sum()) + if n_pos and n_neg: # degenerate folds raise downstream in sklearn + def batched_score(y_pred): + # Mann-Whitney U identity with average-rank tie + # correction. Equivalent to sklearn's roc_auc within + # floating point precision, but different computaion + ranks = rankdata(y_pred, method="average", axis=0) + return sign * ( + ranks[pos].sum(axis=0) - n_pos * (n_pos + 1) / 2.0 + ) / (n_pos * n_neg) for ii, est in enumerate(estimators): y_pred = getattr(est, method)(X_stack) diff --git a/mne/decoding/tests/test_search_light.py b/mne/decoding/tests/test_search_light.py index 56d239c1bcc..5d9f732e22b 100644 --- a/mne/decoding/tests/test_search_light.py +++ b/mne/decoding/tests/test_search_light.py @@ -7,7 +7,7 @@ import numpy as np import pytest -from numpy.testing import assert_array_equal, assert_equal +from numpy.testing import assert_array_equal, assert_equal, assert_allclose sklearn = pytest.importorskip("sklearn") @@ -246,7 +246,11 @@ def test_generalization_light(metadata_routing): gl.fit(X, y) score = gl.score(X, y) auc = roc_auc_score(y, gl.estimators_[0].predict_proba(X[..., 0])[..., 1]) - assert_equal(score[0, 0], auc) + + # The rank identity implemented when batching gives the same AUC as sklearn + # within floating point precision, but implements it with different + # operations. A bit-exact match would need a loop, defeating the batching. + assert_allclose(score[0, 0], auc) for scoring in ["foo", 999]: gl = GeneralizingEstimator(logreg, scoring=scoring) @@ -267,7 +271,8 @@ def test_generalization_light(metadata_routing): [roc_auc_score(y - 1, _y_pred) for _y_pred in _y_preds] for _y_preds in gl.decision_function(X).transpose(1, 2, 0) ] - assert_array_equal(score, manual_score) + # allclose instead of equal: see above, batching roc_auc forces this. + assert_allclose(score, manual_score) # n_jobs gl = GeneralizingEstimator(logreg, n_jobs=2) From aac4a90f63abf1262f31a27ac25622a07c4ce41f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 22 May 2026 16:19:14 +0000 Subject: [PATCH 07/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/decoding/search_light.py | 16 +++++++++++----- mne/decoding/tests/test_search_light.py | 2 +- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index ebb19a6cb2d..76d7b45dd2e 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -768,7 +768,7 @@ def _gl_score(estimators, scoring, X, y, pb): # * "default" (= predict) # * A tuple of those: roc_auc = ("decision_function", "predict_proba") score_func = getattr(scoring, "_score_func", None) - rm = getattr(scoring, "_response_method", None) + rm = getattr(scoring, "_response_method", None) valid = {"predict", "predict_proba", "decision_function"} if rm == "default": response_method = "predict" @@ -818,30 +818,36 @@ def _gl_score(estimators, scoring, X, y, pb): if not kwargs and y.ndim == 1: name = getattr(score_func, "__name__", "") if name == "accuracy_score" and response_method == "predict": + def batched_score(y_pred): return sign * (y_pred == y[:, None]).mean(axis=0) elif name == "balanced_accuracy_score" and response_method == "predict": classes = np.unique(y) + def batched_score(y_pred): return sign * np.stack( [(y_pred[y == c] == c).mean(axis=0) for c in classes] ).mean(axis=0) elif name == "roc_auc_score" and method in ( - "predict_proba", "decision_function" + "predict_proba", + "decision_function", ): classes = np.unique(y) if len(classes) == 2: # multi-class needs ovr/ovo; defer pos = y == classes[1] n_pos, n_neg = int(pos.sum()), int((~pos).sum()) if n_pos and n_neg: # degenerate folds raise downstream in sklearn + def batched_score(y_pred): # Mann-Whitney U identity with average-rank tie # correction. Equivalent to sklearn's roc_auc within # floating point precision, but different computaion ranks = rankdata(y_pred, method="average", axis=0) - return sign * ( - ranks[pos].sum(axis=0) - n_pos * (n_pos + 1) / 2.0 - ) / (n_pos * n_neg) + return ( + sign + * (ranks[pos].sum(axis=0) - n_pos * (n_pos + 1) / 2.0) + / (n_pos * n_neg) + ) for ii, est in enumerate(estimators): y_pred = getattr(est, method)(X_stack) diff --git a/mne/decoding/tests/test_search_light.py b/mne/decoding/tests/test_search_light.py index 5d9f732e22b..749d241c36c 100644 --- a/mne/decoding/tests/test_search_light.py +++ b/mne/decoding/tests/test_search_light.py @@ -7,7 +7,7 @@ import numpy as np import pytest -from numpy.testing import assert_array_equal, assert_equal, assert_allclose +from numpy.testing import assert_allclose, assert_array_equal, assert_equal sklearn = pytest.importorskip("sklearn") From df1cfddc97e860d1c0b2d7c6edc9f878fb4f4578 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathias=20Sabl=C3=A9-Meyer?= Date: Fri, 22 May 2026 17:27:17 +0100 Subject: [PATCH 08/22] Fix typo: computaion ==> computation... --- mne/decoding/search_light.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index 76d7b45dd2e..8f1f0a322e2 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -841,7 +841,7 @@ def batched_score(y_pred): def batched_score(y_pred): # Mann-Whitney U identity with average-rank tie # correction. Equivalent to sklearn's roc_auc within - # floating point precision, but different computaion + # floating point precision, but different computation ranks = rankdata(y_pred, method="average", axis=0) return ( sign From d7faec491b10c9e4d32dc28d24f4668c95bc1a44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathias=20Sabl=C3=A9-Meyer?= Date: Fri, 22 May 2026 18:44:00 +0100 Subject: [PATCH 09/22] Adds entry in `doc/changes/dev/` + `names.inc` --- doc/changes/dev/13909.other.rst | 4 ++++ doc/changes/names.inc | 1 + 2 files changed, 5 insertions(+) create mode 100644 doc/changes/dev/13909.other.rst diff --git a/doc/changes/dev/13909.other.rst b/doc/changes/dev/13909.other.rst new file mode 100644 index 00000000000..a2131a6ade9 --- /dev/null +++ b/doc/changes/dev/13909.other.rst @@ -0,0 +1,4 @@ +Batch and vectorise classifier estimation and scoring in +:meth:`mne.decoding.GeneralizingEstimator.score` for ``scoring=None``, +``"accuracy"``, ``"balanced_accuracy"`` and ``"roc_auc"``, by +:newcontrib:`Mathias Sablé-Meyer`. diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 86644392e3b..3418311d780 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -212,6 +212,7 @@ .. _Martin Luessi: https://github.com/mluessi .. _Martin Oberg: https://github.com/obergmartin .. _Martin Schulz: https://github.com/marsipu +.. _Mathias Sablé-Meyer: https://s-m.ac/ .. _Mathieu Scheltienne: https://github.com/mscheltienne .. _Mathurin Massias: https://mathurinm.github.io/ .. _Mats van Es: https://github.com/matsvanes From b1830f219cf52bda97c4e090d9a3ca55a6a4760b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathias=20Sabl=C3=A9-Meyer?= Date: Mon, 25 May 2026 13:35:24 +0100 Subject: [PATCH 10/22] Clarify comments + harmonize var names --- mne/decoding/search_light.py | 42 +++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index 8f1f0a322e2..a776406e8e4 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -755,17 +755,14 @@ def _gl_score(estimators, scoring, X, y, pb): # compare `type(est).score.__qualname__` rather than `.__name__` because # the bare name is "score" no matter which class defined the method. A bare # method has qualname "ClassifierMixin.score", whereas any override - # resolves to ".score". We only take over bare methods. + # resolves to ".score". We only override the bare implementation. if len(estimators) and getattr(scoring, "_score_func", None) is None: qname = getattr(type(estimators[0]).score, "__qualname__", "") if qname == "ClassifierMixin.score": scoring = check_scoring(estimators[0], "accuracy") - # Detect whether we can batch the estimator. Recognised: - # * predict, - # * predict_proba - # * decision_function - # * "default" (= predict) + # Detect whether we can batch the estimator. Supported methods: + # * predict, predict_proba, decision_function, * "default" (= predict), # * A tuple of those: roc_auc = ("decision_function", "predict_proba") score_func = getattr(scoring, "_score_func", None) rm = getattr(scoring, "_response_method", None) @@ -780,8 +777,7 @@ def _gl_score(estimators, scoring, X, y, pb): response_method = None can_batch = score_func is not None and response_method is not None - # If we can't batch we do a simple nested loop. - # Covers scoring=None / unrecognised scorers + # If we can't batch, fall back to a simple nested loop for scoring if not can_batch: for jj in range(n_iter): for ii, est in enumerate(estimators): @@ -792,13 +788,16 @@ def _gl_score(estimators, scoring, X, y, pb): pb.update(jj * n_train + ii + 1) return score - # We can batch; the logic is: reshape X, predict once, reshape back, score - # First: stack X across slices for one batched response call per estimator + # If we can batch; the logic is: reshape X; for each estimator: predict; + # reshape back; score with batching if we can and with a loop if not. + # Use the provided response method, or pick the first one supported + # by the estimator + # Collapse the last dimension into the sample dimension: + # (n_trials, ..., n_times) -> (n_trials * n_times, ...) X_stack = np.moveaxis(X, -1, 1) X_stack = X_stack.reshape(n_sample * n_iter, *X_stack.shape[2:]) - # Use the provided response method, or pick the first one supported - # by the estimator + # Resolve the prediction method to call later. if isinstance(response_method, str): method = response_method else: @@ -807,12 +806,13 @@ def _gl_score(estimators, scoring, X, y, pb): method = m break - # Ensures score_func(..., **kwargs) doesn't crash when scoring._kwargs=None + # Extract scoring args if any. Don't batch if non-default arguments; also + # ensures score_func(..., **kwargs) won't crash when scoring._kwargs=None kwargs = scoring._kwargs or {} - # Batched path: when we recognise score_func, build `batched_score` that - # scores all n_iter slices in a single vectorised reduction. it stays None - # for unrecognised scorers which falls back to nested loops + # When we recognise score_func, we build `batched_score` that scores in a + # single vectorised reduction. + # `batched_score`=None for unrecognised scorers: fall back to nested loops sign = scoring._sign batched_score = None if not kwargs and y.ndim == 1: @@ -849,6 +849,8 @@ def batched_score(y_pred): / (n_pos * n_neg) ) + # For each estimator: predict, then score either vectorially or + # slice-by-slice. for ii, est in enumerate(estimators): y_pred = getattr(est, method)(X_stack) # predict_proba returns probabilities for both classes; use the @@ -857,12 +859,12 @@ def batched_score(y_pred): y_pred = y_pred[:, 1] # Now, reshape back the prediction, then score y_pred = y_pred.reshape((n_sample, n_iter) + y_pred.shape[1:]) - # Either we can score with batching (if) or we loop again (else) + # Use vectorized scoring when available; otherwise score slice-by-slice. if batched_score is not None: - row = batched_score(y_pred) + _score = batched_score(y_pred) if ii == 0: - score = np.zeros(score_shape, row.dtype) - score[ii] = row + score = np.zeros(score_shape, _score.dtype) + score[ii] = _score pb.update((ii + 1) * n_iter) else: for jj in range(n_iter): From f0596225b451b00bb412681ab6cfb47c48c885ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathias=20Sabl=C3=A9-Meyer?= Date: Mon, 25 May 2026 14:13:40 +0100 Subject: [PATCH 11/22] Match comment and docstring --- mne/decoding/search_light.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index a776406e8e4..2842cf53f4e 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -793,7 +793,7 @@ def _gl_score(estimators, scoring, X, y, pb): # Use the provided response method, or pick the first one supported # by the estimator # Collapse the last dimension into the sample dimension: - # (n_trials, ..., n_times) -> (n_trials * n_times, ...) + # (n_sample, ..., n_iter) -> (n_sample * n_iter, ...) X_stack = np.moveaxis(X, -1, 1) X_stack = X_stack.reshape(n_sample * n_iter, *X_stack.shape[2:]) From 86f9b3ba915dbe703285a6160ce32bd882e48c3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathias=20Sabl=C3=A9-Meyer?= Date: Mon, 25 May 2026 16:36:42 +0100 Subject: [PATCH 12/22] Minimal comment change to re-trigger tests --- mne/decoding/search_light.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index 2842cf53f4e..fd6500d2873 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -788,16 +788,14 @@ def _gl_score(estimators, scoring, X, y, pb): pb.update(jj * n_train + ii + 1) return score - # If we can batch; the logic is: reshape X; for each estimator: predict; + # If we batch, the logic is: reshape X; for each estimator: predict; # reshape back; score with batching if we can and with a loop if not. - # Use the provided response method, or pick the first one supported - # by the estimator - # Collapse the last dimension into the sample dimension: + # Collapse the n_iter into the n_sample dimension: # (n_sample, ..., n_iter) -> (n_sample * n_iter, ...) X_stack = np.moveaxis(X, -1, 1) X_stack = X_stack.reshape(n_sample * n_iter, *X_stack.shape[2:]) - # Resolve the prediction method to call later. + # Resolve the prediction method to call later if isinstance(response_method, str): method = response_method else: From e8341d0292f550c97c78e24547335e9ddfe7754f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathias=20Sabl=C3=A9-Meyer?= Date: Tue, 26 May 2026 22:47:12 +0100 Subject: [PATCH 13/22] Reuse `_gl_transform` where possible My code partially re-implemented logic that was present in `_gl_transform` already. This new version, as much as possible, simply reuses it. This cleans up the scoring's logic and avoid duplication. --- mne/decoding/search_light.py | 52 +++++++++++++++--------------------- 1 file changed, 21 insertions(+), 31 deletions(-) diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index fd6500d2873..356fb930d0d 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -744,7 +744,7 @@ def _gl_score(estimators, scoring, X, y, pb): """ # FIXME: The level parallelization may be a bit high, and might be memory # consuming. Perhaps need to lower it down to the loop across X slices. - n_sample, n_iter = X.shape[0], X.shape[-1] + n_iter = X.shape[-1] n_train = len(estimators) score_shape = [n_train, n_iter] score = None @@ -788,14 +788,8 @@ def _gl_score(estimators, scoring, X, y, pb): pb.update(jj * n_train + ii + 1) return score - # If we batch, the logic is: reshape X; for each estimator: predict; - # reshape back; score with batching if we can and with a loop if not. - # Collapse the n_iter into the n_sample dimension: - # (n_sample, ..., n_iter) -> (n_sample * n_iter, ...) - X_stack = np.moveaxis(X, -1, 1) - X_stack = X_stack.reshape(n_sample * n_iter, *X_stack.shape[2:]) - - # Resolve the prediction method to call later + # Resolve a single method name for _gl_transform: pick the first available + # if response_method is a tuple (e.g. roc_auc). if isinstance(response_method, str): method = response_method else: @@ -804,12 +798,21 @@ def _gl_score(estimators, scoring, X, y, pb): method = m break + # Batch all predictions through _gl_transform. y_pred shape: + # (n_sample, n_train, n_iter) or (n_sample, n_train, n_iter, n_classes). + y_pred = _gl_transform(estimators, X, method, pb) + + # Binary predict_proba: take the positive-class column to match sklearn + # scorer expectations for binary problems. + if method == "predict_proba" and y_pred.ndim == 4 and y_pred.shape[-1] == 2: + y_pred = y_pred[..., 1] + # Extract scoring args if any. Don't batch if non-default arguments; also # ensures score_func(..., **kwargs) won't crash when scoring._kwargs=None kwargs = scoring._kwargs or {} # When we recognise score_func, we build `batched_score` that scores in a - # single vectorised reduction. + # single vectorised reduction over (n_sample, n_train, n_iter). # `batched_score`=None for unrecognised scorers: fall back to nested loops sign = scoring._sign batched_score = None @@ -818,7 +821,7 @@ def _gl_score(estimators, scoring, X, y, pb): if name == "accuracy_score" and response_method == "predict": def batched_score(y_pred): - return sign * (y_pred == y[:, None]).mean(axis=0) + return sign * (y_pred == y[:, None, None]).mean(axis=0) elif name == "balanced_accuracy_score" and response_method == "predict": classes = np.unique(y) @@ -847,30 +850,17 @@ def batched_score(y_pred): / (n_pos * n_neg) ) - # For each estimator: predict, then score either vectorially or - # slice-by-slice. - for ii, est in enumerate(estimators): - y_pred = getattr(est, method)(X_stack) - # predict_proba returns probabilities for both classes; use the - # positive-class probabilities expected by binary scorers - if method == "predict_proba" and y_pred.ndim == 2 and y_pred.shape[1] == 2: - y_pred = y_pred[:, 1] - # Now, reshape back the prediction, then score - y_pred = y_pred.reshape((n_sample, n_iter) + y_pred.shape[1:]) - # Use vectorized scoring when available; otherwise score slice-by-slice. - if batched_score is not None: - _score = batched_score(y_pred) - if ii == 0: - score = np.zeros(score_shape, _score.dtype) - score[ii] = _score - pb.update((ii + 1) * n_iter) - else: + # Reduce predictions to scores. Vectorised if we recognised the scorer, + # otherwise nested loops over (estimator, slice). + if batched_score is not None: + score = batched_score(y_pred) + else: + for ii in range(n_train): for jj in range(n_iter): - _score = sign * score_func(y, y_pred[:, jj], **kwargs) + _score = sign * score_func(y, y_pred[:, ii, jj], **kwargs) if (ii == 0) and (jj == 0): score = np.zeros(score_shape, type(_score)) score[ii, jj, ...] = _score - pb.update(ii * n_iter + jj + 1) return score From 19e918f07e963d11e58f0574b30f757378539421 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathias=20Sabl=C3=A9-Meyer?= Date: Tue, 26 May 2026 23:19:17 +0100 Subject: [PATCH 14/22] Create helper functions to _gl_score So far, this re-write implemented everything in the body of `_gl_score`, but this turned out to make reading the overall logic harder. Abstract away (i) defaulting to "accuracy" for default args, (ii) extracting response method and ability to batch, and (iii) creating vectorized scoring functions --- mne/decoding/search_light.py | 171 +++++++++++++++++++++-------------- 1 file changed, 103 insertions(+), 68 deletions(-) diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index 356fb930d0d..34a58f99643 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -716,6 +716,100 @@ def _gl_init_pred(y_pred, X, n_train): return y_pred +def _resolve_scoring_for_classifier(scoring, estimators): + """Promote scoring=None to 'accuracy' for default estimator. + + scoring=None goes through sklearn's ``_PassthroughScorer``, which delegates + to ``estimator.score(X, y)``. For a classifier that inherits + ``ClassifierMixin.score`` unchanged, that's accuracy — which we can batch. + We compare ``type(est).score.__qualname__`` rather than ``.__name__`` + because the bare name is "score" regardless of the defining class. A bare + method has qualname "ClassifierMixin.score"; any override resolves to + ".score", which we leave untouched. + """ + if len(estimators) and getattr(scoring, "_score_func", None) is None: + qname = getattr(type(estimators[0]).score, "__qualname__", "") + if qname == "ClassifierMixin.score": + scoring = check_scoring(estimators[0], "accuracy") + return scoring + + +def _detect_response_method(scoring): + """Return (response_method, can_batch) for ``scoring``. + + If we can batch the estimator (one of predict, predict_proba, + decision_function, or a tuple of those) then we return that, and + can_batch is True if additionally _score_func is None (default). + Otherwise we return None (for the response_method) and False + """ + score_func = getattr(scoring, "_score_func", None) + rm = getattr(scoring, "_response_method", None) + valid = {"predict", "predict_proba", "decision_function"} + if rm == "default": + response_method = "predict" + elif isinstance(rm, str) and rm in valid: + response_method = rm + elif isinstance(rm, tuple) and all(m in valid for m in rm): + response_method = rm + else: + response_method = None + can_batch = score_func is not None and response_method is not None + return response_method, can_batch + + +def _make_batched_score(score_func, response_method, method, y, sign, kwargs): + """Return a callable ``y_pred``->score per task, None if not recognised. + + The returned callable expects ``y_pred`` of shape + ``(n_sample, n_train, n_iter)`` and returns shape ``(n_train, n_iter)``. + Falls back to None for any scorer with non-default ``kwargs`` or + multi-target ``y``, both of which require a slice-by-slice loop. + """ + if kwargs or y.ndim != 1: + return None + name = getattr(score_func, "__name__", "") + + if name == "accuracy_score" and response_method == "predict": + + def batched_score(y_pred): + return sign * (y_pred == y[:, None, None]).mean(axis=0) + + return batched_score + + if name == "balanced_accuracy_score" and response_method == "predict": + classes = np.unique(y) + + def batched_score(y_pred): + return sign * np.stack( + [(y_pred[y == c] == c).mean(axis=0) for c in classes] + ).mean(axis=0) + + return batched_score + + if name == "roc_auc_score" and method in ("predict_proba", "decision_function"): + classes = np.unique(y) + if len(classes) != 2: # multi-class needs ovr/ovo; defer + return None + pos = y == classes[1] + n_pos, n_neg = int(pos.sum()), int((~pos).sum()) + if not (n_pos and n_neg): # degenerate folds raise downstream in sklearn + return None + + def batched_score(y_pred): + # Mann-Whitney U identity with average-rank tie correction. + # Equivalent to sklearn's roc_auc within floating point precision, + # but different computation. + ranks = rankdata(y_pred, method="average", axis=0) + return ( + sign + * (ranks[pos].sum(axis=0) - n_pos * (n_pos + 1) / 2.0) + / (n_pos * n_neg) + ) + + return batched_score + return None + + def _gl_score(estimators, scoring, X, y, pb): """Score GeneralizingEstimator in parallel. @@ -749,33 +843,8 @@ def _gl_score(estimators, scoring, X, y, pb): score_shape = [n_train, n_iter] score = None - # scoring=None goes through sklearn's _PassthroughScorer, which delegates - # to estimator.score(X, y). For a classifier inheriting - # ClassifierMixin.score unchanged, that's accuracy which we now set. We - # compare `type(est).score.__qualname__` rather than `.__name__` because - # the bare name is "score" no matter which class defined the method. A bare - # method has qualname "ClassifierMixin.score", whereas any override - # resolves to ".score". We only override the bare implementation. - if len(estimators) and getattr(scoring, "_score_func", None) is None: - qname = getattr(type(estimators[0]).score, "__qualname__", "") - if qname == "ClassifierMixin.score": - scoring = check_scoring(estimators[0], "accuracy") - - # Detect whether we can batch the estimator. Supported methods: - # * predict, predict_proba, decision_function, * "default" (= predict), - # * A tuple of those: roc_auc = ("decision_function", "predict_proba") - score_func = getattr(scoring, "_score_func", None) - rm = getattr(scoring, "_response_method", None) - valid = {"predict", "predict_proba", "decision_function"} - if rm == "default": - response_method = "predict" - elif isinstance(rm, str) and rm in valid: - response_method = rm - elif isinstance(rm, tuple) and all(m in valid for m in rm): - response_method = rm - else: - response_method = None - can_batch = score_func is not None and response_method is not None + scoring = _resolve_scoring_for_classifier(scoring, estimators) + response_method, can_batch = _detect_response_method(scoring) # If we can't batch, fall back to a simple nested loop for scoring if not can_batch: @@ -807,48 +876,14 @@ def _gl_score(estimators, scoring, X, y, pb): if method == "predict_proba" and y_pred.ndim == 4 and y_pred.shape[-1] == 2: y_pred = y_pred[..., 1] - # Extract scoring args if any. Don't batch if non-default arguments; also - # ensures score_func(..., **kwargs) won't crash when scoring._kwargs=None - kwargs = scoring._kwargs or {} - - # When we recognise score_func, we build `batched_score` that scores in a - # single vectorised reduction over (n_sample, n_train, n_iter). - # `batched_score`=None for unrecognised scorers: fall back to nested loops + # `scoring._kwargs or {}` also guards score_func(..., **kwargs) against + # scoring._kwargs being None. + score_func = scoring._score_func sign = scoring._sign - batched_score = None - if not kwargs and y.ndim == 1: - name = getattr(score_func, "__name__", "") - if name == "accuracy_score" and response_method == "predict": - - def batched_score(y_pred): - return sign * (y_pred == y[:, None, None]).mean(axis=0) - elif name == "balanced_accuracy_score" and response_method == "predict": - classes = np.unique(y) - - def batched_score(y_pred): - return sign * np.stack( - [(y_pred[y == c] == c).mean(axis=0) for c in classes] - ).mean(axis=0) - elif name == "roc_auc_score" and method in ( - "predict_proba", - "decision_function", - ): - classes = np.unique(y) - if len(classes) == 2: # multi-class needs ovr/ovo; defer - pos = y == classes[1] - n_pos, n_neg = int(pos.sum()), int((~pos).sum()) - if n_pos and n_neg: # degenerate folds raise downstream in sklearn - - def batched_score(y_pred): - # Mann-Whitney U identity with average-rank tie - # correction. Equivalent to sklearn's roc_auc within - # floating point precision, but different computation - ranks = rankdata(y_pred, method="average", axis=0) - return ( - sign - * (ranks[pos].sum(axis=0) - n_pos * (n_pos + 1) / 2.0) - / (n_pos * n_neg) - ) + kwargs = scoring._kwargs or {} + batched_score = _make_batched_score( + score_func, response_method, method, y, sign, kwargs + ) # Reduce predictions to scores. Vectorised if we recognised the scorer, # otherwise nested loops over (estimator, slice). From e775c12dd7f7141804ea098c06c175595abed90e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathias=20Sabl=C3=A9-Meyer?= Date: Tue, 26 May 2026 23:29:54 +0100 Subject: [PATCH 15/22] transpose + reshape once only in `_gl_transform` In reusing `_gl_transform` in the new batched `_gl_score` I ended up with a significant slowdown, which it turns out was due to a transpose+reshape performed many times instead of once --- mne/decoding/search_light.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index 34a58f99643..1a16fe92073 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -685,10 +685,10 @@ def _gl_transform(estimators, X, method, pb): The transformed values generated by each estimator. """ n_sample, n_iter = X.shape[0], X.shape[-1] + # stack generalized data for faster prediction + X_stack = X.transpose(np.r_[0, X.ndim - 1, range(1, X.ndim - 1)]) + X_stack = X_stack.reshape(np.r_[n_sample * n_iter, X_stack.shape[2:]]) for ii, est in enumerate(estimators): - # stack generalized data for faster prediction - X_stack = X.transpose(np.r_[0, X.ndim - 1, range(1, X.ndim - 1)]) - X_stack = X_stack.reshape(np.r_[n_sample * n_iter, X_stack.shape[2:]]) transform = getattr(est, method) _y_pred = transform(X_stack) # unstack generalizations From bd870685570b406a47db59847423b85aee5f21b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathias=20Sabl=C3=A9-Meyer?= Date: Tue, 26 May 2026 23:42:10 +0100 Subject: [PATCH 16/22] Variable initialization not used anymore --- mne/decoding/search_light.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index 1a16fe92073..b2727f56b3f 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -841,7 +841,6 @@ def _gl_score(estimators, scoring, X, y, pb): n_iter = X.shape[-1] n_train = len(estimators) score_shape = [n_train, n_iter] - score = None scoring = _resolve_scoring_for_classifier(scoring, estimators) response_method, can_batch = _detect_response_method(scoring) From ed3ada8003a340af19ac59da94623bf54bece02f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathias=20Sabl=C3=A9-Meyer?= Date: Wed, 27 May 2026 00:53:10 +0100 Subject: [PATCH 17/22] Fix docstrings in _gl_score and _gl_transform --- mne/decoding/search_light.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index b2727f56b3f..ce11ae296ae 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -681,9 +681,10 @@ def _gl_transform(estimators, X, method, pb): Returns ------- - Xt : array, shape (n_samples, n_slices) - The transformed values generated by each estimator. - """ + Xt : array, shape (n_samples, n_estimators, n_slices) | (n_samples, n_estimators, n_slices, n_classes) + Predictions for each (estimator, slice) pair. The trailing axis is + present when ``method`` returns multi-output (e.g. ``predict_proba``). + """ # noqa: E501 n_sample, n_iter = X.shape[0], X.shape[-1] # stack generalized data for faster prediction X_stack = X.transpose(np.r_[0, X.ndim - 1, range(1, X.ndim - 1)]) @@ -739,7 +740,7 @@ def _detect_response_method(scoring): If we can batch the estimator (one of predict, predict_proba, decision_function, or a tuple of those) then we return that, and - can_batch is True if additionally _score_func is None (default). + can_batch is True if additionally _score_func is not None. Otherwise we return None (for the response_method) and False """ score_func = getattr(scoring, "_score_func", None) From 98003f62c61aef7223d3ccec92f5fd1332b98405 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathias=20Sabl=C3=A9-Meyer?= Date: Wed, 27 May 2026 00:54:01 +0100 Subject: [PATCH 18/22] Adds tests for the batched/vectorized _gl_score --- mne/decoding/tests/test_search_light.py | 74 ++++++++++++++++++++++++- 1 file changed, 73 insertions(+), 1 deletion(-) diff --git a/mne/decoding/tests/test_search_light.py b/mne/decoding/tests/test_search_light.py index 749d241c36c..9c3a3e4731b 100644 --- a/mne/decoding/tests/test_search_light.py +++ b/mne/decoding/tests/test_search_light.py @@ -15,7 +15,7 @@ from sklearn.discriminant_analysis import LinearDiscriminantAnalysis from sklearn.ensemble import BaggingClassifier from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge -from sklearn.metrics import make_scorer, roc_auc_score +from sklearn.metrics import check_scoring, make_scorer, roc_auc_score from sklearn.model_selection import cross_val_predict from sklearn.multiclass import OneVsRestClassifier from sklearn.pipeline import make_pipeline @@ -298,6 +298,78 @@ def test_generalization_light(metadata_routing): assert_array_equal(y_preds[0], y_preds[1]) +@pytest.mark.parametrize( + "scoring, est_name, method", + [ + (None, "logreg", "predict"), + ("accuracy", "logreg", "predict"), + ("balanced_accuracy", "logreg", "predict"), + ("roc_auc", "logreg", "decision_function"), + ("neg_log_loss", "logreg", "predict_proba"), + (None, "ridge", "predict"), + ("roc_auc_multiclass", "logreg", "predict_proba"), + ("accuracy_kwargs", "logreg", "predict"), + ], +) +def test_gl_score_branches(scoring, est_name, method): + """Test _gl_score against its own can_batch=False nested-loop fallback.""" + n_trials, n_sensors, n_iter = 12, 3, 4 + rng = np.random.RandomState(0) + X = rng.randn(n_trials, n_sensors, n_iter) + y = rng.randint(0, 3 if scoring == "roc_auc_multiclass" else 2, n_trials) + per_slice = scoring in ("neg_log_loss", "roc_auc_multiclass", "accuracy_kwargs") + # liblinear is binary-only, switch to lbfgs for the multi-class case. + solver = "lbfgs" if scoring == "roc_auc_multiclass" else "liblinear" + if scoring == "roc_auc_multiclass": + scoring = make_scorer( + roc_auc_score, response_method="predict_proba", multi_class="ovr" + ) + elif scoring == "accuracy_kwargs": + # start from the default scorer but add a kwarg to prevent batching + acc_func = check_scoring(LogisticRegression(), "accuracy")._score_func + scoring = make_scorer(acc_func, normalize=False) + est = Ridge() if est_name == "ridge" else LogisticRegression(solver=solver) + gl = GeneralizingEstimator(est, scoring=scoring).fit(X, y) + + # Measure batching: count pred and call scores. Wraps `fn` calls so they + # append to a bucket; preserve __name__ because _gl_score matches it + def counting(fn, bucket): + def wrapped(*a, **k): + bucket.append(1) + return fn(*a, **k) + + wrapped.__name__ = getattr(fn, "__name__", wrapped.__name__) + return wrapped + + # First we count calls to scorer + score_calls = [] + scorer = check_scoring(est, scoring) + if getattr(scorer, "_score_func", None) is not None: + scorer._score_func = counting(scorer._score_func, score_calls) + + # Now we count calls to the estimator that _gl_score will call (hardcoded) + pred_calls = [] + for e in gl.estimators_: + setattr(e, method, counting(getattr(e, method), pred_calls)) + + # Batched path: assert call counts immediately so the buckets only reflect + # this run (the reference run below would otherwise add to them). + gl.scoring = scorer + actual = gl.score(X, y) + assert len(pred_calls) == (n_iter if est_name != "ridge" else n_iter**2) + assert len(score_calls) == (n_iter**2 if per_slice else 0) + + # Reference: force can_batch=False. _score_func set (non-None) bypasses the + # qname coercion; missing _response_method makes can_batch False. + def force_fallback(e, X, y): + return scorer(e, X, y) + + force_fallback._score_func = id + gl.scoring = force_fallback + expected = gl.score(X, y) + assert_allclose(actual, expected) + + @pytest.mark.parametrize( "n_jobs, verbose", [(1, False), (2, False), (1, True), (2, "info")] ) From 321c05dcb4825e71aaac89ef760da144095b9540 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathias=20Sabl=C3=A9-Meyer?= Date: Thu, 28 May 2026 10:39:31 +0100 Subject: [PATCH 19/22] More transparent reshaping syntax Co-authored-by: Eric Larson --- mne/decoding/search_light.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index ce11ae296ae..4016dcd2af0 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -688,7 +688,7 @@ def _gl_transform(estimators, X, method, pb): n_sample, n_iter = X.shape[0], X.shape[-1] # stack generalized data for faster prediction X_stack = X.transpose(np.r_[0, X.ndim - 1, range(1, X.ndim - 1)]) - X_stack = X_stack.reshape(np.r_[n_sample * n_iter, X_stack.shape[2:]]) + X_stack = X_stack.reshape((n_sample * n_iter,) + X_stack.shape[2:]]) for ii, est in enumerate(estimators): transform = getattr(est, method) _y_pred = transform(X_stack) From 7e0fa3aa258090e55cc8c658a77f3ba40827499b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathias=20Sabl=C3=A9-Meyer?= Date: Thu, 28 May 2026 10:40:05 +0100 Subject: [PATCH 20/22] Use standard syntax for optional dims in docstring Co-authored-by: Eric Larson --- mne/decoding/search_light.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index 4016dcd2af0..1063454e7db 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -681,7 +681,7 @@ def _gl_transform(estimators, X, method, pb): Returns ------- - Xt : array, shape (n_samples, n_estimators, n_slices) | (n_samples, n_estimators, n_slices, n_classes) + Xt : array, shape (n_samples, n_estimators, n_slices[, n_classes]) Predictions for each (estimator, slice) pair. The trailing axis is present when ``method`` returns multi-output (e.g. ``predict_proba``). """ # noqa: E501 From a17fb0e4d83c529a038afa090f13d2502ef38a17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathias=20Sabl=C3=A9-Meyer?= Date: Thu, 28 May 2026 10:48:41 +0100 Subject: [PATCH 21/22] Fix typo / syntax error in new reshaping syntax --- mne/decoding/search_light.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index 1063454e7db..fa7f5e2f882 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -688,7 +688,7 @@ def _gl_transform(estimators, X, method, pb): n_sample, n_iter = X.shape[0], X.shape[-1] # stack generalized data for faster prediction X_stack = X.transpose(np.r_[0, X.ndim - 1, range(1, X.ndim - 1)]) - X_stack = X_stack.reshape((n_sample * n_iter,) + X_stack.shape[2:]]) + X_stack = X_stack.reshape((n_sample * n_iter,) + X_stack.shape[2:]) for ii, est in enumerate(estimators): transform = getattr(est, method) _y_pred = transform(X_stack) From 5ca53ebb2829a4e5206470f07abbd481ece18d50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathias=20Sabl=C3=A9-Meyer?= Date: Thu, 28 May 2026 10:50:05 +0100 Subject: [PATCH 22/22] Uses `functools.wraps` instead of manual attr copy --- mne/decoding/tests/test_search_light.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mne/decoding/tests/test_search_light.py b/mne/decoding/tests/test_search_light.py index 9c3a3e4731b..1ed96c034d1 100644 --- a/mne/decoding/tests/test_search_light.py +++ b/mne/decoding/tests/test_search_light.py @@ -3,6 +3,7 @@ # Copyright the MNE-Python contributors. import platform +from functools import wraps from inspect import signature import numpy as np @@ -331,14 +332,14 @@ def test_gl_score_branches(scoring, est_name, method): est = Ridge() if est_name == "ridge" else LogisticRegression(solver=solver) gl = GeneralizingEstimator(est, scoring=scoring).fit(X, y) - # Measure batching: count pred and call scores. Wraps `fn` calls so they - # append to a bucket; preserve __name__ because _gl_score matches it + # Measure batching: count pred/call scores. Wraps `fn` calls so they append + # to a bucket; @wraps preserves __name__ (needed as _gl_score matches it) def counting(fn, bucket): + @wraps(fn) def wrapped(*a, **k): bucket.append(1) return fn(*a, **k) - wrapped.__name__ = getattr(fn, "__name__", wrapped.__name__) return wrapped # First we count calls to scorer