Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
estimate_templates,
Templates,
create_sorting_analyzer,
ms_to_samples,
)
from spikeinterface.generation import generate_drifting_recording

Expand Down Expand Up @@ -54,8 +55,8 @@ def make_dataset(job_kwargs={}):
def compute_gt_templates(recording, gt_sorting, ms_before=2.0, ms_after=3.0, return_in_uV=False, **job_kwargs):
spikes = gt_sorting.to_spike_vector() # [spike_indices]
fs = recording.sampling_frequency
nbefore = int(ms_before * fs / 1000)
nafter = int(ms_after * fs / 1000)
nbefore = ms_to_samples(ms_before, fs)
nafter = ms_to_samples(ms_after, fs)
templates_array = estimate_templates(
recording,
spikes,
Expand Down
1 change: 1 addition & 0 deletions src/spikeinterface/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
read_python,
write_python,
normal_pdf,
ms_to_samples,
)
from .job_tools import (
get_best_job_kwargs,
Expand Down
9 changes: 5 additions & 4 deletions src/spikeinterface/core/analyzer_extension_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .template import Templates
from .sorting_tools import random_spikes_selection, select_sorting_periods_mask, spike_vector_to_indices
from .job_tools import fix_job_kwargs, split_job_kwargs
from .core_tools import ms_to_samples


class ComputeRandomSpikes(AnalyzerExtension):
Expand Down Expand Up @@ -170,11 +171,11 @@ class ComputeWaveforms(AnalyzerExtension):

@property
def nbefore(self):
return int(self.params["ms_before"] * self.sorting_analyzer.sampling_frequency / 1000.0)
return ms_to_samples(self.params["ms_before"], self.sorting_analyzer.sampling_frequency)

@property
def nafter(self):
return int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0)
return ms_to_samples(self.params["ms_after"], self.sorting_analyzer.sampling_frequency)

def _run(self, verbose=False, **job_kwargs):
self.data.clear()
Expand Down Expand Up @@ -540,12 +541,12 @@ def _compute_and_append_from_waveforms(self, operators):

@property
def nbefore(self):
nbefore = int(self.params["ms_before"] * self.sorting_analyzer.sampling_frequency / 1000.0)
nbefore = ms_to_samples(self.params["ms_before"], self.sorting_analyzer.sampling_frequency)
return nbefore

@property
def nafter(self):
nafter = int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0)
nafter = ms_to_samples(self.params["ms_after"], self.sorting_analyzer.sampling_frequency)
return nafter

def _select_extension_data(self, unit_ids):
Expand Down
5 changes: 5 additions & 0 deletions src/spikeinterface/core/core_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,3 +757,8 @@ def is_path_remote(path: str | Path) -> bool:
Whether the path is a remote path.
"""
return "s3://" in str(path) or "gcs://" in str(path)


def ms_to_samples(ms: float, sampling_frequency: float) -> int:
    """Convert a duration in milliseconds to the nearest number of samples.

    Parameters
    ----------
    ms : float
        Duration in milliseconds.
    sampling_frequency : float
        Sampling frequency in Hz.

    Returns
    -------
    int
        The sample count closest to the given duration. Rounding follows
        Python's built-in ``round`` (ties go to the nearest even integer).
    """
    # Keep the exact expression order (ms * fs / 1000) so floating-point
    # results — and therefore tie-breaking at .5 boundaries — are unchanged.
    n_samples_float = ms * sampling_frequency / 1000.0
    return round(n_samples_float)
14 changes: 7 additions & 7 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from spikeinterface.core import BaseRecording, BaseRecordingSegment, BaseSorting
from .snippets_tools import snippets_from_sorting
from .core_tools import define_function_from_class
from .core_tools import define_function_from_class, ms_to_samples


def _ensure_seed(seed):
Expand Down Expand Up @@ -1598,8 +1598,8 @@ def generate_single_fake_waveform(
assert ms_after > depolarization_ms + repolarization_ms
assert ms_before > depolarization_ms

nbefore = int(sampling_frequency * ms_before / 1000.0)
nafter = int(sampling_frequency * ms_after / 1000.0)
nbefore = ms_to_samples(ms_before, sampling_frequency)
nafter = ms_to_samples(ms_after, sampling_frequency)
width = nbefore + nafter
wf = np.zeros(width, dtype=dtype)

Expand Down Expand Up @@ -1776,8 +1776,8 @@ def generate_templates(

num_units = units_locations.shape[0]
num_channels = channel_locations.shape[0]
nbefore = int(sampling_frequency * ms_before / 1000.0)
nafter = int(sampling_frequency * ms_after / 1000.0)
nbefore = ms_to_samples(ms_before, sampling_frequency)
nafter = ms_to_samples(ms_after, sampling_frequency)
width = nbefore + nafter

if upsample_factor is not None:
Expand Down Expand Up @@ -2451,8 +2451,8 @@ def generate_ground_truth_recording(
upsample_factor = templates.shape[3]
upsample_vector = rng.integers(0, upsample_factor, size=num_spikes)

nbefore = int(ms_before * sampling_frequency / 1000.0)
nafter = int(ms_after * sampling_frequency / 1000.0)
nbefore = ms_to_samples(ms_before, sampling_frequency)
nafter = ms_to_samples(ms_after, sampling_frequency)
assert (nbefore + nafter) == templates.shape[1]

# construct recording
Expand Down
5 changes: 3 additions & 2 deletions src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from spikeinterface.core import BaseRecording, get_chunk_with_margin
from spikeinterface.core.job_tools import ChunkRecordingExecutor, fix_job_kwargs, _shared_job_kwargs_doc
from spikeinterface.core import get_channel_distances
from spikeinterface.core.core_tools import ms_to_samples


class PipelineNode:
Expand Down Expand Up @@ -314,8 +315,8 @@ def __init__(
PipelineNode.__init__(self, recording=recording, parents=parents, return_output=return_output)
self.ms_before = ms_before
self.ms_after = ms_after
self.nbefore = int(ms_before * recording.get_sampling_frequency() / 1000.0)
self.nafter = int(ms_after * recording.get_sampling_frequency() / 1000.0)
self.nbefore = ms_to_samples(ms_before, recording.get_sampling_frequency())
self.nafter = ms_to_samples(ms_after, recording.get_sampling_frequency())
self.neighbours_mask = None


Expand Down
5 changes: 3 additions & 2 deletions src/spikeinterface/core/sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .sorting_tools import random_spikes_selection
from .job_tools import _shared_job_kwargs_doc
from .waveform_tools import estimate_templates_with_accumulator
from .core_tools import ms_to_samples

_sparsity_doc = """
method : str
Expand Down Expand Up @@ -784,8 +785,8 @@ def estimate_sparsity(
probe = recording.create_dummy_probe_from_locations(chan_locs)

if method != "by_property":
nbefore = int(ms_before * recording.sampling_frequency / 1000.0)
nafter = int(ms_after * recording.sampling_frequency / 1000.0)
nbefore = ms_to_samples(ms_before, recording.sampling_frequency)
nafter = ms_to_samples(ms_after, recording.sampling_frequency)

num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())]
random_spikes_indices = random_spikes_selection(
Expand Down
3 changes: 2 additions & 1 deletion src/spikeinterface/core/tests/test_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
generate_ground_truth_recording,
create_sorting_analyzer,
load,
ms_to_samples,
SortingAnalyzer,
Templates,
aggregate_channels,
Expand Down Expand Up @@ -71,7 +72,7 @@ def generate_templates_object():
templates = Templates(
templates_array=templates_arr,
sampling_frequency=sampling_frequency,
nbefore=int(ms_before * sampling_frequency / 1000),
nbefore=ms_to_samples(ms_before, sampling_frequency),
probe=probe,
)
return templates
Expand Down
14 changes: 7 additions & 7 deletions src/spikeinterface/core/tests/test_waveform_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np

from spikeinterface.core import generate_recording, generate_sorting, generate_ground_truth_recording
from spikeinterface.core import generate_recording, generate_sorting, generate_ground_truth_recording, ms_to_samples
from spikeinterface.core.waveform_tools import (
extract_waveforms_to_buffers,
extract_waveforms_to_single_buffer,
Expand Down Expand Up @@ -56,8 +56,8 @@ def test_waveform_tools(create_cache_folder):
recording, sorting = get_dataset()
sampling_frequency = recording.sampling_frequency

nbefore = int(3.0 * sampling_frequency / 1000.0)
nafter = int(4.0 * sampling_frequency / 1000.0)
nbefore = ms_to_samples(3.0, sampling_frequency)
nafter = ms_to_samples(4.0, sampling_frequency)

dtype = recording.get_dtype()
# return_in_uV = False
Expand Down Expand Up @@ -164,8 +164,8 @@ def test_estimate_templates_with_accumulator():
ms_before = 1.0
ms_after = 1.5

nbefore = int(ms_before * recording.sampling_frequency / 1000.0)
nafter = int(ms_after * recording.sampling_frequency / 1000.0)
nbefore = ms_to_samples(ms_before, recording.sampling_frequency)
nafter = ms_to_samples(ms_after, recording.sampling_frequency)

spikes = sorting.to_spike_vector()
# take one spikes every 10
Expand Down Expand Up @@ -218,8 +218,8 @@ def test_estimate_templates():
ms_before = 1.0
ms_after = 1.5

nbefore = int(ms_before * recording.sampling_frequency / 1000.0)
nafter = int(ms_after * recording.sampling_frequency / 1000.0)
nbefore = ms_to_samples(ms_before, recording.sampling_frequency)
nafter = ms_to_samples(ms_after, recording.sampling_frequency)

spikes = sorting.to_spike_vector()
# take one spikes every 10
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .sparsity import ChannelSparsity
from .sortinganalyzer import SortingAnalyzer, load_sorting_analyzer
from .loading import load
from .core_tools import ms_to_samples
from .analyzer_extension_core import ComputeRandomSpikes, ComputeWaveforms, ComputeTemplates

_backwards_compatibility_msg = """####
Expand Down Expand Up @@ -162,12 +163,12 @@ def unit_ids(self) -> np.ndarray:
@property
def nbefore(self) -> int:
ms_before = self.sorting_analyzer.get_extension("waveforms").params["ms_before"]
return int(ms_before * self.sampling_frequency / 1000.0)
return ms_to_samples(ms_before, self.sampling_frequency)

@property
def nafter(self) -> int:
ms_after = self.sorting_analyzer.get_extension("waveforms").params["ms_after"]
return int(ms_after * self.sampling_frequency / 1000.0)
return ms_to_samples(ms_after, self.sampling_frequency)

@property
def nsamples(self) -> int:
Expand Down Expand Up @@ -522,8 +523,8 @@ def _read_old_waveforms_extractor_binary(folder, sorting):
else:
max_num_channel = np.max(np.sum(sparsity.mask, axis=1))

nbefore = int(params["ms_before"] * sorting.sampling_frequency / 1000.0)
nafter = int(params["ms_after"] * sorting.sampling_frequency / 1000.0)
nbefore = ms_to_samples(params["ms_before"], sorting.sampling_frequency)
nafter = ms_to_samples(params["ms_after"], sorting.sampling_frequency)

waveforms = np.zeros((num_spikes, nbefore + nafter, max_num_channel), dtype=params["dtype"])
# then read waveforms per units
Expand Down
3 changes: 2 additions & 1 deletion src/spikeinterface/generation/drifting_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from probeinterface import generate_multi_columns_probe

from spikeinterface import Templates
from spikeinterface.core import ms_to_samples
from spikeinterface.core.generate import (
generate_unit_locations,
generate_sorting,
Expand Down Expand Up @@ -516,7 +517,7 @@ def generate_drifting_recording(
)

ms_before = generate_templates_kwargs["ms_before"]
nbefore = int(sampling_frequency * ms_before / 1000.0)
nbefore = ms_to_samples(ms_before, sampling_frequency)
templates = Templates(
templates_array=templates_array,
sampling_frequency=sampling_frequency,
Expand Down
10 changes: 5 additions & 5 deletions src/spikeinterface/generation/hybrid_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Literal
import numpy as np

from spikeinterface.core import BaseRecording, BaseSorting, Templates
from spikeinterface.core import BaseRecording, BaseSorting, Templates, ms_to_samples

from spikeinterface.core.generate import (
generate_templates,
Expand Down Expand Up @@ -71,8 +71,8 @@ def estimate_templates_from_recording(
spikes = sorting.to_spike_vector()
unit_ids = sorting.unit_ids
sampling_frequency = recording.get_sampling_frequency()
nbefore = int(ms_before * sampling_frequency / 1000.0)
nafter = int(ms_after * sampling_frequency / 1000.0)
nbefore = ms_to_samples(ms_before, sampling_frequency)
nafter = ms_to_samples(ms_after, sampling_frequency)

job_kwargs = job_kwargs or {}
templates_array = estimate_templates(recording, spikes, unit_ids, nbefore, nafter, return_in_uV=False, **job_kwargs)
Expand Down Expand Up @@ -440,8 +440,8 @@ def generate_hybrid_recording(
)
ms_before = generate_templates_kwargs["ms_before"]
ms_after = generate_templates_kwargs["ms_after"]
nbefore = int(ms_before * sampling_frequency / 1000.0)
nafter = int(ms_after * sampling_frequency / 1000.0)
nbefore = ms_to_samples(ms_before, sampling_frequency)
nafter = ms_to_samples(ms_after, sampling_frequency)
templates_ = Templates(templates_array, sampling_frequency, nbefore, True, None, None, None, probe)
else:
from spikeinterface.postprocessing.localization_tools import compute_monopolar_triangulation
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/generation/tests/test_hybrid_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from spikeinterface.core import Templates
from spikeinterface.core import Templates, ms_to_samples
from spikeinterface.core.generate import (
generate_ground_truth_recording,
generate_sorting,
Expand Down Expand Up @@ -60,7 +60,7 @@ def test_generate_hybrid_from_templates():
templates_array = generate_templates(
channel_locations, unit_locations, rec.sampling_frequency, ms_before, ms_after, seed=0
)
nbefore = int(ms_before * rec.sampling_frequency / 1000)
nbefore = ms_to_samples(ms_before, rec.sampling_frequency)
templates = Templates(templates_array, rec.sampling_frequency, nbefore, True, None, None, None, rec.get_probe())
hybrid, sorting_hybrid = generate_hybrid_recording(rec, templates=templates, seed=0)
assert np.array_equal(hybrid.templates, templates.templates_array)
Expand Down
5 changes: 3 additions & 2 deletions src/spikeinterface/postprocessing/amplitude_scalings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np

from spikeinterface.core import ChannelSparsity
from spikeinterface.core.core_tools import ms_to_samples
from spikeinterface.core.template_tools import get_template_extremum_channel, get_dense_templates_array, _get_nbefore
from spikeinterface.core.sortinganalyzer import register_result_extension
from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension
Expand Down Expand Up @@ -87,15 +88,15 @@ def _get_pipeline_nodes(self):

# if ms_before / ms_after are set in params then the original templates are shorten
if self.params["ms_before"] is not None:
cut_out_before = int(self.params["ms_before"] * self.sorting_analyzer.sampling_frequency / 1000.0)
cut_out_before = ms_to_samples(self.params["ms_before"], self.sorting_analyzer.sampling_frequency)
assert (
cut_out_before <= nbefore
), f"`ms_before` must be smaller than `ms_before` used in ComputeTemplates: {nbefore}"
else:
cut_out_before = nbefore

if self.params["ms_after"] is not None:
cut_out_after = int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0)
cut_out_after = ms_to_samples(self.params["ms_after"], self.sorting_analyzer.sampling_frequency)
assert (
cut_out_after <= nafter
), f"`ms_after` must be smaller than `ms_after` used in templates: {templates_ext.params['ms_after']}"
Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/preprocessing/remove_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from spikeinterface.core.core_tools import define_function_handling_dict_from_class
from spikeinterface.core.core_tools import define_function_handling_dict_from_class, ms_to_samples

from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment
from spikeinterface.core import NumpySorting, estimate_templates
Expand Down Expand Up @@ -170,8 +170,8 @@ def __init__(
list_triggers, list_labels, recording.get_sampling_frequency()
)

nbefore = int(ms_before * recording.sampling_frequency / 1000.0)
nafter = int(ms_after * recording.sampling_frequency / 1000.0)
nbefore = ms_to_samples(ms_before, recording.sampling_frequency)
nafter = ms_to_samples(ms_after, recording.sampling_frequency)

templates = estimate_templates(
recording=recording,
Expand Down
5 changes: 3 additions & 2 deletions src/spikeinterface/sorters/internal/lupin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
estimate_templates_with_accumulator,
Templates,
compute_sparsity,
ms_to_samples,
)

from spikeinterface.core.job_tools import fix_job_kwargs
Expand Down Expand Up @@ -332,8 +333,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
)

# Template are sparse from radius using unit_location
nbefore = int(ms_before * sampling_frequency / 1000.0)
nafter = int(ms_after * sampling_frequency / 1000.0)
nbefore = ms_to_samples(ms_before, sampling_frequency)
nafter = ms_to_samples(ms_after, sampling_frequency)
templates_array = estimate_templates_with_accumulator(
recording,
sorting_pre_peeler.to_spike_vector(),
Expand Down
5 changes: 3 additions & 2 deletions src/spikeinterface/sorters/internal/tridesclous2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
estimate_templates_with_accumulator,
Templates,
compute_sparsity,
ms_to_samples,
)

from spikeinterface.core.job_tools import fix_job_kwargs
Expand Down Expand Up @@ -291,8 +292,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
# we recompute the template even if the clustering give it already because we use different ms_before/ms_after
ms_before = params["ms_before"]
ms_after = params["ms_after"]
nbefore = int(ms_before * sampling_frequency / 1000.0)
nafter = int(ms_after * sampling_frequency / 1000.0)
nbefore = ms_to_samples(ms_before, sampling_frequency)
nafter = ms_to_samples(ms_after, sampling_frequency)

templates_array = estimate_templates_with_accumulator(
recording_for_peeler,
Expand Down
Loading
Loading