diff --git a/doc/changes/dev/13921.bugfix.rst b/doc/changes/dev/13921.bugfix.rst new file mode 100644 index 00000000000..9179860b0b2 --- /dev/null +++ b/doc/changes/dev/13921.bugfix.rst @@ -0,0 +1 @@ +Ensure grouped OPM tangential topomaps use unsigned RMS magnitude and per-group colormap scaling, aligning behavior with Neuromag-style grouping, by `Pragnya Khandelwal`_. diff --git a/mne/viz/evoked.py b/mne/viz/evoked.py index a62d2379f03..d98d0b1c695 100644 --- a/mne/viz/evoked.py +++ b/mne/viz/evoked.py @@ -1954,9 +1954,24 @@ def plot_evoked_joint( del times _, times_ts = _check_time_unit(ts_args["time_unit"], times_sec) - # prepare axes for topomap + ch_type = ch_types.pop() # set should only contain one element + use_opm_orientation_groups = False + if ch_type == "mag": + from .topomap import _should_use_opm_orientation_groups + + picks_topo, _, merge_channels, *_ = _prepare_topomap_plot( + evoked, ch_type, sphere=topomap_args.get("sphere") + ) + use_opm_orientation_groups = _should_use_opm_orientation_groups( + evoked.info, picks_topo, merge_channels, ch_type + ) + n_group_axes = 2 if use_opm_orientation_groups else 1 + + # prepare axes for topomap and butterfly plots if not got_axes: - fig, ts_ax, map_ax = _prepare_joint_axes(len(times_sec), figsize=(8.0, 4.2)) + fig, ts_ax, map_ax = _prepare_joint_axes( + len(times_sec) * n_group_axes, figsize=(8.0, 4.2) + ) cbar_ax = None else: ts_ax = ts_args["axes"] @@ -2002,7 +2017,7 @@ def plot_evoked_joint( # topomap contours = topomap_args.get("contours", 6) - ch_type = ch_types.pop() # set should only contain one element + # Since the data has all the ch_types, we get the limits from the plot. vmin, vmax = (None, None) norm = ch_type == "grad" diff --git a/mne/viz/tests/test_ica.py b/mne/viz/tests/test_ica.py index 0aaabb418e2..b9802e2d22f 100644 --- a/mne/viz/tests/test_ica.py +++ b/mne/viz/tests/test_ica.py @@ -616,7 +616,8 @@ def test_plot_components_opm(): ica = ICA(max_iter=1, random_state=0, n_components=10) ica.fit(RawArray(evoked.data, evoked.info), picks="mag", verbose="error") fig = ica.plot_components() - assert len(fig.axes) == 10 + # Biaxial OPM overlaps render grouped radial+tangential maps. + assert len(fig.axes) == 20 @pytest.mark.slowtest @@ -628,4 +629,4 @@ def test_plot_components_opm_triaxial(triaxial_raw): ica = ICA(max_iter=1, random_state=0, n_components=3) ica.fit(triaxial_raw, picks="mag", verbose="error") fig = ica.plot_components() - assert len(fig.axes) == 3 + assert len(fig.axes) == 6 diff --git a/mne/viz/tests/test_topomap.py b/mne/viz/tests/test_topomap.py index c0c976b484f..a9fc3353c89 100644 --- a/mne/viz/tests/test_topomap.py +++ b/mne/viz/tests/test_topomap.py @@ -794,7 +794,9 @@ def test_plot_topomap_opm(): fig_evoked = evoked.plot_topomap( times=[-0.1, 0, 0.1, 0.2], ch_type="mag", show=False ) - assert len(fig_evoked.axes) == 5 + # Biaxial OPM pairs trigger grouped rendering + # (4 radial + 4 tangential + 2 colorbars) + assert len(fig_evoked.axes) == 10 def test_prepare_topomap_plot_opm_non_quspin_coils(): @@ -851,6 +853,61 @@ def test_split_opm_overlaps(triaxial_evoked): assert tangential == ["OPM002", "OPM003", "OPM005", "OPM006"] +def test_opm_tangential_rms_unsigned(triaxial_evoked): + """Test that tangential OPM data is RMS magnitude and unsigned.""" + picks, pos, merge_channels, names, *_ = topomap._prepare_topomap_plot( + triaxial_evoked, "mag" + ) + data = triaxial_evoked.data[picks] + grouped = topomap._compute_opm_orientation_topomap_data( + data, names, pos, merge_channels + ) + tangential = [group for group in grouped if group[0] == "tangential"][0] + assert np.all(tangential[1] >= 0) + assert tangential[4] + + +def test_should_use_opm_orientation_groups_only_for_triaxial(): + """Test that OPM orientation grouping works for biaxial and triaxial overlaps.""" + ch_names = [f"OPM{k:03}" for k in range(1, 7)] + info = create_info(ch_names, 1000.0, ch_types="mag") + with info._unlock(): + for ch in info["chs"]: + ch["coil_type"] = FIFF.FIFFV_COIL_FIELDLINE_OPM_MAG_GEN1 + + picks = np.arange(len(ch_names)) + pair_overlaps = [ + np.array(["OPM001", "OPM002"]), + np.array(["OPM003", "OPM004"]), + ] + triax_overlaps = [ + np.array(["OPM001", "OPM002", "OPM003"]), + np.array(["OPM004", "OPM005", "OPM006"]), + ] + + # Both biaxial and triaxial overlaps should trigger grouping + assert topomap._should_use_opm_orientation_groups(info, picks, pair_overlaps, "mag") + assert topomap._should_use_opm_orientation_groups( + info, picks, triax_overlaps, "mag" + ) + + +def test_plot_evoked_topomap_opm_triaxial_groups(triaxial_evoked): + """Test grouped radial/tangential topomap rendering for triaxial OPM.""" + fig = triaxial_evoked.plot_topomap( + times=[0.0], + ch_type="mag", + contours=0, + res=8, + sensors=False, + show=False, + ) + assert len(fig.axes) == 4 + titles = [ax.get_title() for ax in fig.axes] + assert any("radial" in title for title in titles) + assert any("tangential" in title for title in titles) + + def test_plot_topomap_nirs_overlap(fnirs_epochs): """Test plotting nirs topomap with overlapping channels (gh-7414).""" fig = fnirs_epochs["A"].average(picks="hbo").plot_topomap() diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index fcbd213ce78..f95f8d7cd26 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -329,6 +329,87 @@ def _split_opm_overlaps(overlapping_channels): return radial, tangential +def _rms(data, axis=0): + """Compute root-mean-square magnitude along an axis.""" + return np.sqrt(np.mean(data**2, axis=axis)) + + +def _compute_opm_orientation_topomap_data(data, ch_names, pos, overlapping_channels): + """Compute radial and tangential OPM topomap data from overlap sets.""" + from ..channels.layout import _merge_ch_data + + # Radial data matches the existing OPM merge behavior and position layout. + radial_data, radial_names = _merge_ch_data( + data.copy(), "mag", copy.copy(ch_names), modality="opm" + ) + radial_pos = pos + + name_lookup = [name.removesuffix("_MERGE-REMOVE") for name in ch_names] + tangential_data = [] + tangential_names = [] + tangential_pos = [] + for overlap_set in overlapping_channels: + idx = [name_lookup.index(ch_name) for ch_name in overlap_set[1:]] + # Collapse multiple tangential channels at one location using RMS. + tangential_data.append(_rms(data[idx], axis=0)) + tangential_names.append(f"{overlap_set[0]}t") + tangential_pos.append(radial_pos[radial_names.index(overlap_set[0])]) + + tangential_data = np.array(tangential_data) + tangential_pos = np.array(tangential_pos) + + return [ + ("radial", radial_data, radial_pos, radial_names, False), + ("tangential", tangential_data, tangential_pos, tangential_names, True), + ] + + +def _compute_orientation_group_data( + data, + ch_names, + pos, + *, + ch_type, + modality, + merge_channels, + use_opm_orientation_groups, +): + """Compute grouped topomap data for OPM/Neuromag-like orientations.""" + from ..channels.layout import _merge_ch_data + + if merge_channels: + if modality == "opm" and use_opm_orientation_groups: + return _compute_opm_orientation_topomap_data( + data, ch_names, pos, merge_channels + ) + + data, ch_names = _merge_ch_data(data, ch_type, ch_names, modality=modality) + group_norm = ch_type == "grad" + return [(None, data, pos, ch_names, group_norm)] + + return [(None, data, pos, ch_names, False)] + + +def _should_use_opm_orientation_groups(info, picks, merge_channels, ch_type): + """Return whether OPM orientation grouping should be enabled. + + Grouping is used for OPM magnetometer channels with overlap sets that + include at least 2 colocated channels (biaxial or triaxial sensors). + """ + if ch_type != "mag" or not merge_channels: + return False + + pick_chs = [info["chs"][pick] for pick in picks] + if not pick_chs or not all(ch["coil_type"] in _opm_coils for ch in pick_chs): + return False + + # merge_channels should be a list of overlap sets, not a boolean + if not isinstance(merge_channels, (list, tuple)): + return False + + return any(len(overlap_set) >= 2 for overlap_set in merge_channels) + + def _plot_update_evoked_topomap(params, bools): """Update topomaps.""" from ..channels.layout import _merge_ch_data @@ -1681,7 +1762,6 @@ def plot_ica_components( """ # noqa E501 from matplotlib.pyplot import Axes - from ..channels.layout import _merge_ch_data from ..epochs import BaseEpochs from ..io import BaseRaw @@ -1714,9 +1794,9 @@ def plot_ica_components( axes = axes.flatten() if isinstance(axes, np.ndarray) else axes for k, picks in enumerate(pick_groups): - try: # either an iterable, 1D numpy array or others - _axes = axes[k * max_subplots : (k + 1) * max_subplots] - except TypeError: # None or Axes + if axes is None: + _axes = None + else: _axes = axes ( @@ -1729,7 +1809,6 @@ def plot_ica_components( clip_origin, ) = _prepare_topomap_plot(ica, ch_type, sphere=sphere) cmap = _setup_cmap(cmap, n_axes=len(picks)) - disp_names = _prepare_sensor_names(names, show_names) outlines = _make_head_outlines(sphere, pos, outlines, clip_origin) data = np.dot( @@ -1738,64 +1817,96 @@ def plot_ica_components( data = np.atleast_2d(data) data = data[:, data_picks] + use_opm_orientation_groups = _should_use_opm_orientation_groups( + ica.info, data_picks, merge_channels, ch_type + ) + n_group_axes = 2 if use_opm_orientation_groups else 1 + if title is None: title = "ICA components" user_passed_axes = _axes is not None if not user_passed_axes: - fig, _axes, _, _ = _prepare_trellis(len(data), ncols=ncols, nrows=nrows) + fig, _axes, _, _ = _prepare_trellis( + len(data) * n_group_axes, ncols=ncols, nrows=nrows + ) fig.suptitle(title) else: _axes = [_axes] if isinstance(_axes, Axes) else _axes + if len(_axes) != len(data) * n_group_axes: + raise RuntimeError( + "You must provide one axis per component and orientation " + "group for colocated OPM data." + ) fig = _axes[0].get_figure() subplot_titles = list() - for ii, data_, ax in zip(picks, data, _axes): + for comp_offset, (ii, data_) in enumerate(zip(picks, data)): kwargs = dict(color="gray") if ii in ica.exclude else dict() comp_title = ica._ica_names[ii] if len(set(ica.get_channel_types())) > 1: comp_title += f" ({ch_type})" - subplot_titles.append(ax.set_title(comp_title, fontsize=12, **kwargs)) - if merge_channels: - data_, names_ = _merge_ch_data(data_, ch_type, copy.copy(names)) - # ↓↓↓ NOTE: we intentionally use the default norm=False here, so that - # ↓↓↓ we get vlims that are symmetric-about-zero, even if the data for - # ↓↓↓ a given component happens to be one-sided. - _vlim = _setup_vmin_vmax(data_, *vlim) - im = plot_topomap( - data_.flatten(), + + modality = "opm" if use_opm_orientation_groups else "other" + grouped_data = _compute_orientation_group_data( + data_[:, np.newaxis], + copy.copy(names), pos, ch_type=ch_type, - sensors=sensors, - names=disp_names, - contours=contours, - outlines=outlines, - sphere=sphere, - image_interp=image_interp, - extrapolate=extrapolate, - border=border, - res=res, - size=size, - cmap=cmap[0], - vlim=_vlim, - cnorm=cnorm, - axes=ax, - show=False, - )[0] - - im.axes.set_label(ica._ica_names[ii]) - if colorbar: - cbar, cax = _add_colorbar( - ax, - im, - cmap, - title="AU", - format_=cbar_fmt, - kind="ica_comp_topomap", + modality=modality, + merge_channels=merge_channels, + use_opm_orientation_groups=use_opm_orientation_groups, + ) + + for group_idx, ( + group_label, + group_data, + group_pos, + group_names, + group_norm, + ) in enumerate(grouped_data): + ax_idx = comp_offset * n_group_axes + group_idx + ax = _axes[ax_idx] + plot_title = comp_title + if group_label is not None: + plot_title += f" [{group_label}]" + subplot_titles.append(ax.set_title(plot_title, fontsize=12, **kwargs)) + _vlim = _setup_vmin_vmax(group_data[:, 0], *vlim, norm=group_norm) + group_cmap = _setup_cmap(cmap, n_axes=len(picks), norm=group_norm) + im = plot_topomap( + group_data[:, 0].flatten(), + group_pos, ch_type=ch_type, - ) - cbar.ax.tick_params(labelsize=12) - cbar.set_ticks(_vlim) - _hide_frame(ax) + sensors=sensors, + names=_prepare_sensor_names(group_names, show_names), + contours=contours, + outlines=outlines, + sphere=sphere, + image_interp=image_interp, + extrapolate=extrapolate, + border=border, + res=res, + size=size, + cmap=group_cmap[0], + vlim=_vlim, + cnorm=cnorm, + axes=ax, + show=False, + )[0] + + im.axes.set_label(ica._ica_names[ii]) + if colorbar: + cbar, cax = _add_colorbar( + ax, + im, + group_cmap, + title="AU", + format_=cbar_fmt, + kind="ica_comp_topomap", + ch_type=ch_type, + ) + cbar.ax.tick_params(labelsize=12) + cbar.set_ticks(_vlim) + _hide_frame(ax) del pos fig.canvas.draw() @@ -2259,11 +2370,18 @@ def plot_evoked_topomap( clip_origin, ) = _prepare_topomap_plot(evoked, ch_type, sphere=sphere) outlines = _make_head_outlines(sphere, pos, outlines, clip_origin) + use_opm_orientation_groups = _should_use_opm_orientation_groups( + evoked.info, picks, merge_channels, ch_type + ) # check interactive axes_given = axes is not None interactive = isinstance(times, str) and times == "interactive" if interactive and axes_given: raise ValueError("User-provided axes not allowed when times='interactive'.") + if interactive and use_opm_orientation_groups: + raise NotImplementedError( + "times='interactive' is not supported for grouped OPM topomaps." + ) # units, scalings key = "grad" if ch_type.startswith("planar") else ch_type default_scaling = _handle_default("scalings", None)[key] @@ -2273,7 +2391,6 @@ def plot_evoked_topomap( unit = _handle_default("units", units)[key] # ch_names (required for NIRS) ch_names = names - names = _prepare_sensor_names(names, show_names) # apply projections before picking. NOTE: the `if proj is True` # anti-pattern is needed here to exclude proj='interactive' _check_option("proj", proj, (True, False, "interactive", "reconstruct")) @@ -2299,7 +2416,9 @@ def plot_evoked_topomap( f"Times should be between {evoked.times[0]:0.3} and {evoked.times[-1]:0.3}." ) # create axes - want_axes = n_times + int(colorbar) + n_groups = 2 if use_opm_orientation_groups else 1 + n_cbar = int(colorbar) * n_groups + want_axes = n_times * n_groups + n_cbar if interactive: height_ratios = [5, 1] nrows = 2 @@ -2313,7 +2432,7 @@ def plot_evoked_topomap( axes.append(plt.subplot(gs[0, ax_idx])) elif axes is None: fig, axes, ncols, nrows = _prepare_trellis( - n_times, ncols=ncols, nrows=nrows, size=size + n_times * n_groups, ncols=ncols, nrows=nrows, size=size ) else: nrows, ncols = None, None # Deactivate ncols when axes were passed @@ -2326,6 +2445,12 @@ def plot_evoked_topomap( f"each time{cbar_err}), got {len(axes)}." ) del want_axes + if axes_given and colorbar: + plot_axes = axes[:-n_cbar] + cbar_axes = axes[-n_cbar:] + else: + plot_axes = axes + cbar_axes = [] # find first index that's >= (to rounding error) to each time point time_idx = [ np.where( @@ -2386,53 +2511,61 @@ def plot_evoked_topomap( # apply scalings and merge channels data *= scaling - if merge_channels: - # check modality - if any(ch["coil_type"] in _opm_coils for ch in evoked.info["chs"]): - modality = "opm" - elif ch_type in _fnirs_types: - modality = "fnirs" - else: - modality = "other" - # merge data - data, ch_names = _merge_ch_data(data, ch_type, ch_names, modality=modality) - # if ch_type in _fnirs_types: - if modality != "other": - merge_channels = False - # apply mask if requested - if mask is not None: - mask = mask.astype(bool, copy=False) - if ch_type == "grad": - mask_ = ( - mask[np.ix_(picks[::2], time_idx)] | mask[np.ix_(picks[1::2], time_idx)] - ) - else: # mag, eeg, planar1, planar2 - mask_ = mask[np.ix_(picks, time_idx)] - # set up colormap - _vlim = [ - _setup_vmin_vmax(data[:, i], *vlim, norm=merge_channels) for i in range(n_times) - ] - _vlim = [np.min(_vlim), np.max(_vlim)] - cmap = _setup_cmap(cmap, n_axes=n_times, norm=_vlim[0] >= 0) - # set up contours - if not isinstance(contours, list | np.ndarray): - _, contours = _set_contour_locator(*_vlim, contours) + # check modality + is_opm_picks = len(evoked.info["chs"]) > 0 and all( + ch["coil_type"] in _opm_coils for ch in evoked.info["chs"] + ) + if is_opm_picks: + modality = "opm" + elif ch_type in _fnirs_types: + modality = "fnirs" else: - if vlim[0] is None and np.any(contours < _vlim[0]): - _vlim[0] = contours[0] - if vlim[1] is None and np.any(contours > _vlim[1]): - _vlim[1] = contours[-1] + modality = "other" + + grouped_data = _compute_orientation_group_data( + data, + ch_names, + pos, + ch_type=ch_type, + modality=modality, + merge_channels=merge_channels, + use_opm_orientation_groups=use_opm_orientation_groups, + ) + + if modality != "other" or use_opm_orientation_groups: + merge_channels = False + + # set up colormaps, vlims, and contours per group + group_vlims = [] + group_cmaps = [] + group_contours = [] + for group_label, group_data, group_pos, group_names, group_norm in grouped_data: + group_vlim = [ + _setup_vmin_vmax(group_data[:, i], *vlim, norm=group_norm) + for i in range(n_times) + ] + group_vlim = [np.min(group_vlim), np.max(group_vlim)] + if not isinstance(contours, list | np.ndarray): + _, group_contour = _set_contour_locator(*group_vlim, contours) + else: + group_contour = contours + if vlim[0] is None and np.any(group_contour < group_vlim[0]): + group_vlim[0] = group_contour[0] + if vlim[1] is None and np.any(group_contour > group_vlim[1]): + group_vlim[1] = group_contour[-1] + group_vlims.append(group_vlim) + group_cmaps.append( + _setup_cmap(cmap, n_axes=n_times, norm=group_vlim[0] >= 0)[0] + ) + group_contours.append(group_contour) # prepare for main loop over times kwargs = dict( sensors=sensors, res=res, - names=names, - cmap=cmap[0], cnorm=cnorm, mask_params=mask_params, outlines=outlines, - contours=contours, image_interp=image_interp, show=False, extrapolate=extrapolate, @@ -2441,38 +2574,51 @@ def plot_evoked_topomap( ch_type=ch_type, ) images, contours_ = [], [] - # loop over times - for average_idx, (time, this_average) in enumerate(zip(times, average)): - tp, cn, interp = _plot_topomap( - data[:, average_idx], - pos, - axes=axes[average_idx], - mask=mask_[:, average_idx] if mask is not None else None, - vmin=_vlim[0], - vmax=_vlim[1], - **kwargs, - ) + for group_idx, ( + group_label, + group_data, + group_pos, + group_names, + _group_norm, + ) in enumerate(grouped_data): + kwargs["names"] = _prepare_sensor_names(group_names, show_names) + kwargs["cmap"] = group_cmaps[group_idx] + kwargs["contours"] = group_contours[group_idx] + group_vlim = group_vlims[group_idx] + for average_idx, (time, this_average) in enumerate(zip(times, average)): + ax_idx = group_idx * n_times + average_idx + tp, cn, interp = _plot_topomap( + group_data[:, average_idx], + group_pos, + axes=plot_axes[ax_idx], + mask=None, + vmin=group_vlim[0], + vmax=group_vlim[1], + **kwargs, + ) - images.append(tp) - if cn is not None: - contours_.append(cn) - if time_format != "": - if this_average is None: - axes_title = time_format % (time * scaling_time) - else: - tmin_ = averaged_times[average_idx][0] - tmax_ = averaged_times[average_idx][-1] - from_time = time_format % (tmin_ * scaling_time) - from_time = from_time.split(" ")[0] # Remove unit - to_time = time_format % (tmax_ * scaling_time) - axes_title = f"{from_time} – {to_time}" - del from_time, to_time, tmin_, tmax_ - axes[average_idx].set_title(axes_title) + images.append(tp) + if cn is not None: + contours_.append(cn) + if time_format != "": + if this_average is None: + axes_title = time_format % (time * scaling_time) + else: + tmin_ = averaged_times[average_idx][0] + tmax_ = averaged_times[average_idx][-1] + from_time = time_format % (tmin_ * scaling_time) + from_time = from_time.split(" ")[0] # Remove unit + to_time = time_format % (tmax_ * scaling_time) + axes_title = f"{from_time} – {to_time}" + del from_time, to_time, tmin_, tmax_ + if group_label is not None: + axes_title = f"{group_label}\n{axes_title}" + plot_axes[ax_idx].set_title(axes_title) if interactive: # Add a slider to the figure and start publishing and subscribing to time_change # events. - kwargs.update(vlim=_vlim) + kwargs.update(vlim=group_vlims[0], cmap=group_cmaps[0]) axes.append(fig.add_subplot(gs[1])) slider = Slider( axes[-1], @@ -2523,17 +2669,30 @@ def _slider_changed(val): else: # use the default behavior cax = None - cbar = fig.colorbar(images[-1], ax=axes, cax=cax, format=cbar_fmt, shrink=0.6) - if unit is not None: - cbar.ax.set_title(unit) - if cn is not None: - cbar.set_ticks(contours) - cbar.ax.tick_params(labelsize=7) - if cmap[1]: - for im in images: - im.axes.CB = DraggableColorbar( - cbar, im, kind="evoked_topomap", ch_type=ch_type - ) + for group_idx in range(n_groups): + group_images = images[group_idx * n_times : (group_idx + 1) * n_times] + group_axes = plot_axes[group_idx * n_times : (group_idx + 1) * n_times] + if axes_given and colorbar: + cax = cbar_axes[group_idx] + else: + cax = None + cbar = fig.colorbar( + group_images[-1], + ax=group_axes, + cax=cax, + format=cbar_fmt, + shrink=0.6, + ) + if unit is not None: + cbar.ax.set_title(unit) + if cn is not None: + cbar.set_ticks(group_contours[group_idx]) + cbar.ax.tick_params(labelsize=7) + if group_cmaps[group_idx][1]: + for im in group_images: + im.axes.CB = DraggableColorbar( + cbar, im, kind="evoked_topomap", ch_type=ch_type + ) if proj == "interactive": _check_delayed_ssp(evoked)