diff --git a/doc/changes/dev/13931.newfeature.rst b/doc/changes/dev/13931.newfeature.rst new file mode 100644 index 00000000000..73cc9276b88 --- /dev/null +++ b/doc/changes/dev/13931.newfeature.rst @@ -0,0 +1 @@ +Add support for interactive label browsing using ``hover=True`` in :meth:`mne.viz.Brain.add_annotation`, by `Eric Larson`_. \ No newline at end of file diff --git a/mne/conftest.py b/mne/conftest.py index 2ef937c7bd8..bcf9c41cfc4 100644 --- a/mne/conftest.py +++ b/mne/conftest.py @@ -94,7 +94,8 @@ def pytest_configure(config: pytest.Config): "slowtest: mark a test as slow", "ultraslowtest: mark a test as ultraslow or to be run rarely", "pgtest: mark a test as relevant for mne-qt-browser", - "pvtest: mark a test as relevant for pyvistaqt", + # used by PyVista's MNE integration tests (but also useful in some testing): + "pvtest: mark a test as relevant for pyvista", ): config.addinivalue_line("markers", marker) diff --git a/mne/label.py b/mne/label.py index 2d55de755c0..9689f5959b7 100644 --- a/mne/label.py +++ b/mne/label.py @@ -40,6 +40,7 @@ _check_fname, _check_option, _check_subject, + _import_nibabel, _validate_type, check_random_state, fill_doc, @@ -2123,87 +2124,6 @@ def _read_annot_cands(dir_name, raise_error=True): return cands -def _read_annot(fname): - """Read a Freesurfer annotation from a .annot file. - - Note : Copied from PySurfer - - Parameters - ---------- - fname : str - Path to annotation file - - Returns - ------- - annot : numpy array, shape=(n_verts) - Annotation id at each vertex - ctab : numpy array, shape=(n_entries, 5) - RGBA + label id colortable array - names : list of str - List of region names as stored in the annot file - - """ - if not op.isfile(fname): - dir_name = op.split(fname)[0] - cands = _read_annot_cands(dir_name) - if len(cands) == 0: - raise OSError( - f"No such file {fname}, no candidate parcellations found in directory" - ) - else: - raise OSError( - f"No such file {fname}, candidate parcellations in " - "that directory:\n" + "\n".join(cands) - ) - with open(fname, "rb") as fid: - n_verts = np.fromfile(fid, ">i4", 1)[0] - data = np.fromfile(fid, ">i4", n_verts * 2).reshape(n_verts, 2) - annot = data[data[:, 0], 1] - ctab_exists = np.fromfile(fid, ">i4", 1)[0] - if not ctab_exists: - raise Exception("Color table not found in annotation file") - n_entries = np.fromfile(fid, ">i4", 1)[0] - if n_entries > 0: - length = np.fromfile(fid, ">i4", 1)[0] - np.fromfile(fid, ">c", length) # discard orig_tab - - names = list() - ctab = np.zeros((n_entries, 5), np.int64) - for i in range(n_entries): - name_length = np.fromfile(fid, ">i4", 1)[0] - name = np.fromfile(fid, f"|S{name_length}", 1)[0] - names.append(name) - ctab[i, :4] = np.fromfile(fid, ">i4", 4) - ctab[i, 4] = ( - ctab[i, 0] - + ctab[i, 1] * (2**8) - + ctab[i, 2] * (2**16) - + ctab[i, 3] * (2**24) - ) - else: - ctab_version = -n_entries - if ctab_version != 2: - raise Exception("Color table version not supported") - n_entries = np.fromfile(fid, ">i4", 1)[0] - ctab = np.zeros((n_entries, 5), np.int64) - length = np.fromfile(fid, ">i4", 1)[0] - np.fromfile(fid, f"|S{length}", 1) # Orig table path - entries_to_read = np.fromfile(fid, ">i4", 1)[0] - names = list() - for i in range(entries_to_read): - np.fromfile(fid, ">i4", 1) # Structure - name_length = np.fromfile(fid, ">i4", 1)[0] - name = np.fromfile(fid, f"|S{name_length}", 1)[0] - names.append(name) - ctab[i, :4] = np.fromfile(fid, ">i4", 4) - ctab[i, 4] = ctab[i, 0] + ctab[i, 1] * (2**8) + ctab[i, 2] * (2**16) - - # convert to more common alpha value - ctab[:, 3] = 255 - ctab[:, 3] - - return annot, ctab, names - - def _get_annot_fname(annot_fname, subject, hemi, parc, subjects_dir): """Get the .annot filenames and hemispheres.""" if annot_fname is not None: @@ -2251,6 +2171,7 @@ def read_labels_from_annot( parc="aparc", hemi="both", surf_name="white", + *, annot_fname=None, regexp=None, subjects_dir=None, @@ -2295,6 +2216,8 @@ def read_labels_from_annot( write_labels_to_annot morph_labels """ + nib = _import_nibabel("Reading labels from parcellations") + logger.info("Reading labels from parcellation...") subjects_dir = get_subjects_dir(subjects_dir) @@ -2318,7 +2241,9 @@ def read_labels_from_annot( orig_names = set() for fname, hemi in zip(annot_fname, hemis): # read annotation - annot, ctab, label_names = _read_annot(fname) + _check_fname(fname, overwrite="read", must_exist=True, name="annotation file") + annot, ctab, label_names = nib.freesurfer.io.read_annot(fname, orig_ids=True) + ctab[:, 3] = 255 - ctab[:, 3] label_rgbas = ctab[:, :4] / 255.0 label_ids = ctab[:, -1] @@ -2362,7 +2287,7 @@ def read_labels_from_annot( labels = sorted(labels, key=lambda label: label.name) if len(labels) == 0: - msg = "No labels found." + msg = f"No labels found in {annot_fname[0]}." if regexp is not None: orig_names = "\n".join(sorted(orig_names)) msg += ( diff --git a/mne/tests/test_label.py b/mne/tests/test_label.py index f93721188d4..a174f745656 100644 --- a/mne/tests/test_label.py +++ b/mne/tests/test_label.py @@ -41,7 +41,6 @@ _blend_colors, _load_vert_pos, _n_colors, - _read_annot, _read_annot_cands, label_sign_flip, select_sources, @@ -380,7 +379,7 @@ def test_annot_io(tmp_path): shutil.copy(surf_src / "rh.white", surf_dir) # read original labels - with pytest.raises(OSError, match="\nPALS_B12_Lobes$"): + with pytest.raises(OSError, match="PALS_B12_Lobesey"): read_labels_from_annot(subject, "PALS_B12_Lobesey", subjects_dir=tmp_path) labels = read_labels_from_annot(subject, "PALS_B12_Lobes", subjects_dir=tmp_path) @@ -486,8 +485,6 @@ def test_read_labels_from_annot(tmp_path): ) with pytest.raises(OSError, match="does not exist"): _read_annot_cands("foo") - with pytest.raises(OSError, match="no candidate"): - _read_annot(str(tmp_path)) # read labels using hemi specification labels_lh = read_labels_from_annot("sample", hemi="lh", subjects_dir=subjects_dir) diff --git a/mne/utils/misc.py b/mne/utils/misc.py index 12e22197c76..38c14cbceb0 100644 --- a/mne/utils/misc.py +++ b/mne/utils/misc.py @@ -497,7 +497,13 @@ def _auto_weakref(function): __weakref_values__ = dict() evaldict = dict(__weakref_values__=__weakref_values__) for name, value in zip(names, function.__closure__): - __weakref_values__[name] = weakref.ref(value.cell_contents) + try: + __weakref_values__[name] = weakref.ref(value.cell_contents) + except TypeError: # pragma: no cover + raise TypeError( + f"Cannot create weak reference to {name} " + f"(type {type(value.cell_contents)})" + ) body = dedent(inspect.getsource(function)) body = body.splitlines() for li, line in enumerate(body): diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index ad7c7a40995..6a670d67398 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -40,10 +40,8 @@ from ...utils import ( Bunch, _auto_weakref, - _check_fname, _check_option, _ensure_int, - _path_like, _ReuseCycle, _to_rgb, _validate_type, @@ -1278,9 +1276,8 @@ def _on_pick(self, vtk_picker, event): vtk_cell.GetPointId(point_id) for point_id in range(vtk_cell.GetNumberOfPoints()) ] - vertices = mesh.points[cell] - idx = np.argmin(abs(vertices - pos), axis=0) - vertex_id = cell[idx[0]] + vert_pos = mesh.points[cell] + vertex_id = cell[np.argmin(np.linalg.norm(vert_pos - pos, axis=1))] publish(self, VertexSelect(hemi=hemi, vertex_id=vertex_id)) @@ -1957,20 +1954,24 @@ def _iter_views(self, hemi): def remove_labels(self): """Remove all the ROI labels from the image.""" for hemi in self._hemis: - mesh = self._layered_meshes[hemi] for label in self._labels[hemi]: - mesh.remove_overlay(label.name) + self._layered_meshes[hemi].remove_overlay(label.name) self._labels[hemi].clear() self._renderer._update() def remove_annotations(self): """Remove all annotations from the image.""" - for hemi in self._hemis: - if hemi in self._layered_meshes: - mesh = self._layered_meshes[hemi] - mesh.remove_overlay(self._annots[hemi]) - if hemi in self._annots: - self._annots[hemi].clear() + for hemi, overlayer in self._layered_meshes.items(): + overlayer.remove_overlay([annot["name"] for annot in self._annots[hemi]]) + for annot in self._annots[hemi]: + if "caption" in annot: + for _ in self._iter_views(hemi): + self.plotter.remove_actor(annot["caption"], render=False) + try: + self.plotter.RemoveObserver(annot["obs"]) + except AttributeError: # can happen during cleanup + pass + self._annots[hemi].clear() self._renderer._update() def _add_volume_data(self, hemi, src, volume_options): @@ -2905,19 +2906,22 @@ def _configure_label_time_course(self): @fill_doc def add_annotation( - self, annot, borders=True, alpha=1, hemi=None, remove_existing=True, color=None + self, + annot, + borders=True, + alpha=1, + hemi=None, + *, + remove_existing=True, + color=None, + hover=True, ): """Add an annotation file. Parameters ---------- - annot : str | tuple - Either path to annotation file or annotation name. Alternatively, - the annotation can be specified as a ``(labels, ctab)`` tuple per - hemisphere, i.e. ``annot=(labels, ctab)`` for a single hemisphere - or ``annot=((lh_labels, lh_ctab), (rh_labels, rh_ctab))`` for both - hemispheres. ``labels`` and ``ctab`` should be arrays as returned - by :func:`nibabel.freesurfer.io.read_annot`. + annot : str + Either path to annotation file or annotation name. borders : bool | int Show only label borders. If int, specify the number of steps (away from the true border) along the cortical mesh to include @@ -2932,93 +2936,172 @@ def add_annotation( color : matplotlib-style color code If used, show all annotations in the same (specified) color. Probably useful only when showing annotation borders. + hover : bool + If True, show annotation labels on hover. + + .. versionadded:: 1.13 """ - from ...label import _read_annot + from ...label import read_labels_from_annot hemis = self._check_hemis(hemi) - - # Figure out where the data is coming from - if _path_like(annot): - if os.path.isfile(annot): - filepath = _check_fname(annot, overwrite="read") - file_hemi, annot = filepath.name.split(".", 1) - if len(hemis) > 1: - if file_hemi == "lh": - filepaths = [filepath, filepath.parent / ("rh." + annot)] - elif file_hemi == "rh": - filepaths = [filepath.parent / ("lh." + annot), filepath] - else: - raise RuntimeError( - "To add both hemispheres simultaneously, filename must " - 'begin with "lh." or "rh."' - ) - else: - filepaths = [filepath] - else: - filepaths = [] - for hemi in hemis: - filepath = op.join( - self._subjects_dir, - self._subject, - "label", - ".".join([hemi, annot, "annot"]), - ) - if not os.path.exists(filepath): - raise ValueError(f"Annotation file {filepath} does not exist") - filepaths += [filepath] - annots = [] - for hemi, filepath in zip(hemis, filepaths): - # Read in the data - labels, cmap, _ = _read_annot(filepath) - annots.append((labels, cmap)) + kwargs = dict() + if os.path.isfile(annot): + kwargs["annot_fname"] = annot else: - annots = [annot] if len(hemis) == 1 else annot - annot = "annotation" + kwargs["parc"] = annot + + for hemi in hemis: + labels = read_labels_from_annot( + self._subject, hemi=hemi, subjects_dir=self._subjects_dir, **kwargs + ) + n_labels = len(labels) + ids = np.zeros(self.geo[hemi].coords.shape[0], dtype=int) + cmap = np.zeros((len(labels) + 1, 4)) + cmap[:, 3] = 1 + cmap[0] = np.array(self._brain_color) + cmap[0, 3] = 0.0 + centroids = np.zeros((len(labels) + 1, 3)) + for li, label in enumerate(labels): + ids[label.vertices] = li # will have one added later + cmap[li + 1] = label.color + label.values[:] = 1 + centroids[li] = self.geo[hemi].coords[ + label.center_of_mass(subjects_dir=self._subjects_dir) + ] + self._annots[hemi].append( + dict(name=annot, labels=labels, ids=ids, centroids=centroids) + ) + del labels - for hemi, (labels, cmap) in zip(hemis, annots): # Maybe zero-out the non-border vertices - self._to_borders(labels, hemi, borders) - - # Handle null labels properly - cmap[:, 3] = 255 - bgcolor = np.round(np.array(self._brain_color) * 255).astype(int) - bgcolor[-1] = 0 - cmap[cmap[:, 4] < 0, 4] += 2**24 # wrap to positive - cmap[cmap[:, 4] <= 0, :4] = bgcolor - if np.any(labels == 0) and not np.any(cmap[:, -1] <= 0): - cmap = np.vstack((cmap, np.concatenate([bgcolor, [0]]))) - - # Set label ids sensibly - order = np.argsort(cmap[:, -1]) - cmap = cmap[order] - ids = np.searchsorted(cmap[:, -1], labels) - cmap = cmap[:, :4] - - # Set the alpha level - alpha_vec = cmap[:, 3] - alpha_vec[alpha_vec > 0] = alpha * 255 + scalars = ids + 1 # make a copy and reindex + self._to_borders(scalars, hemi, borders) # Override the cmap when a single color is used if color is not None: - rgb = np.round(np.multiply(_to_rgb(color), 255)) - cmap[:, :3] = rgb.astype(cmap.dtype) + cmap[1:, :3] = _to_rgb(color) ctable = cmap.astype(np.float64) for _ in self._iter_views(hemi): mesh = self._layered_meshes[hemi] mesh.add_overlay( - scalars=ids, - colormap=ctable, - rng=[np.min(ids), np.max(ids)], + scalars=scalars, + colormap=ctable * 255, + rng=[0, n_labels], opacity=alpha, name=annot, ) - self._annots[hemi].append(annot) - if not self.time_viewer or self.traces_mode == "vertex": - self._renderer._set_colormap_range( - mesh._actor, cmap.astype(np.uint8), None + + if hover: + obs = self.plotter.AddObserver("MouseMoveEvent", self._on_annotation_hover) + for hemi in hemis: + caption = self._create_caption() + self._annots[hemi][-1].update(caption=caption, obs=obs) + for _ in self._iter_views(hemi): + self.plotter.add_actor( + caption, + name=None, + culling=False, + pickable=False, + reset_camera=False, + render=False, ) + self._renderer._update() + def _create_caption(self): + from vtkmodules.vtkRenderingAnnotation import vtkCaptionActor2D + + caption = vtkCaptionActor2D() + caption.SetVisibility(False) + caption.SetLeader(True) + caption.SetBorder(False) # use the text border instead + caption.GetPositionCoordinate().SetCoordinateSystemToDisplay() + caption.GetPosition2Coordinate().SetCoordinateSystemToDisplay() + caption.SetThreeDimensionalLeader(False) + caption.GetPositionCoordinate().SetValue(20, 20) + caption.GetTextActor().SetTextScaleModeToNone() + prop = caption.GetCaptionTextProperty() + prop.SetFontSize(14) + prop.SetItalic(False) + prop.SetShadow(False) + prop.SetBackgroundOpacity(0.5) + prop.SetColor(*self._fg_color[:3]) + prop.SetFrame(True) + prop.SetFrameWidth(3) + prop.SetBackgroundColor(*self._bg_color[:3]) + return caption + + def _on_annotation_hover(self, iren, event): # event == "MouseMoveEvent" + from pyvista import DataSetMapper + + x, y = iren.GetEventPosition() + picked_renderer = iren.FindPokedRenderer(x, y) + vtk_picker = self._renderer._picker + vtk_picker.Pick(x, y, 0, picked_renderer) + cell_id = vtk_picker.GetCellId() + # This returns a vtkPolyData we don't seem to have access to: + # vtk_picker.GetDataSet() + # So we need to go through the mapper: + mapper = vtk_picker.GetMapper() + if not isinstance(mapper, DataSetMapper) or cell_id == -1: + do_update = False + for annot in self._annots.values(): + if "caption" not in annot[-1]: + continue + caption = annot[-1]["caption"] + if caption.GetVisibility(): + logger.debug("No mesh picked, hiding caption") + caption.SetVisibility(False) + do_update = True + if do_update: + self._renderer._update() + return # didn't find a mesh + for hemi, this_mesh in self._layered_meshes.items(): + if this_mesh._polydata is mapper.dataset: + mesh = this_mesh._polydata + break + else: + return + pos = np.array(vtk_picker.GetPickPosition()) + vtk_cell = mesh.GetCell(cell_id) + cell = [ + vtk_cell.GetPointId(point_id) + for point_id in range(vtk_cell.GetNumberOfPoints()) + ] + vert_pos = mesh.points[cell] + vertex_id = cell[np.argmin(np.linalg.norm(vert_pos - pos, axis=1))] + lidx = self._annots[hemi][-1]["ids"][vertex_id] + label = self._annots[hemi][-1]["labels"][lidx] + centroid = self._annots[hemi][-1]["centroids"][lidx] + caption = self._annots[hemi][-1]["caption"] + if caption.GetCaption() == label.name: + logger.debug("Same label hovered, skipping update") + return # no-op to save a render call + # We have lots of options here... can have the text move with the cursor + # but that's a bit distracting (and slower UX because it takes some + # time to render each time). Could also shift the label in world coords, + # but it's cleaner just to move it by some number of pixels. + logger.debug( + "Hovering label %s from %s %d @ %s", + label.name, + hemi, + vertex_id, + centroid, + ) + other_hemi = "lh" if hemi == "rh" else "rh" + if other_hemi in self._annots: + self._annots[other_hemi][-1]["caption"].SetVisibility(False) + caption.SetCaption(label.name) + caption.SetAttachmentPoint(*centroid) + caption.SetVisibility(True) + actor = caption.GetTextActor() + prop = caption.GetCaptionTextProperty() + prop.SetFrameColor(*label.color[:3]) + # This maybe isn't strictly needed because we hide the frame anyway, but for + # completeness and future compat let's fix our size + wh = np.zeros(2) + actor.GetSize(self.plotter.renderer, wh) + caption.SetPosition2(wh) self._renderer._update() def close(self): diff --git a/mne/viz/_brain/tests/test_brain.py b/mne/viz/_brain/tests/test_brain.py index 212f300c81b..4138817e8a9 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -37,7 +37,7 @@ from mne.minimum_norm import apply_inverse, make_inverse_operator from mne.source_estimate import _BaseSourceEstimate from mne.source_space import read_source_spaces, setup_volume_source_space -from mne.utils import check_version +from mne.utils import catch_logging, check_version from mne.viz import ui_events from mne.viz._brain import Brain, _BrainScraper, _LayeredMesh, _LinkViewer from mne.viz._brain.colormap import calculate_lut @@ -540,7 +540,9 @@ def __init__(self): brain.close() - # add annotation + +def test_add_annotation(renderer_interactive_pyvistaqt, brain_gc): + """Test add_annotation.""" annots = [ "aparc", subjects_dir / "fsaverage" / "label" / "lh.PALS_B12_Lobes.annot", @@ -548,6 +550,7 @@ def __init__(self): borders = [True, 2] alphas = [1, 0.5] colors = [None, "r"] + size = (100, 100) brain = Brain( subject="fsaverage", hemi="both", @@ -555,13 +558,54 @@ def __init__(self): surf="inflated", subjects_dir=subjects_dir, ) - with pytest.raises(ValueError, match="does not exist"): + with pytest.raises(FileNotFoundError, match="does not exist"): brain.add_annotation("foo") - brain.add_annotation(annots[1]) + brain.add_annotation(annots[1], hover=True) + + # mock some events + class MockIrenAndPicker: + def __init__(self): + self._cell_id = 0 + + def GetEventPosition(self): + return 50, 50 # middle of display + + def FindPokedRenderer(self, x, y): + return brain.plotter.renderers[0] + + def Pick(self, x, y, z, renderer): + pass + + def GetMapper(self): + return brain.plotter.mapper + + def GetCellId(self): + return self._cell_id + + def GetPickPosition(self): + return np.zeros(3) + + mocked = MockIrenAndPicker() + brain._renderer._picker = mocked + + with catch_logging(verbose="debug") as log: + brain._on_annotation_hover(mocked, "MouseMoveEvent") + log = log.getvalue() + assert "Hovering label LOBE.FRONTAL" in log + with catch_logging(verbose="debug") as log: + brain._on_annotation_hover(mocked, "MouseMoveEvent") + log = log.getvalue() + assert "Same label hovered" in log + mocked._cell_id = -1 + with catch_logging(verbose="debug") as log: + brain._on_annotation_hover(mocked, "MouseMoveEvent") + log = log.getvalue() + assert "No mesh picked, hiding" in log brain.close() + brain = Brain( subject="fsaverage", - hemi=hemi, + hemi="lh", size=size, surf="inflated", subjects_dir=subjects_dir, diff --git a/mne/viz/backends/_pyvista.py b/mne/viz/backends/_pyvista.py index 50038046e69..4f22430573e 100644 --- a/mne/viz/backends/_pyvista.py +++ b/mne/viz/backends/_pyvista.py @@ -263,6 +263,7 @@ def __init__( self._hide_axes() self._toggle_antialias() self._enable_depth_peeling() + self._picker = vtkCellPicker() # FIX: https://github.com/pyvista/pyvistaqt/pull/68 if not hasattr(self.plotter, "iren"): @@ -907,7 +908,6 @@ def _update_picking_callback( add_obs(vtkCommand.RenderEvent, on_mouse_move) add_obs(vtkCommand.LeftButtonPressEvent, on_button_press) add_obs(vtkCommand.EndInteractionEvent, on_button_release) - self._picker = vtkCellPicker() self._picker.AddObserver(vtkCommand.EndPickEvent, on_pick) self._picker.SetVolumeOpacityIsovalue(0.0) @@ -1071,7 +1071,7 @@ def _compute_normals(mesh): ) -def _add_mesh(plotter, *args, **kwargs): +def _add_mesh(plotter, **kwargs): """Patch PyVista add_mesh.""" mesh = kwargs.get("mesh") if "smooth_shading" in kwargs: @@ -1084,7 +1084,7 @@ def _add_mesh(plotter, *args, **kwargs): kwargs["render"] = False if "reset_camera" not in kwargs: kwargs["reset_camera"] = False - actor = plotter.add_mesh(*args, **kwargs) + actor = plotter.add_mesh(**kwargs) if smooth_shading and "Normals" in mesh.point_data: prop = actor.GetProperty() prop.SetInterpolationToPhong()