diff --git a/docs/source/whats_new.rst b/docs/source/whats_new.rst index 4448bce2c..f8216502f 100644 --- a/docs/source/whats_new.rst +++ b/docs/source/whats_new.rst @@ -19,6 +19,7 @@ Version 1.5 (Source - GitHub) Enhancements ~~~~~~~~~~~~ +- Implementation of Pseudo Online framework (:gh:`641` by `Igor Carrara`_) - Introduce a new logo for the MOABB library (:gh:`858` by `Pierre Guetschel`_ and community) - Better verbosity control for initialization of the library (:gh:`850` by `Bruno Aristimunha`_) - Ability to join rows from the tables of MOABB predictive performance scores and detailed CodeCarbon compute profiling metrics by the column `codecarbon_task_name` in MOABB results and the column `task_name` in CodeCarbon results (:gh:`866` by `Ethan Davis`_). @@ -215,6 +216,7 @@ Enhancements - Add choice to choose the size of time window (by `Sebastien Velut`_) + Bugs ~~~~ - Fix caching in the workflows (:gh:`632` by `Pierre Guetschel`_) diff --git a/examples/plot_pseudoonline.py b/examples/plot_pseudoonline.py new file mode 100644 index 000000000..a84cd66b5 --- /dev/null +++ b/examples/plot_pseudoonline.py @@ -0,0 +1,60 @@ +# Set up the Directory for made it run on a server. + +import numpy as np +from pyriemann.classification import MDM, FgMDM +from pyriemann.estimation import Covariances +from sklearn.pipeline import Pipeline + +from moabb.datasets import BNCI2014_001 +from moabb.evaluations import WithinSessionEvaluation +from moabb.paradigms import MotorImagery + + +sub = 1 + +# Initialize parameter for the Band Pass filter +fmin = 8 +fmax = 30 +tmax = 3 + +# Load Dataset and switch to Pseudoonline mode +dataset = BNCI2014_001() +dataset.pseudoonline = True + +# events = ["right_hand", "left_hand"] +events = list(dataset.event_id.keys()) + +paradigm = MotorImagery( + events=events, n_classes=len(events), fmin=fmin, fmax=fmax, tmax=tmax, overlap=50 +) + +X, y, meta = paradigm.get_data(dataset=dataset, subjects=[sub]) +print("Print Events_id:", y) +unique, counts = np.unique(y, return_counts=True) +print("Number of events per class:", dict(zip(unique, counts))) + + +pipelines = {} +pipelines["MDM"] = Pipeline( + steps=[ + ("Covariances", Covariances("cov")), + ("MDM", MDM(metric=dict(mean="riemann", distance="riemann"))), + ] +) + +pipelines["FgMDM"] = Pipeline( + steps=[("Covariances", Covariances("cov")), ("FgMDM", FgMDM())] +) + +dataset.subject_list = dataset.subject_list[int(sub) - 1 : int(sub)] +# Select an evaluation Within Session +evaluation_online = WithinSessionEvaluation( + paradigm=paradigm, datasets=dataset, overwrite=True, random_state=42, n_jobs=1 +) + +# Print the results +results_ALL = evaluation_online.process(pipelines) +results_pipeline = results_ALL.groupby(["pipeline"], as_index=False)["score"].mean() +results_pipeline_std = results_ALL.groupby(["pipeline"], as_index=False)["score"].std() +results_pipeline["std"] = results_pipeline_std["score"] +print(results_pipeline) diff --git a/moabb/datasets/base.py b/moabb/datasets/base.py index e2d65f0e9..6f116ca79 100644 --- a/moabb/datasets/base.py +++ b/moabb/datasets/base.py @@ -361,6 +361,7 @@ def __init__( paradigm, doi=None, unit_factor=1e6, + overlap=False, ): """Initialize function for the BaseDataset.""" try: @@ -390,6 +391,7 @@ def __init__( self.paradigm = paradigm self.doi = doi self.unit_factor = unit_factor + self.overlap = overlap def _create_process_pipeline(self): return FixedPipeline( diff --git a/moabb/datasets/bnci.py b/moabb/datasets/bnci.py index bfec02708..78f4c9315 100644 --- a/moabb/datasets/bnci.py +++ b/moabb/datasets/bnci.py @@ -5,7 +5,7 @@ from pathlib import Path import numpy as np -from mne import Annotations, create_info +from mne import Annotations, create_info, find_events from mne.channels import make_standard_montage from mne.io import RawArray from mne.utils import verbose @@ -37,6 +37,7 @@ def load_data( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): # noqa: D301 """Get paths to local copies of a BNCI dataset files. @@ -122,6 +123,7 @@ def load_data( baseurl_list[dataset], only_filenames, verbose, + pseudoonline, ) @@ -185,6 +187,7 @@ def _load_data_001_2014( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): """Load data for 001-2014 dataset.""" if (subject < 1) or (subject > 9): @@ -201,13 +204,21 @@ def _load_data_001_2014( sessions = {} filenames = [] + time_task = 4 + time_fix = 2 for session_idx, r in enumerate(["T", "E"]): url = "{u}001-2014/A{s:02d}{r}.mat".format(u=base_url, s=subject, r=r) filename = data_path(url, path, force_update, update_path) filenames += filename if only_filenames: continue - runs, ev = _convert_mi(filename[0], ch_names, ch_types) + + if pseudoonline: + runs, ev = _convert_mi_pseudoonline( + filename[0], time_task, time_fix, ch_names, ch_types, pseudoonline + ) + else: + runs, ev = _convert_mi(filename[0], ch_names, ch_types) # FIXME: deal with run with no event (1:3) and name them sessions[f"{session_idx}{_map[r]}"] = { str(ii): run for ii, run in enumerate(runs) @@ -226,12 +237,15 @@ def _load_data_002_2014( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): """Load data for 002-2014 dataset.""" if (subject < 1) or (subject > 14): raise ValueError("Subject must be between 1 and 14. Got %d." % subject) runs = [] + time_task = 5 + time_fix = 3 filenames = [] for r in ["T", "E"]: url = "{u}002-2014/S{s:02d}{r}.mat".format(u=base_url, s=subject, r=r) @@ -240,7 +254,12 @@ def _load_data_002_2014( if only_filenames: continue # FIXME: electrode position and name are not provided directly. - raws, _ = _convert_mi(filename, None, ["eeg"] * 15) + if pseudoonline: + raws, _ = _convert_mi_pseudoonline( + filename, time_task, time_fix, None, ["eeg"] * 15, pseudoonline + ) + else: + raws, _ = _convert_mi(filename, None, ["eeg"] * 15) runs.extend(zip([r] * len(raws), raws)) if only_filenames: return filenames @@ -257,6 +276,7 @@ def _load_data_004_2014( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): """Load data for 004-2014 dataset.""" if (subject < 1) or (subject > 9): @@ -266,6 +286,8 @@ def _load_data_004_2014( ch_types = ["eeg"] * 3 + ["eog"] * 3 sessions = [] + time_task = 4.5 + time_fix = 3 filenames = [] for r in ["T", "E"]: url = "{u}004-2014/B{s:02d}{r}.mat".format(u=base_url, s=subject, r=r) @@ -273,7 +295,12 @@ def _load_data_004_2014( filenames.append(filename) if only_filenames: continue - raws, _ = _convert_mi(filename, ch_names, ch_types) + if pseudoonline: + raws, _ = _convert_mi_pseudoonline( + filename, time_task, time_fix, ch_names, ch_types, pseudoonline + ) + else: + raws, _ = _convert_mi(filename, ch_names, ch_types) sessions.extend(zip([r] * len(raws), raws)) if only_filenames: @@ -291,7 +318,14 @@ def _load_data_008_2014( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): + + if pseudoonline: + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) + """Load data for 008-2014 dataset.""" if (subject < 1) or (subject > 8): raise ValueError("Subject must be between 1 and 8. Got %d." % subject) @@ -317,7 +351,12 @@ def _load_data_009_2014( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): + if pseudoonline: + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) """Load data for 009-2014 dataset.""" if (subject < 1) or (subject > 10): raise ValueError("Subject must be between 1 and 10. Got %d." % subject) @@ -356,6 +395,7 @@ def _load_data_001_2015( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): """Load data for 001-2015 dataset.""" if (subject < 1) or (subject > 12): @@ -375,6 +415,8 @@ def _load_data_001_2015( ch_types = ["eeg"] * 13 sessions = {} + time_task = 5 + time_fix = 0 filenames = [] for session_idx, r in ses: url = "{u}001-2015/S{s:02d}{r}.mat".format(u=base_url, s=subject, r=r) @@ -382,7 +424,12 @@ def _load_data_001_2015( filenames += filename if only_filenames: continue - runs, ev = _convert_mi(filename[0], ch_names, ch_types) + if pseudoonline: + runs, ev = _convert_mi_pseudoonline( + filename[0], time_task, time_fix, ch_names, ch_types, pseudoonline + ) + else: + runs, ev = _convert_mi(filename[0], ch_names, ch_types) sessions[f"{session_idx}{r}"] = {str(ii): run for ii, run in enumerate(runs)} if only_filenames: return filenames @@ -398,7 +445,12 @@ def _load_data_003_2015( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): + if pseudoonline: + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) """Load data for 003-2015 dataset.""" if (subject < 1) or (subject > 10): raise ValueError("Subject must be between 1 and 12. Got %d." % subject) @@ -457,7 +509,12 @@ def _load_data_004_2015( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): + if pseudoonline: + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) """Load data for 004-2015 dataset.""" if (subject < 1) or (subject > 9): raise ValueError("Subject must be between 1 and 9. Got %d." % subject) @@ -491,7 +548,12 @@ def _load_data_009_2015( base_url=BBCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): + if pseudoonline: + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) """Load data for 009-2015 dataset.""" if (subject < 1) or (subject > 21): raise ValueError("Subject must be between 1 and 21. Got %d." % subject) @@ -522,7 +584,12 @@ def _load_data_010_2015( base_url=BBCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): + if pseudoonline: + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) """Load data for 010-2015 dataset.""" if (subject < 1) or (subject > 12): raise ValueError("Subject must be between 1 and 12. Got %d." % subject) @@ -554,7 +621,12 @@ def _load_data_012_2015( base_url=BBCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): + if pseudoonline: + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) """Load data for 012-2015 dataset.""" if (subject < 1) or (subject > 12): raise ValueError("Subject must be between 1 and 12. Got %d." % subject) @@ -581,7 +653,12 @@ def _load_data_013_2015( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False, ): + if pseudoonline: + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) """Load data for 013-2015 dataset.""" if (subject < 1) or (subject > 6): raise ValueError("Subject must be between 1 and 6. Got %d." % subject) @@ -641,6 +718,50 @@ def _convert_mi(filename, ch_names, ch_types): return runs, event_id +def _convert_mi_pseudoonline( + filename, time_task, time_fix, ch_names, ch_types, pseudoonline +): + """Process (Graz) motor imagery data from MAT files. + + Parameters + ---------- + filename : str + Path to the MAT file. + time_task: float + Actual duration of the task + time_fix: + Duration of Fixation Cross + ch_names : list of str + List of channel names. + ch_types : list of str + List of channel types. + + Returns + ------- + raw : instance of RawArray + returns list of recording runs.""" + runs = [] + event_id = {} + data = loadmat(filename, struct_as_record=False, squeeze_me=True) + + if isinstance(data["data"], np.ndarray): + run_array = data["data"] + else: + run_array = [data["data"]] + + for run in run_array: + raw, evd = _convert_run_pseudoonline( + run, time_task, time_fix, ch_names, ch_types, None, pseudoonline + ) + if raw is None: + continue + runs.append(raw) + event_id.update(evd) + # change labels to match rest + standardize_keys(event_id) + return runs, event_id + + def standardize_keys(d): master_list = [ ["both feet", "feet"], @@ -691,6 +812,72 @@ def _convert_run(run, ch_names=None, ch_types=None, verbose=None): return raw, event_id +def _convert_run_pseudoonline( + run, + time_task, + time_fix, + ch_names=None, + ch_types=None, + verbose=None, + pseudoonline=False, +): + """Convert one run to raw.""" + # parse eeg data + event_id = {} + n_chan = run.X.shape[1] + montage = make_standard_montage("standard_1005") + eeg_data = 1e-6 * run.X + sfreq = run.fs + + if not ch_names: + ch_names = ["EEG%d" % ch for ch in range(1, n_chan + 1)] + montage = None # no montage + + if not ch_types: + ch_types = ["eeg"] * n_chan + + trigger = np.zeros((len(eeg_data), 1)) + # some runs does not contains trials i.e baseline runs + if len(run.trial) > 0: + trigger[run.trial - 1, 0] = run.y + else: + return None, None + + eeg_data = np.c_[eeg_data, trigger] + ch_names = ch_names + ["stim"] + ch_types = ch_types + ["stim"] + event_id = {ev: (ii + 1) for ii, ev in enumerate(run.classes)} + info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq) + raw = RawArray(data=eeg_data.T, info=info, verbose=verbose) + raw.set_montage(montage) + + if pseudoonline: + # ================================================================================================================= + # Code to add the event Nothing with label 9 + # ================================================================================================================= + # The idea is to replace the old stim channel with a new STIM channel that locate the events at the exact time that + # start and the event also for the nothing phase. + events = find_events(raw, stim_channel="stim") + stim_data = np.zeros((1, len(raw.times))) + + # Time when the task finish + time_nothing = (sfreq * time_task) + 1 + # Time where the task actually begin, because the events of "stim" give us when the fix cross appear, but not when + # the task begin. + time_fixation_cross = sfreq * time_fix + for i in np.arange(len(events[:, 0])): + stim_data[0, int(events[i, 0] + time_fixation_cross)] = events[i, 2] + stim_data[0, int(events[i, 0] + time_fixation_cross + time_nothing)] = 9 + + info = create_info(ch_names=["STI"], ch_types=["stim"], sfreq=sfreq) + new_stim = RawArray(data=stim_data, info=info, verbose=verbose) + raw.add_channels([new_stim], force_update_info=True) + raw.drop_channels(["stim"]) # Delete old stim channel + event_id["nothing"] = 9 + + return raw, event_id + + @verbose def _convert_run_p300_sl(run, verbose=None): """Convert one p300 run from santa lucia file format.""" @@ -886,9 +1073,16 @@ def _convert_run_epfl(run, verbose=None): class MNEBNCI(BaseDataset): """Base BNCI dataset.""" + pseudoonline = False + def _get_single_subject_data(self, subject): """Return data for a single subject.""" - sessions = load_data(subject=subject, dataset=self.code, verbose=False) + sessions = load_data( + subject=subject, + dataset=self.code, + verbose=False, + pseudoonline=self.pseudoonline, + ) return sessions def data_path( @@ -902,6 +1096,7 @@ def data_path( path=path, force_update=force_update, only_filenames=True, + pseudoonline=self.pseudoonline, ) @@ -997,7 +1192,13 @@ def __init__(self): super().__init__( subjects=list(range(1, 10)), sessions_per_subject=2, - events={"left_hand": 1, "right_hand": 2, "feet": 3, "tongue": 4}, + events={ + "left_hand": 1, + "right_hand": 2, + "feet": 3, + "tongue": 4, + "nothing": 9, + }, code="BNCI2014-001", interval=[2, 6], paradigm="imagery", @@ -1050,7 +1251,7 @@ def __init__(self): super().__init__( subjects=list(range(1, 15)), sessions_per_subject=1, - events={"right_hand": 1, "feet": 2}, + events={"right_hand": 1, "feet": 2, "nothing": 9}, code="BNCI2014-002", interval=[3, 8], paradigm="imagery", @@ -1124,7 +1325,7 @@ def __init__(self): super().__init__( subjects=list(range(1, 10)), sessions_per_subject=5, - events={"left_hand": 1, "right_hand": 2}, + events={"left_hand": 1, "right_hand": 2, "nothing": 9}, code="BNCI2014-004", interval=[3, 7.5], paradigm="imagery", @@ -1286,7 +1487,7 @@ def __init__(self): super().__init__( subjects=list(range(1, 13)), sessions_per_subject=2, - events={"right_hand": 1, "feet": 2}, + events={"right_hand": 1, "feet": 2, "nothing": 9}, code="BNCI2015-001", interval=[0, 5], paradigm="imagery", diff --git a/moabb/datasets/preprocessing.py b/moabb/datasets/preprocessing.py index 590ed91e2..eea0ccbbb 100644 --- a/moabb/datasets/preprocessing.py +++ b/moabb/datasets/preprocessing.py @@ -81,6 +81,55 @@ def _unsafe_pick_events(events, include): raise e +def _events_pseudoonline(events, tmin, tmax, sfreq, overlap): + """ + This function create new events every duration length. + :param events: Real event created during registrations of the dataset + :param tmin: Minimum time where create new events(tmin MUST be 0). Is the starting time of epoch, and we consider as starting time + the initial value of the interval in normal MOABB [2, 6] + :param tmax: Maximum time of the windows. Is the final time of epoch. + :param sfreq: Sfreq of the recorded signal + :param overlap: Percentage of overlapping that we want in the sliding windows + :return: + return the new events, ove every starting point of the sliding windows and with univocal label + """ + # Compute duration of the windows in seconds + duration_s = tmax - tmin + # Convert the duration in time point. + duration = duration_s * sfreq + # The starting point of the new windows in time point + ove = (((tmax - tmin) / 100) * (100 - overlap)) * sfreq + + # Total number of new events that need to be created + total = int((events[-1, 0] - events[0, 0]) / (100 - overlap)) + events_new = np.zeros((total, 3), dtype=int) + # Fill the first event with the same old events + events_new[0, :] = events[0, :] + + j = 0 + i = 1 + # Go on while we are at a time sample less than the last events in the data acquisition + while events_new[i - 1, 0] + duration <= events[-1, -0]: + # Assign the time stamp to the new events, so we add ove + events_new[i, 0] = events_new[i - 1, 0] + ove + # Now we have to check. If the new added events plus the duration is less then the time stamp of the new event + # we assign an univocal label. If is not we check the percentage of time stamp associate with a label is predominant in a windows. + # If we have 50/50 we assign the label as the next event since the subject want to switch in that direction. + if events_new[i, 0] + duration <= events[j + 1, 0]: + events_new[i, 2] = events[j, 2] + else: + First = abs(events[j + 1, 0] - events_new[i, 0]) + Second = abs((events_new[i, 0] + duration) - events[j + 1, 0]) + if First > Second: + events_new[i, 2] = events[j, 2] + else: + events_new[i, 2] = events[j + 1, 2] + j = j + 1 + i = i + 1 + + return events_new + + class ForkPipelines(TransformerMixin, BaseEstimator): def __init__(self, transformers: List[Tuple[str, Union[Pipeline, TransformerMixin]]]): for _, t in transformers: @@ -196,6 +245,59 @@ def transform(self, raw, y=None): events = mne.find_events(raw, shortest_event=0, verbose=False) events = _unsafe_pick_events(events, include=_get_event_id_values(self.event_id)) events[:, 0] += offset + + if len(events) != 0: + annotations = mne.annotations_from_events( + events, + raw.info["sfreq"], + self.event_desc, + first_samp=raw.first_samp, + verbose=False, + ) + annotations.set_durations(duration) + raw.set_annotations(annotations) + # raw.plot() + # print("OK") + else: + log.warning("No events found, skipping setting annotations.") + return raw + + +class SetRawAnnotations_PseudoOnline(FixedTransformer): + """ + Always sets the annotations, even if the events list is empty + """ + + def __init__(self, event_id, interval: Tuple[float, float], tmin, tmax, overlap): + assert isinstance(event_id, dict) # not None + self.event_id = event_id + if len(set(event_id.values())) != len(event_id): + raise ValueError("Duplicate event code") + self.event_desc = dict((code, desc) for desc, code in self.event_id.items()) + self.interval = interval + self.overlap = overlap + self.tmin = tmin + self.tmax = tmax + + def transform(self, raw, y=None): + if raw.annotations: + return raw + stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False) + if len(stim_channels) == 0: + log.warning( + "No stim channel nor annotations found, skipping setting annotations." + ) + return raw + events_ = mne.find_events(raw, shortest_event=0, verbose=False) + events = _events_pseudoonline( + events_, + tmin=self.tmin, + tmax=self.tmax, + sfreq=raw.info["sfreq"], + overlap=self.overlap, + ) + duration = self.tmax - self.tmin + if len(events) != 0: annotations = mne.annotations_from_events( events, @@ -206,6 +308,8 @@ def transform(self, raw, y=None): ) annotations.set_durations(duration) raw.set_annotations(annotations) + # raw.plot() + # print("OK") else: log.warning("No events found, skipping setting annotations.") return raw @@ -245,6 +349,54 @@ def transform(self, raw, y=None): return _unsafe_pick_events(events, list(self.event_id.values())) +class RawToEvents_PseudoOnline(FixedTransformer): + """ + Always returns an array for shape (n_events, 3), even if no events found + """ + + def __init__( + self, event_id: dict[str, int], interval: Tuple[float, float], tmin, tmax, overlap + ): + assert isinstance(event_id, dict) # not None + self.event_id = event_id + self.interval = interval + self.tmin = tmin + self.tmax = tmax + self.overlap = overlap + + def _find_events(self, raw): + stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False) + if len(stim_channels) > 0: + # returns empty array if none found + if self.overlap is None: + events = mne.find_events(raw, shortest_event=0, verbose=False) + else: + events_ = mne.find_events(raw, shortest_event=0, verbose=False) + events = _events_pseudoonline( + events_, + tmin=self.tmin, + tmax=self.tmax, + sfreq=raw.info["sfreq"], + overlap=self.overlap, + ) + else: + try: + events, _ = mne.events_from_annotations( + raw, event_id=self.event_id, verbose=False + ) + offset = int(self.interval[0] * raw.info["sfreq"]) + events[:, 0] -= offset # return the original events onset + except ValueError as e: + if str(e) == "Could not find any of the events you specified.": + return np.zeros((0, 3), dtype="int32") + raise e + return events + + def transform(self, raw, y=None): + events = self._find_events(raw) + return _unsafe_pick_events(events, list(self.event_id.values())) + + class RawToEventsP300(RawToEvents): def __init__(self, event_id, interval, ignore_relabelling=False): self.ignore_relabelling = ignore_relabelling diff --git a/moabb/evaluations/utils.py b/moabb/evaluations/utils.py index ea5837e99..e4915fa83 100644 --- a/moabb/evaluations/utils.py +++ b/moabb/evaluations/utils.py @@ -9,7 +9,7 @@ from mne.utils.config import _open_lock from numpy import argmax from sklearn.base import ClassifierMixin -from sklearn.metrics import check_scoring +from sklearn.metrics import check_scoring, matthews_corrcoef from sklearn.model_selection import GridSearchCV from sklearn.pipeline import Pipeline @@ -25,6 +25,11 @@ optuna_available = False +def _normalized_mcc(y_true, y_pred): + mcc = matthews_corrcoef(y_true, y_pred) + return (mcc + 1) / 2 + + def _ensure_fitted(estimator): """Ensure an estimator is properly marked as fitted for sklearn 1.8+. diff --git a/moabb/paradigms/base.py b/moabb/paradigms/base.py index ec7bfe5a7..b575b43dc 100644 --- a/moabb/paradigms/base.py +++ b/moabb/paradigms/base.py @@ -19,7 +19,9 @@ ForkPipelines, RawToEpochs, RawToEvents, + RawToEvents_PseudoOnline, SetRawAnnotations, + SetRawAnnotations_PseudoOnline, get_crop_pipeline, get_filter_pipeline, get_resample_pipeline, @@ -227,6 +229,7 @@ def __init__( baseline: Optional[Tuple[float, float]] = None, channels: Optional[List[str]] = None, resample: Optional[float] = None, + overlap: Optional[float] = None, ): if tmax is not None: if tmin >= tmax: @@ -238,6 +241,7 @@ def __init__( self.tmin = tmin self.tmax = tmax self.interpolate_missing_channels = False + self.overlap = overlap @property @abc.abstractmethod @@ -341,15 +345,29 @@ def make_process_pipelines( process_pipelines = [] for raw_pipeline in raw_pipelines: steps = [] - steps.append( - ( - StepType.RAW, - SetRawAnnotations( - dataset.event_id, - interval=dataset.interval, - ), + if self.overlap is not None: + steps.append( + ( + StepType.RAW, + SetRawAnnotations_PseudoOnline( + dataset.event_id, + interval=dataset.interval, + tmin=self.tmin, + tmax=self.tmax, + overlap=self.overlap, + ), + ) + ) + else: + steps.append( + ( + StepType.RAW, + SetRawAnnotations( + dataset.event_id, + interval=dataset.interval, + ), + ) ) - ) if raw_pipeline is not None: steps.append((StepType.RAW, raw_pipeline)) if epochs_pipeline is not None: @@ -718,6 +736,7 @@ def __init__( baseline=None, channels=None, resample=None, + overlap=None, scorer=None, ): super().__init__( @@ -727,6 +746,7 @@ def __init__( resample=resample, tmin=tmin, tmax=tmax, + overlap=overlap, ) self.events = events @@ -752,4 +772,16 @@ def scoring(self): def _get_events_pipeline(self, dataset): event_id = self.used_events(dataset) - return RawToEvents(event_id=event_id, interval=dataset.interval) + if self.overlap is not None: + return RawToEvents_PseudoOnline( + event_id=event_id, + interval=dataset.interval, + tmin=self.tmin, + tmax=self.tmax, + overlap=self.overlap, + ) + else: + return RawToEvents( + event_id=event_id, + interval=dataset.interval, + ) diff --git a/moabb/paradigms/motor_imagery.py b/moabb/paradigms/motor_imagery.py index befb9d50e..5448e42d0 100644 --- a/moabb/paradigms/motor_imagery.py +++ b/moabb/paradigms/motor_imagery.py @@ -3,8 +3,11 @@ import abc import logging +from sklearn.metrics import make_scorer + from moabb.datasets import utils from moabb.datasets.fake import FakeDataset +from moabb.evaluations.utils import _normalized_mcc from moabb.paradigms.base import BaseParadigm @@ -52,6 +55,8 @@ class BaseMotorImagery(BaseParadigm): resample: float | None (default None) If not None, resample the eeg data with the sampling rate provided. + overlap: Overlap (in percentage) of the sliding windows approach for the pseudoonline evaluation + scorer: sklearn-compatible string or a compatible sklearn scorer | None (default None) If None, and n_classes==2 use the roc_auc, else use accuracy. """ @@ -65,8 +70,14 @@ def __init__( baseline=None, channels=None, resample=None, + overlap=None, scorer=None, ): + + if overlap is not None: + print("Overlap available only for pseudo online evaluation") + tmin = 0.0 + super().__init__( filters=filters, events=events, @@ -75,6 +86,7 @@ def __init__( resample=resample, tmin=tmin, tmax=tmax, + overlap=overlap, scorer=scorer, ) @@ -109,7 +121,10 @@ def datasets(self): def scoring(self): if self.scorer is not None: return self.scorer - return "accuracy" + if self.overlap is None: + return "accuracy" + else: + return make_scorer(_normalized_mcc) class SinglePass(BaseMotorImagery): @@ -420,7 +435,10 @@ def scoring(self): return self.scorer if self.n_classes == 2: return "roc_auc" - return "accuracy" + if self.overlap is None: + return "accuracy" + else: + return make_scorer(_normalized_mcc) class FakeImageryParadigm(LeftRightImagery):