Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 136 additions & 1 deletion src/vitabel/vitals.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import logging
import vitaldb
from collections import defaultdict

from copy import deepcopy
from typing import Any, Literal

from IPython.display import display
Expand All @@ -23,6 +23,7 @@
Label,
IntervalLabel,
TimeDataCollection,
TimeSeriesBase,
_timeseries_list_info,
)
from vitabel.utils import (
Expand Down Expand Up @@ -2860,3 +2861,137 @@ def area_under_threshold(
stop_time=stop_time,
threshold=threshold
)

def filter_by_intervallabel(
self,
channel_or_label_to_filter: Channel | Label,
filter_by: IntervalLabel,
invert: bool = False,
start_inclusive: bool = True,
end_inclusive: bool = True,
full_cover: bool = False,
) -> Channel | Label:
"""Filters a channel or label by a given interval label.

Parameters
----------
channel_or_label_to_filter
The channel or label to filter.
filter_by
The interval label to filter by.
invert
If ``True``, inverts the filter, i.e., keeps the parts
outside the intervals defined by `filter_by`.
start_inclusive
If ``True``, the start of the intervals in `filter_by`
is considered inclusive for filtering.
If `invert` is ``True``, this parameter becomes inverted too.
end_inclusive
If ``True``, the end of the intervals in `filter_by`
is considered inclusive for filtering.
If `invert` is ``True``, this parameter becomes inverted too.
full_cover
Only relevant for interval Labels:
- If True, keep label intervals only if they are **fully contained** in at least one filter interval.
- If False, keep label intervals if they **overlap at all** with any filter interval.

Returns
-------
Channel | Label
The filtered channel or label.
"""

if not isinstance(channel_or_label_to_filter, (Channel, Label)):
raise ValueError(
f"The specified channel_or_label_to_filter {channel_or_label_to_filter} is neither a "
"Channel nor a Label"
)
if not isinstance(filter_by, IntervalLabel):
raise ValueError(
f"The specified filter_by {filter_by} is not an IntervalLabel"
)

# Determine whether we are filtering timestamps (Channel or Label)
# or intervals (IntervalLabel)
timestamps_to_filter = not isinstance(channel_or_label_to_filter, IntervalLabel)
is_label = isinstance(channel_or_label_to_filter, Label)

# Get timestamps and the data of the channel or label
channel_or_label_to_filter = deepcopy(channel_or_label_to_filter)
time_index = channel_or_label_to_filter.get_data().time_index
data = channel_or_label_to_filter.get_data().data.copy() if channel_or_label_to_filter.get_data().data is not None else np.array([])
text_data = channel_or_label_to_filter.get_data().text_data if channel_or_label_to_filter.get_data().text_data is not None else np.arrayy([])


# build a mask for filtering
intervals = filter_by.get_data().time_index
start = intervals[:, 0]
stop = intervals[:, 1]

if timestamps_to_filter: # Channel or Label
timestamps_to_filter = time_index.to_numpy()

if start_inclusive:
condition_start = timestamps_to_filter[:, None] >= start
else:
condition_start = timestamps_to_filter[:, None] > start

if end_inclusive:
condition_end = timestamps_to_filter[:, None] <= stop
else:
condition_end = timestamps_to_filter[:, None] < stop

mask = (condition_start & condition_end).any(axis=1)

if invert:
mask = ~mask

masked_index = time_index[mask]

else: # IntervalLabel
intervals_to_filter = time_index

intervals_to_filter_start = intervals_to_filter[:, 0]
intervals_to_filter_end = intervals_to_filter[:, 1]

# A is the Interval to filter by B
if full_cover:
if start_inclusive and end_inclusive:
# require full containment of A in some B, including touching borders
mask = ((intervals_to_filter_start[:, None] >= start) & (intervals_to_filter_end[:, None] <= stop)).any(axis=1)
elif start_inclusive and not end_inclusive:
# require full containment of A in some B, including touching start border
mask = ((intervals_to_filter_start[:, None] >= start) & (intervals_to_filter_end[:, None] < stop)).any(axis=1)
elif not start_inclusive and end_inclusive:
# require full containment of A in some B, including touching end border
mask = ((intervals_to_filter_start[:, None] > start) & (intervals_to_filter_end[:, None] <= stop)).any(axis=1)

else: # any overlap
if start_inclusive and end_inclusive:
# require any overlap between A and some B, including touching borders
mask = ((intervals_to_filter_start[:, None] <= stop) & (intervals_to_filter_end[:, None] >= start)).any(axis=1)
elif start_inclusive and not end_inclusive:
# require any overlap between A and some B, including touching start border
mask = ((intervals_to_filter_start[:, None] < stop) & (intervals_to_filter_end[:, None] >= start)).any(axis=1)
elif not start_inclusive and end_inclusive:
# require any overlap between A and some B, including touching end border
mask = ((intervals_to_filter_start[:, None] <= stop) & (intervals_to_filter_end[:, None] > start)).any(axis=1)

if invert:
mask = ~mask

masked_index = time_index[mask]
masked_index = masked_index.ravel()

# Filter the given Channel or Label
TS = TimeSeriesBase(masked_index-channel_or_label_to_filter.offset)
new_index = TS.time_index
new_start = TS.time_start
channel_or_label_to_filter.time_index = new_index
channel_or_label_to_filter.time_start = new_start
channel_or_label_to_filter.data = data[mask] if len(data) > 0 else None
if is_label:
channel_or_label_to_filter.text_data = text_data[mask] if len(text_data) > 0 else None

return channel_or_label_to_filter

Loading