From c6a6b891ff35c931aa42bb6bd45b20bb24bd72df Mon Sep 17 00:00:00 2001 From: Antorhythms Date: Tue, 31 Mar 2026 19:43:06 +0200 Subject: [PATCH] =?UTF-8?q?Replace=20int()=20truncation=20with=20shared=20?= =?UTF-8?q?ms=5Fto=5Fsamples()=20utility=20for=20waveform/template=20sampl?= =?UTF-8?q?e=20count=20conversions.=20Adds=20ms=5Fto=5Fsamples()=20to=20co?= =?UTF-8?q?re=5Ftools.py=20using=20round()=20instead=20of=20int()=20to=20p?= =?UTF-8?q?revent=20=C2=B11=20sample=20inconsistencies=20across=20datasets=20?= =?UTF-8?q?with=20nearly=20identical=20sampling=20rates=20(e.g.=2029999=20?= =?UTF-8?q?vs=2030000=20Hz).=20Adopted=20across=2021=20production=20files?= =?UTF-8?q?=20and=205=20test=20files=20for=20all=20waveform/template=20win?= =?UTF-8?q?dow=20geometry;=20non-waveform=20conversions=20(margins,=20excl?= =?UTF-8?q?usion=20zones,=20refractory=20periods)=20left=20unchanged.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../benchmark/tests/common_benchmark_testing.py | 5 +++-- src/spikeinterface/core/__init__.py | 1 + src/spikeinterface/core/analyzer_extension_core.py | 9 +++++---- src/spikeinterface/core/core_tools.py | 5 +++++ src/spikeinterface/core/generate.py | 14 +++++++------- src/spikeinterface/core/node_pipeline.py | 5 +++-- src/spikeinterface/core/sparsity.py | 5 +++-- src/spikeinterface/core/tests/test_loading.py | 3 ++- .../core/tests/test_waveform_tools.py | 14 +++++++------- .../waveforms_extractor_backwards_compatibility.py | 9 +++++---- .../generation/drifting_generator.py | 3 ++- src/spikeinterface/generation/hybrid_tools.py | 10 +++++----- .../generation/tests/test_hybrid_tools.py | 4 ++-- .../postprocessing/amplitude_scalings.py | 5 +++-- .../preprocessing/remove_artifacts.py | 6 +++--- src/spikeinterface/sorters/internal/lupin.py | 5 +++-- .../sorters/internal/tridesclous2.py | 5 +++-- .../clustering/iterative_isosplit.py | 5 +++-- .../clustering/random_projections.py | 5 +++-- 
.../sortingcomponents/clustering/tools.py | 9 +++++---- .../sortingcomponents/matching/tdc_peeler.py | 5 +++-- .../peak_detection/matched_filtering.py | 3 ++- src/spikeinterface/sortingcomponents/tools.py | 11 ++++++----- .../sortingcomponents/waveforms/peak_svd.py | 5 +++-- src/spikeinterface/widgets/spikes_on_traces.py | 3 ++- 25 files changed, 89 insertions(+), 65 deletions(-) diff --git a/src/spikeinterface/benchmark/tests/common_benchmark_testing.py b/src/spikeinterface/benchmark/tests/common_benchmark_testing.py index 2c9cd957e4..49f9e9350a 100644 --- a/src/spikeinterface/benchmark/tests/common_benchmark_testing.py +++ b/src/spikeinterface/benchmark/tests/common_benchmark_testing.py @@ -16,6 +16,7 @@ estimate_templates, Templates, create_sorting_analyzer, + ms_to_samples, ) from spikeinterface.generation import generate_drifting_recording @@ -54,8 +55,8 @@ def make_dataset(job_kwargs={}): def compute_gt_templates(recording, gt_sorting, ms_before=2.0, ms_after=3.0, return_in_uV=False, **job_kwargs): spikes = gt_sorting.to_spike_vector() # [spike_indices] fs = recording.sampling_frequency - nbefore = int(ms_before * fs / 1000) - nafter = int(ms_after * fs / 1000) + nbefore = ms_to_samples(ms_before, fs) + nafter = ms_to_samples(ms_after, fs) templates_array = estimate_templates( recording, spikes, diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 168494caf7..d757d6d98b 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -90,6 +90,7 @@ read_python, write_python, normal_pdf, + ms_to_samples, ) from .job_tools import ( get_best_job_kwargs, diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 53fe7be1f2..1c261e3cad 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -21,6 +21,7 @@ from .template import Templates from .sorting_tools import 
random_spikes_selection, select_sorting_periods_mask, spike_vector_to_indices from .job_tools import fix_job_kwargs, split_job_kwargs +from .core_tools import ms_to_samples class ComputeRandomSpikes(AnalyzerExtension): @@ -170,11 +171,11 @@ class ComputeWaveforms(AnalyzerExtension): @property def nbefore(self): - return int(self.params["ms_before"] * self.sorting_analyzer.sampling_frequency / 1000.0) + return ms_to_samples(self.params["ms_before"], self.sorting_analyzer.sampling_frequency) @property def nafter(self): - return int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0) + return ms_to_samples(self.params["ms_after"], self.sorting_analyzer.sampling_frequency) def _run(self, verbose=False, **job_kwargs): self.data.clear() @@ -540,12 +541,12 @@ def _compute_and_append_from_waveforms(self, operators): @property def nbefore(self): - nbefore = int(self.params["ms_before"] * self.sorting_analyzer.sampling_frequency / 1000.0) + nbefore = ms_to_samples(self.params["ms_before"], self.sorting_analyzer.sampling_frequency) return nbefore @property def nafter(self): - nafter = int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0) + nafter = ms_to_samples(self.params["ms_after"], self.sorting_analyzer.sampling_frequency) return nafter def _select_extension_data(self, unit_ids): diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index ed98613553..fda08ff1b0 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -757,3 +757,8 @@ def is_path_remote(path: str | Path) -> bool: Whether the path is a remote path. 
""" return "s3://" in str(path) or "gcs://" in str(path) + + +def ms_to_samples(ms: float, sampling_frequency: float) -> int: + """Convert a duration in milliseconds to the nearest number of samples.""" + return round(ms * sampling_frequency / 1000.0) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 35116a9e4c..1c9ece728f 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -12,7 +12,7 @@ from spikeinterface.core import BaseRecording, BaseRecordingSegment, BaseSorting from .snippets_tools import snippets_from_sorting -from .core_tools import define_function_from_class +from .core_tools import define_function_from_class, ms_to_samples def _ensure_seed(seed): @@ -1598,8 +1598,8 @@ def generate_single_fake_waveform( assert ms_after > depolarization_ms + repolarization_ms assert ms_before > depolarization_ms - nbefore = int(sampling_frequency * ms_before / 1000.0) - nafter = int(sampling_frequency * ms_after / 1000.0) + nbefore = ms_to_samples(ms_before, sampling_frequency) + nafter = ms_to_samples(ms_after, sampling_frequency) width = nbefore + nafter wf = np.zeros(width, dtype=dtype) @@ -1776,8 +1776,8 @@ def generate_templates( num_units = units_locations.shape[0] num_channels = channel_locations.shape[0] - nbefore = int(sampling_frequency * ms_before / 1000.0) - nafter = int(sampling_frequency * ms_after / 1000.0) + nbefore = ms_to_samples(ms_before, sampling_frequency) + nafter = ms_to_samples(ms_after, sampling_frequency) width = nbefore + nafter if upsample_factor is not None: @@ -2451,8 +2451,8 @@ def generate_ground_truth_recording( upsample_factor = templates.shape[3] upsample_vector = rng.integers(0, upsample_factor, size=num_spikes) - nbefore = int(ms_before * sampling_frequency / 1000.0) - nafter = int(ms_after * sampling_frequency / 1000.0) + nbefore = ms_to_samples(ms_before, sampling_frequency) + nafter = ms_to_samples(ms_after, sampling_frequency) assert (nbefore + 
nafter) == templates.shape[1] # construct recording diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 43cdd30c87..ad953382c9 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -14,6 +14,7 @@ from spikeinterface.core import BaseRecording, get_chunk_with_margin from spikeinterface.core.job_tools import ChunkRecordingExecutor, fix_job_kwargs, _shared_job_kwargs_doc from spikeinterface.core import get_channel_distances +from spikeinterface.core.core_tools import ms_to_samples class PipelineNode: @@ -314,8 +315,8 @@ def __init__( PipelineNode.__init__(self, recording=recording, parents=parents, return_output=return_output) self.ms_before = ms_before self.ms_after = ms_after - self.nbefore = int(ms_before * recording.get_sampling_frequency() / 1000.0) - self.nafter = int(ms_after * recording.get_sampling_frequency() / 1000.0) + self.nbefore = ms_to_samples(ms_before, recording.get_sampling_frequency()) + self.nafter = ms_to_samples(ms_after, recording.get_sampling_frequency()) self.neighbours_mask = None diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 05963520cd..a203c2ff05 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -7,6 +7,7 @@ from .sorting_tools import random_spikes_selection from .job_tools import _shared_job_kwargs_doc from .waveform_tools import estimate_templates_with_accumulator +from .core_tools import ms_to_samples _sparsity_doc = """ method : str @@ -784,8 +785,8 @@ def estimate_sparsity( probe = recording.create_dummy_probe_from_locations(chan_locs) if method != "by_property": - nbefore = int(ms_before * recording.sampling_frequency / 1000.0) - nafter = int(ms_after * recording.sampling_frequency / 1000.0) + nbefore = ms_to_samples(ms_before, recording.sampling_frequency) + nafter = ms_to_samples(ms_after, recording.sampling_frequency) num_samples = 
[recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] random_spikes_indices = random_spikes_selection( diff --git a/src/spikeinterface/core/tests/test_loading.py b/src/spikeinterface/core/tests/test_loading.py index bfaf97ec4a..c9d6e888f9 100644 --- a/src/spikeinterface/core/tests/test_loading.py +++ b/src/spikeinterface/core/tests/test_loading.py @@ -5,6 +5,7 @@ generate_ground_truth_recording, create_sorting_analyzer, load, + ms_to_samples, SortingAnalyzer, Templates, aggregate_channels, @@ -71,7 +72,7 @@ def generate_templates_object(): templates = Templates( templates_array=templates_arr, sampling_frequency=sampling_frequency, - nbefore=int(ms_before * sampling_frequency / 1000), + nbefore=ms_to_samples(ms_before, sampling_frequency), probe=probe, ) return templates diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index 88a60a660b..5e0350f833 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -5,7 +5,7 @@ import numpy as np -from spikeinterface.core import generate_recording, generate_sorting, generate_ground_truth_recording +from spikeinterface.core import generate_recording, generate_sorting, generate_ground_truth_recording, ms_to_samples from spikeinterface.core.waveform_tools import ( extract_waveforms_to_buffers, extract_waveforms_to_single_buffer, @@ -56,8 +56,8 @@ def test_waveform_tools(create_cache_folder): recording, sorting = get_dataset() sampling_frequency = recording.sampling_frequency - nbefore = int(3.0 * sampling_frequency / 1000.0) - nafter = int(4.0 * sampling_frequency / 1000.0) + nbefore = ms_to_samples(3.0, sampling_frequency) + nafter = ms_to_samples(4.0, sampling_frequency) dtype = recording.get_dtype() # return_in_uV = False @@ -164,8 +164,8 @@ def test_estimate_templates_with_accumulator(): ms_before = 1.0 ms_after = 1.5 - nbefore = int(ms_before * 
recording.sampling_frequency / 1000.0) - nafter = int(ms_after * recording.sampling_frequency / 1000.0) + nbefore = ms_to_samples(ms_before, recording.sampling_frequency) + nafter = ms_to_samples(ms_after, recording.sampling_frequency) spikes = sorting.to_spike_vector() # take one spikes every 10 @@ -218,8 +218,8 @@ def test_estimate_templates(): ms_before = 1.0 ms_after = 1.5 - nbefore = int(ms_before * recording.sampling_frequency / 1000.0) - nafter = int(ms_after * recording.sampling_frequency / 1000.0) + nbefore = ms_to_samples(ms_before, recording.sampling_frequency) + nafter = ms_to_samples(ms_after, recording.sampling_frequency) spikes = sorting.to_spike_vector() # take one spikes every 10 diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index 61591acb4c..0fde6e70dc 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -21,6 +21,7 @@ from .sparsity import ChannelSparsity from .sortinganalyzer import SortingAnalyzer, load_sorting_analyzer from .loading import load +from .core_tools import ms_to_samples from .analyzer_extension_core import ComputeRandomSpikes, ComputeWaveforms, ComputeTemplates _backwards_compatibility_msg = """#### @@ -162,12 +163,12 @@ def unit_ids(self) -> np.ndarray: @property def nbefore(self) -> int: ms_before = self.sorting_analyzer.get_extension("waveforms").params["ms_before"] - return int(ms_before * self.sampling_frequency / 1000.0) + return ms_to_samples(ms_before, self.sampling_frequency) @property def nafter(self) -> int: ms_after = self.sorting_analyzer.get_extension("waveforms").params["ms_after"] - return int(ms_after * self.sampling_frequency / 1000.0) + return ms_to_samples(ms_after, self.sampling_frequency) @property def nsamples(self) -> int: @@ -522,8 +523,8 @@ def 
_read_old_waveforms_extractor_binary(folder, sorting): else: max_num_channel = np.max(np.sum(sparsity.mask, axis=1)) - nbefore = int(params["ms_before"] * sorting.sampling_frequency / 1000.0) - nafter = int(params["ms_after"] * sorting.sampling_frequency / 1000.0) + nbefore = ms_to_samples(params["ms_before"], sorting.sampling_frequency) + nafter = ms_to_samples(params["ms_after"], sorting.sampling_frequency) waveforms = np.zeros((num_spikes, nbefore + nafter, max_num_channel), dtype=params["dtype"]) # then read waveforms per units diff --git a/src/spikeinterface/generation/drifting_generator.py b/src/spikeinterface/generation/drifting_generator.py index 3a80c2aef2..7c388713d7 100644 --- a/src/spikeinterface/generation/drifting_generator.py +++ b/src/spikeinterface/generation/drifting_generator.py @@ -13,6 +13,7 @@ from probeinterface import generate_multi_columns_probe from spikeinterface import Templates +from spikeinterface.core import ms_to_samples from spikeinterface.core.generate import ( generate_unit_locations, generate_sorting, @@ -516,7 +517,7 @@ def generate_drifting_recording( ) ms_before = generate_templates_kwargs["ms_before"] - nbefore = int(sampling_frequency * ms_before / 1000.0) + nbefore = ms_to_samples(ms_before, sampling_frequency) templates = Templates( templates_array=templates_array, sampling_frequency=sampling_frequency, diff --git a/src/spikeinterface/generation/hybrid_tools.py b/src/spikeinterface/generation/hybrid_tools.py index 4add37e8a6..2476bc6336 100644 --- a/src/spikeinterface/generation/hybrid_tools.py +++ b/src/spikeinterface/generation/hybrid_tools.py @@ -2,7 +2,7 @@ from typing import Literal import numpy as np -from spikeinterface.core import BaseRecording, BaseSorting, Templates +from spikeinterface.core import BaseRecording, BaseSorting, Templates, ms_to_samples from spikeinterface.core.generate import ( generate_templates, @@ -71,8 +71,8 @@ def estimate_templates_from_recording( spikes = sorting.to_spike_vector() unit_ids = 
sorting.unit_ids sampling_frequency = recording.get_sampling_frequency() - nbefore = int(ms_before * sampling_frequency / 1000.0) - nafter = int(ms_after * sampling_frequency / 1000.0) + nbefore = ms_to_samples(ms_before, sampling_frequency) + nafter = ms_to_samples(ms_after, sampling_frequency) job_kwargs = job_kwargs or {} templates_array = estimate_templates(recording, spikes, unit_ids, nbefore, nafter, return_in_uV=False, **job_kwargs) @@ -440,8 +440,8 @@ def generate_hybrid_recording( ) ms_before = generate_templates_kwargs["ms_before"] ms_after = generate_templates_kwargs["ms_after"] - nbefore = int(ms_before * sampling_frequency / 1000.0) - nafter = int(ms_after * sampling_frequency / 1000.0) + nbefore = ms_to_samples(ms_before, sampling_frequency) + nafter = ms_to_samples(ms_after, sampling_frequency) templates_ = Templates(templates_array, sampling_frequency, nbefore, True, None, None, None, probe) else: from spikeinterface.postprocessing.localization_tools import compute_monopolar_triangulation diff --git a/src/spikeinterface/generation/tests/test_hybrid_tools.py b/src/spikeinterface/generation/tests/test_hybrid_tools.py index 936f9dbdc9..772a95a21f 100644 --- a/src/spikeinterface/generation/tests/test_hybrid_tools.py +++ b/src/spikeinterface/generation/tests/test_hybrid_tools.py @@ -2,7 +2,7 @@ import numpy as np -from spikeinterface.core import Templates +from spikeinterface.core import Templates, ms_to_samples from spikeinterface.core.generate import ( generate_ground_truth_recording, generate_sorting, @@ -60,7 +60,7 @@ def test_generate_hybrid_from_templates(): templates_array = generate_templates( channel_locations, unit_locations, rec.sampling_frequency, ms_before, ms_after, seed=0 ) - nbefore = int(ms_before * rec.sampling_frequency / 1000) + nbefore = ms_to_samples(ms_before, rec.sampling_frequency) templates = Templates(templates_array, rec.sampling_frequency, nbefore, True, None, None, None, rec.get_probe()) hybrid, sorting_hybrid = 
generate_hybrid_recording(rec, templates=templates, seed=0) assert np.array_equal(hybrid.templates, templates.templates_array) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 310be8cceb..656febd6bb 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -1,6 +1,7 @@ import numpy as np from spikeinterface.core import ChannelSparsity +from spikeinterface.core.core_tools import ms_to_samples from spikeinterface.core.template_tools import get_template_extremum_channel, get_dense_templates_array, _get_nbefore from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension @@ -87,7 +88,7 @@ def _get_pipeline_nodes(self): # if ms_before / ms_after are set in params then the original templates are shorten if self.params["ms_before"] is not None: - cut_out_before = int(self.params["ms_before"] * self.sorting_analyzer.sampling_frequency / 1000.0) + cut_out_before = ms_to_samples(self.params["ms_before"], self.sorting_analyzer.sampling_frequency) assert ( cut_out_before <= nbefore ), f"`ms_before` must be smaller than `ms_before` used in ComputeTemplates: {nbefore}" @@ -95,7 +96,7 @@ def _get_pipeline_nodes(self): cut_out_before = nbefore if self.params["ms_after"] is not None: - cut_out_after = int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0) + cut_out_after = ms_to_samples(self.params["ms_after"], self.sorting_analyzer.sampling_frequency) assert ( cut_out_after <= nafter ), f"`ms_after` must be smaller than `ms_after` used in templates: {templates_ext.params['ms_after']}" diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py index 3fc5449ff2..2129e01d90 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ 
b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -2,7 +2,7 @@ import numpy as np -from spikeinterface.core.core_tools import define_function_handling_dict_from_class +from spikeinterface.core.core_tools import define_function_handling_dict_from_class, ms_to_samples from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment from spikeinterface.core import NumpySorting, estimate_templates @@ -170,8 +170,8 @@ def __init__( list_triggers, list_labels, recording.get_sampling_frequency() ) - nbefore = int(ms_before * recording.sampling_frequency / 1000.0) - nafter = int(ms_after * recording.sampling_frequency / 1000.0) + nbefore = ms_to_samples(ms_before, recording.sampling_frequency) + nafter = ms_to_samples(ms_after, recording.sampling_frequency) templates = estimate_templates( recording=recording, diff --git a/src/spikeinterface/sorters/internal/lupin.py b/src/spikeinterface/sorters/internal/lupin.py index 98d7cdbf7d..54c8ad5fcf 100644 --- a/src/spikeinterface/sorters/internal/lupin.py +++ b/src/spikeinterface/sorters/internal/lupin.py @@ -8,6 +8,7 @@ estimate_templates_with_accumulator, Templates, compute_sparsity, + ms_to_samples, ) from spikeinterface.core.job_tools import fix_job_kwargs @@ -332,8 +333,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) # Template are sparse from radius using unit_location - nbefore = int(ms_before * sampling_frequency / 1000.0) - nafter = int(ms_after * sampling_frequency / 1000.0) + nbefore = ms_to_samples(ms_before, sampling_frequency) + nafter = ms_to_samples(ms_after, sampling_frequency) templates_array = estimate_templates_with_accumulator( recording, sorting_pre_peeler.to_spike_vector(), diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index fb833b1a46..c042d8ee56 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -11,6 +11,7 @@ 
estimate_templates_with_accumulator, Templates, compute_sparsity, + ms_to_samples, ) from spikeinterface.core.job_tools import fix_job_kwargs @@ -291,8 +292,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # we recompute the template even if the clustering give it already because we use different ms_before/ms_after ms_before = params["ms_before"] ms_after = params["ms_after"] - nbefore = int(ms_before * sampling_frequency / 1000.0) - nafter = int(ms_after * sampling_frequency / 1000.0) + nbefore = ms_to_samples(ms_before, sampling_frequency) + nafter = ms_to_samples(ms_after, sampling_frequency) templates_array = estimate_templates_with_accumulator( recording_for_peeler, diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py index a8d13dd9df..7e039da18c 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py @@ -3,6 +3,7 @@ import numpy as np from spikeinterface.core import get_channel_distances, Templates, ChannelSparsity +from spikeinterface.core.core_tools import ms_to_samples from spikeinterface.sortingcomponents.clustering.itersplit_tools import split_clusters # from spikeinterface.sortingcomponents.clustering.merge import merge_clusters @@ -107,8 +108,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): ms_before = params["peaks_svd"]["ms_before"] ms_after = params["peaks_svd"]["ms_after"] - nbefore = int(ms_before * recording.sampling_frequency / 1000.0) - nafter = int(ms_after * recording.sampling_frequency / 1000.0) + nbefore = ms_to_samples(ms_before, recording.sampling_frequency) + nafter = ms_to_samples(ms_after, recording.sampling_frequency) # radius_um = params["waveforms"]["radius_um"] verbose = params["verbose"] diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py 
b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index df4c4ae39d..cd451f4d07 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -12,6 +12,7 @@ HAVE_HDBSCAN = False from spikeinterface.core.base import minimum_spike_dtype +from spikeinterface.core.core_tools import ms_to_samples from spikeinterface.core.waveform_tools import estimate_templates from spikeinterface.sortingcomponents.clustering.merging_tools import merge_peak_labels_from_templates from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser @@ -60,8 +61,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): radius_um = params.get("radius_um", 30) ms_before = params["waveforms"].get("ms_before", 0.5) ms_after = params["waveforms"].get("ms_before", 1.5) - nbefore = int(ms_before * fs / 1000.0) - nafter = int(ms_after * fs / 1000.0) + nbefore = ms_to_samples(ms_before, fs) + nafter = ms_to_samples(ms_after, fs) verbose = params.get("verbose", True) num_chans = recording.get_num_channels() debug_folder = params.get("debug_folder", None) diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index 96ae218e1a..2fb05bbc6b 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -3,6 +3,7 @@ import numpy as np from spikeinterface.core import Templates, estimate_templates, fix_job_kwargs +from spikeinterface.core.core_tools import ms_to_samples from spikeinterface.core.base import minimum_spike_dtype # TODO find a way to attach a a sparse_mask to a given features (waveforms, pca, tsvd ....) 
@@ -190,8 +191,8 @@ def get_templates_from_peaks_and_recording( labels, indices = np.unique(valid_labels, return_inverse=True) fs = recording.get_sampling_frequency() - nbefore = int(ms_before * fs / 1000.0) - nafter = int(ms_after * fs / 1000.0) + nbefore = ms_to_samples(ms_before, fs) + nafter = ms_to_samples(ms_after, fs) spikes = np.zeros(valid_peaks.size, dtype=minimum_spike_dtype) spikes["sample_index"] = valid_peaks["sample_index"] @@ -280,8 +281,8 @@ def get_templates_from_peaks_and_svd( labels = np.unique(valid_labels) fs = recording.get_sampling_frequency() - nbefore = int(ms_before * fs / 1000.0) - nafter = int(ms_after * fs / 1000.0) + nbefore = ms_to_samples(ms_before, fs) + nafter = ms_to_samples(ms_after, fs) num_channels = recording.get_num_channels() templates_array = np.zeros((len(labels), nbefore + nafter, num_channels), dtype=np.float32) diff --git a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py index 947eaf391f..f6ceecf705 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py @@ -5,6 +5,7 @@ get_channel_distances, get_template_extremum_channel, ) +from spikeinterface.core.core_tools import ms_to_samples from spikeinterface.sortingcomponents.peak_detection.method_list import ( LocallyExclusivePeakDetector, @@ -132,8 +133,8 @@ def __init__( self.peak_sign = peak_sign - nbefore_short = int(ms_before * sr / 1000.0) - nafter_short = int(ms_after * sr / 1000.0) + nbefore_short = ms_to_samples(ms_before, sr) + nafter_short = ms_to_samples(ms_after, sr) assert nbefore_short <= templates.nbefore assert nafter_short <= templates.nafter self.nbefore_short = nbefore_short diff --git a/src/spikeinterface/sortingcomponents/peak_detection/matched_filtering.py b/src/spikeinterface/sortingcomponents/peak_detection/matched_filtering.py index 509c3f76f8..efef8c710e 100644 --- 
a/src/spikeinterface/sortingcomponents/peak_detection/matched_filtering.py +++ b/src/spikeinterface/sortingcomponents/peak_detection/matched_filtering.py @@ -4,6 +4,7 @@ import numpy as np from spikeinterface.core.base import base_peak_dtype +from spikeinterface.core.core_tools import ms_to_samples from spikeinterface.core.node_pipeline import PeakDetector from spikeinterface.core.recording_tools import get_channel_distances, get_random_data_chunks from spikeinterface.postprocessing.localization_tools import get_convolution_weights @@ -63,7 +64,7 @@ def __init__( self.conv_margin = prototype.shape[0] assert peak_sign in ("both", "neg", "pos") - self.nbefore = int(ms_before * recording.sampling_frequency / 1000) + self.nbefore = ms_to_samples(ms_before, recording.sampling_frequency) if peak_sign == "neg": assert prototype[self.nbefore] < 0, "Prototype should have a negative peak" peak_sign = "pos" diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 73a14bdee7..baa8b4471c 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -16,6 +16,7 @@ from spikeinterface.core.template_tools import get_template_extremum_channel_peak_shift from spikeinterface.core.recording_tools import get_noise_levels from spikeinterface.core.sorting_tools import get_numba_vector_to_list_of_spiketrain +from spikeinterface.core.core_tools import ms_to_samples def make_multi_method_doc(methods, indent=" "): @@ -57,8 +58,8 @@ def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, job spikes["unit_index"] = peaks["channel_index"] spikes["segment_index"] = peaks["segment_index"] - nbefore = int(ms_before * rec.sampling_frequency / 1000.0) - nafter = int(ms_after * rec.sampling_frequency / 1000.0) + nbefore = ms_to_samples(ms_before, rec.sampling_frequency) + nafter = ms_to_samples(ms_after, rec.sampling_frequency) all_wfs = extract_waveforms_to_single_buffer( 
rec, @@ -115,8 +116,8 @@ def get_prototype_and_waveforms_from_peaks( job_kwargs = fix_job_kwargs(job_kwargs) - nbefore = int(ms_before * recording.sampling_frequency / 1000.0) - nafter = int(ms_after * recording.sampling_frequency / 1000.0) + nbefore = ms_to_samples(ms_before, recording.sampling_frequency) + nafter = ms_to_samples(ms_after, recording.sampling_frequency) few_peaks = select_peaks( peaks, recording=recording, method="uniform", n_peaks=n_peaks, margin=(nbefore, nafter), seed=seed @@ -179,7 +180,7 @@ def get_prototype_and_waveforms_from_recording( node0 = LocallyExclusivePeakDetector(recording, return_output=True, **detection_kwargs) - nbefore = int(ms_before * recording.sampling_frequency / 1000.0) + nbefore = ms_to_samples(ms_before, recording.sampling_frequency) node1 = ExtractSparseWaveforms( recording, parents=[node0], diff --git a/src/spikeinterface/sortingcomponents/waveforms/peak_svd.py b/src/spikeinterface/sortingcomponents/waveforms/peak_svd.py index 548deafb5d..331345975f 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/peak_svd.py +++ b/src/spikeinterface/sortingcomponents/waveforms/peak_svd.py @@ -5,6 +5,7 @@ import numpy as np from spikeinterface.core import get_channel_distances, fix_job_kwargs +from spikeinterface.core.core_tools import ms_to_samples from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.waveforms.temporal_pca import ( @@ -48,8 +49,8 @@ def extract_peaks_svd( job_kwargs = fix_job_kwargs(job_kwargs) - nbefore = int(ms_before * recording.sampling_frequency / 1000.0) - nafter = int(ms_after * recording.sampling_frequency / 1000.0) + nbefore = ms_to_samples(ms_before, recording.sampling_frequency) + nafter = ms_to_samples(ms_after, recording.sampling_frequency) # Step 1 : select a few peaks to fit the SVD if svd_model is None: diff --git 
a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index d2b1a21fdd..8505a706e9 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -6,6 +6,7 @@ from .utils import get_unit_colors from .traces import TracesWidget from spikeinterface.core import ChannelSparsity +from spikeinterface.core.core_tools import ms_to_samples from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.sortinganalyzer import SortingAnalyzer from spikeinterface.core.baserecording import BaseRecording @@ -230,7 +231,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): vspacing = traces_widget.data_plot["vspacing"] traces = traces_widget.data_plot["list_traces"][0] * dp.options["scale"] - nbefore = nafter = int(dp.spike_width_ms / 2 * sorting_analyzer.sampling_frequency / 1000) + nbefore = nafter = ms_to_samples(dp.spike_width_ms / 2, sorting_analyzer.sampling_frequency) waveform_idxs = spike_frames_to_plot[:, None] + np.arange(-nbefore, nafter) - frame_range[0] waveform_idxs = np.clip(waveform_idxs, 0, len(traces_widget.data_plot["times_in_range"]) - 1)