Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
b2cbd64
Batch estimator call in GeneralizingEstimator.score
mathias-sm May 21, 2026
3cef4a5
ENH: Vectorise accuracy scoring in GeneralizingEstimator.score
mathias-sm May 22, 2026
7a66f15
MAINT: Refactor batched fast path in GeneralizingEstimator.score
mathias-sm May 22, 2026
3202895
ENH: Vectorise balanced_accuracy in GeneralizingEstimator.score
mathias-sm May 22, 2026
7a30425
ENH: Batch `scoring=None` in GeneralizingEstimator.score
mathias-sm May 22, 2026
b23e72d
ENH: Vectorise roc_auc in GeneralizingEstimator.score
mathias-sm May 22, 2026
aac4a90
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 22, 2026
df1cfdd
Fix typo: computaion ==> computation...
mathias-sm May 22, 2026
d7faec4
Adds entry in `doc/changes/dev/` + `names.inc`
mathias-sm May 22, 2026
b1830f2
Clarify comments + harmonize var names
mathias-sm May 25, 2026
d3f94ec
Merge branch 'mne-tools:main' into enh-issue-13906-BatchedGeneralizin…
mathias-sm May 25, 2026
f059622
Match comment and docstring
mathias-sm May 25, 2026
86f9b3b
Minimal comment change to re-trigger tests
mathias-sm May 25, 2026
c475df7
Merge branch 'main' into enh-issue-13906-BatchedGeneralizingEstimator
mathias-sm May 25, 2026
e8341d0
Reuse `_gl_transform` where possible
mathias-sm May 26, 2026
19e918f
Create helper functions to _gl_score
mathias-sm May 26, 2026
e775c12
transpose + reshape once only in `_gl_transform`
mathias-sm May 26, 2026
bd87068
Variable initialization not used anymore
mathias-sm May 26, 2026
ed3ada8
Fix docstrings in _gl_score and _gl_transform
mathias-sm May 26, 2026
98003f6
Adds tests for the batched/vectorized _gl_score
mathias-sm May 26, 2026
5df991b
Merge branch 'main' into enh-issue-13906-BatchedGeneralizingEstimator
larsoner May 27, 2026
321c05d
More transparent reshaping syntax
mathias-sm May 28, 2026
7e0fa3a
Use standard syntax for optional dims in docstring
mathias-sm May 28, 2026
a17fb0e
Fix typo / syntax error in new reshaping syntax
mathias-sm May 28, 2026
5ca53eb
Uses `functools.wraps` instead of manual attr copy
mathias-sm May 28, 2026
82eb284
Merge branch 'mne-tools:main' into enh-issue-13906-BatchedGeneralizin…
mathias-sm May 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/changes/dev/13909.other.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Batch and vectorise classifier estimation and scoring in
:meth:`mne.decoding.GeneralizingEstimator.score` for ``scoring=None``,
``"accuracy"``, ``"balanced_accuracy"`` and ``"roc_auc"``, by
:newcontrib:`Mathias Sablé-Meyer`.
1 change: 1 addition & 0 deletions doc/changes/names.inc
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@
.. _Martin Luessi: https://github.com/mluessi
.. _Martin Oberg: https://github.com/obergmartin
.. _Martin Schulz: https://github.com/marsipu
.. _Mathias Sablé-Meyer: https://s-m.ac/
.. _Mathieu Scheltienne: https://github.com/mscheltienne
.. _Mathurin Massias: https://mathurinm.github.io/
.. _Mats van Es: https://github.com/matsvanes
Expand Down
176 changes: 159 additions & 17 deletions mne/decoding/search_light.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging

import numpy as np
from scipy.stats import rankdata
from sklearn.base import BaseEstimator, MetaEstimatorMixin, clone
from sklearn.metrics import check_scoring
from sklearn.preprocessing import LabelEncoder
Expand Down Expand Up @@ -680,14 +681,15 @@ def _gl_transform(estimators, X, method, pb):

Returns
-------
Xt : array, shape (n_samples, n_slices)
The transformed values generated by each estimator.
"""
Xt : array, shape (n_samples, n_estimators, n_slices[, n_classes])
Predictions for each (estimator, slice) pair. The trailing axis is
present when ``method`` returns multi-output (e.g. ``predict_proba``).
""" # noqa: E501
n_sample, n_iter = X.shape[0], X.shape[-1]
# stack generalized data for faster prediction
X_stack = X.transpose(np.r_[0, X.ndim - 1, range(1, X.ndim - 1)])
X_stack = X_stack.reshape((n_sample * n_iter,) + X_stack.shape[2:])
for ii, est in enumerate(estimators):
# stack generalized data for faster prediction
X_stack = X.transpose(np.r_[0, X.ndim - 1, range(1, X.ndim - 1)])
X_stack = X_stack.reshape(np.r_[n_sample * n_iter, X_stack.shape[2:]])
transform = getattr(est, method)
_y_pred = transform(X_stack)
# unstack generalizations
Expand Down Expand Up @@ -715,6 +717,100 @@ def _gl_init_pred(y_pred, X, n_train):
return y_pred


def _resolve_scoring_for_classifier(scoring, estimators):
    """Promote sklearn's passthrough scorer to 'accuracy' when equivalent.

    ``scoring=None`` goes through sklearn's ``_PassthroughScorer``, which
    delegates to ``estimator.score(X, y)``. For a classifier that inherits
    ``ClassifierMixin.score`` unchanged, that is accuracy, which we can batch.

    Only ``None`` or the passthrough scorer itself is promoted. A user-supplied
    callable ``scorer(estimator, X, y)`` also lacks ``_score_func``, but it
    encodes its own metric and must be returned untouched; otherwise it would
    be silently replaced by accuracy.

    We compare ``type(est).score.__qualname__`` rather than ``.__name__``
    because the bare name is "score" regardless of the defining class. An
    inherited method has qualname "ClassifierMixin.score"; any override
    resolves to "<Subclass>.score", which we leave untouched.

    Parameters
    ----------
    scoring : callable | None
        The scorer as resolved by the caller.
    estimators : list of estimator
        The fitted estimators; only the first one is inspected.

    Returns
    -------
    scoring : callable | None
        Either the input ``scoring`` unchanged, or an accuracy scorer built
        with ``check_scoring(estimators[0], "accuracy")``.
    """
    if not len(estimators):
        return scoring
    # Identify the passthrough scorer by name so that we do not need to import
    # a private sklearn symbol (class in recent releases, function in old ones).
    is_passthrough = (
        scoring is None
        or type(scoring).__name__ == "_PassthroughScorer"
        or getattr(scoring, "__name__", "") == "_passthrough_scorer"
    )
    if not is_passthrough:
        return scoring
    # getattr with a default: an estimator type without ``score`` simply is
    # not promoted instead of raising AttributeError here.
    score_method = getattr(type(estimators[0]), "score", None)
    if getattr(score_method, "__qualname__", "") == "ClassifierMixin.score":
        scoring = check_scoring(estimators[0], "accuracy")
    return scoring


def _detect_response_method(scoring):
    """Work out which estimator method ``scoring`` relies on.

    Parameters
    ----------
    scoring : object
        A scorer; its ``_response_method`` and ``_score_func`` attributes are
        read if present.

    Returns
    -------
    response_method : str | tuple of str | None
        ``"predict"``, ``"predict_proba"``, ``"decision_function"`` or a tuple
        of those (``"default"`` is mapped to ``"predict"``). ``None`` when the
        scorer does not declare a recognised method.
    can_batch : bool
        True when a response method was recognised and the scorer also exposes
        a non-None ``_score_func``.
    """
    allowed = frozenset(("predict", "predict_proba", "decision_function"))
    metric = getattr(scoring, "_score_func", None)
    declared = getattr(scoring, "_response_method", None)

    resolved = None
    if declared == "default":
        resolved = "predict"
    elif isinstance(declared, str):
        if declared in allowed:
            resolved = declared
    elif isinstance(declared, tuple):
        if all(name in allowed for name in declared):
            resolved = declared

    batchable = not (metric is None or resolved is None)
    return resolved, batchable


def _make_batched_score(score_func, response_method, method, y, sign, kwargs):
    """Return a callable ``y_pred`` -> scores, or None if not recognised.

    The returned callable expects ``y_pred`` of shape
    ``(n_sample, n_train, n_iter)`` and returns shape ``(n_train, n_iter)``.
    Falls back to None for any scorer with non-default ``kwargs`` or
    multi-target ``y``, both of which require a slice-by-slice loop.

    Parameters
    ----------
    score_func : callable
        The metric wrapped by the scorer; matched on its ``__name__``.
    response_method : str | tuple of str
        The response method(s) declared by the scorer.
    method : str
        The estimator method actually used to produce ``y_pred``.
    y : array-like, shape (n_sample,)
        The target values. Coerced with ``np.asarray`` so that lists work.
    sign : int
        Multiplier applied to the score (the scorer's sign).
    kwargs : dict
        Extra keyword arguments of the scorer; any non-empty value disables
        batching.

    Returns
    -------
    batched_score : callable | None
        Vectorised scorer for ``accuracy_score``, ``balanced_accuracy_score``
        or binary ``roc_auc_score``; None otherwise.
    """
    # ``y`` may be a plain list: the per-slice scoring loop accepts that, so
    # the batched path must too (``.ndim`` and fancy indexing need an ndarray).
    y = np.asarray(y)
    if kwargs or y.ndim != 1:
        return None
    name = getattr(score_func, "__name__", "")

    if name == "accuracy_score" and response_method == "predict":

        def batched_score(y_pred):
            # Broadcast y over the (n_train, n_iter) axes, average over samples
            return sign * (y_pred == y[:, None, None]).mean(axis=0)

        return batched_score

    if name == "balanced_accuracy_score" and response_method == "predict":
        classes = np.unique(y)

        def batched_score(y_pred):
            # Mean of per-class recall, over the classes present in y
            return sign * np.stack(
                [(y_pred[y == c] == c).mean(axis=0) for c in classes]
            ).mean(axis=0)

        return batched_score

    if name == "roc_auc_score" and method in ("predict_proba", "decision_function"):
        classes = np.unique(y)
        if len(classes) != 2:  # multi-class needs ovr/ovo; defer
            return None
        pos = y == classes[1]
        n_pos, n_neg = int(pos.sum()), int((~pos).sum())
        if not (n_pos and n_neg):  # degenerate folds raise downstream in sklearn
            return None

        def batched_score(y_pred):
            # Mann-Whitney U identity with average-rank tie correction.
            # Equivalent to sklearn's roc_auc within floating point precision,
            # but different computation.
            ranks = rankdata(y_pred, method="average", axis=0)
            return (
                sign
                * (ranks[pos].sum(axis=0) - n_pos * (n_pos + 1) / 2.0)
                / (n_pos * n_neg)
            )

        return batched_score
    return None


def _gl_score(estimators, scoring, X, y, pb):
"""Score GeneralizingEstimator in parallel.

Expand Down Expand Up @@ -743,17 +839,63 @@ def _gl_score(estimators, scoring, X, y, pb):
"""
# FIXME: The level of parallelization may be too high and memory-consuming.
# It may need to be lowered to the loop across X slices.
score_shape = [len(estimators), X.shape[-1]]
for jj in range(X.shape[-1]):
for ii, est in enumerate(estimators):
_score = scoring(est, X[..., jj], y)
# Initialize array of predictions on the first score iteration
if (ii == 0) and (jj == 0):
dtype = type(_score)
score = np.zeros(score_shape, dtype)
score[ii, jj, ...] = _score

pb.update(jj * len(estimators) + ii + 1)
n_iter = X.shape[-1]
n_train = len(estimators)
score_shape = [n_train, n_iter]

scoring = _resolve_scoring_for_classifier(scoring, estimators)
response_method, can_batch = _detect_response_method(scoring)

# If we can't batch, fall back to a simple nested loop for scoring
if not can_batch:
for jj in range(n_iter):
for ii, est in enumerate(estimators):
_score = scoring(est, X[..., jj], y)
if (ii == 0) and (jj == 0):
score = np.zeros(score_shape, type(_score))
score[ii, jj, ...] = _score
pb.update(jj * n_train + ii + 1)
return score

# Resolve a single method name for _gl_transform: pick the first available
# if response_method is a tuple (e.g. roc_auc).
if isinstance(response_method, str):
method = response_method
else:
for m in response_method:
if hasattr(estimators[0], m):
method = m
break

# Batch all predictions through _gl_transform. y_pred shape:
# (n_sample, n_train, n_iter) or (n_sample, n_train, n_iter, n_classes).
y_pred = _gl_transform(estimators, X, method, pb)

# Binary predict_proba: take the positive-class column to match sklearn
# scorer expectations for binary problems.
if method == "predict_proba" and y_pred.ndim == 4 and y_pred.shape[-1] == 2:
y_pred = y_pred[..., 1]

# `scoring._kwargs or {}` also guards score_func(..., **kwargs) against
# scoring._kwargs being None.
score_func = scoring._score_func
sign = scoring._sign
kwargs = scoring._kwargs or {}
batched_score = _make_batched_score(
score_func, response_method, method, y, sign, kwargs
)

# Reduce predictions to scores. Vectorised if we recognised the scorer,
# otherwise nested loops over (estimator, slice).
if batched_score is not None:
score = batched_score(y_pred)
else:
for ii in range(n_train):
for jj in range(n_iter):
_score = sign * score_func(y, y_pred[:, ii, jj], **kwargs)
if (ii == 0) and (jj == 0):
score = np.zeros(score_shape, type(_score))
score[ii, jj, ...] = _score
return score


Expand Down
86 changes: 82 additions & 4 deletions mne/decoding/tests/test_search_light.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,20 @@
# Copyright the MNE-Python contributors.

import platform
from functools import wraps
from inspect import signature

import numpy as np
import pytest
from numpy.testing import assert_array_equal, assert_equal
from numpy.testing import assert_allclose, assert_array_equal, assert_equal

sklearn = pytest.importorskip("sklearn")

from sklearn.base import BaseEstimator, clone, is_classifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.ensemble import BaggingClassifier
from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge
from sklearn.metrics import make_scorer, roc_auc_score
from sklearn.metrics import check_scoring, make_scorer, roc_auc_score
from sklearn.model_selection import cross_val_predict
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import make_pipeline
Expand Down Expand Up @@ -246,7 +247,11 @@ def test_generalization_light(metadata_routing):
gl.fit(X, y)
score = gl.score(X, y)
auc = roc_auc_score(y, gl.estimators_[0].predict_proba(X[..., 0])[..., 1])
assert_equal(score[0, 0], auc)

# The rank identity implemented when batching gives the same AUC as sklearn
# within floating point precision, but implements it with different
# operations. A bit-exact match would need a loop, defeating the batching.
assert_allclose(score[0, 0], auc)

for scoring in ["foo", 999]:
gl = GeneralizingEstimator(logreg, scoring=scoring)
Expand All @@ -267,7 +272,8 @@ def test_generalization_light(metadata_routing):
[roc_auc_score(y - 1, _y_pred) for _y_pred in _y_preds]
for _y_preds in gl.decision_function(X).transpose(1, 2, 0)
]
assert_array_equal(score, manual_score)
# allclose instead of equal: see above, batching roc_auc forces this.
assert_allclose(score, manual_score)

# n_jobs
gl = GeneralizingEstimator(logreg, n_jobs=2)
Expand All @@ -293,6 +299,78 @@ def test_generalization_light(metadata_routing):
assert_array_equal(y_preds[0], y_preds[1])


@pytest.mark.parametrize(
    "scoring, est_name, method",
    [
        (None, "logreg", "predict"),
        ("accuracy", "logreg", "predict"),
        ("balanced_accuracy", "logreg", "predict"),
        ("roc_auc", "logreg", "decision_function"),
        ("neg_log_loss", "logreg", "predict_proba"),
        (None, "ridge", "predict"),
        ("roc_auc_multiclass", "logreg", "predict_proba"),
        ("accuracy_kwargs", "logreg", "predict"),
    ],
)
def test_gl_score_branches(scoring, est_name, method):
    """Compare ``_gl_score`` results and call counts with the loop fallback."""
    n_trials, n_sensors, n_iter = 12, 3, 4
    rng = np.random.RandomState(0)
    X = rng.randn(n_trials, n_sensors, n_iter)
    multiclass = scoring == "roc_auc_multiclass"
    n_classes = 3 if multiclass else 2
    y = rng.randint(0, n_classes, n_trials)
    # These cases are expected to call the metric once per (estimator, slice)
    per_slice = scoring in ("neg_log_loss", "roc_auc_multiclass", "accuracy_kwargs")
    # liblinear is binary-only, so the multi-class case uses lbfgs instead.
    solver = "lbfgs" if multiclass else "liblinear"
    if multiclass:
        scoring = make_scorer(
            roc_auc_score, response_method="predict_proba", multi_class="ovr"
        )
    elif scoring == "accuracy_kwargs":
        # Take the metric behind the default accuracy scorer and attach a
        # kwarg, which is enough to prevent batching.
        base_scorer = check_scoring(LogisticRegression(), "accuracy")
        scoring = make_scorer(base_scorer._score_func, normalize=False)
    if est_name == "ridge":
        est = Ridge()
    else:
        est = LogisticRegression(solver=solver)
    gl = GeneralizingEstimator(est, scoring=scoring).fit(X, y)

    # Call-counting wrapper: every call to ``func`` appends to ``log``.
    # @wraps keeps __name__ intact (needed as _gl_score matches on it).
    def counting(func, log):
        @wraps(func)
        def counted(*args, **kwargs):
            log.append(1)
            return func(*args, **kwargs)

        return counted

    # Count calls to the metric wrapped by the scorer, when there is one
    score_calls = []
    scorer = check_scoring(est, scoring)
    metric = getattr(scorer, "_score_func", None)
    if metric is not None:
        scorer._score_func = counting(metric, score_calls)

    # Count calls to the estimator method _gl_score is expected to use
    pred_calls = []
    for fitted in gl.estimators_:
        setattr(fitted, method, counting(getattr(fitted, method), pred_calls))

    # Check the counters right after this run so they only reflect it (the
    # reference run below would otherwise add to them).
    gl.scoring = scorer
    actual = gl.score(X, y)
    expected_pred_calls = n_iter**2 if est_name == "ridge" else n_iter
    expected_score_calls = n_iter**2 if per_slice else 0
    assert len(pred_calls) == expected_pred_calls
    assert len(score_calls) == expected_score_calls

    # Reference: force can_batch=False. A non-None _score_func bypasses the
    # qname coercion; the missing _response_method makes can_batch False.
    def force_fallback(e, X, y):
        return scorer(e, X, y)

    force_fallback._score_func = id
    gl.scoring = force_fallback
    expected = gl.score(X, y)
    assert_allclose(actual, expected)


@pytest.mark.parametrize(
"n_jobs, verbose", [(1, False), (2, False), (1, True), (2, "info")]
)
Expand Down