diff --git a/notebooks/tutorials/02_parallel_execution_on_ray.ipynb b/notebooks/tutorials/02_parallel_execution_on_ray.ipynb
index d6ff7d5..7a94855 100644
--- a/notebooks/tutorials/02_parallel_execution_on_ray.ipynb
+++ b/notebooks/tutorials/02_parallel_execution_on_ray.ipynb
@@ -3,7 +3,7 @@
{
"cell_type": "code",
"execution_count": 1,
- "id": "0c2dfaec",
+ "id": "db13f7d9",
"metadata": {},
"outputs": [],
"source": [
@@ -22,36 +22,20 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2025-08-10 23:27:14,560\tINFO client_builder.py:242 -- Passing the following kwargs to ray.init() on the server: log_to_driver\n",
- "SIGTERM handler is not set because current thread is not the main thread.\n",
- "2025-08-10 23:27:17,455\tWARNING utils.py:1280 -- Python patch version mismatch: The cluster was started with:\n",
+ "2025-11-12 08:26:40,429\tINFO client_builder.py:242 -- Passing the following kwargs to ray.init() on the server: log_to_driver\n",
+ "2025-11-12 08:26:46,184\tWARNING utils.py:1280 -- Python patch version mismatch: The cluster was started with:\n",
" Ray: 2.48.0\n",
- " Python: 3.13.5\n",
+ " Python: 3.12.12\n",
"This process on Ray Client was started with:\n",
" Ray: 2.48.0\n",
- " Python: 3.13.3\n",
+ " Python: 3.12.10\n",
"\n"
]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\u001b[36m(autoscaler +28s)\u001b[0m Tip: use `ray status` to view detailed cluster status. To disable these messages, set RAY_SCHEDULER_EVENTS=0.\n",
- "\u001b[36m(autoscaler +28s)\u001b[0m Adding 5 node(s) of type workergroup.\n",
- "\u001b[36m(autoscaler +28s)\u001b[0m Resized to 6 CPUs, 5 GPUs.\n",
- "\u001b[36m(autoscaler +28s)\u001b[0m No available node types can fulfill resource requests {'CPU': 1.0}*44. Add suitable node types to this cluster to resolve this issue.\n",
- "\u001b[36m(autoscaler +34s)\u001b[0m No available node types can fulfill resource requests {'CPU': 1.0}*22. Add suitable node types to this cluster to resolve this issue.\n",
- "\u001b[36m(autoscaler +49s)\u001b[0m No available node types can fulfill resource requests {'CPU': 1.0}*11. Add suitable node types to this cluster to resolve this issue.\n",
- "\u001b[36m(autoscaler +55s)\u001b[0m No available node types can fulfill resource requests {'CPU': 1.0}*22. Add suitable node types to this cluster to resolve this issue.\n",
- "\u001b[36m(autoscaler +1m0s)\u001b[0m No available node types can fulfill resource requests {'CPU': 1.0}*27. Add suitable node types to this cluster to resolve this issue.\n",
- "\u001b[36m(autoscaler +1m10s)\u001b[0m No available node types can fulfill resource requests {'CPU': 1.0}*6. Add suitable node types to this cluster to resolve this issue.\n"
- ]
}
],
"source": [
"ray_engine = RayEngine(\n",
- " \"ray://raycluster-op-test-kuberay-head-svc.ray.svc.cluster.local:10001\"\n",
+ " \"ray://op-pipe-kuberay-head-svc.ray.svc.cluster.local:10001\",\n",
")"
]
},
@@ -142,7 +126,7 @@
}
],
"source": [
- "result_stream1 = add_numbers(input_stream)\n",
+ "result_stream1 = add_numbers.pod(input_stream)\n",
"result_stream1.run()\n",
"result_stream1.as_df()"
]
@@ -157,7 +141,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 7,
"id": "75ade620",
"metadata": {},
"outputs": [
@@ -194,14 +178,14 @@
"└─────┴─────┘"
]
},
- "execution_count": 9,
+ "execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "result_stream2 = add_numbers(input_stream)\n",
- "result_stream2.run(ray_engine)\n",
+ "result_stream2 = add_numbers.pod(input_stream)\n",
+ "await result_stream2.run_async(execution_engine=ray_engine)\n",
"result_stream2.as_df()"
]
},
@@ -223,7 +207,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 22,
"id": "f459da03",
"metadata": {},
"outputs": [],
@@ -236,39 +220,281 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 23,
"id": "2befd400",
"metadata": {},
"outputs": [],
"source": [
- "store = op.stores.BatchedDeltaTableArrowStore(\"./test_store\")\n",
- "pipeline = op.Pipeline(\"pipeline_with_ray\", store)"
+ "database = op.databases.DeltaTableDatabase(\"./test_store\")\n",
+ "pipeline = op.Pipeline(\"pipeline_with_ray\", database)"
]
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 24,
"id": "e21ecaf2",
"metadata": {},
"outputs": [],
"source": [
"with pipeline:\n",
- " result_stream = add_numbers(input_stream)"
+ " result_stream = add_numbers.pod(input_stream)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "a9c6cc81",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "PodNode[add_numbers]\n",
+ "
\n",
+ "
shape: (0, 4)| *id | sum | _source_sum | _context_key |
|---|
| i64 | i64 | str | str |
"
+ ],
+ "text/plain": [
+ "PodNode(pod=FunctionPod:add_numbers)"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pipeline.add_numbers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "2a3e6023",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipeline.run(execution_engine=ray_engine)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "565b3aa9",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "PodNode[add_numbers]\n",
+ "\n",
+ "
shape: (50, 2)| *id | sum |
|---|
| i64 | i64 |
| 0 | 0 |
| 1 | 5 |
| 2 | 10 |
| 3 | 15 |
| 4 | 20 |
| … | … |
| 45 | 225 |
| 46 | 230 |
| 47 | 235 |
| 48 | 240 |
| 49 | 245 |
"
+ ],
+ "text/plain": [
+ "PodNode(pod=FunctionPod:add_numbers)"
+ ]
+ },
+ "execution_count": 27,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pipeline.add_numbers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 55,
+ "id": "84826cb2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipeline.add_numbers.run()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 58,
+ "id": "bdd1f48b",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "PodNode[add_numbers]\n",
+ "\n",
+ "
shape: (50, 2)| *id | sum |
|---|
| i64 | i64 |
| 0 | 0 |
| 1 | 5 |
| 2 | 10 |
| 3 | 15 |
| 4 | 20 |
| … | … |
| 45 | 225 |
| 46 | 230 |
| 47 | 235 |
| 48 | 240 |
| 49 | 245 |
"
+ ],
+ "text/plain": [
+ "PodNode(pod=FunctionPod:add_numbers)"
+ ]
+ },
+ "execution_count": 58,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pipeline.add_numbers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "id": "a67a57b0",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "PodNode[add_numbers]\n",
+ "\n",
+ "
shape: (50, 2)| *id | sum |
|---|
| i64 | i64 |
| 0 | 0 |
| 1 | 5 |
| 2 | 10 |
| 3 | 15 |
| 4 | 20 |
| … | … |
| 45 | 225 |
| 46 | 230 |
| 47 | 235 |
| 48 | 240 |
| 49 | 245 |
"
+ ],
+ "text/plain": [
+ "PodNode(pod=FunctionPod:add_numbers)"
+ ]
+ },
+ "execution_count": 39,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[36m(autoscaler +2m31s)\u001b[0m Removing 1 nodes of type cpuOnlyGroup (idle).\n",
+ "\u001b[36m(autoscaler +2m31s)\u001b[0m Resized to 30 CPUs.\n",
+ "\u001b[36m(autoscaler +3m12s)\u001b[0m Removing 1 nodes of type cpuOnlyGroup (idle).\n",
+ "\u001b[36m(autoscaler +3m12s)\u001b[0m Resized to 0 CPUs.\n"
+ ]
+ }
+ ],
+ "source": [
+ "pipeline.add_numbers"
]
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 38,
"id": "8449cb5d",
"metadata": {},
"outputs": [],
"source": [
- "pipeline.run(ray_engine)"
+ "pipeline.run(execution_engine=ray_engine)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "id": "78f48078",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import asyncio"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "1575620a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def synchronous_run(async_func, *args, **kwargs):\n",
+ " \"\"\"\n",
+ " Use existing event loop if available.\n",
+ "\n",
+ " Pros: Reuses existing loop, more efficient\n",
+ " Cons: More complex, need to handle loop detection\n",
+ " \"\"\"\n",
+ " import asyncio\n",
+ " try:\n",
+ " # Check if we're already in an event loop\n",
+ " _ = asyncio.get_running_loop()\n",
+ "\n",
+ " def run_in_thread():\n",
+ " return asyncio.run(async_func(*args, **kwargs))\n",
+ "\n",
+ " import concurrent.futures\n",
+ "\n",
+ " with concurrent.futures.ThreadPoolExecutor() as executor:\n",
+ " future = executor.submit(run_in_thread)\n",
+ " return future.result()\n",
+ " except RuntimeError:\n",
+ " # No event loop running, safe to use asyncio.run()\n",
+ " return asyncio.run(async_func(*args, **kwargs))\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 55,
+ "id": "2744b279",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "async def show_message():\n",
+ " await asyncio.sleep(10)\n",
+ " print(\"Hello, World!\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 56,
+ "id": "49c781b5",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Hello, World!\n"
+ ]
+ }
+ ],
+ "source": [
+ "synchronous_run(show_message)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 47,
+ "id": "fdae4e82",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "async def test():\n",
+ " show_message()\n",
+ " "
]
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 16,
"id": "40743bb7",
"metadata": {},
"outputs": [
@@ -282,30 +508,19 @@
" white-space: pre-wrap;\n",
"}\n",
"\n",
- "shape: (50, 2)| id | sum |
|---|
| i64 | i64 |
| 0 | 0 |
| 1 | 5 |
| 2 | 10 |
| 3 | 15 |
| 4 | 20 |
| … | … |
| 45 | 225 |
| 46 | 230 |
| 47 | 235 |
| 48 | 240 |
| 49 | 245 |
"
+ "shape: (0, 4)| id | sum | _source_sum | _context_key |
|---|
| i64 | i64 | str | str |
"
],
"text/plain": [
- "shape: (50, 2)\n",
- "┌─────┬─────┐\n",
- "│ id ┆ sum │\n",
- "│ --- ┆ --- │\n",
- "│ i64 ┆ i64 │\n",
- "╞═════╪═════╡\n",
- "│ 0 ┆ 0 │\n",
- "│ 1 ┆ 5 │\n",
- "│ 2 ┆ 10 │\n",
- "│ 3 ┆ 15 │\n",
- "│ 4 ┆ 20 │\n",
- "│ … ┆ … │\n",
- "│ 45 ┆ 225 │\n",
- "│ 46 ┆ 230 │\n",
- "│ 47 ┆ 235 │\n",
- "│ 48 ┆ 240 │\n",
- "│ 49 ┆ 245 │\n",
- "└─────┴─────┘"
+ "shape: (0, 4)\n",
+ "┌─────┬─────┬─────────────┬──────────────┐\n",
+ "│ id ┆ sum ┆ _source_sum ┆ _context_key │\n",
+ "│ --- ┆ --- ┆ --- ┆ --- │\n",
+ "│ i64 ┆ i64 ┆ str ┆ str │\n",
+ "╞═════╪═════╪═════════════╪══════════════╡\n",
+ "└─────┴─────┴─────────────┴──────────────┘"
]
},
- "execution_count": 14,
+ "execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
@@ -313,6 +528,14 @@
"source": [
"pipeline.add_numbers.as_df()"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "02af60fa",
+ "metadata": {},
+ "outputs": [],
+ "source": []
}
],
"metadata": {
@@ -331,7 +554,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.13.3"
+ "version": "3.12.10"
}
},
"nbformat": 4,
diff --git a/src/orcapod/core/pods.py b/src/orcapod/core/pods.py
index 0379bb5..9e2f9ad 100644
--- a/src/orcapod/core/pods.py
+++ b/src/orcapod/core/pods.py
@@ -10,6 +10,8 @@
ArrowPacket,
DictPacket,
)
+from functools import wraps
+
from orcapod.utils.git_utils import get_git_info_for_python_object
from orcapod.core.kernels import KernelStream, TrackedKernelBase
from orcapod.core.operators import Join
@@ -252,9 +254,14 @@ def function_pod(
"""
def decorator(func: Callable) -> CallableWithPod:
+
if func.__name__ == "":
raise ValueError("Lambda functions cannot be used with function_pod")
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ return func(*args, **kwargs)
+
# Store the original function in the module for pickling purposes
# and make sure to change the name of the function
@@ -267,9 +274,8 @@ def decorator(func: Callable) -> CallableWithPod:
label=label,
**kwargs,
)
- setattr(func, "pod", pod)
- return cast(CallableWithPod, func)
-
+ setattr(wrapper, "pod", pod)
+ return cast(CallableWithPod, wrapper)
return decorator
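
With the change above, `function_pod` wraps the decorated function with `functools.wraps` and attaches the pod to the wrapper instead of mutating the original function, so both the plain call path and the `.pod` handle live on the returned object. A minimal sketch of the resulting usage (decorator arguments and the function body are illustrative; the import path is the module touched in this hunk):

    from orcapod.core.pods import function_pod  # module shown in this diff

    @function_pod()  # decorator arguments elided; pass whatever the project's API expects
    def add_numbers(a: int, b: int) -> int:
        return a + b

    add_numbers(2, 3)   # still calls the underlying function through the wrapper
    add_numbers.pod     # the FunctionPod attached by the decorator
    # add_numbers.pod(input_stream) is the invocation style used in the updated notebook
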
diff --git a/src/orcapod/core/sources/base.py b/src/orcapod/core/sources/base.py
index b8e128f..8f44599 100644
--- a/src/orcapod/core/sources/base.py
+++ b/src/orcapod/core/sources/base.py
@@ -1,5 +1,4 @@
from abc import abstractmethod
-from ast import Not
from collections.abc import Collection, Iterator
from typing import TYPE_CHECKING, Any
diff --git a/src/orcapod/core/streams/base.py b/src/orcapod/core/streams/base.py
index d96230a..2959cf3 100644
--- a/src/orcapod/core/streams/base.py
+++ b/src/orcapod/core/streams/base.py
@@ -475,6 +476,9 @@ def flow(
def _repr_html_(self) -> str:
df = self.as_polars_df()
+ # reorder columns so tag columns come first
+ tag_columns = [c for c in df.columns if c in self.tag_keys()]
+ other_columns = [c for c in df.columns if c not in self.tag_keys()]
+ df = df[tag_columns + other_columns]
tag_map = {t: f"*{t}" for t in self.tag_keys()}
# TODO: construct repr html better
df = df.rename(tag_map)
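
The `_repr_html_` tweak above moves tag columns to the front before rendering. A self-contained polars sketch of that reorder-plus-rename step, with made-up column names standing in for `self.tag_keys()` and the stream's columns:

    import polars as pl

    df = pl.DataFrame({"sum": [0, 5], "id": [0, 1], "_source_sum": ["a", "b"]})
    tag_keys = ("id",)  # stand-in for self.tag_keys()

    # tag columns first, remaining columns after, each group keeping its original order
    ordered = [c for c in df.columns if c in tag_keys] + [c for c in df.columns if c not in tag_keys]
    df = df[ordered].rename({t: f"*{t}" for t in tag_keys})
    print(df.columns)  # ['*id', 'sum', '_source_sum']
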
diff --git a/src/orcapod/core/streams/pod_node_stream.py b/src/orcapod/core/streams/pod_node_stream.py
index 496bbc6..4596bcb 100644
--- a/src/orcapod/core/streams/pod_node_stream.py
+++ b/src/orcapod/core/streams/pod_node_stream.py
@@ -57,70 +57,24 @@ def mode(self) -> str:
async def run_async(
self,
+ *args: Any,
execution_engine: cp.ExecutionEngine | None = None,
execution_engine_opts: dict[str, Any] | None = None,
+ **kwargs: Any,
) -> None:
"""
Runs the stream, processing the input stream and preparing the output stream.
This is typically called before iterating over the packets.
"""
if self._cached_output_packets is None:
- cached_results = []
-
- # identify all entries in the input stream for which we still have not computed packets
- target_entries = self.input_stream.as_table(
- include_content_hash=constants.INPUT_PACKET_HASH,
- include_source=True,
- include_system_tags=True,
- )
- existing_entries = self.pod_node.get_all_cached_outputs(
- include_system_columns=True
+ cached_results, missing = self._identify_existing_and_missing_entries(*args,
+ execution_engine=execution_engine,
+ execution_engine_opts=execution_engine_opts,
+ **kwargs,
)
- if existing_entries is None or existing_entries.num_rows == 0:
- missing = target_entries.drop_columns([constants.INPUT_PACKET_HASH])
- existing = None
- else:
- all_results = target_entries.join(
- existing_entries.append_column(
- "_exists", pa.array([True] * len(existing_entries))
- ),
- keys=[constants.INPUT_PACKET_HASH],
- join_type="left outer",
- right_suffix="_right",
- )
- # grab all columns from target_entries first
- missing = (
- all_results.filter(pc.is_null(pc.field("_exists")))
- .select(target_entries.column_names)
- .drop_columns([constants.INPUT_PACKET_HASH])
- )
-
- existing = all_results.filter(
- pc.is_valid(pc.field("_exists"))
- ).drop_columns(
- [
- "_exists",
- constants.INPUT_PACKET_HASH,
- constants.PACKET_RECORD_ID,
- *self.input_stream.keys()[1], # remove the input packet keys
- ]
- # TODO: look into NOT fetching back the record ID
- )
-
- renamed = [
- c.removesuffix("_right") if c.endswith("_right") else c
- for c in existing.column_names
- ]
- existing = existing.rename_columns(renamed)
tag_keys = self.input_stream.keys()[0]
- if existing is not None and existing.num_rows > 0:
- # If there are existing entries, we can cache them
- existing_stream = TableStream(existing, tag_columns=tag_keys)
- for tag, packet in existing_stream.iter_packets():
- cached_results.append((tag, packet))
-
pending_calls = []
if missing is not None and missing.num_rows > 0:
for tag, packet in TableStream(missing, tag_columns=tag_keys):
@@ -134,23 +88,23 @@ async def run_async(
or self._execution_engine_opts,
)
pending_calls.append(pending)
- import asyncio
+ import asyncio
completed_calls = await asyncio.gather(*pending_calls)
for result in completed_calls:
cached_results.append(result)
+ self.clear_cache()
self._cached_output_packets = cached_results
self._set_modified_time()
+ self.pod_node.flush()
- def run(
- self,
- *args: Any,
+ def _identify_existing_and_missing_entries(
+ self,
+ *args: Any,
execution_engine: cp.ExecutionEngine | None = None,
execution_engine_opts: dict[str, Any] | None = None,
- **kwargs: Any,
- ) -> None:
- cached_results = []
+ **kwargs: Any,
+ ) -> tuple[list[tuple[cp.Tag, cp.Packet | None]], pa.Table | None]:
+ cached_results: list[tuple[cp.Tag, cp.Packet | None]] = []
# identify all entries in the input stream for which we still have not computed packets
if len(args) > 0 or len(kwargs) > 0:
@@ -223,6 +177,25 @@ def run(
for tag, packet in existing_stream.iter_packets():
cached_results.append((tag, packet))
+
+ return cached_results, missing
+
+ def run(
+ self,
+ *args: Any,
+ execution_engine: cp.ExecutionEngine | None = None,
+ execution_engine_opts: dict[str, Any] | None = None,
+ **kwargs: Any,
+ ) -> None:
+ tag_keys = self.input_stream.keys()[0]
+ cached_results, missing = self._identify_existing_and_missing_entries(
+ *args,
+ execution_engine=execution_engine,
+ execution_engine_opts=execution_engine_opts,
+ **kwargs,
+ )
+
if missing is not None and missing.num_rows > 0:
packet_record_to_output_lut: dict[str, cp.Packet | None] = {}
execution_engine_hash = (
@@ -257,11 +230,14 @@ def run(
)
cached_results.append((tag, output_packet))
+
+ # reset the cache and set new results
+ self.clear_cache()
self._cached_output_packets = cached_results
self._set_modified_time()
self.pod_node.flush()
# TODO: evaluate proper handling of cache here
- self.clear_cache()
+ # self.clear_cache()
def clear_cache(self) -> None:
self._cached_output_packets = None
@@ -300,115 +276,7 @@ def iter_packets(
self._cached_output_packets = cached_results
self._set_modified_time()
- # if self._cached_output_packets is None:
- # cached_results = []
-
- # # identify all entries in the input stream for which we still have not computed packets
- # target_entries = self.input_stream.as_table(
- # include_system_tags=True,
- # include_source=True,
- # include_content_hash=constants.INPUT_PACKET_HASH,
- # execution_engine=execution_engine,
- # )
- # existing_entries = self.pod_node.get_all_cached_outputs(
- # include_system_columns=True
- # )
- # if existing_entries is None or existing_entries.num_rows == 0:
- # missing = target_entries.drop_columns([constants.INPUT_PACKET_HASH])
- # existing = None
- # else:
- # # missing = target_entries.join(
- # # existing_entries,
- # # keys=[constants.INPUT_PACKET_HASH],
- # # join_type="left anti",
- # # )
- # # Single join that gives you both missing and existing
- # # More efficient - only bring the key column from existing_entries
- # # .select([constants.INPUT_PACKET_HASH]).append_column(
- # # "_exists", pa.array([True] * len(existing_entries))
- # # ),
-
- # # TODO: do more proper replacement operation
- # target_df = pl.DataFrame(target_entries)
- # existing_df = pl.DataFrame(
- # existing_entries.append_column(
- # "_exists", pa.array([True] * len(existing_entries))
- # )
- # )
- # all_results_df = target_df.join(
- # existing_df,
- # on=constants.INPUT_PACKET_HASH,
- # how="left",
- # suffix="_right",
- # )
- # all_results = all_results_df.to_arrow()
- # # all_results = target_entries.join(
- # # existing_entries.append_column(
- # # "_exists", pa.array([True] * len(existing_entries))
- # # ),
- # # keys=[constants.INPUT_PACKET_HASH],
- # # join_type="left outer",
- # # right_suffix="_right", # rename the existing records in case of collision of output packet keys with input packet keys
- # # )
- # # grab all columns from target_entries first
- # missing = (
- # all_results.filter(pc.is_null(pc.field("_exists")))
- # .select(target_entries.column_names)
- # .drop_columns([constants.INPUT_PACKET_HASH])
- # )
-
- # existing = all_results.filter(
- # pc.is_valid(pc.field("_exists"))
- # ).drop_columns(
- # [
- # "_exists",
- # constants.INPUT_PACKET_HASH,
- # constants.PACKET_RECORD_ID,
- # *self.input_stream.keys()[1], # remove the input packet keys
- # ]
- # # TODO: look into NOT fetching back the record ID
- # )
- # renamed = [
- # c.removesuffix("_right") if c.endswith("_right") else c
- # for c in existing.column_names
- # ]
- # existing = existing.rename_columns(renamed)
-
- # tag_keys = self.input_stream.keys()[0]
-
- # if existing is not None and existing.num_rows > 0:
- # # If there are existing entries, we can cache them
- # existing_stream = TableStream(existing, tag_columns=tag_keys)
- # for tag, packet in existing_stream.iter_packets():
- # cached_results.append((tag, packet))
- # yield tag, packet
-
- # if missing is not None and missing.num_rows > 0:
- # hash_to_output_lut: dict[str, cp.Packet | None] = {}
- # for tag, packet in TableStream(missing, tag_columns=tag_keys):
- # # Since these packets are known to be missing, skip the cache lookup
- # packet_hash = packet.content_hash().to_string()
- # if packet_hash in hash_to_output_lut:
- # output_packet = hash_to_output_lut[packet_hash]
- # else:
- # tag, output_packet = self.pod_node.call(
- # tag,
- # packet,
- # skip_cache_lookup=True,
- # execution_engine=execution_engine,
- # )
- # hash_to_output_lut[packet_hash] = output_packet
- # cached_results.append((tag, output_packet))
- # if output_packet is not None:
- # yield tag, output_packet
-
- # self._cached_output_packets = cached_results
- # self._set_modified_time()
- # else:
- # for tag, packet in self._cached_output_packets:
- # if packet is not None:
- # yield tag, packet
-
+
def keys(
self, include_system_tags: bool = False
) -> tuple[tuple[str, ...], tuple[str, ...]]:
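
Both `run` and `run_async` now delegate to `_identify_existing_and_missing_entries`, which splits the input rows into cached and missing sets with a left-outer join on the input packet hash plus an `_exists` sentinel column. A standalone pyarrow sketch of that split, with toy tables and a made-up `hash` column name:

    import pyarrow as pa
    import pyarrow.compute as pc

    target = pa.table({"hash": ["a", "b", "c"], "id": [0, 1, 2]})
    existing = pa.table({"hash": ["a", "c"], "sum": [0, 10]})

    # join against the cache with a sentinel column, then split on its validity
    joined = target.join(
        existing.append_column("_exists", pa.array([True] * len(existing))),
        keys=["hash"],
        join_type="left outer",
    )
    missing = joined.filter(pc.is_null(pc.field("_exists"))).select(target.column_names)
    cached = joined.filter(pc.is_valid(pc.field("_exists"))).drop_columns(["_exists"])
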
diff --git a/src/orcapod/execution_engines/ray_execution_engine.py b/src/orcapod/execution_engines/ray_execution_engine.py
index d4e6727..b36de84 100644
--- a/src/orcapod/execution_engines/ray_execution_engine.py
+++ b/src/orcapod/execution_engines/ray_execution_engine.py
@@ -26,6 +26,11 @@ class RayEngine:
3. No polling needed - Ray handles async integration
"""
+ @property
+ def supports_async(self) -> bool:
+ """Indicate that this engine supports async execution."""
+ return True
+
def __init__(self, ray_address: str | None = None, **ray_init_kwargs):
"""Initialize Ray with native async support."""
diff --git a/src/orcapod/pipeline/graph.py b/src/orcapod/pipeline/graph.py
index b3534ca..45d83e0 100644
--- a/src/orcapod/pipeline/graph.py
+++ b/src/orcapod/pipeline/graph.py
@@ -4,7 +4,7 @@
from orcapod import contexts
from orcapod.protocols import core_protocols as cp
from orcapod.protocols import database_protocols as dbp
-from typing import Any
+from typing import Any, cast
from collections.abc import Collection
import os
import tempfile
@@ -18,6 +18,8 @@
else:
nx = LazyModule("networkx")
+logger = logging.getLogger(__name__)
+
def synchronous_run(async_func, *args, **kwargs):
"""
@@ -43,7 +45,6 @@ def run_in_thread():
return asyncio.run(async_func(*args, **kwargs))
-logger = logging.getLogger(__name__)
class GraphNode:
@@ -94,11 +95,39 @@ def __init__(
self.results_store_path_prefix = self.name + ("_results",)
self.pipeline_database = pipeline_database
self.results_database = results_database
- self.nodes: dict[str, Node] = {}
+ self._nodes: dict[str, Node] = {}
self.auto_compile = auto_compile
self._dirty = False
self._ordered_nodes = [] # Track order of invocations
+ @property
+ def nodes(self) -> dict[str, Node]:
+ return self._nodes.copy()
+
+ @property
+ def function_pods(self) -> dict[str, cp.Pod]:
+ return {
+ label: cast(cp.Pod, node)
+ for label, node in self._nodes.items()
+ if getattr(node, "kernel_type", None) == "function"
+ }
+
+ @property
+ def source_pods(self) -> dict[str, cp.Source]:
+ return {
+ label: node
+ for label, node in self._nodes.items()
+ if getattr(node, "kernel_type", None) == "source"
+ }
+
+ @property
+ def operator_pods(self) -> dict[str, cp.Kernel]:
+ return {
+ label: node
+ for label, node in self._nodes.items()
+ if getattr(node, "kernel_type", None) == "operator"
+ }
+
def __exit__(self, exc_type=None, exc_value=None, traceback=None):
"""
Exit the pipeline context, ensuring all nodes are properly closed.
@@ -156,13 +185,13 @@ def compile(self) -> None:
# If there are multiple nodes with the same label, we need to resolve the collision
logger.info(f"Collision detected for label '{label}': {nodes}")
for i, node in enumerate(nodes, start=1):
- self.nodes[f"{label}_{i}"] = node
+ self._nodes[f"{label}_{i}"] = node
node.label = f"{label}_{i}"
else:
- self.nodes[label] = nodes[0]
+ self._nodes[label] = nodes[0]
nodes[0].label = label
- self.label_lut = {v: k for k, v in self.nodes.items()}
+ self.label_lut = {v: k for k, v in self._nodes.items()}
self.graph = node_graph
@@ -172,7 +201,7 @@ def show_graph(self, **kwargs) -> None:
def set_mode(self, mode: str) -> None:
if mode not in ("production", "development"):
raise ValueError("Mode must be either 'production' or 'development'")
- for node in self.nodes.values():
+ for node in self._nodes.values():
if hasattr(node, "set_mode"):
node.set_mode(mode)
@@ -201,6 +230,16 @@ def run(
may implement more efficient graph traversal algorithms.
"""
import networkx as nx
+ if run_async is True and (execution_engine is None or not execution_engine.supports_async):
+ raise ValueError(
+ "Cannot run asynchronously with an execution engine that does not support async."
+ )
+
+ # if set to None, determine based on execution engine capabilities
+ if run_async is None:
+ run_async = execution_engine is not None and execution_engine.supports_async
+
+ logger.info(f"Running pipeline with run_async={run_async}")
for node in nx.topological_sort(self.graph):
if run_async:
@@ -257,27 +296,27 @@ def wrap_invocation(
def __getattr__(self, item: str) -> Any:
"""Allow direct access to pipeline attributes."""
- if item in self.nodes:
- return self.nodes[item]
+ if item in self._nodes:
+ return self._nodes[item]
raise AttributeError(f"Pipeline has no attribute '{item}'")
def __dir__(self) -> list[str]:
"""Return a list of attributes and methods of the pipeline."""
- return list(super().__dir__()) + list(self.nodes.keys())
+ return list(super().__dir__()) + list(self._nodes.keys())
def rename(self, old_name: str, new_name: str) -> None:
"""
Rename a node in the pipeline.
This will update the label and the internal mapping.
"""
- if old_name not in self.nodes:
+ if old_name not in self._nodes:
raise KeyError(f"Node '{old_name}' does not exist in the pipeline.")
- if new_name in self.nodes:
+ if new_name in self._nodes:
raise KeyError(f"Node '{new_name}' already exists in the pipeline.")
- node = self.nodes[old_name]
- del self.nodes[old_name]
+ node = self._nodes[old_name]
+ del self._nodes[old_name]
node.label = new_name
- self.nodes[new_name] = node
+ self._nodes[new_name] = node
logger.info(f"Node '{old_name}' renamed to '{new_name}'")
diff --git a/src/orcapod/pipeline/nodes.py b/src/orcapod/pipeline/nodes.py
index db9d0d4..af63971 100644
--- a/src/orcapod/pipeline/nodes.py
+++ b/src/orcapod/pipeline/nodes.py
@@ -264,6 +264,15 @@ def __init__(
pipeline_path_prefix=pipeline_path_prefix,
**kwargs,
)
+ self._execution_engine_opts: dict[str, Any] = {}
+
+ @property
+ def execution_engine_opts(self) -> dict[str, Any]:
+ return self._execution_engine_opts.copy()
+
+ @execution_engine_opts.setter
+ def execution_engine_opts(self, opts: dict[str, Any]) -> None:
+ self._execution_engine_opts = opts
def flush(self):
self.pipeline_database.flush()
@@ -309,6 +318,11 @@ def call(
if record_id is None:
record_id = self.get_record_id(packet, execution_engine_hash)
+ combined_execution_engine_opts = self.execution_engine_opts
+ if execution_engine_opts is not None:
+ combined_execution_engine_opts.update(execution_engine_opts)
+
tag, output_packet = super().call(
tag,
packet,
@@ -316,7 +330,7 @@ def call(
skip_cache_lookup=skip_cache_lookup,
skip_cache_insert=skip_cache_insert,
execution_engine=execution_engine,
- execution_engine_opts=execution_engine_opts,
+ execution_engine_opts=combined_execution_engine_opts,
)
# if output_packet is not None:
@@ -348,6 +362,12 @@ async def async_call(
if record_id is None:
record_id = self.get_record_id(packet, execution_engine_hash)
+
+ combined_execution_engine_opts = self.execution_engine_opts
+ if execution_engine_opts is not None:
+ combined_execution_engine_opts.update(execution_engine_opts)
+
tag, output_packet = await super().async_call(
tag,
packet,
@@ -355,7 +375,7 @@ async def async_call(
skip_cache_lookup=skip_cache_lookup,
skip_cache_insert=skip_cache_insert,
execution_engine=execution_engine,
- execution_engine_opts=execution_engine_opts,
+ execution_engine_opts=combined_execution_engine_opts,
)
if output_packet is not None:
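
The opts plumbing above treats the node-level `execution_engine_opts` as defaults and lets call-time opts override them; because the property returns a copy, the stored defaults are never mutated by the merge. A small sketch of those merge semantics (the option keys are illustrative, not a documented set):

    node_defaults = {"num_cpus": 1, "max_retries": 2}   # node.execution_engine_opts (copy)
    call_opts = {"num_cpus": 4}                          # passed to call()/async_call()

    combined = dict(node_defaults)
    combined.update(call_opts)                           # call-time values win on collisions
    assert combined == {"num_cpus": 4, "max_retries": 2}
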
diff --git a/src/orcapod/protocols/core_protocols/base.py b/src/orcapod/protocols/core_protocols/base.py
index 7faad25..87d9a81 100644
--- a/src/orcapod/protocols/core_protocols/base.py
+++ b/src/orcapod/protocols/core_protocols/base.py
@@ -41,6 +41,10 @@ class ExecutionEngine(Protocol):
"local", "threadpool", "processpool", or "ray" and is used for logging
and diagnostics.
"""
+ @property
+ def supports_async(self) -> bool:
+ """Indicate whether this engine supports async execution."""
+ ...
@property
def name(self) -> str: