Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 27 additions & 24 deletions src/spikeinterface/core/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,30 +112,33 @@ def check_sortings_equal(

max_spike_index = SX1.to_spike_vector()["sample_index"].max()

# TODO for later use to_spike_vector() to do this without looping
for segment_idx in range(SX1.get_num_segments()):
# get_unit_ids
ids1 = np.sort(np.array(SX1.get_unit_ids()))
ids2 = np.sort(np.array(SX2.get_unit_ids()))
assert_array_equal(ids1, ids2)
for id in ids1:
train1 = np.sort(SX1.get_unit_spike_train(id, segment_index=segment_idx))
train2 = np.sort(SX2.get_unit_spike_train(id, segment_index=segment_idx))
assert np.array_equal(train1, train2)
train1 = np.sort(SX1.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30))
train2 = np.sort(SX2.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30))
assert np.array_equal(train1, train2)
# test that slicing works correctly
train1 = np.sort(SX1.get_unit_spike_train(id, segment_index=segment_idx, end_frame=max_spike_index - 30))
train2 = np.sort(SX2.get_unit_spike_train(id, segment_index=segment_idx, end_frame=max_spike_index - 30))
assert np.array_equal(train1, train2)
train1 = np.sort(
SX1.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30, end_frame=max_spike_index - 30)
)
train2 = np.sort(
SX2.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30, end_frame=max_spike_index - 30)
)
assert np.array_equal(train1, train2)
def _sorted_spike_vector(SX):
spikes = SX.to_spike_vector()
order = np.lexsort((spikes["sample_index"], spikes["unit_index"], spikes["segment_index"]))
return spikes[order]

def _slice_spikes(spikes, start_frame=None, end_frame=None):
mask = np.ones(spikes.size, dtype=bool)
if start_frame is not None:
mask &= spikes["sample_index"] >= start_frame
if end_frame is not None:
mask &= spikes["sample_index"] <= end_frame
return spikes[mask]

s1 = _sorted_spike_vector(SX1)
s2 = _sorted_spike_vector(SX2)
assert_array_equal(s1, s2)

for start_frame, end_frame in [
(None, None),
(30, None),
(None, max_spike_index - 30),
(30, max_spike_index - 30),
]:

slice1 = _slice_spikes(s1, start_frame, end_frame)
slice2 = _slice_spikes(s2, start_frame, end_frame)
assert np.array_equal(slice1, slice2)

if check_annotations:
check_extractor_annotations_equal(SX1, SX2)
Expand Down