diff --git a/src/spikeinterface/core/testing.py b/src/spikeinterface/core/testing.py index 3f311f1bdd..2dbfbc604e 100644 --- a/src/spikeinterface/core/testing.py +++ b/src/spikeinterface/core/testing.py @@ -112,30 +112,33 @@ def check_sortings_equal( max_spike_index = SX1.to_spike_vector()["sample_index"].max() - # TODO for later use to_spike_vector() to do this without looping - for segment_idx in range(SX1.get_num_segments()): - # get_unit_ids - ids1 = np.sort(np.array(SX1.get_unit_ids())) - ids2 = np.sort(np.array(SX2.get_unit_ids())) - assert_array_equal(ids1, ids2) - for id in ids1: - train1 = np.sort(SX1.get_unit_spike_train(id, segment_index=segment_idx)) - train2 = np.sort(SX2.get_unit_spike_train(id, segment_index=segment_idx)) - assert np.array_equal(train1, train2) - train1 = np.sort(SX1.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30)) - train2 = np.sort(SX2.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30)) - assert np.array_equal(train1, train2) - # test that slicing works correctly - train1 = np.sort(SX1.get_unit_spike_train(id, segment_index=segment_idx, end_frame=max_spike_index - 30)) - train2 = np.sort(SX2.get_unit_spike_train(id, segment_index=segment_idx, end_frame=max_spike_index - 30)) - assert np.array_equal(train1, train2) - train1 = np.sort( - SX1.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30, end_frame=max_spike_index - 30) - ) - train2 = np.sort( - SX2.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30, end_frame=max_spike_index - 30) - ) - assert np.array_equal(train1, train2) + def _sorted_spike_vector(SX): + spikes = SX.to_spike_vector() + order = np.lexsort((spikes["sample_index"], spikes["unit_index"], spikes["segment_index"])) + return spikes[order] + + def _slice_spikes(spikes, start_frame=None, end_frame=None): + mask = np.ones(spikes.size, dtype=bool) + if start_frame is not None: + mask &= spikes["sample_index"] >= start_frame + if end_frame is not None: + mask &= spikes["sample_index"] <= end_frame + return spikes[mask] + + s1 = _sorted_spike_vector(SX1) + s2 = _sorted_spike_vector(SX2) + assert_array_equal(s1, s2) + + for start_frame, end_frame in [ + (None, None), + (30, None), + (None, max_spike_index - 30), + (30, max_spike_index - 30), + ]: + + slice1 = _slice_spikes(s1, start_frame, end_frame) + slice2 = _slice_spikes(s2, start_frame, end_frame) + assert np.array_equal(slice1, slice2) if check_annotations: check_extractor_annotations_equal(SX1, SX2)