diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 56f558bce7..b2f7a40669 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -61,7 +61,7 @@ def is_filtered(self): # the is_filtered is handle with annotation return self._annotations.get("is_filtered", False) - def set_probe(self, probe, group_mode="by_probe", in_place=False): + def set_probe(self, probe, group_mode="auto", in_place=False): """ Attach a list of Probe object to a recording. @@ -69,9 +69,9 @@ def set_probe(self, probe, group_mode="by_probe", in_place=False): ---------- probe_or_probegroup: Probe, list of Probe, or ProbeGroup The probe(s) to be attached to the recording - group_mode: "by_probe" | "by_shank", default: "by_probe - "by_probe" or "by_shank". Adds grouping property to the recording based on the probes ("by_probe") - or shanks ("by_shanks") + group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto" + How to add the "group" property. + "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks and two sides are present. in_place: bool False by default. Useful internally when extractor do self.set_probegroup(probe) @@ -86,10 +86,10 @@ def set_probe(self, probe, group_mode="by_probe", in_place=False): probegroup.add_probe(probe) return self._set_probes(probegroup, group_mode=group_mode, in_place=in_place) - def set_probegroup(self, probegroup, group_mode="by_probe", in_place=False): + def set_probegroup(self, probegroup, group_mode="auto", in_place=False): return self._set_probes(probegroup, group_mode=group_mode, in_place=in_place) - def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False): + def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): """ Attach a list of Probe objects to a recording. For this Probe.device_channel_indices is used to link contacts to recording channels. @@ -103,9 +103,9 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False ---------- probe_or_probegroup: Probe, list of Probe, or ProbeGroup The probe(s) to be attached to the recording - group_mode: "by_probe" | "by_shank", default: "by_probe" - "by_probe" or "by_shank". Adds grouping property to the recording based on the probes ("by_probe") - or shanks ("by_shank") + group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto" + How to add the "group" property. + "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks and two sides are present. in_place: bool False by default. Useful internally when extractor do self.set_probegroup(probe) @@ -115,7 +115,12 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False sub_recording: BaseRecording A view of the recording (ChannelSlice or clone or itself) """ - assert group_mode in ("by_probe", "by_shank"), "'group_mode' can be 'by_probe' or 'by_shank'" + assert group_mode in ( + "auto", + "by_probe", + "by_shank", + "by_side", + ), "'group_mode' can be 'auto' 'by_probe' 'by_shank' or 'by_side'" # handle several input possibilities if isinstance(probe_or_probegroup, Probe): @@ -150,6 +155,7 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False warn("The given probes have unconnected contacts: they are removed") probe_as_numpy_array = probe_as_numpy_array[keep] + device_channel_indices = probe_as_numpy_array["device_channel_indices"] order = np.argsort(device_channel_indices) device_channel_indices = device_channel_indices[order] @@ -199,20 +205,32 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False sub_recording.set_property("location", locations, ids=None) # handle groups - groups = np.zeros(probe_as_numpy_array.size, dtype="int64") - if group_mode == "by_probe": - for group, probe_index in enumerate(np.unique(probe_as_numpy_array["probe_index"])): - mask = probe_as_numpy_array["probe_index"] == probe_index - groups[mask] = group + has_shank_id = "shank_ids" in probe_as_numpy_array.dtype.fields + has_contact_side = "contact_sides" in probe_as_numpy_array.dtype.fields + if group_mode == "auto": + group_keys = ["probe_index"] + if has_shank_id: + group_keys += ["shank_ids"] + if has_contact_side: + group_keys += ["contact_sides"] + elif group_mode == "by_probe": + group_keys = ["probe_index"] elif group_mode == "by_shank": - assert all( - probe.shank_ids is not None for probe in probegroup.probes - ), "shank_ids is None in probe, you cannot group by shank" - for group, a in enumerate(np.unique(probe_as_numpy_array[["probe_index", "shank_ids"]])): - mask = (probe_as_numpy_array["probe_index"] == a["probe_index"]) & ( - probe_as_numpy_array["shank_ids"] == a["shank_ids"] - ) - groups[mask] = group + assert has_shank_id, "shank_ids is None in probe, you cannot group by shank" + group_keys = ["probe_index", "shank_ids"] + elif group_mode == "by_side": + assert has_contact_side, "contact_sides is None in probe, you cannot group by side" + if has_shank_id: + group_keys = ["probe_index", "shank_ids", "contact_sides"] + else: + group_keys = ["probe_index", "contact_sides"] + groups = np.zeros(probe_as_numpy_array.size, dtype="int64") + unique_keys = np.unique(probe_as_numpy_array[group_keys]) + for group, a in enumerate(unique_keys): + mask = np.ones(probe_as_numpy_array.size, dtype=bool) + for k in group_keys: + mask &= probe_as_numpy_array[k] == a[k] + groups[mask] = group sub_recording.set_property("group", groups, ids=None) # add probe annotations to recording diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 9de800b33d..1ebeb677c6 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -179,11 +179,23 @@ def test_BaseRecording(create_cache_folder): # set/get Probe only 2 channels probe = Probe(ndim=2) - positions = [[0.0, 0.0], [0.0, 15.0], [0, 30.0]] - probe.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5}) - probe.set_device_channel_indices([2, -1, 0]) + positions = [ + [0.0, 0.0], + [0.0, 15.0], + [0, 30.0], + [100.0, 0.0], + [100.0, 15.0], + [100.0, 30.0], + ] + probe.set_contacts( + positions=positions, shapes="circle", shape_params={"radius": 5}, shank_ids=["a"] * 3 + ["b"] * 3 + ) + probe.set_device_channel_indices( + [2, -1, 0, -1, -1, -1], + ) probe.create_auto_shape() + rec_p = rec.set_probe(probe, group_mode="auto") rec_p = rec.set_probe(probe, group_mode="by_shank") rec_p = rec.set_probe(probe, group_mode="by_probe") positions2 = rec_p.get_channel_locations() @@ -213,10 +225,35 @@ def test_BaseRecording(create_cache_folder): # plot_probe(probe2) # plt.show() + # test different group mode + probe = Probe(ndim=2) + positions_two_side = positions + positions + shank_ids = ["a", "a", "a", "b", "b", "b"] * 2 + contact_sides = ["front"] * 6 + ["back"] * 6 + probe.set_contacts( + positions=positions_two_side, + shapes="circle", + shape_params={"radius": 5}, + shank_ids=shank_ids, + contact_sides=contact_sides, + ) + probe.set_device_channel_indices(np.arange(12)) + probe.create_auto_shape() + traces = np.zeros((1000, 12), dtype="int16") + rec = NumpyRecording([traces], 30000.0) + rec1 = rec.set_probe(probe, group_mode="auto") + assert np.unique(rec1.get_property("group")).size == 4 + rec2 = rec.set_probe(probe, group_mode="by_probe") + assert np.unique(rec2.get_property("group")).size == 1 + rec3 = rec.set_probe(probe, group_mode="by_shank") + assert np.unique(rec3.get_property("group")).size == 2 + rec4 = rec.set_probe(probe, group_mode="by_side") + assert np.unique(rec4.get_property("group")).size == 4 + # set unconnected probe probe = Probe(ndim=2) positions = [[0.0, 0.0], [0.0, 15.0], [0, 30.0]] - probe.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5}) + probe.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5}, shank_ids=["a", "a", "a"]) probe.set_device_channel_indices([-1, -1, -1]) probe.create_auto_shape() diff --git a/src/spikeinterface/core/tests/test_basesnippets.py b/src/spikeinterface/core/tests/test_basesnippets.py index 04ebd1bd1d..751a03460c 100644 --- a/src/spikeinterface/core/tests/test_basesnippets.py +++ b/src/spikeinterface/core/tests/test_basesnippets.py @@ -142,7 +142,7 @@ def test_BaseSnippets(create_cache_folder): probe.set_device_channel_indices([2, -1, 0]) probe.create_auto_shape() - snippets_p = snippets.set_probe(probe, group_mode="by_shank") + snippets_p = snippets.set_probe(probe, group_mode="auto") snippets_p = snippets.set_probe(probe, group_mode="by_probe") positions2 = snippets_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) diff --git a/src/spikeinterface/sorters/tests/test_runsorter.py b/src/spikeinterface/sorters/tests/test_runsorter.py index 67f0582b7c..332e6e857e 100644 --- a/src/spikeinterface/sorters/tests/test_runsorter.py +++ b/src/spikeinterface/sorters/tests/test_runsorter.py @@ -56,7 +56,7 @@ def test_run_sorter_dict(generate_recording, create_cache_folder): recording.set_property(key="split_property", values=[4, 4, "g", "g", 4, 4, 4, "g"]) dict_of_recordings = recording.split_by("split_property") - sorter_params = {"detection": {"detect_threshold": 4.9}} + sorter_params = {"detect_threshold": 4.9} folder = cache_folder / "sorting_tdc_local_dict"