diff --git a/doc/changes/dev/#13684.newfeature.rst b/doc/changes/dev/#13684.newfeature.rst new file mode 100644 index 00000000000..e31d194a6fb --- /dev/null +++ b/doc/changes/dev/#13684.newfeature.rst @@ -0,0 +1 @@ +The EGI MFF reader has been refactored to use the ``mffpy`` backend, improving support for multi-stream files and high-precision metadata, by :newcontrib:`Pragnya Khandelwal`. \ No newline at end of file diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 09c5818b6af..f52ca025fa3 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -258,6 +258,7 @@ .. _Pierre Guetschel: https://github.com/PierreGtch .. _Pierre-Antoine Bannier: https://github.com/PABannier .. _Ping-Keng Jao: https://github.com/nafraw +.. _Pragnya Khandelwal: https://github.com/PragnyaKhandelwal .. _Proloy Das: https://github.com/proloyd .. _Qian Chu: https://github.com/qian-chu .. _Qianliang Li: https://www.dtu.dk/english/service/phonebook/person?id=126774 diff --git a/mne/io/egi/egimff.py b/mne/io/egi/egimff.py index 870f58890a2..899a4518f67 100644 --- a/mne/io/egi/egimff.py +++ b/mne/io/egi/egimff.py @@ -5,7 +5,6 @@ """EGI NetStation Load Function.""" import datetime -import math import os.path as op import re from collections import OrderedDict @@ -13,6 +12,11 @@ import numpy as np +try: + import mffpy +except ImportError: + mffpy = None + from ..._fiff.constants import FIFF from ..._fiff.meas_info import _empty_info, _ensure_meas_date_none_or_dt, create_info from ..._fiff.proj import setup_proj @@ -22,12 +26,9 @@ from ...evoked import EvokedArray from ...utils import _check_fname, _check_option, _soft_import, logger, verbose, warn from ..base import BaseRaw -from .events import _combine_triggers, _read_events, _triage_include_exclude +from .events import _combine_triggers, _triage_include_exclude from .general import ( - _block_r, _extract, - _get_blocks, - _get_ep_info, _get_gains, 
_get_signalfname, ) @@ -35,217 +36,137 @@ REFERENCE_NAMES = ("VREF", "Vertex Reference") -def _read_mff_header(filepath): - """Read mff header.""" - _soft_import("defusedxml", "reading EGI MFF data") - from defusedxml.minidom import parse +def _get_mff_reader(input_fname): + """Instantiate an mffpy Reader (hard dependency for MFF reading).""" + mffpy = _import_mffpy() + return mffpy.Reader(input_fname) - all_files = _get_signalfname(filepath) - eeg_file = all_files["EEG"]["signal"] - eeg_info_file = all_files["EEG"]["info"] - info_filepath = op.join(filepath, "info.xml") # add with filepath - tags = ["mffVersion", "recordTime"] - version_and_date = _extract(tags, filepath=info_filepath) - version = "" - if len(version_and_date["mffVersion"]): - version = version_and_date["mffVersion"][0] - - fname = op.join(filepath, eeg_file) - signal_blocks = _get_blocks(fname) - epochs = _get_ep_info(filepath) - summaryinfo = dict(eeg_fname=eeg_file, info_fname=eeg_info_file) - summaryinfo.update(signal_blocks) - # sanity check and update relevant values - record_time = version_and_date["recordTime"][0] - # e.g., - # 2018-07-30T10:47:01.021673-04:00 - # 2017-09-20T09:55:44.072000000+01:00 +def _get_mff_startdatetime(input_fname, mff_reader): + """Get robust start datetime for MFF files, handling 9-digit fractional secs.""" + try: + return mff_reader.startdatetime + except Exception: + info_filepath = op.join(str(input_fname), "info.xml") + record_time = _extract(["recordTime"], filepath=info_filepath)["recordTime"][0] + if len(record_time) > 32: + dt, tz = [record_time[:26], record_time[-6:]] + record_time = dt + tz + return datetime.datetime.strptime(record_time, "%Y-%m-%dT%H:%M:%S.%f%z") + + +def _parse_egi_datetime(time_str): + """Parse EGI time strings allowing 6 or 9 fractional second digits.""" + if time_str is None: + return None + txt = time_str.strip() g = re.match( - r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.(\d{6}(?:\d{3})?)[+-]\d{2}:\d{2}", # noqa: E501 - record_time, + 
r"(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.)(\d+)([+-]\d{2}:?\d{2})$", + txt, ) if g is None: - raise RuntimeError(f"Could not parse recordTime {repr(record_time)}") - frac = g.groups()[0] - assert len(frac) in (6, 9) and all(f.isnumeric() for f in frac) # regex - div = 1000 if len(frac) == 6 else 1000000 - for key in ("last_samps", "first_samps"): - # convert from times in µS to samples - for ei, e in enumerate(epochs[key]): - if e % div != 0: - raise RuntimeError(f"Could not parse epoch time {e}") - epochs[key][ei] = e // div - epochs[key] = np.array(epochs[key], np.uint64) - # I guess they refer to times in milliseconds? - # What we really need to do here is: - # epochs[key] *= signal_blocks['sfreq'] - # epochs[key] //= 1000 - # But that multiplication risks an overflow, so let's only multiply - # by what we need to (e.g., a sample rate of 500 means we can multiply - # by 1 and divide by 2 rather than multiplying by 500 and dividing by - # 1000) - numerator = int(signal_blocks["sfreq"]) - denominator = 1000 - this_gcd = math.gcd(numerator, denominator) - numerator = numerator // this_gcd - denominator = denominator // this_gcd - with np.errstate(over="raise"): - epochs[key] *= numerator - epochs[key] //= denominator - # Should be safe to cast to int now, which makes things later not - # upbroadcast to float - epochs[key] = epochs[key].astype(np.int64) - n_samps_block = signal_blocks["samples_block"].sum() - n_samps_epochs = (epochs["last_samps"] - epochs["first_samps"]).sum() - bad = ( - n_samps_epochs != n_samps_block - or not (epochs["first_samps"] < epochs["last_samps"]).all() - or not (epochs["first_samps"][1:] >= epochs["last_samps"][:-1]).all() - ) - if bad: - raise RuntimeError( - "EGI epoch first/last samps could not be parsed:\n" - f"{list(epochs['first_samps'])}\n{list(epochs['last_samps'])}" - ) - summaryinfo.update(epochs) - # index which samples in raw are actually readable from disk (i.e., not - # in a skip) - disk_samps = 
np.full(epochs["last_samps"][-1], -1) - offset = 0 - for first, last in zip(epochs["first_samps"], epochs["last_samps"]): - n_this = last - first - disk_samps[first:last] = np.arange(offset, offset + n_this) - offset += n_this - summaryinfo["disk_samps"] = disk_samps + return datetime.datetime.strptime(txt, "%Y-%m-%dT%H:%M:%S.%f%z") + prefix, frac, tz = g.groups() + frac = (frac[:6]).ljust(6, "0") + return datetime.datetime.strptime(prefix + frac + tz, "%Y-%m-%dT%H:%M:%S.%f%z") - # Add the sensor info. - sensor_layout_file = op.join(filepath, "sensorLayout.xml") - sensor_layout_obj = parse(sensor_layout_file) - summaryinfo["device"] = sensor_layout_obj.getElementsByTagName("name")[ - 0 - ].firstChild.data +def _get_info_from_mff_reader(input_fname, mff_reader): + """Build EGI info dict from mffpy.Reader metadata.""" + input_fname = str(input_fname) + sfreq_dict = mff_reader.sampling_rates + sfreq = float(sfreq_dict.get("EEG", next(iter(sfreq_dict.values())))) + meas_dt_local = _get_mff_startdatetime(input_fname, mff_reader) + + all_files = _get_signalfname(input_fname) + eeg_file = all_files["EEG"]["signal"] + eeg_info_file = all_files["EEG"]["info"] + + # Parse channel metadata from sensorLayout.xml + _soft_import("defusedxml", "reading EGI MFF data") + from defusedxml.minidom import parse + + sensor_layout_file = op.join(input_fname, "sensorLayout.xml") + sensor_layout_obj = parse(sensor_layout_file) + device = sensor_layout_obj.getElementsByTagName("name")[0].firstChild.data sensors = sensor_layout_obj.getElementsByTagName("sensor") - chan_type = list() - chan_unit = list() + + chan_type = [] + chan_unit = [] + numbers = [] n_chans = 0 - numbers = list() # used for identification for sensor in sensors: sensortype = int(sensor.getElementsByTagName("type")[0].firstChild.data) if sensortype in [0, 1]: - sn = sensor.getElementsByTagName("number")[0].firstChild.data - sn = sn.encode() + sn = sensor.getElementsByTagName("number")[0].firstChild.data.encode() 
numbers.append(sn) chan_type.append("eeg") chan_unit.append("uV") - n_chans = n_chans + 1 - if n_chans != summaryinfo["n_channels"]: - raise RuntimeError( - f"Number of defined channels ({n_chans}) did not match the " - f"expected channels ({summaryinfo['n_channels']})." - ) + n_chans += 1 + + # Collect epoch bounds and per-epoch sample counts from mffpy + first_samps = [] + last_samps = [] + samples_block = [] + pns_samples_block = [] + for ei in range(len(mff_reader.epochs)): + epoch = mff_reader.epochs[ei] + data_epoch = mff_reader.get_physical_samples_from_epoch(epoch) + eeg_samples = int(data_epoch["EEG"][0].shape[1]) + first = int(np.round(epoch.t0 * sfreq)) + last = first + eeg_samples + first_samps.append(first) + last_samps.append(last) + samples_block.append(eeg_samples) + + pns_arr = data_epoch.get("PNSData") + pns_samples_block.append(0 if pns_arr is None else int(pns_arr[0].shape[1])) + + first_samps = np.array(first_samps, dtype=np.int64) + last_samps = np.array(last_samps, dtype=np.int64) + samples_block = np.array(samples_block, dtype=np.int64) + pns_samples_block = np.array(pns_samples_block, dtype=np.int64) + + # index which samples in raw are actually readable from disk (i.e., not in a skip) + disk_samps = np.full(last_samps[-1], -1, dtype=np.int64) + offset = 0 + for first, last in zip(first_samps, last_samps): + n_this = last - first + disk_samps[first:last] = np.arange(offset, offset + n_this) + offset += n_this - # Check presence of PNS data + # Parse PNS channel metadata if present pns_names = [] + pns_types = [] + pns_units = [] + pns_fname = None if "PNS" in all_files: - pns_fpath = op.join(filepath, all_files["PNS"]["signal"]) - pns_blocks = _get_blocks(pns_fpath) - pns_samples = pns_blocks["samples_block"] - signal_samples = signal_blocks["samples_block"] - same_blocks = np.array_equal( - pns_samples[:-1], signal_samples[:-1] - ) and pns_samples[-1] in (signal_samples[-1] - np.arange(2)) - if not same_blocks: - raise RuntimeError( - "PNS 
and signals samples did not match:\n" - f"{list(pns_samples)}\nvs\n{list(signal_samples)}" - ) - - pns_file = op.join(filepath, "pnsSet.xml") - pns_obj = parse(pns_file) - sensors = pns_obj.getElementsByTagName("sensor") - pns_types = [] - pns_units = [] - for sensor in sensors: - # sensor number: - # sensor.getElementsByTagName('number')[0].firstChild.data - name = sensor.getElementsByTagName("name")[0].firstChild.data - unit_elem = sensor.getElementsByTagName("unit")[0].firstChild - unit = "" - if unit_elem is not None: - unit = unit_elem.data - - if name == "ECG": - ch_type = "ecg" - elif "EMG" in name: - ch_type = "emg" - else: - ch_type = "bio" - pns_types.append(ch_type) - pns_units.append(unit) - pns_names.append(name) - - summaryinfo.update( - pns_types=pns_types, - pns_units=pns_units, - pns_fname=all_files["PNS"]["signal"], - pns_sample_blocks=pns_blocks, - ) - summaryinfo.update( - pns_names=pns_names, - version=version, - date=version_and_date["recordTime"][0], - chan_type=chan_type, - chan_unit=chan_unit, - numbers=numbers, - ) - - return summaryinfo - - -def _read_header(input_fname): - """Obtain the headers from the file package mff. - - Parameters - ---------- - input_fname : path-like - Path for the file - - Returns - ------- - info : dict - Main headers set. - """ - input_fname = str(input_fname) # cast to str any Paths - mff_hdr = _read_mff_header(input_fname) - with open(input_fname + "/signal1.bin", "rb") as fid: - version = np.fromfile(fid, np.int32, 1)[0] - """ - the datetime.strptime .f directive (milleseconds) - will only accept up to 6 digits. if there are more than - six millesecond digits in the provided timestamp string - (i.e. 
because of trailing zeros, as in test_egi_pns.mff) - then slice both the first 26 elements and the last 6 - elements of the timestamp string to truncate the - milleseconds to 6 digits and extract the timezone, - and then piece these together and assign back to mff_hdr['date'] - """ - if len(mff_hdr["date"]) > 32: - dt, tz = [mff_hdr["date"][:26], mff_hdr["date"][-6:]] - mff_hdr["date"] = dt + tz - - time_n = datetime.datetime.strptime(mff_hdr["date"], "%Y-%m-%dT%H:%M:%S.%f%z") + pns_fname = all_files["PNS"]["signal"] + pns_file = op.join(input_fname, "pnsSet.xml") + if op.exists(pns_file): + pns_obj = parse(pns_file) + pns_sensors = pns_obj.getElementsByTagName("sensor") + for sensor in pns_sensors: + name = sensor.getElementsByTagName("name")[0].firstChild.data + unit_elem = sensor.getElementsByTagName("unit")[0].firstChild + unit = "" if unit_elem is None else unit_elem.data + if name == "ECG": + ch_type = "ecg" + elif "EMG" in name: + ch_type = "emg" + else: + ch_type = "bio" + pns_names.append(name) + pns_types.append(ch_type) + pns_units.append(unit) info = dict( - version=version, - meas_dt_local=time_n, - utc_offset=time_n.strftime("%z"), + version=0, + meas_dt_local=meas_dt_local, + utc_offset=meas_dt_local.strftime("%z"), gain=0, bits=0, value_range=0, - ) - info.update( n_categories=0, n_segments=1, n_events=0, @@ -253,11 +174,104 @@ def _read_header(input_fname): category_names=[], category_lengths=[], pre_baseline=0, + sfreq=sfreq, + n_channels=n_chans, + eeg_fname=eeg_file, + info_fname=eeg_info_file, + device=device, + chan_type=chan_type, + chan_unit=chan_unit, + numbers=numbers, + first_samps=first_samps, + last_samps=last_samps, + samples_block=samples_block, + disk_samps=disk_samps, + pns_names=pns_names, + pns_types=pns_types, + pns_units=pns_units, + pns_fname=pns_fname, + pns_sample_blocks={ + "n_channels": len(pns_names), + "samples_block": pns_samples_block, + }, + mff_path=input_fname, ) - info.update(mff_hdr) return info +def 
_read_mff_events(input_fname, mff_reader, sfreq, n_samples, start_dt): + """Read event tracks using mffpy XML parsing and return dense event matrix.""" + from mffpy.xml_files import XML, EventTrack + + mff_events = OrderedDict() + basenames = mff_reader.directory.listdir() + for basename in basenames: + lower_name = basename.lower() + if not lower_name.endswith(".xml") or basename.startswith("._"): + continue + stem = Path(basename).stem + try: + with mff_reader.directory.filepointer(stem) as fp: + xml_obj = XML.from_file(fp, recover=False) + except Exception as err: + if "XMLSyntaxError" in type(err).__name__: + warn(f"Could not parse the XML file {basename}. Skipping it.") + continue + if not isinstance(xml_obj, EventTrack): + continue + try: + events_iter = xml_obj.events + for event in events_iter: + code = event.get("code") or event.get("label") or xml_obj.name + begin_time = event.get("beginTime") + if code is None or begin_time is None: + continue + delta = (begin_time - start_dt).total_seconds() + sample = int(np.floor(delta * sfreq)) + if 0 <= sample < n_samples: + mff_events.setdefault(code, []).append(sample) + except Exception: + _soft_import("defusedxml", "reading EGI MFF event tracks") + from defusedxml import ElementTree as ET + + xml_path = op.join(str(input_fname), basename) + try: + root = ET.parse(xml_path).getroot() + except Exception as err: + if ( + "ParseError" in type(err).__name__ + or "XMLSyntaxError" in type(err).__name__ + ): + warn(f"Could not parse the XML file {basename}. 
Skipping it.") + continue + for event_el in root.iter(): + if event_el.tag.split("}")[-1] != "event": + continue + event_fields = {} + for child in event_el: + event_fields[child.tag.split("}")[-1]] = child.text + code = ( + event_fields.get("code") + or event_fields.get("label") + or xml_obj.name + ) + begin_time = _parse_egi_datetime(event_fields.get("beginTime")) + if code is None or begin_time is None: + continue + delta = (begin_time - start_dt).total_seconds() + sample = int(np.floor(delta * sfreq)) + if 0 <= sample < n_samples: + mff_events.setdefault(code, []).append(sample) + + event_codes = list(mff_events.keys()) + egi_events = np.zeros((len(event_codes), n_samples)) + for event_idx, code in enumerate(event_codes): + if len(mff_events[code]): + event_samples = np.asarray(mff_events[code], dtype=int) + egi_events[event_idx, event_samples] = 1 + return egi_events, event_codes, mff_events + + def _get_eeg_calibration_info(filepath, egi_info): """Calculate calibration info for EEG channels.""" gains = _get_gains(op.join(filepath, egi_info["info_fname"])) @@ -402,17 +416,25 @@ def __init__( ) ) logger.info(f"Reading EGI MFF Header from {input_fname}...") - egi_info = _read_header(input_fname) + mff_reader = _get_mff_reader(input_fname) + egi_info = _get_info_from_mff_reader(input_fname, mff_reader) if eog is None: eog = [] if misc is None: misc = np.where(np.array(egi_info["chan_type"]) != "eeg")[0].tolist() logger.info(" Reading events ...") - egi_events, egi_info, mff_events = _read_events(input_fname, egi_info) + egi_events, event_codes, mff_events = _read_mff_events( + input_fname, + mff_reader, + egi_info["sfreq"], + egi_info["last_samps"][-1], + egi_info["meas_dt_local"], + ) + egi_info["n_events"] = len(event_codes) + egi_info["event_codes"] = event_codes cals = _get_eeg_calibration_info(input_fname, egi_info) logger.info(" Assembling measurement info ...") - event_codes = egi_info["event_codes"] include = _triage_include_exclude(include, exclude, 
egi_events, egi_info) if egi_info["n_events"] > 0 and not events_as_annotations: logger.info(' Synthesizing trigger channel "STI 014" ...') @@ -557,15 +579,22 @@ def __init__( ) # Annotate acquisition skips + has_skips = False for first, prev_last in zip( egi_info["first_samps"][1:], egi_info["last_samps"][:-1] ): gap = first - prev_last assert gap >= 0 if gap: + has_skips = True annot["onset"].append((prev_last - 0.5) / egi_info["sfreq"]) annot["duration"].append(gap / egi_info["sfreq"]) annot["description"].append("BAD_ACQ_SKIP") + if has_skips and (not events_as_annotations) and len(mff_events): + warn( + "Acquisition skips detected. EGI MFF file contains gaps between " + "recording epochs." + ) # create events from annotations if events_as_annotations: @@ -582,169 +611,63 @@ def __init__( def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): """Read a chunk of data.""" logger.debug(f"Reading MFF {start:6d} ... {stop:6d} ...") - dtype = "= bounds[1]) & (idx < bounds[2]))[0] stim_one = idx[stim_out] stim_in = idx[stim_out] - bounds[1] pns_out = np.where((idx >= bounds[2]) & (idx < bounds[3]))[0] pns_in = idx[pns_out] - bounds[2] - pns_one = idx[pns_out, np.newaxis] + pns_one = idx[pns_out] del eeg_out, stim_out, pns_out # take into account events (already extended to correct size) one[stim_one, :] = egi_info["egi_events"][stim_in, start:stop] - # Convert start and stop to limits in terms of the data - # actually on disk, plus an indexer (disk_use_idx) that populates - # the potentially larger `data` with it, taking skips into account - disk_samps = egi_info["disk_samps"][start:stop] - disk_use_idx = np.where(disk_samps > -1)[0] - # short circuit in case we don't need any samples - if not len(disk_use_idx): - _mult_cal_one(data, one, idx, cals, mult) - return - - start = disk_samps[disk_use_idx[0]] - stop = disk_samps[disk_use_idx[-1]] + 1 - assert len(disk_use_idx) == stop - start - - # Get starting/stopping block/samples - block_samples_offset = 
np.cumsum(samples_block) - offset_blocks = np.sum(block_samples_offset <= start) - offset_samples = start - ( - block_samples_offset[offset_blocks - 1] if offset_blocks > 0 else 0 - ) - - # TODO: Refactor this reading with the PNS reading in a single function - # (DRY) - samples_to_read = stop - start - with open(self.filenames[fi], "rb", buffering=0) as fid: - # Go to starting block - current_block = 0 - current_block_info = None - current_data_sample = 0 - while current_block < offset_blocks: - this_block_info = _block_r(fid) - if this_block_info is not None: - current_block_info = this_block_info - fid.seek(current_block_info["block_size"], 1) - current_block += 1 - - # Start reading samples - while samples_to_read > 0: - logger.debug(f" Reading from block {current_block}") - this_block_info = _block_r(fid) - current_block += 1 - if this_block_info is not None: - current_block_info = this_block_info - - to_read = current_block_info["nsamples"] * current_block_info["nc"] - block_data = np.fromfile(fid, dtype, to_read) - block_data = block_data.reshape(n_channels, -1, order="C") - - # Compute indexes - samples_read = block_data.shape[1] - logger.debug(f" Read {samples_read} samples") - logger.debug(f" Offset {offset_samples} samples") - if offset_samples > 0: - # First block read, skip to the offset: - block_data = block_data[:, offset_samples:] - samples_read = samples_read - offset_samples - offset_samples = 0 - if samples_to_read < samples_read: - # Last block to read, skip the last samples - block_data = block_data[:, :samples_to_read] - samples_read = samples_to_read - logger.debug(f" Keep {samples_read} samples") - - s_start = current_data_sample - s_end = s_start + samples_read - - one[eeg_one, disk_use_idx[s_start:s_end]] = block_data[eeg_in] - samples_to_read = samples_to_read - samples_read - current_data_sample = current_data_sample + samples_read - - if len(pns_one) > 0: - # PNS Data is present and should be read: - pns_filepath = 
egi_info["pns_filepath"] - pns_info = egi_info["pns_sample_blocks"] - n_channels = pns_info["n_channels"] - samples_block = pns_info["samples_block"] - - # Get starting/stopping block/samples - block_samples_offset = np.cumsum(samples_block) - offset_blocks = np.sum(block_samples_offset < start) - offset_samples = start - ( - block_samples_offset[offset_blocks - 1] if offset_blocks > 0 else 0 - ) - - samples_to_read = stop - start - with open(pns_filepath, "rb", buffering=0) as fid: - # Check file size - fid.seek(0, 2) - file_size = fid.tell() - fid.seek(0) - # Go to starting block - current_block = 0 - current_block_info = None - current_data_sample = 0 - while current_block < offset_blocks: - this_block_info = _block_r(fid) - if this_block_info is not None: - current_block_info = this_block_info - fid.seek(current_block_info["block_size"], 1) - current_block += 1 - - # Start reading samples - while samples_to_read > 0: - if samples_to_read == 1 and fid.tell() == file_size: - # We are in the presence of the EEG bug - # fill with zeros and break the loop - one[pns_one, -1] = 0 - break - - this_block_info = _block_r(fid) - if this_block_info is not None: - current_block_info = this_block_info - - to_read = current_block_info["nsamples"] * current_block_info["nc"] - block_data = np.fromfile(fid, dtype, to_read) - block_data = block_data.reshape(n_channels, -1, order="C") - - # Compute indexes - samples_read = block_data.shape[1] - if offset_samples > 0: - # First block read, skip to the offset: - block_data = block_data[:, offset_samples:] - samples_read = samples_read - offset_samples - offset_samples = 0 - - if samples_to_read < samples_read: - # Last block to read, skip the last samples - block_data = block_data[:, :samples_to_read] - samples_read = samples_to_read - - s_start = current_data_sample - s_end = s_start + samples_read - - one[pns_one, disk_use_idx[s_start:s_end]] = block_data[pns_in] - samples_to_read = samples_to_read - samples_read - 
current_data_sample = current_data_sample + samples_read + # Read only overlapping epoch segments from mffpy (keeps skips as zeros) + for epoch_idx, (first, last) in enumerate(zip(first_samps, last_samps)): + overlap_start = max(start, first) + overlap_stop = min(stop, last) + if overlap_stop <= overlap_start: + continue + + epoch = mff_reader.epochs[epoch_idx] + epoch_data = mff_reader.get_physical_samples_from_epoch(epoch) + eeg_block = epoch_data["EEG"][0][:n_channels] + src_start = overlap_start - first + src_stop = overlap_stop - first + dst_start = overlap_start - start + dst_stop = overlap_stop - start + + if len(eeg_one): + one[eeg_one, dst_start:dst_stop] = eeg_block[eeg_in, src_start:src_stop] + + if len(pns_one) and "PNSData" in epoch_data: + pns_block = epoch_data["PNSData"][0] + src_stop_pns = min(src_stop, pns_block.shape[1]) + if src_stop_pns > src_start: + dst_stop_pns = dst_start + (src_stop_pns - src_start) + one[pns_one, dst_start:dst_stop_pns] = pns_block[ + pns_in, src_start:src_stop_pns + ] # do the calibration _mult_cal_one(data, one, idx, cals, mult) @@ -844,10 +767,8 @@ def read_evokeds_mff( def _read_evoked_mff(fname, condition, channel_naming="E%d", verbose=None): """Read evoked data from MFF file.""" - import mffpy - - egi_info = _read_header(fname) mff = mffpy.Reader(fname) + egi_info = _get_info_from_mff_reader(str(fname), mff) categories = mff.categories.categories if isinstance(condition, str): @@ -965,10 +886,8 @@ def _read_evoked_mff(fname, condition, channel_naming="E%d", verbose=None): def _import_mffpy(why="read averaged .mff files"): """Import and return module mffpy.""" - try: - import mffpy - except ImportError as exp: - msg = f"mffpy is required to {why}, got:\n{exp}" + if mffpy is None: + msg = f"mffpy is required to {why}." 
raise ImportError(msg) return mffpy diff --git a/mne/io/egi/events.py b/mne/io/egi/events.py index c160ceb208c..540d0fb5f51 100644 --- a/mne/io/egi/events.py +++ b/mne/io/egi/events.py @@ -3,160 +3,9 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -from datetime import datetime -from glob import glob -from os.path import basename, join, splitext - import numpy as np -from ...utils import _soft_import, _validate_type, logger, warn - - -def _read_events(input_fname, info): - """Read events for the record. - - Parameters - ---------- - input_fname : path-like - The file path. - info : dict - Header info array. - """ - n_samples = info["last_samps"][-1] - mff_events, event_codes = _read_mff_events(input_fname, info["sfreq"]) - info["n_events"] = len(event_codes) - info["event_codes"] = event_codes - events = np.zeros([info["n_events"], info["n_segments"] * n_samples]) - for n, event in enumerate(event_codes): - for i in mff_events[event]: - if (i < 0) or (i >= events.shape[1]): - continue - events[n][i] = n + 1 - return events, info, mff_events - - -def _read_mff_events(filename, sfreq): - """Extract the events. - - Parameters - ---------- - filename : path-like - File path. 
- sfreq : float - The sampling frequency - """ - orig = {} - for xml_file in glob(join(filename, "*.xml")): - xml_type = splitext(basename(xml_file))[0] - et = _parse_xml(xml_file) - if et is not None: - orig[xml_type] = et - xml_files = orig.keys() - xml_events = [x for x in xml_files if x[:7] == "Events_"] - for item in orig["info"]: - if "recordTime" in item: - start_time = _ns2py_time(item["recordTime"]) - break - markers = [] - code = [] - for xml in xml_events: - for event in orig[xml][2:]: - event_start = _ns2py_time(event["beginTime"]) - start = (event_start - start_time).total_seconds() - if event["code"] not in code: - code.append(event["code"]) - marker = { - "name": event["code"], - "start": start, - "start_sample": int(np.trunc(start * sfreq)), - "end": start + float(event["duration"]) / 1e9, - "chan": None, - } - markers.append(marker) - events_tims = dict() - for ev in code: - trig_samp = list( - c["start_sample"] for n, c in enumerate(markers) if c["name"] == ev - ) - events_tims.update({ev: trig_samp}) - return events_tims, code - - -def _parse_xml(xml_file: str) -> list[dict[str, str]] | None: - """Parse XML file.""" - defusedxml = _soft_import("defusedxml", "reading EGI MFF data") - try: - xml = defusedxml.ElementTree.parse(xml_file) - except defusedxml.ElementTree.ParseError as e: - warn(f"Could not parse the XML file {xml_file}: {e}") - return - root = xml.getroot() - return _xml2list(root) - - -def _xml2list(root): - """Parse XML item.""" - output = [] - for element in root: - if len(element) > 0: - if element[0].tag != element[-1].tag: - output.append(_xml2dict(element)) - else: - output.append(_xml2list(element)) - - elif element.text: - text = element.text.strip() - if text: - tag = _ns(element.tag) - output.append({tag: text}) - - return output - - -def _ns(s): - """Remove namespace, but only if there is a namespace to begin with.""" - if "}" in s: - return "}".join(s.split("}")[1:]) - else: - return s - - -def _xml2dict(root): - """Use 
functions instead of Class. - - remove namespace based on - http://stackoverflow.com/questions/2148119 - """ - output = {} - if root.items(): - output.update(dict(root.items())) - - for element in root: - if len(element) > 0: - if len(element) == 1 or element[0].tag != element[1].tag: - one_dict = _xml2dict(element) - else: - one_dict = {_ns(element[0].tag): _xml2list(element)} - - if element.items(): - one_dict.update(dict(element.items())) - output.update({_ns(element.tag): one_dict}) - - elif element.items(): - output.update({_ns(element.tag): dict(element.items())}) - - else: - output.update({_ns(element.tag): element.text}) - return output - - -def _ns2py_time(nstime): - """Parse times.""" - nsdate = nstime[0:10] - nstime0 = nstime[11:26] - nstime00 = nsdate + " " + nstime0 - pytime = datetime.strptime(nstime00, "%Y-%m-%d %H:%M:%S.%f") - return pytime +from ...utils import _validate_type, logger, warn def _combine_triggers(data, remapping=None): diff --git a/mne/io/egi/general.py b/mne/io/egi/general.py index ed028e3e5ed..d3819b2abc5 100644 --- a/mne/io/egi/general.py +++ b/mne/io/egi/general.py @@ -50,86 +50,6 @@ def _get_gains(filepath): return gains -def _get_ep_info(filepath): - """Get epoch info.""" - _soft_import("defusedxml", "reading EGI MFF data") - from defusedxml.minidom import parse - - epochfile = filepath + "/epochs.xml" - epochlist = parse(epochfile) - epochs = epochlist.getElementsByTagName("epoch") - keys = ("first_samps", "last_samps", "first_blocks", "last_blocks") - epoch_info = {key: list() for key in keys} - for epoch in epochs: - ep_begin = int(epoch.getElementsByTagName("beginTime")[0].firstChild.data) - ep_end = int(epoch.getElementsByTagName("endTime")[0].firstChild.data) - first_block = int(epoch.getElementsByTagName("firstBlock")[0].firstChild.data) - last_block = int(epoch.getElementsByTagName("lastBlock")[0].firstChild.data) - epoch_info["first_samps"].append(ep_begin) - epoch_info["last_samps"].append(ep_end) - 
epoch_info["first_blocks"].append(first_block) - epoch_info["last_blocks"].append(last_block) - # Don't turn into ndarray here, keep native int because it can deal with - # huge numbers (could use np.uint64 but it's more work) - return epoch_info - - -def _get_blocks(filepath): - """Get info from meta data blocks.""" - binfile = os.path.join(filepath) - n_blocks = 0 - samples_block = [] - header_sizes = [] - n_channels = [] - sfreq = [] - # Meta data consists of: - # * 1 byte of flag (1 for meta data, 0 for data) - # * 1 byte of header size - # * 1 byte of block size - # * 1 byte of n_channels - # * n_channels bytes of offsets - # * n_channels bytes of sigfreqs? - with open(binfile, "rb") as fid: - fid.seek(0, 2) # go to end of file - file_length = fid.tell() - block_size = file_length - fid.seek(0) - position = 0 - while position < file_length: - block = _block_r(fid) - if block is None: - samples_block.append(samples_block[n_blocks - 1]) - n_blocks += 1 - fid.seek(block_size, 1) - position = fid.tell() - continue - block_size = block["block_size"] - header_size = block["header_size"] - header_sizes.append(header_size) - samples_block.append(block["nsamples"]) - n_blocks += 1 - fid.seek(block_size, 1) - sfreq.append(block["sfreq"]) - n_channels.append(block["nc"]) - position = fid.tell() - - if any([n != n_channels[0] for n in n_channels]): - raise RuntimeError("All the blocks don't have the same amount of channels.") - if any([f != sfreq[0] for f in sfreq]): - raise RuntimeError("All the blocks don't have the same sampling frequency.") - if len(samples_block) < 1: - raise RuntimeError("There seems to be no data") - samples_block = np.array(samples_block) - signal_blocks = dict( - n_channels=n_channels[0], - sfreq=sfreq[0], - n_blocks=n_blocks, - samples_block=samples_block, - header_sizes=header_sizes, - ) - return signal_blocks - - def _get_signalfname(filepath): """Get filenames.""" _soft_import("defusedxml", "reading EGI MFF data") @@ -162,31 +82,3 @@ def 
_get_signalfname(filepath): f"found in {filepath}:\n{infofiles_str}" ) return all_files - - -def _block_r(fid): - """Read meta data.""" - if np.fromfile(fid, dtype=np.dtype("i4"), count=1).item() != 1: # not meta - return None - header_size = np.fromfile(fid, dtype=np.dtype("i4"), count=1).item() - block_size = np.fromfile(fid, dtype=np.dtype("i4"), count=1).item() - hl = int(block_size / 4) - nc = np.fromfile(fid, dtype=np.dtype("i4"), count=1).item() - nsamples = int(hl / nc) - np.fromfile(fid, dtype=np.dtype("i4"), count=nc) # sigoffset - sigfreq = np.fromfile(fid, dtype=np.dtype("i4"), count=nc) - depth = sigfreq[0] & 0xFF - if depth != 32: - raise ValueError("I do not know how to read this MFF (depth != 32)") - sfreq = sigfreq[0] >> 8 - count = int(header_size / 4 - (4 + 2 * nc)) - np.fromfile(fid, dtype=np.dtype("i4"), count=count) # sigoffset - block = dict( - nc=nc, - hl=hl, - nsamples=nsamples, - block_size=block_size, - header_size=header_size, - sfreq=sfreq, - ) - return block diff --git a/mne/io/egi/tests/test_egi.py b/mne/io/egi/tests/test_egi.py index 261a9c80da3..0ede82603ce 100644 --- a/mne/io/egi/tests/test_egi.py +++ b/mne/io/egi/tests/test_egi.py @@ -7,6 +7,7 @@ import shutil from copy import deepcopy from datetime import datetime, timezone +from importlib.util import find_spec from pathlib import Path import numpy as np @@ -36,6 +37,11 @@ egi_txt_evoked_cat1_fname = egi_path / "test_egi_evoked_cat1.txt" egi_txt_evoked_cat2_fname = egi_path / "test_egi_evoked_cat2.txt" +requires_mffpy = pytest.mark.skipif( + find_spec("mffpy") is None, + reason="Test requires mffpy", +) + # absolute event times from NetStation egi_pause_events = { "AM40": [7.224, 11.928, 14.413, 16.848], @@ -59,6 +65,7 @@ @requires_testing_data +@requires_mffpy @pytest.mark.parametrize( "fname, skip_times, event_times", [ @@ -121,6 +128,7 @@ def test_egi_mff_pause(fname, skip_times, event_times): @requires_testing_data +@requires_mffpy @pytest.mark.parametrize( "fname", [ @@ 
-143,6 +151,7 @@ def test_egi_mff_pause_chunks(fname, tmp_path): @requires_testing_data +@requires_mffpy @pytest.mark.parametrize("events_as_annotations", (True, False)) def test_io_egi_mff(events_as_annotations): """Test importing EGI MFF simple binary files.""" @@ -290,6 +299,7 @@ def test_io_egi(): @requires_testing_data +@requires_mffpy def test_io_egi_pns_mff(tmp_path): """Test importing EGI MFF with PNS data.""" pytest.importorskip("defusedxml") @@ -346,6 +356,7 @@ def test_io_egi_pns_mff(tmp_path): @requires_testing_data +@requires_mffpy @pytest.mark.parametrize("preload", (True, False)) def test_io_egi_pns_mff_bug(preload): """Test importing EGI MFF with PNS data (BUG).""" @@ -390,6 +401,7 @@ def test_io_egi_pns_mff_bug(preload): @requires_testing_data +@requires_mffpy def test_io_egi_crop_no_preload(): """Test crop non-preloaded EGI MFF data (BUG).""" pytest.importorskip("defusedxml") @@ -503,6 +515,7 @@ def test_read_evokeds_mff_bad_input(): @requires_testing_data +@requires_mffpy def test_egi_coord_frame(): """Test that EGI coordinate frame is changed to head.""" pytest.importorskip("defusedxml") @@ -532,6 +545,7 @@ def test_egi_coord_frame(): @requires_testing_data +@requires_mffpy @pytest.mark.parametrize( "fname, timestamp, utc_offset", [ @@ -556,6 +570,7 @@ def test_meas_date(fname, timestamp, utc_offset): @requires_testing_data +@requires_mffpy @pytest.mark.parametrize( "fname, standard_montage", [ @@ -590,6 +605,7 @@ def test_set_standard_montage_mff(fname, standard_montage): @requires_testing_data +@requires_mffpy def test_egi_mff_bad_xml(tmp_path): """Test that corrupt XML files are gracefully handled.""" pytest.importorskip("defusedxml")