From b210494fa941914792d1e06b6b1708c044252127 Mon Sep 17 00:00:00 2001 From: tayheau Date: Tue, 23 Dec 2025 13:02:54 +0100 Subject: [PATCH 1/6] using to_spike_vector on sorting_equal test --- src/spikeinterface/core/testing.py | 53 +++++++++++++++++------------- 1 file changed, 30 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/core/testing.py b/src/spikeinterface/core/testing.py index 3f311f1bdd..a5ad90c6fc 100644 --- a/src/spikeinterface/core/testing.py +++ b/src/spikeinterface/core/testing.py @@ -113,29 +113,36 @@ def check_sortings_equal( max_spike_index = SX1.to_spike_vector()["sample_index"].max() # TODO for later use to_spike_vector() to do this without looping - for segment_idx in range(SX1.get_num_segments()): - # get_unit_ids - ids1 = np.sort(np.array(SX1.get_unit_ids())) - ids2 = np.sort(np.array(SX2.get_unit_ids())) - assert_array_equal(ids1, ids2) - for id in ids1: - train1 = np.sort(SX1.get_unit_spike_train(id, segment_index=segment_idx)) - train2 = np.sort(SX2.get_unit_spike_train(id, segment_index=segment_idx)) - assert np.array_equal(train1, train2) - train1 = np.sort(SX1.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30)) - train2 = np.sort(SX2.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30)) - assert np.array_equal(train1, train2) - # test that slicing works correctly - train1 = np.sort(SX1.get_unit_spike_train(id, segment_index=segment_idx, end_frame=max_spike_index - 30)) - train2 = np.sort(SX2.get_unit_spike_train(id, segment_index=segment_idx, end_frame=max_spike_index - 30)) - assert np.array_equal(train1, train2) - train1 = np.sort( - SX1.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30, end_frame=max_spike_index - 30) - ) - train2 = np.sort( - SX2.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30, end_frame=max_spike_index - 30) - ) - assert np.array_equal(train1, train2) + def _sorted_spike_vector(SX): + spikes = SX.to_spike_vector() + order = np.lexsort( + (spikes["sample_index"], spikes["unit_index"], spikes["segment_index"]) + ) + return spikes[order] + + def _slice_spikes(spikes, start_frame = None, end_frame = None): + mask = np.ones(spikes.size, dtype=bool) + if start_frame is not None: + mask &= spikes["sample_index"] >= start_frame + if end_frame is not None: + mask &= spikes["sample_index"] <= end_frame + return spikes[mask] + + s1 = _sorted_spike_vector(SX1) + s2 = _sorted_spike_vector(SX2) + assert_array_equal(s1, s2) + + for start_frame, end_frame in [ + (None, None), + (30, None), + (None, max_spike_index - 30), + (30, max_spike_index - 30), + ]: + + slice1 = _slice_spikes(s1, start_frame, end_frame) + slice2 = _slice_spikes(s2, start_frame, end_frame) + assert np.array_equal(slice1, slice2) + if check_annotations: check_extractor_annotations_equal(SX1, SX2) From a8842ea64f1c6e9fb7a03c00c88983513d7e952c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 12:11:40 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/testing.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/core/testing.py b/src/spikeinterface/core/testing.py index a5ad90c6fc..4cd55f1665 100644 --- a/src/spikeinterface/core/testing.py +++ b/src/spikeinterface/core/testing.py @@ -115,12 +115,10 @@ def check_sortings_equal( # TODO for later use to_spike_vector() to do this without looping def _sorted_spike_vector(SX): spikes = SX.to_spike_vector() - order = np.lexsort( - (spikes["sample_index"], spikes["unit_index"], spikes["segment_index"]) - ) + order = np.lexsort((spikes["sample_index"], spikes["unit_index"], spikes["segment_index"])) return spikes[order] - - def _slice_spikes(spikes, start_frame = None, end_frame = None): + + def _slice_spikes(spikes, start_frame=None, end_frame=None): mask = np.ones(spikes.size, dtype=bool) if start_frame is not None: mask &= spikes["sample_index"] >= start_frame @@ -133,17 +131,16 @@ def _slice_spikes(spikes, start_frame = None, end_frame = None): assert_array_equal(s1, s2) for start_frame, end_frame in [ - (None, None), - (30, None), - (None, max_spike_index - 30), - (30, max_spike_index - 30), - ]: + (None, None), + (30, None), + (None, max_spike_index - 30), + (30, max_spike_index - 30), + ]: slice1 = _slice_spikes(s1, start_frame, end_frame) slice2 = _slice_spikes(s2, start_frame, end_frame) assert np.array_equal(slice1, slice2) - if check_annotations: check_extractor_annotations_equal(SX1, SX2) if check_properties: From 497f150645fa40eaa2eb3df40c04026296401fcb Mon Sep 17 00:00:00 2001 From: tayheau Date: Tue, 23 Dec 2025 15:35:50 +0100 Subject: [PATCH 3/6] removed the sorting since already sorted --- src/spikeinterface/core/testing.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/testing.py b/src/spikeinterface/core/testing.py index 4cd55f1665..2bd95b6c95 100644 --- a/src/spikeinterface/core/testing.py +++ b/src/spikeinterface/core/testing.py @@ -113,12 +113,7 @@ def check_sortings_equal( max_spike_index = SX1.to_spike_vector()["sample_index"].max() # TODO for later use to_spike_vector() to do this without looping - def _sorted_spike_vector(SX): - spikes = SX.to_spike_vector() - order = np.lexsort((spikes["sample_index"], spikes["unit_index"], spikes["segment_index"])) - return spikes[order] - - def _slice_spikes(spikes, start_frame=None, end_frame=None): + def _slice_spikes(spikes, start_frame = None, end_frame = None): mask = np.ones(spikes.size, dtype=bool) if start_frame is not None: mask &= spikes["sample_index"] >= start_frame @@ -126,8 +121,8 @@ def _slice_spikes(spikes, start_frame=None, end_frame=None): mask &= spikes["sample_index"] <= end_frame return spikes[mask] - s1 = _sorted_spike_vector(SX1) - s2 = _sorted_spike_vector(SX2) + s1 = SX1.to_spike_vector() + s2 = SX2.to_spike_vector() assert_array_equal(s1, s2) for start_frame, end_frame in [ From df9d18fddb45dc2ea28437cf40ab1503569335a7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 14:46:08 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/testing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/testing.py b/src/spikeinterface/core/testing.py index 2bd95b6c95..a9da4e0dae 100644 --- a/src/spikeinterface/core/testing.py +++ b/src/spikeinterface/core/testing.py @@ -113,7 +113,7 @@ def check_sortings_equal( max_spike_index = SX1.to_spike_vector()["sample_index"].max() # TODO for later use to_spike_vector() to do this without looping - def _slice_spikes(spikes, start_frame = None, end_frame = None): + def _slice_spikes(spikes, start_frame=None, end_frame=None): mask = np.ones(spikes.size, dtype=bool) if start_frame is not None: mask &= spikes["sample_index"] >= start_frame From 1f523e8a8f33d549717c2f02765c7a2997b74eb6 Mon Sep 17 00:00:00 2001 From: tayheau Date: Tue, 23 Dec 2025 17:25:14 +0100 Subject: [PATCH 5/6] restored lexsort --- src/spikeinterface/core/testing.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/core/testing.py b/src/spikeinterface/core/testing.py index a9da4e0dae..f93d6f329d 100644 --- a/src/spikeinterface/core/testing.py +++ b/src/spikeinterface/core/testing.py @@ -112,8 +112,14 @@ def check_sortings_equal( max_spike_index = SX1.to_spike_vector()["sample_index"].max() - # TODO for later use to_spike_vector() to do this without looping - def _slice_spikes(spikes, start_frame=None, end_frame=None): + def _sorted_spike_vector(SX): + spikes = SX.to_spike_vector() + order = np.lexsort( + (spikes["sample_index"], spikes["unit_index"], spikes["segment_index"]) + ) + return spikes[order] + + def _slice_spikes(spikes, start_frame = None, end_frame = None): mask = np.ones(spikes.size, dtype=bool) if start_frame is not None: mask &= spikes["sample_index"] >= start_frame @@ -121,21 +127,22 @@ def _slice_spikes(spikes, start_frame=None, end_frame=None): mask &= spikes["sample_index"] <= end_frame return spikes[mask] - s1 = SX1.to_spike_vector() - s2 = SX2.to_spike_vector() + s1 = _sorted_spike_vector(SX1) + s2 = _sorted_spike_vector(SX2) assert_array_equal(s1, s2) for start_frame, end_frame in [ - (None, None), - (30, None), - (None, max_spike_index - 30), - (30, max_spike_index - 30), - ]: + (None, None), + (30, None), + (None, max_spike_index - 30), + (30, max_spike_index - 30), + ]: slice1 = _slice_spikes(s1, start_frame, end_frame) slice2 = _slice_spikes(s2, start_frame, end_frame) assert np.array_equal(slice1, slice2) + if check_annotations: check_extractor_annotations_equal(SX1, SX2) if check_properties: From f75c3d40d6c8e2b59d98938faa7930daaf121a3c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 16:26:40 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/testing.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/core/testing.py b/src/spikeinterface/core/testing.py index f93d6f329d..2dbfbc604e 100644 --- a/src/spikeinterface/core/testing.py +++ b/src/spikeinterface/core/testing.py @@ -114,12 +114,10 @@ def check_sortings_equal( def _sorted_spike_vector(SX): spikes = SX.to_spike_vector() - order = np.lexsort( - (spikes["sample_index"], spikes["unit_index"], spikes["segment_index"]) - ) + order = np.lexsort((spikes["sample_index"], spikes["unit_index"], spikes["segment_index"])) return spikes[order] - - def _slice_spikes(spikes, start_frame = None, end_frame = None): + + def _slice_spikes(spikes, start_frame=None, end_frame=None): mask = np.ones(spikes.size, dtype=bool) if start_frame is not None: mask &= spikes["sample_index"] >= start_frame @@ -132,17 +130,16 @@ def _slice_spikes(spikes, start_frame = None, end_frame = None): assert_array_equal(s1, s2) for start_frame, end_frame in [ - (None, None), - (30, None), - (None, max_spike_index - 30), - (30, max_spike_index - 30), - ]: + (None, None), + (30, None), + (None, max_spike_index - 30), + (30, max_spike_index - 30), + ]: slice1 = _slice_spikes(s1, start_frame, end_frame) slice2 = _slice_spikes(s2, start_frame, end_frame) assert np.array_equal(slice1, slice2) - if check_annotations: check_extractor_annotations_equal(SX1, SX2) if check_properties: