Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
316 changes: 315 additions & 1 deletion src/orcapod/core/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from os import PathLike
from pathlib import Path
from typing import Any, Literal

import pandas as pd
import polars as pl
Copy link

Copilot AI Jul 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The import 'polars as pl' is not used in this file. Consider removing it to keep imports clean.

Copilot uses AI. Check for mistakes.

from orcapod.core.base import Source
Expand Down Expand Up @@ -202,3 +202,317 @@ def keys(

    def computed_label(self) -> str | None:
        """Return the label of the wrapped stream, or None if the stream has no label."""
        return self.stream.label


class DataFrameSource(Source):
    """
    A stream source that sources data from a pandas DataFrame.

    For each row in the DataFrame, yields a tuple containing:
    - A tag generated either by the provided tag_function or defaulting to the row index
    - A packet containing values from specified columns as key-value pairs

    Parameters
    ----------
    columns : list[str]
        List of column names to include in the packet. These will serve as the keys
        in the packet, with the corresponding row values as the packet values.
    data : pd.DataFrame
        The pandas DataFrame to source data from
    tag_function : Callable[[pd.Series, int], Tag] | None, default=None
        Optional function to generate a tag from a DataFrame row and its index.
        The function receives the row as a pandas Series and the row index as arguments.
        If None, uses the row index in a dict with key 'row_index'
    tag_function_hash_mode : Literal["content", "signature", "name"], default="name"
        How to hash the tag function for identity purposes
    expected_tag_keys : Collection[str] | None, default=None
        Expected tag keys for the stream
    label : str | None, default=None
        Optional label for the source

    Raises
    ------
    ValueError
        If any entry of ``columns`` is not a column of ``data``.

    Examples
    --------
    >>> import pandas as pd
    >>> df = pd.DataFrame({
    ...     'file_path': ['/path/to/file1.txt', '/path/to/file2.txt'],
    ...     'metadata_path': ['/path/to/meta1.json', '/path/to/meta2.json'],
    ...     'sample_id': ['sample_1', 'sample_2']
    ... })
    >>> # Use sample_id column for tags and include file paths in packets
    >>> source = DataFrameSource(
    ...     columns=['file_path', 'metadata_path'],
    ...     data=df,
    ...     tag_function=lambda row, idx: {'sample_id': row['sample_id']}
    ... )
    >>> # Use default row index tagging
    >>> source = DataFrameSource(['file_path', 'metadata_path'], df)
    """

    @staticmethod
    def default_tag_function(row: pd.Series, idx: int) -> Tag:
        # Default tagging: identify each packet by its DataFrame row index.
        return {"row_index": idx}

    def __init__(
        self,
        columns: list[str],
        data: pd.DataFrame,
        tag_function: Callable[[pd.Series, int], Tag] | None = None,
        label: str | None = None,
        tag_function_hash_mode: Literal["content", "signature", "name"] = "name",
        expected_tag_keys: Collection[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(label=label, **kwargs)

        # Validate BEFORE storing any state so a failed construction does not
        # leave a partially-initialized instance behind.
        missing_columns = set(columns) - set(data.columns)
        if missing_columns:
            raise ValueError(f"Columns not found in DataFrame: {missing_columns}")

        self.columns = columns
        self.dataframe = data

        if tag_function is None:
            tag_function = self.__class__.default_tag_function
            # If using default tag function and no explicit expected_tag_keys, set to default
            if expected_tag_keys is None:
                expected_tag_keys = ["row_index"]

        self.expected_tag_keys = expected_tag_keys
        self.tag_function = tag_function
        self.tag_function_hash_mode = tag_function_hash_mode

    def forward(self, *streams: SyncStream) -> SyncStream:
        """Produce the stream of (tag, packet) pairs; accepts no input streams."""
        if len(streams) != 0:
            raise ValueError(
                "DataFrameSource does not support forwarding streams. "
                "It generates its own stream from the DataFrame."
            )

        def generator() -> Iterator[tuple[Tag, Packet]]:
            # iterrows() keeps each row as a pd.Series, which is the type the
            # tag_function contract promises. NOTE(review): iterrows() is slow
            # on very large frames; acceptable here since rows become packets
            # one at a time anyway.
            for idx, row in self.dataframe.iterrows():
                tag = self.tag_function(row, idx)
                packet = {col: row[col] for col in self.columns}
                yield tag, packet

        return SyncStreamFromGenerator(generator)

    def __repr__(self) -> str:
        return f"DataFrameSource(cols={self.columns}, rows={len(self.dataframe)})"

    def identity_structure(self, *streams: SyncStream) -> Any:
        """Return a hashable structure identifying this source's configuration and data."""
        hash_function_kwargs = {}
        if self.tag_function_hash_mode == "content":
            # if using content hash, exclude few
            hash_function_kwargs = {
                "include_name": False,
                "include_module": False,
                "include_declaration": False,
            }

        tag_function_hash = hash_function(
            self.tag_function,
            function_hash_mode=self.tag_function_hash_mode,
            hash_kwargs=hash_function_kwargs,
        )

        # Convert DataFrame to hashable representation. Sorting by the (unique,
        # string) column names is safe; only value hashability can fail.
        df_subset = self.dataframe[self.columns]
        df_content = df_subset.to_dict(orient="records")
        df_hashable = tuple(tuple(sorted(record.items())) for record in df_content)
        try:
            hash(df_hashable)
        except TypeError:
            # Some cell values are unhashable (e.g. lists); fall back to their
            # string representation so the identity structure stays hashable.
            df_hashable = tuple(
                tuple((key, str(value)) for key, value in sorted(record.items()))
                for record in df_content
            )

        return (
            self.__class__.__name__,
            tuple(self.columns),
            df_hashable,
            tag_function_hash,
        ) + tuple(streams)

    def keys(
        self, *streams: SyncStream, trigger_run: bool = False
    ) -> tuple[Collection[str] | None, Collection[str] | None]:
        """
        Returns the keys of the stream. The keys are the names of the packets
        in the stream. The keys are used to identify the packets in the stream.
        If expected_keys are provided, they will be used instead of the default keys.
        """
        if len(streams) != 0:
            raise ValueError(
                "DataFrameSource does not support forwarding streams. "
                "It generates its own stream from the DataFrame."
            )

        if self.expected_tag_keys is not None:
            return tuple(self.expected_tag_keys), tuple(self.columns)
        return super().keys(trigger_run=trigger_run)

    def claims_unique_tags(
        self, *streams: "SyncStream", trigger_run: bool = True
    ) -> bool | None:
        """Return True when tags are provably unique (default row-index tagging)."""
        if len(streams) != 0:
            raise ValueError(
                "DataFrameSource does not support forwarding streams. "
                "It generates its own stream from the DataFrame."
            )
        # Claim uniqueness only if the default tag function is used
        if self.tag_function == self.__class__.default_tag_function:
            return True
        # Otherwise, delegate to the base class
        return super().claims_unique_tags(trigger_run=trigger_run)


class ListSource(Source):
    """
    A stream source that sources data from a list of elements.

    For each element in the list, yields a tuple containing:
    - A tag generated either by the provided tag_function or defaulting to the element index
    - A packet containing the element under the provided name key

    Parameters
    ----------
    name : str
        The key name under which each list element will be stored in the packet
    data : list[Any]
        The list of elements to source data from
    tag_function : Callable[[Any, int], Tag] | None, default=None
        Optional function to generate a tag from a list element and its index.
        The function receives the element and the index as arguments.
        If None, uses the element index in a dict with key 'element_index'
    tag_function_hash_mode : Literal["content", "signature", "name"], default="name"
        How to hash the tag function for identity purposes
    expected_tag_keys : Collection[str] | None, default=None
        Expected tag keys for the stream
    label : str | None, default=None
        Optional label for the source

    Examples
    --------
    >>> # Simple list of file names
    >>> file_list = ['/path/to/file1.txt', '/path/to/file2.txt', '/path/to/file3.txt']
    >>> source = ListSource('file_path', file_list)
    >>>
    >>> # Custom tag function using filename stems
    >>> from pathlib import Path
    >>> source = ListSource(
    ...     'file_path',
    ...     file_list,
    ...     tag_function=lambda elem, idx: {'file_name': Path(elem).stem}
    ... )
    >>>
    >>> # List of sample IDs
    >>> samples = ['sample_001', 'sample_002', 'sample_003']
    >>> source = ListSource(
    ...     'sample_id',
    ...     samples,
    ...     tag_function=lambda elem, idx: {'sample': elem}
    ... )
    """

    @staticmethod
    def default_tag_function(element: Any, idx: int) -> Tag:
        # Default tagging: identify each packet by its position in the list.
        return {"element_index": idx}

    def __init__(
        self,
        name: str,
        data: list[Any],
        tag_function: Callable[[Any, int], Tag] | None = None,
        label: str | None = None,
        tag_function_hash_mode: Literal["content", "signature", "name"] = "name",
        expected_tag_keys: Collection[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(label=label, **kwargs)
        self.name = name
        self.elements = list(data)  # Create a copy to avoid external modifications

        if tag_function is None:
            tag_function = self.__class__.default_tag_function
            # If using default tag function and no explicit expected_tag_keys, set to default
            if expected_tag_keys is None:
                expected_tag_keys = ["element_index"]

        self.expected_tag_keys = expected_tag_keys
        self.tag_function = tag_function
        self.tag_function_hash_mode = tag_function_hash_mode

    def forward(self, *streams: SyncStream) -> SyncStream:
        """Produce the stream of (tag, packet) pairs; accepts no input streams."""
        if len(streams) != 0:
            raise ValueError(
                "ListSource does not support forwarding streams. "
                "It generates its own stream from the list elements."
            )

        def generator() -> Iterator[tuple[Tag, Packet]]:
            for idx, element in enumerate(self.elements):
                tag = self.tag_function(element, idx)
                packet = {self.name: element}
                yield tag, packet

        return SyncStreamFromGenerator(generator)

    def __repr__(self) -> str:
        return f"ListSource({self.name}, {len(self.elements)} elements)"

    def identity_structure(self, *streams: SyncStream) -> Any:
        """Return a hashable structure identifying this source's configuration and data."""
        hash_function_kwargs = {}
        if self.tag_function_hash_mode == "content":
            # if using content hash, exclude few
            hash_function_kwargs = {
                "include_name": False,
                "include_module": False,
                "include_declaration": False,
            }

        tag_function_hash = hash_function(
            self.tag_function,
            function_hash_mode=self.tag_function_hash_mode,
            hash_kwargs=hash_function_kwargs,
        )

        # Convert list to hashable representation.
        # BUGFIX: tuple(list) never raises TypeError, so the previous
        # try/except around the tuple() call was dead code and unhashable
        # elements only failed later, when the identity structure was hashed.
        # Probe hashability explicitly and fall back to string representation.
        elements_hashable = tuple(self.elements)
        try:
            hash(elements_hashable)
        except TypeError:
            # If elements are not hashable, convert to string representation
            elements_hashable = tuple(str(elem) for elem in self.elements)

        return (
            self.__class__.__name__,
            self.name,
            elements_hashable,
            tag_function_hash,
        ) + tuple(streams)

    def keys(
        self, *streams: SyncStream, trigger_run: bool = False
    ) -> tuple[Collection[str] | None, Collection[str] | None]:
        """
        Returns the keys of the stream. The keys are the names of the packets
        in the stream. The keys are used to identify the packets in the stream.
        If expected_keys are provided, they will be used instead of the default keys.
        """
        if len(streams) != 0:
            raise ValueError(
                "ListSource does not support forwarding streams. "
                "It generates its own stream from the list elements."
            )

        if self.expected_tag_keys is not None:
            return tuple(self.expected_tag_keys), (self.name,)
        return super().keys(trigger_run=trigger_run)

    def claims_unique_tags(
        self, *streams: "SyncStream", trigger_run: bool = True
    ) -> bool | None:
        """Return True when tags are provably unique (default element-index tagging)."""
        if len(streams) != 0:
            raise ValueError(
                "ListSource does not support forwarding streams. "
                "It generates its own stream from the list elements."
            )
        # Claim uniqueness only if the default tag function is used
        if self.tag_function == self.__class__.default_tag_function:
            return True
        # Otherwise, delegate to the base class
        return super().claims_unique_tags(trigger_run=trigger_run)