def filter_by_intervallabel(
    self,
    channel_or_label_to_filter: Channel | Label,
    filter_by: IntervalLabel,
    invert: bool = False,
    start_inclusive: bool = True,
    end_inclusive: bool = True,
    full_cover: bool = False,
) -> Channel | Label:
    """Filter a channel or label by the intervals of an interval label.

    The input object is not modified: a deep copy is filtered and
    returned.

    Parameters
    ----------
    channel_or_label_to_filter
        The channel or label to filter.
    filter_by
        The interval label whose intervals define the filter.
    invert
        If ``True``, the computed mask is negated, i.e., the parts
        *not* selected by `filter_by` are kept.
    start_inclusive
        If ``True``, comparisons against the start of an interval in
        `filter_by` are non-strict (the start border counts as inside).
        Because `invert` negates the finished mask, border items that
        are selected here are dropped from an inverted result.
    end_inclusive
        If ``True``, comparisons against the end of an interval in
        `filter_by` are non-strict (the end border counts as inside).
        Because `invert` negates the finished mask, border items that
        are selected here are dropped from an inverted result.
    full_cover
        Only relevant when an interval label is filtered:

        - If ``True``, an interval is selected only if it is **fully
          contained** in at least one filter interval.
        - If ``False``, an interval is selected if it **overlaps**
          with any filter interval.

    Returns
    -------
    Channel | Label
        A filtered deep copy of `channel_or_label_to_filter`.

    Raises
    ------
    ValueError
        If `channel_or_label_to_filter` is neither a ``Channel`` nor a
        ``Label``, or if `filter_by` is not an ``IntervalLabel``.
    """
    if not isinstance(channel_or_label_to_filter, (Channel, Label)):
        raise ValueError(
            f"The specified channel_or_label_to_filter {channel_or_label_to_filter} is neither a "
            "Channel nor a Label"
        )
    if not isinstance(filter_by, IntervalLabel):
        raise ValueError(
            f"The specified filter_by {filter_by} is not an IntervalLabel"
        )

    # Interval labels are filtered interval-wise; channels and plain
    # labels are filtered timestamp-wise.
    filtering_intervals = isinstance(channel_or_label_to_filter, IntervalLabel)
    is_label = isinstance(channel_or_label_to_filter, Label)

    # Work on a copy so that the caller's object stays untouched.
    filtered = deepcopy(channel_or_label_to_filter)
    source = filtered.get_data()
    time_index = source.time_index
    data = source.data.copy() if source.data is not None else np.array([])
    # Text data is only written back for labels, so only read it for them.
    text_data = source.text_data if is_label else None

    if filtering_intervals:
        # One row per interval: column 0 is the start, column 1 the end.
        own_intervals = np.asarray(time_index).reshape(-1, 2)
        n_items = len(own_intervals)
    else:
        own_intervals = None
        n_items = len(time_index)

    intervals = np.asarray(filter_by.get_data().time_index)
    if intervals.size == 0:
        # No filter intervals: nothing is selected (everything, if inverted).
        mask = np.zeros(n_items, dtype=bool)
    else:
        intervals = intervals.reshape(-1, 2)
        start = intervals[:, 0]
        stop = intervals[:, 1]

        # Choose the border comparisons once; this covers all four
        # combinations of start_inclusive / end_inclusive.
        after_start = np.greater_equal if start_inclusive else np.greater
        before_stop = np.less_equal if end_inclusive else np.less

        # Comparing a column vector of items against the row of filter
        # borders yields an (items x filter intervals) boolean matrix.
        if filtering_intervals:
            own_start = own_intervals[:, 0][:, None]
            own_end = own_intervals[:, 1][:, None]
            if full_cover:
                # Interval lies completely inside some filter interval.
                hits = after_start(own_start, start) & before_stop(own_end, stop)
            else:
                # Interval overlaps some filter interval.
                hits = before_stop(own_start, stop) & after_start(own_end, start)
        else:
            stamps = time_index.to_numpy()[:, None]
            hits = after_start(stamps, start) & before_stop(stamps, stop)

        # An item is selected if it matches at least one filter interval.
        mask = hits.any(axis=1)

    if invert:
        mask = ~mask

    if filtering_intervals:
        # Flatten the kept (start, end) rows into a 1-D sequence.
        masked_index = own_intervals[mask].ravel()
    else:
        masked_index = time_index[mask]

    # NOTE(review): the offset is subtracted because the index returned by
    # get_data() presumably already contains it -- confirm against
    # TimeSeriesBase / get_data.
    rebuilt = TimeSeriesBase(masked_index - filtered.offset)
    filtered.time_index = rebuilt.time_index
    filtered.time_start = rebuilt.time_start
    filtered.data = data[mask] if len(data) > 0 else None
    if is_label:
        has_text = text_data is not None and len(text_data) > 0
        filtered.text_data = text_data[mask] if has_text else None

    return filtered