From f91f693f328df3f6f524c093ea77ab81a903fab5 Mon Sep 17 00:00:00 2001 From: Erick Cobos Date: Sat, 6 Dec 2025 07:19:41 +0100 Subject: [PATCH 1/6] minor: change deprecated channel_slice to select_channels --- src/spikeinterface/core/channelslice.py | 2 +- .../preprocessing/tests/test_detect_bad_channels.py | 2 +- .../sorters/external/tests/test_docker_containers.py | 2 +- .../sorters/external/tests/test_singularity_containers.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index 8a4f29e86c..bc5143fa13 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -106,7 +106,7 @@ class ChannelSliceSnippets(BaseSnippets): """ Class to slice a Snippets object based on channel_ids. - Do not use this class directly but use `snippets.channel_slice(...)` + Do not use this class directly but use `snippets.select_channels(...)` """ diff --git a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py index 31ea5f5523..836a9adefb 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py @@ -230,7 +230,7 @@ def test_detect_bad_channels_ibl(num_channels): # Test on randomly sorted channels rng = np.random.default_rng(seed=None) - recording_scrambled = recording.channel_slice( + recording_scrambled = recording.select_channels( rng.choice(recording.channel_ids, len(recording.channel_ids), replace=False) ) bad_channel_ids_scrambled, bad_channel_label_scrambled = detect_bad_channels( diff --git a/src/spikeinterface/sorters/external/tests/test_docker_containers.py b/src/spikeinterface/sorters/external/tests/test_docker_containers.py index a3841a9d3e..6204d57ef2 100644 --- a/src/spikeinterface/sorters/external/tests/test_docker_containers.py +++ b/src/spikeinterface/sorters/external/tests/test_docker_containers.py @@ -93,7 +93,7 @@ def test_combinato(run_kwargs, create_cache_folder): cache_folder = create_cache_folder rec = run_kwargs["recording"] channels = rec.get_channel_ids()[0:1] - rec_one_channel = rec.channel_slice(channels) + rec_one_channel = rec.select_channels(channels) run_kwargs["recording"] = rec_one_channel sorting = ss.run_sorter(sorter_name="combinato", folder=cache_folder / "combinato", **run_kwargs) print(sorting) diff --git a/src/spikeinterface/sorters/external/tests/test_singularity_containers.py b/src/spikeinterface/sorters/external/tests/test_singularity_containers.py index 61b928b6f7..a26ecc2d11 100644 --- a/src/spikeinterface/sorters/external/tests/test_singularity_containers.py +++ b/src/spikeinterface/sorters/external/tests/test_singularity_containers.py @@ -108,7 +108,7 @@ def test_combinato(run_kwargs, create_cache_folder): clean_singularity_cache() rec = run_kwargs["recording"] channels = rec.get_channel_ids()[0:1] - rec_one_channel = rec.channel_slice(channels) + rec_one_channel = rec.select_channels(channels) run_kwargs["recording"] = rec_one_channel sorting = ss.run_sorter(sorter_name="combinato", folder=cache_folder / "combinato", **run_kwargs) print(sorting) From 7882052de293c9605cfa92672be39ad09180487b Mon Sep 17 00:00:00 2001 From: Erick Cobos Date: Sat, 6 Dec 2025 07:20:45 +0100 Subject: [PATCH 2/6] delete _channel_slice (used by no previous channel_slice function) --- src/spikeinterface/core/basesnippets.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index 09c4416c2d..6360ecaa45 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -201,18 +201,6 @@ def select_channels(self, channel_ids: list | np.array | tuple) -> "BaseSnippets return ChannelSliceSnippets(self, channel_ids) - def _channel_slice(self, channel_ids, renamed_channel_ids=None): - from .channelslice import ChannelSliceSnippets - import warnings - - warnings.warn( - "Snippets.channel_slice will be removed in version 0.103, use `select_channels` or `rename_channels` instead.", - DeprecationWarning, - stacklevel=2, - ) - sub_recording = ChannelSliceSnippets(self, channel_ids, renamed_channel_ids=renamed_channel_ids) - return sub_recording - def _remove_channels(self, remove_channel_ids): from .channelslice import ChannelSliceSnippets From 47e31d710ce6be95c6f39afcde69e688750306e7 Mon Sep 17 00:00:00 2001 From: Erick Cobos Date: Thu, 25 Dec 2025 22:37:06 +0100 Subject: [PATCH 3/6] fix edges in density_map --- .../widgets/unit_waveforms_density_map.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 9543cbf734..627e6c2af9 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -83,8 +83,8 @@ def __init__( templates = ext_templates.get_templates(unit_ids=unit_ids) bin_min = np.min(templates) * 1.3 bin_max = np.max(templates) * 1.3 - bin_size = (bin_max - bin_min) / 100 - bins = np.arange(bin_min, bin_max, bin_size) + num_bins = 100 + bins = np.linspace(bin_min, bin_max, num_bins + 1) # 2d histograms if same_axis: @@ -121,14 +121,9 @@ def __init__( wfs = wfs_ # make histogram density - wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1) - hist2d = np.zeros((wfs_flat.shape[1], bins.size)) - indexes0 = np.arange(wfs_flat.shape[1]) - - wf_bined = np.floor((wfs_flat - bin_min) / bin_size).astype("int32") - wf_bined = wf_bined.clip(0, bins.size - 1) - for d in wf_bined: - hist2d[indexes0, d] += 1 + wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1) # num_spikes x times*num_channels + hists_per_timepoint = [np.histogram(one_timepoint, bins=bins)[0] for one_timepoint in wfs_flat.T] + hist2d = np.stack(hists_per_timepoint) if same_axis: if all_hist2d is None: @@ -169,7 +164,6 @@ def __init__( BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) From c30f5879ceab3cfff75175119470d970f38b5416 Mon Sep 17 00:00:00 2001 From: Erick Cobos Date: Thu, 8 Jan 2026 18:07:39 +0100 Subject: [PATCH 4/6] minor: code formatting --- .../widgets/unit_waveforms_density_map.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 627e6c2af9..3bbb2fbae1 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -168,16 +168,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): dp = to_attr(data_plot) - if backend_kwargs["axes"] is not None or backend_kwargs["ax"] is not None: - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - else: - if dp.same_axis: - num_axes = 1 - else: - num_axes = len(dp.unit_ids) + if backend_kwargs['axes'] is None and backend_kwargs['ax'] is None: backend_kwargs["ncols"] = 1 - backend_kwargs["num_axes"] = num_axes - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + backend_kwargs["num_axes"] = 1 if dp.same_axis else len(dp.unit_ids) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) if dp.same_axis: ax = self.ax From e5b932c35a0422cdd510c2a4d754cd09e732eac9 Mon Sep 17 00:00:00 2001 From: Erick Cobos Date: Fri, 9 Jan 2026 02:10:39 +0100 Subject: [PATCH 5/6] plot density map with msecs on x axis --- src/spikeinterface/widgets/unit_summary.py | 1 + .../widgets/unit_waveforms_density_map.py | 32 +++++++++---------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index fb26a228ef..e86689ebf4 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -180,6 +180,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): **unitwaveformdensitymapwidget_kwargs, ) col_counter += 1 + ax_waveform_density.set_xlabel(None) ax_waveform_density.set_ylabel(None) if sorting_analyzer.has_extension("correlograms"): diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 3bbb2fbae1..d4cafe1eb6 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -121,7 +121,7 @@ def __init__( wfs = wfs_ # make histogram density - wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1) # num_spikes x times*num_channels + wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1) # num_spikes x (num_channels * timepoints) hists_per_timepoint = [np.histogram(one_timepoint, bins=bins)[0] for one_timepoint in wfs_flat.T] hist2d = np.stack(hists_per_timepoint) @@ -157,6 +157,7 @@ def __init__( bin_min=bin_min, bin_max=bin_max, all_hist2d=all_hist2d, + sampling_frequency=sorting_analyzer.sampling_frequency, templates_flat=templates_flat, template_width=wfs.shape[1], ) @@ -173,37 +174,36 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): backend_kwargs["num_axes"] = 1 if dp.same_axis else len(dp.unit_ids) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + freq_khz = dp.sampling_frequency / 1000 # samples / msec if dp.same_axis: - ax = self.ax hist2d = dp.all_hist2d - im = ax.imshow( + x_max = len(hist2d) / freq_khz # in milliseconds + self.ax.imshow( hist2d.T, interpolation="nearest", origin="lower", aspect="auto", - extent=(0, hist2d.shape[0], dp.bin_min, dp.bin_max), + extent=(0, x_max, dp.bin_min, dp.bin_max), cmap="hot", ) else: - for unit_index, unit_id in enumerate(dp.unit_ids): + for ax, unit_id in zip(self.axes.flatten(), dp.unit_ids): hist2d = dp.all_hist2d[unit_id] - ax = self.axes.flatten()[unit_index] - im = ax.imshow( + x_max = len(hist2d) / freq_khz # in milliseconds + ax.imshow( hist2d.T, interpolation="nearest", origin="lower", aspect="auto", - extent=(0, hist2d.shape[0], dp.bin_min, dp.bin_max), + extent=(0, x_max, dp.bin_min, dp.bin_max), cmap="hot", ) for unit_index, unit_id in enumerate(dp.unit_ids): - if dp.same_axis: - ax = self.ax - else: - ax = self.axes.flatten()[unit_index] + ax = self.ax if dp.same_axis else self.axes.flatten()[unit_index] color = dp.unit_colors[unit_id] - ax.plot(dp.templates_flat[unit_id], color=color, lw=1) + x = np.arange(len(dp.templates_flat[unit_id])) / freq_khz + ax.plot(x, dp.templates_flat[unit_id], color=color, lw=1) # final cosmetics for unit_index, unit_id in enumerate(dp.unit_ids): @@ -216,11 +216,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): chan_inds = dp.channel_inds[unit_id] for i, chan_ind in enumerate(chan_inds): if i != 0: - ax.axvline(i * dp.template_width, color="w", lw=3) + ax.axvline(i * dp.template_width / freq_khz, color="w", lw=3) channel_id = dp.channel_ids[chan_ind] - x = i * dp.template_width + dp.template_width // 2 + x = (i + 0.5) * dp.template_width / freq_khz y = (dp.bin_max + dp.bin_min) / 2.0 ax.text(x, y, f"chan_id {channel_id}", color="w", ha="center", va="center") - ax.set_xticks([]) + ax.set_xlabel('Time [ms]') ax.set_ylabel(f"unit_id {unit_id}") From 4a08ec642eeda68d7bdb770af197006ef4c03412 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 9 Jan 2026 02:57:05 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/unit_waveforms_density_map.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index d4cafe1eb6..a4260ac752 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -169,7 +169,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): dp = to_attr(data_plot) - if backend_kwargs['axes'] is None and backend_kwargs['ax'] is None: + if backend_kwargs["axes"] is None and backend_kwargs["ax"] is None: backend_kwargs["ncols"] = 1 backend_kwargs["num_axes"] = 1 if dp.same_axis else len(dp.unit_ids) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) @@ -189,7 +189,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): else: for ax, unit_id in zip(self.axes.flatten(), dp.unit_ids): hist2d = dp.all_hist2d[unit_id] - x_max = len(hist2d) / freq_khz # in milliseconds + x_max = len(hist2d) / freq_khz # in milliseconds ax.imshow( hist2d.T, interpolation="nearest", @@ -222,5 +222,5 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): y = (dp.bin_max + dp.bin_min) / 2.0 ax.text(x, y, f"chan_id {channel_id}", color="w", ha="center", va="center") - ax.set_xlabel('Time [ms]') + ax.set_xlabel("Time [ms]") ax.set_ylabel(f"unit_id {unit_id}")