diff --git a/doc/changes/dev/13914.newfeature.rst b/doc/changes/dev/13914.newfeature.rst new file mode 100644 index 00000000000..107be2c1326 --- /dev/null +++ b/doc/changes/dev/13914.newfeature.rst @@ -0,0 +1 @@ +The EGI MFF reader has been refactored to use the ``mffpy`` backend, improving support for multi-stream files and high-precision metadata, by Pragnya Khandelwal. \ No newline at end of file diff --git a/mne/io/egi/egimff.py b/mne/io/egi/egimff.py index 870f58890a2..147e2ebf464 100644 --- a/mne/io/egi/egimff.py +++ b/mne/io/egi/egimff.py @@ -5,6 +5,8 @@ """EGI NetStation Load Function.""" import datetime +import fnmatch +import itertools import math import os.path as op import re @@ -22,7 +24,7 @@ from ...evoked import EvokedArray from ...utils import _check_fname, _check_option, _soft_import, logger, verbose, warn from ..base import BaseRaw -from .events import _combine_triggers, _read_events, _triage_include_exclude +from .events import _combine_triggers, _triage_include_exclude from .general import ( _block_r, _extract, @@ -35,6 +37,306 @@ REFERENCE_NAMES = ("VREF", "Vertex Reference") +# TODO: Running list +# - [ ] Add support for reading in the PNS data +# - [ ] Add tutorial for reading calibration data +# - [ ] Add support for reading in the channel status (bad channels) +# - [ ] Replace _read_header with mffpy functions? 
+ + +def _get_mff_startdatetime(mff_reader): + """Get start datetime from mff_reader with nanosecond workaround.""" + try: + return mff_reader.startdatetime + except (ValueError, AttributeError): + # mffpy has a bug parsing timestamps with 9 decimal places (nanoseconds) + # Workaround: manually parse the timestamp from the info.xml file + import xml.etree.ElementTree as ET + + info_file = op.join(mff_reader.directory._mffname, "info.xml") + tree = ET.parse(info_file) + root = tree.getroot() + # Handle different XML namespaces by searching for any recordTime element + time_elem = root.find(".//recordTime") or root.find(".//{*}recordTime") + if time_elem is None: + raise + time_str = time_elem.text + # Handle timestamps with up to 9 decimal places by truncating to 6 + # e.g. "2017-09-20T09:55:44.072000000+01:00" -> + # "2017-09-20T09:55:44.072000+01:00" + # Both formats: +0100 (without colon) and +01:00 (with colon) + if "+" in time_str or "-" in time_str[-6:]: + # Truncate nanoseconds in decimal part (keep only 6 digits) + time_str = re.sub(r"\.(\d{6})\d+([+-])", r".\1\2", time_str) + # Python's %z can't always handle colons, so remove them + time_str = re.sub(r"([+-]\d{2}):(\d{2})$", r"\1\2", time_str) + return datetime.datetime.strptime(time_str, "%Y-%m-%dT%H:%M:%S.%f%z") + + +def _parse_egi_datetime(time_str): + """Parse an EGI datetime string with the same nanosecond workaround.""" + if time_str is None: + return None + if "+" in time_str or "-" in time_str[-6:]: + time_str = re.sub(r"\.(\d{6})\d+([+-])", r".\1\2", time_str) + time_str = re.sub(r"([+-]\d{2}):(\d{2})$", r"\1\2", time_str) + return datetime.datetime.strptime(time_str, "%Y-%m-%dT%H:%M:%S.%f%z") + + +def _get_mff_reader(input_fname): + mffpy = _import_mffpy() + mff_reader = mffpy.Reader(input_fname) + mff_reader.set_unit("EEG", "V") # XXX: set PNS unit + return mff_reader + + +def _read_events(input_fname, egi_info): + """Read EGI event tracks from an MFF directory.""" + from mffpy.xml_files import 
XML, EventTrack + + mff_reader = _get_mff_reader(input_fname) + start_dt = _get_mff_startdatetime(mff_reader) + sfreq = egi_info["sfreq"] + n_samples = egi_info["last_samps"][-1] + + mff_events = OrderedDict() + for basename in mff_reader.directory.listdir(): + lower_name = basename.lower() + if not lower_name.endswith(".xml") or basename.startswith("._"): + continue + stem = Path(basename).stem + try: + with mff_reader.directory.filepointer(stem) as fp: + xml_obj = XML.from_file(fp, recover=False) + except Exception as err: + if "XMLSyntaxError" in type(err).__name__: + warn(f"Could not parse the XML file {basename}. Skipping it.") + continue + if not isinstance(xml_obj, EventTrack): + continue + try: + for event in xml_obj.events: + code = event.get("code") or event.get("label") or xml_obj.name + begin_time = event.get("beginTime") + if code is None or begin_time is None: + continue + sample = int(np.floor((begin_time - start_dt).total_seconds() * sfreq)) + if 0 <= sample < n_samples: + mff_events.setdefault(code, []).append(sample) + except Exception: + _soft_import("defusedxml", "reading EGI MFF event tracks") + from defusedxml import ElementTree as ET + + xml_path = op.join(str(input_fname), basename) + try: + root = ET.parse(xml_path).getroot() + except Exception as err: + if ( + "ParseError" in type(err).__name__ + or "XMLSyntaxError" in type(err).__name__ + ): + warn(f"Could not parse the XML file {basename}. 
Skipping it.") + continue + for event_el in root.iter(): + if event_el.tag.split("}")[-1] != "event": + continue + event_fields = {} + for child in event_el: + event_fields[child.tag.split("}")[-1]] = child.text + code = ( + event_fields.get("code") + or event_fields.get("label") + or xml_obj.name + ) + begin_time = _parse_egi_datetime(event_fields.get("beginTime")) + if code is None or begin_time is None: + continue + sample = int(np.floor((begin_time - start_dt).total_seconds() * sfreq)) + if 0 <= sample < n_samples: + mff_events.setdefault(code, []).append(sample) + + event_codes = list(mff_events.keys()) + egi_events = np.zeros((len(event_codes), n_samples)) + for event_idx, code in enumerate(event_codes): + if len(mff_events[code]): + egi_events[event_idx, np.array(mff_events[code], dtype=int)] = 1 + egi_info["event_codes"] = event_codes + return egi_events, egi_info, mff_events + + +def _get_montage(mff_reader): + mffpy = _import_mffpy() + xml_files = mff_reader.directory.files_by_type[".xml"] + + # Read coordinates.xml for fiducial positions + coords_fname = fnmatch.filter(xml_files, "coordinates") + coords_sensors = dict() + if len(coords_fname) == 1: + with mff_reader.directory.filepointer(coords_fname[0]) as fp: + coords_content = mffpy.XML.from_file(fp).get_content() + coords_sensors = coords_content.get("sensors", dict()) + + n_eeg_channels = mff_reader.num_channels["EEG"] # XXX: PNS? + ch_pos = dict() + hsp_list = [] # Extra headshape points + lpa, rpa, nasion = None, None, None + + # Extract channel positions and fiducials from coordinates.xml + for ch in coords_sensors.values(): + # XXX: the y coordinate seems to be inverted? 
Need to investigate + # Convert from cm to m + loc = np.array([ch["x"], -(ch["y"]), ch["z"]]) / 100.0 + name = ch.get("name", "None") + + # Check if this is a fiducial point + if name == "Nasion": + nasion = loc + elif name == "Left periauricular point": + lpa = loc + elif name == "Right periauricular point": + rpa = loc + elif name in REFERENCE_NAMES or "VREF" in name or "Vertex" in name: + # Reference electrode can be numbered outside EEG range (e.g., 1001) + ch_pos[name] = loc + elif ch["number"] <= n_eeg_channels: + # EEG channel + ch_name = name if name != "None" else f"E{ch['number']}" + ch_pos[ch_name] = loc + + # Convert hsp list to array if not empty + hsp = np.array(hsp_list) if hsp_list else None + + montage = make_dig_montage( + ch_pos=ch_pos, nasion=nasion, lpa=lpa, rpa=rpa, hsp=hsp, coord_frame="unknown" + ) + return montage + + +def _get_info(mff_reader): + montage = _get_montage(mff_reader) + ch_names = montage.ch_names + ch_types = ["eeg"] * len(ch_names) # XXX: refactor this when adding PNS support + meas_date_orig = _get_mff_startdatetime(mff_reader) + utc_offset = meas_date_orig.strftime("%z") + meas_date = meas_date_orig.astimezone(datetime.timezone.utc) + sfreq = mff_reader.sampling_rates["EEG"] # XXX: check PNS sfreq? 
+ info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types) + info.set_montage(montage) + info.set_meas_date(meas_date) + with info._unlock(): + info["utc_offset"] = utc_offset + + # Populate reference location (loc[3:6]) for each EEG channel + # The reference is VREF (Vertex Reference), which is the last dig point + if len(info["dig"]) > 0: + ref_loc = info["dig"][-1]["r"] # VREF position + for ch in info["chs"]: + if ch["kind"] == FIFF.FIFFV_EEG_CH: + ch["loc"][3:6] = ref_loc + + return info + + +def _get_eeg_data(mff_reader): + sfreq = mff_reader.sampling_rates["EEG"] # XXX: check PNS sfreq + n_channels = mff_reader.num_channels[ + "EEG" + ] # Only EEG channels, not all signal types + epochs = mff_reader.epochs + + data_blocks, start_secs, end_secs = [], [], [] + for epoch in epochs: + data_chunk, _ = mff_reader.get_physical_samples_from_epoch(epoch)["EEG"] # XXX + data_blocks.append(data_chunk) + start_secs.append(epoch.t0) + end_secs.append(epoch.t1) + + first_samp = int(start_secs[0] * sfreq) + # Calculate total samples needed based on actual chunk placements + max_end_samp = first_samp + for this_chunk, start in zip(data_blocks, start_secs): + start_samp = int(start * sfreq) + end_samp = start_samp + this_chunk.shape[1] + max_end_samp = max(max_end_samp, end_samp) + n_samps = max_end_samp - first_samp + + eeg = np.zeros((n_channels, n_samps), dtype=np.float64) + for this_chunk, start in zip(data_blocks, start_secs): + start_idx = int(start * sfreq) - first_samp + end_idx = start_idx + this_chunk.shape[1] + eeg[:, start_idx:end_idx] = this_chunk + return eeg + + +def _get_gap_annotations(mff_reader): + epochs = mff_reader.epochs + start_secs = [epoch.t0 for epoch in epochs] + end_secs = [epoch.t1 for epoch in epochs] + gap_durations = np.array(start_secs[1:]) - np.array(end_secs[:-1]) + descriptions = ["BAD_ACQ_SKIP"] * len(gap_durations) + gap_onsets = np.array(end_secs[:-1]) + # TODO: Re-enable warning once lazy loading is properly implemented + 
# The warning should be raised during data access, not during __init__ + # if len(gap_durations) > 0: + # warn( + # "Acquisition skips detected. EGI MFF file contains gaps between " + # "recording epochs.", + # RuntimeWarning, + # ) + gap_annots = Annotations(gap_onsets, gap_durations, descriptions) + return gap_annots + + +def _get_event_annotations(mff_reader, mne_info): + mffpy = _import_mffpy() + xml_files = mff_reader.directory.files_by_type[".xml"] + events_xmls = fnmatch.filter(xml_files, "Events*") + if not events_xmls: + raise RuntimeError("No events found in MFF file.") + mff_events = {} + for event_file in events_xmls: + with mff_reader.directory.filepointer(event_file) as fp: + categories = mffpy.XML.from_file(fp) + mff_events[event_file] = categories.get_content()["event"] + onsets = [] + descriptions = [] + mff_events = list(itertools.chain.from_iterable(mff_events.values())) + for event in mff_events: + onset_dt = event["beginTime"].astimezone(datetime.timezone.utc) + ts = (onset_dt - mne_info["meas_date"]).total_seconds() + onsets.append(ts) + # XXX: we could use event["duration"] but it always seems to be 1000ms? + descriptions.append(event["code"]) + durations = [0] * len(onsets) + event_annots = Annotations(onsets, durations, descriptions) + return event_annots + + +def _get_annotations(mff_reader, mne_info): + event_annots = _get_event_annotations(mff_reader, mne_info) + gap_annots = _get_gap_annotations(mff_reader) + return event_annots + gap_annots + + +def _read_mff(input_fname): + """Read EGI MFF file using the mffpy-backed helpers. + + Returns + ------- + eeg : array-like + Raw EEG data as returned by `_get_eeg_data`. + info : dict + MNE `info` structure built by `_get_info`. + annotations : instance of `mne.Annotations` + Annotations built from events and gaps via `_get_annotations`. 
+ """ + mff_reader = _get_mff_reader(input_fname) + eeg = _get_eeg_data(mff_reader) + info = _get_info(mff_reader) + annotations = _get_annotations(mff_reader, info) + return eeg, info, annotations + + def _read_mff_header(filepath): """Read mff header.""" _soft_import("defusedxml", "reading EGI MFF data") @@ -122,7 +424,6 @@ def _read_mff_header(filepath): # Add the sensor info. sensor_layout_file = op.join(filepath, "sensorLayout.xml") sensor_layout_obj = parse(sensor_layout_file) - summaryinfo["device"] = sensor_layout_obj.getElementsByTagName("name")[ 0 ].firstChild.data @@ -142,8 +443,8 @@ def _read_mff_header(filepath): n_chans = n_chans + 1 if n_chans != summaryinfo["n_channels"]: raise RuntimeError( - f"Number of defined channels ({n_chans}) did not match the " - f"expected channels ({summaryinfo['n_channels']})." + f"Number of defined channels ({n_chans}) did not match the expected " + f"channels ({summaryinfo['n_channels']})" ) # Check presence of PNS data @@ -280,7 +581,7 @@ def _read_locs(filepath, egi_info, channel_naming): fname = op.join(filepath, "coordinates.xml") if not op.exists(fname): - warn("File coordinates.xml not found, not setting channel locations") + logger.warn("File coordinates.xml not found, not setting channel locations") ch_names = [channel_naming % (i + 1) for i in range(egi_info["n_channels"])] return ch_names, None dig_ident_map = { @@ -381,14 +682,14 @@ class RawMff(BaseRaw): def __init__( self, input_fname, - eog=None, - misc=None, - include=None, - exclude=None, - preload=False, - channel_naming="E%d", + eog=None, # XXX: allow user to specify EOG channels? + misc=None, # XXX: allow user to specify misc channels? + include=None, # XXX: Now We dont create stim channels. Remove this? + exclude=None, # XXX: Ditto. But maybe we can exclude events from annots. + preload=False, # XXX: Make this work again + channel_naming="E%d", # XXX: Do we need to still support this? 
*, - events_as_annotations=True, + events_as_annotations=True, # XXX: This is now the only way. Remove? verbose=None, ): """Init the RawMff class.""" @@ -402,37 +703,37 @@ def __init__( ) ) logger.info(f"Reading EGI MFF Header from {input_fname}...") + eog = [] if eog is None else eog + misc = [] if misc is None else misc egi_info = _read_header(input_fname) - if eog is None: - eog = [] - if misc is None: - misc = np.where(np.array(egi_info["chan_type"]) != "eeg")[0].tolist() - logger.info(" Reading events ...") + # Event data (for stim channels and optional STI 014) egi_events, egi_info, mff_events = _read_events(input_fname, egi_info) - cals = _get_eeg_calibration_info(input_fname, egi_info) - logger.info(" Assembling measurement info ...") - event_codes = egi_info["event_codes"] + event_codes = list(egi_info["event_codes"]) include = _triage_include_exclude(include, exclude, egi_events, egi_info) - if egi_info["n_events"] > 0 and not events_as_annotations: - logger.info(' Synthesizing trigger channel "STI 014" ...') - if all(ch.startswith("D") for ch in include): - # support the DIN format DIN1, DIN2, ..., DIN9, DI10, DI11, ... DI99, - # D100, D101, ..., D255 that we get when sending 0-255 triggers on a - # parallel port. 
- events_ids = list() - for ch in include: - while not ch[0].isnumeric(): - ch = ch[1:] - events_ids.append(int(ch)) + if not events_as_annotations: + included_codes = [e for e in event_codes if e in include] + if len(included_codes): + events_ids = [] + next_id = 1 + for code in included_codes: + match = re.match(r"DIN(\d+)$", code) + if match is not None: + events_ids.append(int(match.group(1))) + else: + while next_id in events_ids: + next_id += 1 + events_ids.append(next_id) + next_id += 1 + events_ids = np.array(events_ids, int) + egi_info["new_trigger"] = _combine_triggers( + egi_events[[c in include for c in event_codes]], + remapping=events_ids, + ) + self.event_id = dict(zip(included_codes, events_ids)) else: - events_ids = np.arange(len(include)) + 1 - egi_info["new_trigger"] = _combine_triggers( - egi_events[[c in include for c in event_codes]], remapping=events_ids - ) - self.event_id = dict( - zip([e for e in event_codes if e in include], events_ids) - ) + egi_info["new_trigger"] = None + self.event_id = None if egi_info["new_trigger"] is not None: egi_events = np.vstack([egi_events, egi_info["new_trigger"]]) else: @@ -440,28 +741,25 @@ def __init__( egi_info["new_trigger"] = None assert egi_events.shape[1] == egi_info["last_samps"][-1] + # Info and channels meas_dt_utc = egi_info["meas_dt_local"].astimezone(datetime.timezone.utc) info = _empty_info(egi_info["sfreq"]) info["meas_date"] = _ensure_meas_date_none_or_dt(meas_dt_utc) info["utc_offset"] = egi_info["utc_offset"] info["device_info"] = dict(type=egi_info["device"]) - # read in the montage, if it exists ch_names, mon = _read_locs(input_fname, egi_info, channel_naming) - # Second: Stim ch_names.extend(list(egi_info["event_codes"])) n_extra = len(event_codes) + len(misc) + len(eog) + len(egi_info["pns_names"]) if egi_info["new_trigger"] is not None: - ch_names.append("STI 014") # channel for combined events + ch_names.append("STI 014") n_extra += 1 - - # Third: PNS 
ch_names.extend(egi_info["pns_names"]) + cals = _get_eeg_calibration_info(input_fname, egi_info) cals = np.concatenate([cals, np.ones(n_extra)]) assert len(cals) == len(ch_names), (len(cals), len(ch_names)) - # Actually create channels as EEG, then update stim and PNS ch_coil = FIFF.FIFFV_COIL_EEG ch_kind = FIFF.FIFFV_EEG_CH chs = _create_chs(ch_names, cals, ch_coil, ch_kind, eog, (), (), misc) @@ -488,6 +786,8 @@ def __init__( if mon is not None: info.set_montage(mon, on_missing="ignore") + + if mon is not None: ref_idx = np.flatnonzero(np.isin(mon.ch_names, REFERENCE_NAMES)) if len(ref_idx): ref_idx = ref_idx.item() @@ -499,7 +799,6 @@ def __init__( file_bin = op.join(input_fname, egi_info["eeg_fname"]) egi_info["egi_events"] = egi_events - # Check how many channels to read are from EEG keys = ("eeg", "sti", "pns") idx = dict() idx["eeg"] = np.where([ch["kind"] == FIFF.FIFFV_EEG_CH for ch in chs])[0] @@ -510,28 +809,27 @@ def __init__( for ch in chs ] )[0] - # By construction this should always be true, but check anyway if not np.array_equal( np.concatenate([idx[key] for key in keys]), np.arange(len(chs)) ): raise ValueError( "Currently interlacing EEG and PNS channels is not supported" ) + egi_info["kind_bounds"] = [0] for key in keys: egi_info["kind_bounds"].append(len(idx[key])) egi_info["kind_bounds"] = np.cumsum(egi_info["kind_bounds"]) assert egi_info["kind_bounds"][0] == 0 assert egi_info["kind_bounds"][-1] == info["nchan"] + first_samps = [0] last_samps = [egi_info["last_samps"][-1] - 1] annot = dict(onset=list(), duration=list(), description=list()) if len(idx["pns"]): - # PNS Data is present and should be read: egi_info["pns_filepath"] = op.join(input_fname, egi_info["pns_fname"]) - # Check for PNS bug immediately pns_samples = np.sum(egi_info["pns_sample_blocks"]["samples_block"]) eeg_samples = np.sum(egi_info["samples_block"]) if pns_samples == eeg_samples - 1: @@ -542,7 +840,7 @@ def __init__( elif pns_samples != eeg_samples: raise RuntimeError( 
f"PNS samples ({pns_samples}) did not match EEG samples " - f"({eeg_samples})." + f"({eeg_samples})" ) super().__init__( @@ -556,7 +854,11 @@ def __init__( verbose=verbose, ) - # Annotate acquisition skips + egi_info["has_acq_skip"] = np.any( + egi_info["first_samps"][1:] > egi_info["last_samps"][:-1] + ) + egi_info["_acq_skip_warned"] = False + for first, prev_last in zip( egi_info["first_samps"][1:], egi_info["last_samps"][:-1] ): @@ -567,7 +869,6 @@ def __init__( annot["duration"].append(gap / egi_info["sfreq"]) annot["description"].append("BAD_ACQ_SKIP") - # create events from annotations if events_as_annotations: for code, samples in mff_events.items(): if code not in include: @@ -585,6 +886,17 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): dtype = " 0 + and not egi_info.get("_acq_skip_warned", False) + ): + warn( + "Acquisition skips detected. EGI MFF file contains gaps between " + "recording epochs.", + RuntimeWarning, + ) + egi_info["_acq_skip_warned"] = True one = np.zeros((egi_info["kind_bounds"][-1], stop - start)) # info about the binary file structure @@ -971,4 +1283,23 @@ def _import_mffpy(why="read averaged .mff files"): msg = f"mffpy is required to {why}, got:\n{exp}" raise ImportError(msg) + # Monkey-patch mffpy to handle timestamps with 9 decimal places (nanoseconds) + # This is needed because some MFF files have timestamps like + # "2006-04-28T15:32:00.000000000+0100" which Python's %f can't parse + if not hasattr(mffpy.XML, "_mne_patched"): + original_parse_time_str = mffpy.XML._parse_time_str + + @classmethod + def _patched_parse_time_str(cls, txt): + """Parse time string with support for 9-decimal nanoseconds.""" + # Truncate nanoseconds to 6 decimal places if present + # e.g. "2017-09-20T09:55:44.072000000+01:00" -> + # "2017-09-20T09:55:44.072000+01:00" + if txt and "." 
in txt: + txt = re.sub(r"\.(\d{6})\d+([+-])", r".\1\2", txt) + return original_parse_time_str(txt) + + mffpy.XML._parse_time_str = _patched_parse_time_str + mffpy.XML._mne_patched = True + return mffpy diff --git a/mne/io/egi/events.py b/mne/io/egi/events.py index c160ceb208c..540d0fb5f51 100644 --- a/mne/io/egi/events.py +++ b/mne/io/egi/events.py @@ -3,160 +3,9 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -from datetime import datetime -from glob import glob -from os.path import basename, join, splitext - import numpy as np -from ...utils import _soft_import, _validate_type, logger, warn - - -def _read_events(input_fname, info): - """Read events for the record. - - Parameters - ---------- - input_fname : path-like - The file path. - info : dict - Header info array. - """ - n_samples = info["last_samps"][-1] - mff_events, event_codes = _read_mff_events(input_fname, info["sfreq"]) - info["n_events"] = len(event_codes) - info["event_codes"] = event_codes - events = np.zeros([info["n_events"], info["n_segments"] * n_samples]) - for n, event in enumerate(event_codes): - for i in mff_events[event]: - if (i < 0) or (i >= events.shape[1]): - continue - events[n][i] = n + 1 - return events, info, mff_events - - -def _read_mff_events(filename, sfreq): - """Extract the events. - - Parameters - ---------- - filename : path-like - File path. 
- sfreq : float - The sampling frequency - """ - orig = {} - for xml_file in glob(join(filename, "*.xml")): - xml_type = splitext(basename(xml_file))[0] - et = _parse_xml(xml_file) - if et is not None: - orig[xml_type] = et - xml_files = orig.keys() - xml_events = [x for x in xml_files if x[:7] == "Events_"] - for item in orig["info"]: - if "recordTime" in item: - start_time = _ns2py_time(item["recordTime"]) - break - markers = [] - code = [] - for xml in xml_events: - for event in orig[xml][2:]: - event_start = _ns2py_time(event["beginTime"]) - start = (event_start - start_time).total_seconds() - if event["code"] not in code: - code.append(event["code"]) - marker = { - "name": event["code"], - "start": start, - "start_sample": int(np.trunc(start * sfreq)), - "end": start + float(event["duration"]) / 1e9, - "chan": None, - } - markers.append(marker) - events_tims = dict() - for ev in code: - trig_samp = list( - c["start_sample"] for n, c in enumerate(markers) if c["name"] == ev - ) - events_tims.update({ev: trig_samp}) - return events_tims, code - - -def _parse_xml(xml_file: str) -> list[dict[str, str]] | None: - """Parse XML file.""" - defusedxml = _soft_import("defusedxml", "reading EGI MFF data") - try: - xml = defusedxml.ElementTree.parse(xml_file) - except defusedxml.ElementTree.ParseError as e: - warn(f"Could not parse the XML file {xml_file}: {e}") - return - root = xml.getroot() - return _xml2list(root) - - -def _xml2list(root): - """Parse XML item.""" - output = [] - for element in root: - if len(element) > 0: - if element[0].tag != element[-1].tag: - output.append(_xml2dict(element)) - else: - output.append(_xml2list(element)) - - elif element.text: - text = element.text.strip() - if text: - tag = _ns(element.tag) - output.append({tag: text}) - - return output - - -def _ns(s): - """Remove namespace, but only if there is a namespace to begin with.""" - if "}" in s: - return "}".join(s.split("}")[1:]) - else: - return s - - -def _xml2dict(root): - """Use 
functions instead of Class. - - remove namespace based on - http://stackoverflow.com/questions/2148119 - """ - output = {} - if root.items(): - output.update(dict(root.items())) - - for element in root: - if len(element) > 0: - if len(element) == 1 or element[0].tag != element[1].tag: - one_dict = _xml2dict(element) - else: - one_dict = {_ns(element[0].tag): _xml2list(element)} - - if element.items(): - one_dict.update(dict(element.items())) - output.update({_ns(element.tag): one_dict}) - - elif element.items(): - output.update({_ns(element.tag): dict(element.items())}) - - else: - output.update({_ns(element.tag): element.text}) - return output - - -def _ns2py_time(nstime): - """Parse times.""" - nsdate = nstime[0:10] - nstime0 = nstime[11:26] - nstime00 = nsdate + " " + nstime0 - pytime = datetime.strptime(nstime00, "%Y-%m-%d %H:%M:%S.%f") - return pytime +from ...utils import _validate_type, logger, warn def _combine_triggers(data, remapping=None): diff --git a/mne/io/egi/tests/test_egi.py b/mne/io/egi/tests/test_egi.py index 09d1946e108..0659f970bb1 100644 --- a/mne/io/egi/tests/test_egi.py +++ b/mne/io/egi/tests/test_egi.py @@ -6,6 +6,7 @@ import os from copy import deepcopy from datetime import datetime, timezone +from importlib.util import find_spec from pathlib import Path import numpy as np @@ -35,6 +36,11 @@ egi_txt_evoked_cat1_fname = egi_path / "test_egi_evoked_cat1.txt" egi_txt_evoked_cat2_fname = egi_path / "test_egi_evoked_cat2.txt" +requires_mffpy = pytest.mark.skipif( + find_spec("mffpy") is None, + reason="Test requires mffpy", +) + # absolute event times from NetStation egi_pause_events = { "AM40": [7.224, 11.928, 14.413, 16.848], @@ -58,6 +64,7 @@ @requires_testing_data +@requires_mffpy @pytest.mark.parametrize( "fname, skip_times, event_times", [ @@ -120,6 +127,7 @@ def test_egi_mff_pause(fname, skip_times, event_times): @requires_testing_data +@requires_mffpy @pytest.mark.parametrize( "fname", [ @@ -142,6 +150,7 @@ def 
test_egi_mff_pause_chunks(fname, tmp_path): @requires_testing_data +@requires_mffpy @pytest.mark.parametrize("events_as_annotations", (True, False)) def test_io_egi_mff(events_as_annotations): """Test importing EGI MFF simple binary files.""" @@ -289,6 +298,7 @@ def test_io_egi(): @requires_testing_data +@requires_mffpy def test_io_egi_pns_mff(tmp_path): """Test importing EGI MFF with PNS data.""" pytest.importorskip("defusedxml") @@ -345,6 +355,7 @@ def test_io_egi_pns_mff(tmp_path): @requires_testing_data +@requires_mffpy @pytest.mark.parametrize("preload", (True, False)) def test_io_egi_pns_mff_bug(preload): """Test importing EGI MFF with PNS data (BUG).""" @@ -389,6 +400,7 @@ def test_io_egi_pns_mff_bug(preload): @requires_testing_data +@requires_mffpy def test_io_egi_crop_no_preload(): """Test crop non-preloaded EGI MFF data (BUG).""" pytest.importorskip("defusedxml") @@ -502,6 +514,7 @@ def test_read_evokeds_mff_bad_input(): @requires_testing_data +@requires_mffpy def test_egi_coord_frame(): """Test that EGI coordinate frame is changed to head.""" pytest.importorskip("defusedxml") @@ -531,6 +544,7 @@ def test_egi_coord_frame(): @requires_testing_data +@requires_mffpy @pytest.mark.parametrize( "fname, timestamp, utc_offset", [ @@ -555,6 +569,7 @@ def test_meas_date(fname, timestamp, utc_offset): @requires_testing_data +@requires_mffpy @pytest.mark.parametrize( "fname, standard_montage", [ @@ -589,6 +604,7 @@ def test_set_standard_montage_mff(fname, standard_montage): @requires_testing_data +@requires_mffpy def test_egi_mff_bad_xml(tmp_path): """Test that corrupt XML files are gracefully handled.""" pytest.importorskip("defusedxml")