diff --git a/notebooks/tutorials/02_parallel_execution_on_ray.ipynb b/notebooks/tutorials/02_parallel_execution_on_ray.ipynb index d6ff7d5..7a94855 100644 --- a/notebooks/tutorials/02_parallel_execution_on_ray.ipynb +++ b/notebooks/tutorials/02_parallel_execution_on_ray.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": 1, - "id": "0c2dfaec", + "id": "db13f7d9", "metadata": {}, "outputs": [], "source": [ @@ -22,36 +22,20 @@ "name": "stderr", "output_type": "stream", "text": [ - "2025-08-10 23:27:14,560\tINFO client_builder.py:242 -- Passing the following kwargs to ray.init() on the server: log_to_driver\n", - "SIGTERM handler is not set because current thread is not the main thread.\n", - "2025-08-10 23:27:17,455\tWARNING utils.py:1280 -- Python patch version mismatch: The cluster was started with:\n", + "2025-11-12 08:26:40,429\tINFO client_builder.py:242 -- Passing the following kwargs to ray.init() on the server: log_to_driver\n", + "2025-11-12 08:26:46,184\tWARNING utils.py:1280 -- Python patch version mismatch: The cluster was started with:\n", " Ray: 2.48.0\n", - " Python: 3.13.5\n", + " Python: 3.12.12\n", "This process on Ray Client was started with:\n", " Ray: 2.48.0\n", - " Python: 3.13.3\n", + " Python: 3.12.10\n", "\n" ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[36m(autoscaler +28s)\u001b[0m Tip: use `ray status` to view detailed cluster status. To disable these messages, set RAY_SCHEDULER_EVENTS=0.\n", - "\u001b[36m(autoscaler +28s)\u001b[0m Adding 5 node(s) of type workergroup.\n", - "\u001b[36m(autoscaler +28s)\u001b[0m Resized to 6 CPUs, 5 GPUs.\n", - "\u001b[36m(autoscaler +28s)\u001b[0m No available node types can fulfill resource requests {'CPU': 1.0}*44. Add suitable node types to this cluster to resolve this issue.\n", - "\u001b[36m(autoscaler +34s)\u001b[0m No available node types can fulfill resource requests {'CPU': 1.0}*22. Add suitable node types to this cluster to resolve this issue.\n", - "\u001b[36m(autoscaler +49s)\u001b[0m No available node types can fulfill resource requests {'CPU': 1.0}*11. Add suitable node types to this cluster to resolve this issue.\n", - "\u001b[36m(autoscaler +55s)\u001b[0m No available node types can fulfill resource requests {'CPU': 1.0}*22. Add suitable node types to this cluster to resolve this issue.\n", - "\u001b[36m(autoscaler +1m0s)\u001b[0m No available node types can fulfill resource requests {'CPU': 1.0}*27. Add suitable node types to this cluster to resolve this issue.\n", - "\u001b[36m(autoscaler +1m10s)\u001b[0m No available node types can fulfill resource requests {'CPU': 1.0}*6. 
Add suitable node types to this cluster to resolve this issue.\n" - ] } ], "source": [ "ray_engine = RayEngine(\n", - " \"ray://raycluster-op-test-kuberay-head-svc.ray.svc.cluster.local:10001\"\n", + " \"ray://op-pipe-kuberay-head-svc.ray.svc.cluster.local:10001\",\n", ")" ] }, @@ -142,7 +126,7 @@ } ], "source": [ - "result_stream1 = add_numbers(input_stream)\n", + "result_stream1 = add_numbers.pod(input_stream)\n", "result_stream1.run()\n", "result_stream1.as_df()" ] @@ -157,7 +141,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "id": "75ade620", "metadata": {}, "outputs": [ @@ -194,14 +178,14 @@ "└─────┴─────┘" ] }, - "execution_count": 9, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "result_stream2 = add_numbers(input_stream)\n", - "result_stream2.run(ray_engine)\n", + "result_stream2 = add_numbers.pod(input_stream)\n", + "await result_stream2.run_async(execution_engine=ray_engine)\n", "result_stream2.as_df()" ] }, @@ -223,7 +207,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 22, "id": "f459da03", "metadata": {}, "outputs": [], @@ -236,39 +220,281 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 23, "id": "2befd400", "metadata": {}, "outputs": [], "source": [ - "store = op.stores.BatchedDeltaTableArrowStore(\"./test_store\")\n", - "pipeline = op.Pipeline(\"pipeline_with_ray\", store)" + "database = op.databases.DeltaTableDatabase(\"./test_store\")\n", + "pipeline = op.Pipeline(\"pipeline_with_ray\", database)" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 24, "id": "e21ecaf2", "metadata": {}, "outputs": [], "source": [ "with pipeline:\n", - " result_stream = add_numbers(input_stream)" + " result_stream = add_numbers.pod(input_stream)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "a9c6cc81", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "PodNode[add_numbers]\n", + "
\n", + "shape: (0, 4)
┌─────┬─────┬─────────────┬──────────────┐
│ *id ┆ sum ┆ _source_sum ┆ _context_key │
│ i64 ┆ i64 ┆ str         ┆ str          │
└─────┴─────┴─────────────┴──────────────┘
" + ], + "text/plain": [ + "PodNode(pod=FunctionPod:add_numbers)" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pipeline.add_numbers" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "2a3e6023", + "metadata": {}, + "outputs": [], + "source": [ + "pipeline.run(execution_engine=ray_engine)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "565b3aa9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "PodNode[add_numbers]\n", + "
\n", + "shape: (50, 2)
┌─────┬─────┐
│ *id ┆ sum │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 0   ┆ 0   │
│ 1   ┆ 5   │
│ 2   ┆ 10  │
│ 3   ┆ 15  │
│ 4   ┆ 20  │
│ …   ┆ …   │
│ 45  ┆ 225 │
│ 46  ┆ 230 │
│ 47  ┆ 235 │
│ 48  ┆ 240 │
│ 49  ┆ 245 │
└─────┴─────┘
" + ], + "text/plain": [ + "PodNode(pod=FunctionPod:add_numbers)" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pipeline.add_numbers" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "84826cb2", + "metadata": {}, + "outputs": [], + "source": [ + "pipeline.add_numbers.run()" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "bdd1f48b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "PodNode[add_numbers]\n", + "
\n", + "shape: (50, 2)
┌─────┬─────┐
│ *id ┆ sum │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 0   ┆ 0   │
│ 1   ┆ 5   │
│ 2   ┆ 10  │
│ 3   ┆ 15  │
│ 4   ┆ 20  │
│ …   ┆ …   │
│ 45  ┆ 225 │
│ 46  ┆ 230 │
│ 47  ┆ 235 │
│ 48  ┆ 240 │
│ 49  ┆ 245 │
└─────┴─────┘
" + ], + "text/plain": [ + "PodNode(pod=FunctionPod:add_numbers)" + ] + }, + "execution_count": 58, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pipeline.add_numbers" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "a67a57b0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "PodNode[add_numbers]\n", + "
\n", + "shape: (50, 2)
┌─────┬─────┐
│ *id ┆ sum │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 0   ┆ 0   │
│ 1   ┆ 5   │
│ 2   ┆ 10  │
│ 3   ┆ 15  │
│ 4   ┆ 20  │
│ …   ┆ …   │
│ 45  ┆ 225 │
│ 46  ┆ 230 │
│ 47  ┆ 235 │
│ 48  ┆ 240 │
│ 49  ┆ 245 │
└─────┴─────┘
" + ], + "text/plain": [ + "PodNode(pod=FunctionPod:add_numbers)" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[36m(autoscaler +2m31s)\u001b[0m Removing 1 nodes of type cpuOnlyGroup (idle).\n", + "\u001b[36m(autoscaler +2m31s)\u001b[0m Resized to 30 CPUs.\n", + "\u001b[36m(autoscaler +3m12s)\u001b[0m Removing 1 nodes of type cpuOnlyGroup (idle).\n", + "\u001b[36m(autoscaler +3m12s)\u001b[0m Resized to 0 CPUs.\n" + ] + } + ], + "source": [ + "pipeline.add_numbers" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 38, "id": "8449cb5d", "metadata": {}, "outputs": [], "source": [ - "pipeline.run(ray_engine)" + "pipeline.run(execution_engine=ray_engine)" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "78f48078", + "metadata": {}, + "outputs": [], + "source": [ + "import asyncio" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "1575620a", + "metadata": {}, + "outputs": [], + "source": [ + "def synchronous_run(async_func, *args, **kwargs):\n", + " \"\"\"\n", + " Use existing event loop if available.\n", + "\n", + " Pros: Reuses existing loop, more efficient\n", + " Cons: More complex, need to handle loop detection\n", + " \"\"\"\n", + " import asyncio\n", + " try:\n", + " # Check if we're already in an event loop\n", + " _ = asyncio.get_running_loop()\n", + "\n", + " def run_in_thread():\n", + " return asyncio.run(async_func(*args, **kwargs))\n", + "\n", + " import concurrent.futures\n", + "\n", + " with concurrent.futures.ThreadPoolExecutor() as executor:\n", + " future = executor.submit(run_in_thread)\n", + " return future.result()\n", + " except RuntimeError:\n", + " # No event loop running, safe to use asyncio.run()\n", + " return asyncio.run(async_func(*args, **kwargs))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "2744b279", + "metadata": {}, + "outputs": [], + "source": [ + "async def show_message():\n", + " await asyncio.sleep(10)\n", + " print(\"Hello, World!\")" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "49c781b5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Hello, World!\n" + ] + } + ], + "source": [ + "synchronous_run(show_message)" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "fdae4e82", + "metadata": {}, + "outputs": [], + "source": [ + "async def test():\n", + " show_message()\n", + " " ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "id": "40743bb7", "metadata": {}, "outputs": [ @@ -282,30 +508,19 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (50, 2)
┌─────┬─────┐
│ id  ┆ sum │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 0   ┆ 0   │
│ 1   ┆ 5   │
│ 2   ┆ 10  │
│ 3   ┆ 15  │
│ 4   ┆ 20  │
│ …   ┆ …   │
│ 45  ┆ 225 │
│ 46  ┆ 230 │
│ 47  ┆ 235 │
│ 48  ┆ 240 │
│ 49  ┆ 245 │
└─────┴─────┘
" + "shape: (0, 4)
┌─────┬─────┬─────────────┬──────────────┐
│ id  ┆ sum ┆ _source_sum ┆ _context_key │
│ i64 ┆ i64 ┆ str         ┆ str          │
└─────┴─────┴─────────────┴──────────────┘
" ], "text/plain": [ - "shape: (50, 2)\n", - "┌─────┬─────┐\n", - "│ id ┆ sum │\n", - "│ --- ┆ --- │\n", - "│ i64 ┆ i64 │\n", - "╞═════╪═════╡\n", - "│ 0 ┆ 0 │\n", - "│ 1 ┆ 5 │\n", - "│ 2 ┆ 10 │\n", - "│ 3 ┆ 15 │\n", - "│ 4 ┆ 20 │\n", - "│ … ┆ … │\n", - "│ 45 ┆ 225 │\n", - "│ 46 ┆ 230 │\n", - "│ 47 ┆ 235 │\n", - "│ 48 ┆ 240 │\n", - "│ 49 ┆ 245 │\n", - "└─────┴─────┘" + "shape: (0, 4)\n", + "┌─────┬─────┬─────────────┬──────────────┐\n", + "│ id ┆ sum ┆ _source_sum ┆ _context_key │\n", + "│ --- ┆ --- ┆ --- ┆ --- │\n", + "│ i64 ┆ i64 ┆ str ┆ str │\n", + "╞═════╪═════╪═════════════╪══════════════╡\n", + "└─────┴─────┴─────────────┴──────────────┘" ] }, - "execution_count": 14, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -313,6 +528,14 @@ "source": [ "pipeline.add_numbers.as_df()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "02af60fa", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -331,7 +554,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.3" + "version": "3.12.10" } }, "nbformat": 4, diff --git a/src/orcapod/core/pods.py b/src/orcapod/core/pods.py index 0379bb5..9e2f9ad 100644 --- a/src/orcapod/core/pods.py +++ b/src/orcapod/core/pods.py @@ -10,6 +10,8 @@ ArrowPacket, DictPacket, ) +from functools import wraps + from orcapod.utils.git_utils import get_git_info_for_python_object from orcapod.core.kernels import KernelStream, TrackedKernelBase from orcapod.core.operators import Join @@ -252,9 +254,14 @@ def function_pod( """ def decorator(func: Callable) -> CallableWithPod: + if func.__name__ == "": raise ValueError("Lambda functions cannot be used with function_pod") + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + # Store the original function in the module for pickling purposes # and make sure to change the name of the function @@ -267,9 +274,8 @@ def decorator(func: Callable) -> CallableWithPod: label=label, **kwargs, ) - setattr(func, "pod", pod) - return cast(CallableWithPod, func) - + setattr(wrapper, "pod", pod) + return cast(CallableWithPod, wrapper) return decorator diff --git a/src/orcapod/core/sources/base.py b/src/orcapod/core/sources/base.py index b8e128f..8f44599 100644 --- a/src/orcapod/core/sources/base.py +++ b/src/orcapod/core/sources/base.py @@ -1,5 +1,4 @@ from abc import abstractmethod -from ast import Not from collections.abc import Collection, Iterator from typing import TYPE_CHECKING, Any diff --git a/src/orcapod/core/streams/base.py b/src/orcapod/core/streams/base.py index d96230a..2959cf3 100644 --- a/src/orcapod/core/streams/base.py +++ b/src/orcapod/core/streams/base.py @@ -1,3 +1,4 @@ +from calendar import c import logging from abc import abstractmethod from collections.abc import Collection, Iterator, Mapping @@ -475,6 +476,9 @@ def flow( def _repr_html_(self) -> str: df = self.as_polars_df() + # reorder columns + new_column_order = [c for c in df.columns if c in self.tag_keys()] + [c for c in df.columns if c not in self.tag_keys()] + df = df[new_column_order] tag_map = {t: f"*{t}" for t in self.tag_keys()} # TODO: construct repr html better df = df.rename(tag_map) diff --git a/src/orcapod/core/streams/pod_node_stream.py b/src/orcapod/core/streams/pod_node_stream.py index 496bbc6..4596bcb 100644 --- a/src/orcapod/core/streams/pod_node_stream.py +++ b/src/orcapod/core/streams/pod_node_stream.py @@ -57,70 +57,24 @@ def mode(self) -> str: async def run_async( self, + *args: Any, execution_engine: 
cp.ExecutionEngine | None = None, execution_engine_opts: dict[str, Any] | None = None, + **kwargs: Any, ) -> None: """ Runs the stream, processing the input stream and preparing the output stream. This is typically called before iterating over the packets. """ if self._cached_output_packets is None: - cached_results = [] - - # identify all entries in the input stream for which we still have not computed packets - target_entries = self.input_stream.as_table( - include_content_hash=constants.INPUT_PACKET_HASH, - include_source=True, - include_system_tags=True, - ) - existing_entries = self.pod_node.get_all_cached_outputs( - include_system_columns=True + cached_results, missing = self._identify_existing_and_missing_entries(*args, + execution_engine=execution_engine, + execution_engine_opts=execution_engine_opts, + **kwargs, ) - if existing_entries is None or existing_entries.num_rows == 0: - missing = target_entries.drop_columns([constants.INPUT_PACKET_HASH]) - existing = None - else: - all_results = target_entries.join( - existing_entries.append_column( - "_exists", pa.array([True] * len(existing_entries)) - ), - keys=[constants.INPUT_PACKET_HASH], - join_type="left outer", - right_suffix="_right", - ) - # grab all columns from target_entries first - missing = ( - all_results.filter(pc.is_null(pc.field("_exists"))) - .select(target_entries.column_names) - .drop_columns([constants.INPUT_PACKET_HASH]) - ) - - existing = all_results.filter( - pc.is_valid(pc.field("_exists")) - ).drop_columns( - [ - "_exists", - constants.INPUT_PACKET_HASH, - constants.PACKET_RECORD_ID, - *self.input_stream.keys()[1], # remove the input packet keys - ] - # TODO: look into NOT fetching back the record ID - ) - - renamed = [ - c.removesuffix("_right") if c.endswith("_right") else c - for c in existing.column_names - ] - existing = existing.rename_columns(renamed) tag_keys = self.input_stream.keys()[0] - if existing is not None and existing.num_rows > 0: - # If there are existing entries, we can cache them - existing_stream = TableStream(existing, tag_columns=tag_keys) - for tag, packet in existing_stream.iter_packets(): - cached_results.append((tag, packet)) - pending_calls = [] if missing is not None and missing.num_rows > 0: for tag, packet in TableStream(missing, tag_columns=tag_keys): @@ -134,23 +88,23 @@ async def run_async( or self._execution_engine_opts, ) pending_calls.append(pending) - import asyncio + import asyncio completed_calls = await asyncio.gather(*pending_calls) for result in completed_calls: cached_results.append(result) + self.clear_cache() self._cached_output_packets = cached_results self._set_modified_time() + self.pod_node.flush() - def run( - self, - *args: Any, + def _identify_existing_and_missing_entries(self, + *args: Any, execution_engine: cp.ExecutionEngine | None = None, execution_engine_opts: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: - cached_results = [] + **kwargs: Any) -> tuple[list[tuple[cp.Tag, cp.Packet|None]], pa.Table | None]: + cached_results: list[tuple[cp.Tag, cp.Packet|None]] = [] # identify all entries in the input stream for which we still have not computed packets if len(args) > 0 or len(kwargs) > 0: @@ -223,6 +177,25 @@ def run( for tag, packet in existing_stream.iter_packets(): cached_results.append((tag, packet)) + + + return cached_results, missing + + def run( + self, + *args: Any, + execution_engine: cp.ExecutionEngine | None = None, + execution_engine_opts: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + tag_keys = 
self.input_stream.keys()[0] + cached_results, missing = self._identify_existing_and_missing_entries( + *args, + execution_engine=execution_engine, + execution_engine_opts=execution_engine_opts, + **kwargs, + ) + if missing is not None and missing.num_rows > 0: packet_record_to_output_lut: dict[str, cp.Packet | None] = {} execution_engine_hash = ( @@ -257,11 +230,14 @@ def run( ) cached_results.append((tag, output_packet)) + + # reset the cache and set new results + self.clear_cache() self._cached_output_packets = cached_results self._set_modified_time() self.pod_node.flush() # TODO: evaluate proper handling of cache here - self.clear_cache() + # self.clear_cache() def clear_cache(self) -> None: self._cached_output_packets = None @@ -300,115 +276,7 @@ def iter_packets( self._cached_output_packets = cached_results self._set_modified_time() - # if self._cached_output_packets is None: - # cached_results = [] - - # # identify all entries in the input stream for which we still have not computed packets - # target_entries = self.input_stream.as_table( - # include_system_tags=True, - # include_source=True, - # include_content_hash=constants.INPUT_PACKET_HASH, - # execution_engine=execution_engine, - # ) - # existing_entries = self.pod_node.get_all_cached_outputs( - # include_system_columns=True - # ) - # if existing_entries is None or existing_entries.num_rows == 0: - # missing = target_entries.drop_columns([constants.INPUT_PACKET_HASH]) - # existing = None - # else: - # # missing = target_entries.join( - # # existing_entries, - # # keys=[constants.INPUT_PACKET_HASH], - # # join_type="left anti", - # # ) - # # Single join that gives you both missing and existing - # # More efficient - only bring the key column from existing_entries - # # .select([constants.INPUT_PACKET_HASH]).append_column( - # # "_exists", pa.array([True] * len(existing_entries)) - # # ), - - # # TODO: do more proper replacement operation - # target_df = pl.DataFrame(target_entries) - # existing_df = pl.DataFrame( - # existing_entries.append_column( - # "_exists", pa.array([True] * len(existing_entries)) - # ) - # ) - # all_results_df = target_df.join( - # existing_df, - # on=constants.INPUT_PACKET_HASH, - # how="left", - # suffix="_right", - # ) - # all_results = all_results_df.to_arrow() - # # all_results = target_entries.join( - # # existing_entries.append_column( - # # "_exists", pa.array([True] * len(existing_entries)) - # # ), - # # keys=[constants.INPUT_PACKET_HASH], - # # join_type="left outer", - # # right_suffix="_right", # rename the existing records in case of collision of output packet keys with input packet keys - # # ) - # # grab all columns from target_entries first - # missing = ( - # all_results.filter(pc.is_null(pc.field("_exists"))) - # .select(target_entries.column_names) - # .drop_columns([constants.INPUT_PACKET_HASH]) - # ) - - # existing = all_results.filter( - # pc.is_valid(pc.field("_exists")) - # ).drop_columns( - # [ - # "_exists", - # constants.INPUT_PACKET_HASH, - # constants.PACKET_RECORD_ID, - # *self.input_stream.keys()[1], # remove the input packet keys - # ] - # # TODO: look into NOT fetching back the record ID - # ) - # renamed = [ - # c.removesuffix("_right") if c.endswith("_right") else c - # for c in existing.column_names - # ] - # existing = existing.rename_columns(renamed) - - # tag_keys = self.input_stream.keys()[0] - - # if existing is not None and existing.num_rows > 0: - # # If there are existing entries, we can cache them - # existing_stream = TableStream(existing, 
tag_columns=tag_keys) - # for tag, packet in existing_stream.iter_packets(): - # cached_results.append((tag, packet)) - # yield tag, packet - - # if missing is not None and missing.num_rows > 0: - # hash_to_output_lut: dict[str, cp.Packet | None] = {} - # for tag, packet in TableStream(missing, tag_columns=tag_keys): - # # Since these packets are known to be missing, skip the cache lookup - # packet_hash = packet.content_hash().to_string() - # if packet_hash in hash_to_output_lut: - # output_packet = hash_to_output_lut[packet_hash] - # else: - # tag, output_packet = self.pod_node.call( - # tag, - # packet, - # skip_cache_lookup=True, - # execution_engine=execution_engine, - # ) - # hash_to_output_lut[packet_hash] = output_packet - # cached_results.append((tag, output_packet)) - # if output_packet is not None: - # yield tag, output_packet - - # self._cached_output_packets = cached_results - # self._set_modified_time() - # else: - # for tag, packet in self._cached_output_packets: - # if packet is not None: - # yield tag, packet - + def keys( self, include_system_tags: bool = False ) -> tuple[tuple[str, ...], tuple[str, ...]]: diff --git a/src/orcapod/execution_engines/ray_execution_engine.py b/src/orcapod/execution_engines/ray_execution_engine.py index d4e6727..b36de84 100644 --- a/src/orcapod/execution_engines/ray_execution_engine.py +++ b/src/orcapod/execution_engines/ray_execution_engine.py @@ -26,6 +26,11 @@ class RayEngine: 3. No polling needed - Ray handles async integration """ + @property + def supports_async(self) -> bool: + """Indicate that this engine supports async execution.""" + return True + def __init__(self, ray_address: str | None = None, **ray_init_kwargs): """Initialize Ray with native async support.""" diff --git a/src/orcapod/pipeline/graph.py b/src/orcapod/pipeline/graph.py index b3534ca..45d83e0 100644 --- a/src/orcapod/pipeline/graph.py +++ b/src/orcapod/pipeline/graph.py @@ -4,7 +4,7 @@ from orcapod import contexts from orcapod.protocols import core_protocols as cp from orcapod.protocols import database_protocols as dbp -from typing import Any +from typing import Any, cast from collections.abc import Collection import os import tempfile @@ -18,6 +18,8 @@ else: nx = LazyModule("networkx") +logger = logging.getLogger(__name__) + def synchronous_run(async_func, *args, **kwargs): """ @@ -43,7 +45,6 @@ def run_in_thread(): return asyncio.run(async_func(*args, **kwargs)) -logger = logging.getLogger(__name__) class GraphNode: @@ -94,11 +95,39 @@ def __init__( self.results_store_path_prefix = self.name + ("_results",) self.pipeline_database = pipeline_database self.results_database = results_database - self.nodes: dict[str, Node] = {} + self._nodes: dict[str, Node] = {} self.auto_compile = auto_compile self._dirty = False self._ordered_nodes = [] # Track order of invocations + @property + def nodes(self) -> dict[str, Node]: + return self._nodes.copy() + + @property + def function_pods(self) -> dict[str, cp.Pod]: + return { + label: cast(cp.Pod, node) + for label, node in self._nodes.items() + if getattr(node, "kernel_type") == "function" + } + + @property + def source_pods(self) -> dict[str, cp.Source]: + return { + label: node + for label, node in self._nodes.items() + if getattr(node, "kernel_type") == "source" + } + + @property + def operator_pods(self) -> dict[str, cp.Kernel]: + return { + label: node + for label, node in self._nodes.items() + if getattr(node, "kernel_type") == "operator" + } + def __exit__(self, exc_type=None, exc_value=None, traceback=None): """ Exit 
the pipeline context, ensuring all nodes are properly closed. @@ -156,13 +185,13 @@ def compile(self) -> None: # If there are multiple nodes with the same label, we need to resolve the collision logger.info(f"Collision detected for label '{label}': {nodes}") for i, node in enumerate(nodes, start=1): - self.nodes[f"{label}_{i}"] = node + self._nodes[f"{label}_{i}"] = node node.label = f"{label}_{i}" else: - self.nodes[label] = nodes[0] + self._nodes[label] = nodes[0] nodes[0].label = label - self.label_lut = {v: k for k, v in self.nodes.items()} + self.label_lut = {v: k for k, v in self._nodes.items()} self.graph = node_graph @@ -172,7 +201,7 @@ def show_graph(self, **kwargs) -> None: def set_mode(self, mode: str) -> None: if mode not in ("production", "development"): raise ValueError("Mode must be either 'production' or 'development'") - for node in self.nodes.values(): + for node in self._nodes.values(): if hasattr(node, "set_mode"): node.set_mode(mode) @@ -201,6 +230,16 @@ def run( may implement more efficient graph traversal algorithms. """ import networkx as nx + if run_async is True and (execution_engine is None or not execution_engine.supports_async): + raise ValueError( + "Cannot run asynchronously with an execution engine that does not support async." + ) + + # if set to None, determine based on execution engine capabilities + if run_async is None: + run_async = execution_engine is not None and execution_engine.supports_async + + logger.info(f"Running pipeline with run_async={run_async}") for node in nx.topological_sort(self.graph): if run_async: @@ -257,27 +296,27 @@ def wrap_invocation( def __getattr__(self, item: str) -> Any: """Allow direct access to pipeline attributes.""" - if item in self.nodes: - return self.nodes[item] + if item in self._nodes: + return self._nodes[item] raise AttributeError(f"Pipeline has no attribute '{item}'") def __dir__(self) -> list[str]: """Return a list of attributes and methods of the pipeline.""" - return list(super().__dir__()) + list(self.nodes.keys()) + return list(super().__dir__()) + list(self._nodes.keys()) def rename(self, old_name: str, new_name: str) -> None: """ Rename a node in the pipeline. This will update the label and the internal mapping. 
""" - if old_name not in self.nodes: + if old_name not in self._nodes: raise KeyError(f"Node '{old_name}' does not exist in the pipeline.") - if new_name in self.nodes: + if new_name in self._nodes: raise KeyError(f"Node '{new_name}' already exists in the pipeline.") - node = self.nodes[old_name] - del self.nodes[old_name] + node = self._nodes[old_name] + del self._nodes[old_name] node.label = new_name - self.nodes[new_name] = node + self._nodes[new_name] = node logger.info(f"Node '{old_name}' renamed to '{new_name}'") diff --git a/src/orcapod/pipeline/nodes.py b/src/orcapod/pipeline/nodes.py index db9d0d4..af63971 100644 --- a/src/orcapod/pipeline/nodes.py +++ b/src/orcapod/pipeline/nodes.py @@ -264,6 +264,15 @@ def __init__( pipeline_path_prefix=pipeline_path_prefix, **kwargs, ) + self._execution_engine_opts: dict[str, Any] = {} + + @property + def execution_engine_opts(self) -> dict[str, Any]: + return self._execution_engine_opts.copy() + + @execution_engine_opts.setter + def execution_engine_opts(self, opts: dict[str, Any]) -> None: + self._execution_engine_opts = opts def flush(self): self.pipeline_database.flush() @@ -309,6 +318,11 @@ def call( if record_id is None: record_id = self.get_record_id(packet, execution_engine_hash) + combined_execution_engine_opts = self.execution_engine_opts + if execution_engine_opts is not None: + combined_execution_engine_opts.update(execution_engine_opts) + + tag, output_packet = super().call( tag, packet, @@ -316,7 +330,7 @@ def call( skip_cache_lookup=skip_cache_lookup, skip_cache_insert=skip_cache_insert, execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, + execution_engine_opts=combined_execution_engine_opts, ) # if output_packet is not None: @@ -348,6 +362,12 @@ async def async_call( if record_id is None: record_id = self.get_record_id(packet, execution_engine_hash) + + combined_execution_engine_opts = self.execution_engine_opts + if execution_engine_opts is not None: + combined_execution_engine_opts.update(execution_engine_opts) + + tag, output_packet = await super().async_call( tag, packet, @@ -355,7 +375,7 @@ async def async_call( skip_cache_lookup=skip_cache_lookup, skip_cache_insert=skip_cache_insert, execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, + execution_engine_opts=combined_execution_engine_opts, ) if output_packet is not None: diff --git a/src/orcapod/protocols/core_protocols/base.py b/src/orcapod/protocols/core_protocols/base.py index 7faad25..87d9a81 100644 --- a/src/orcapod/protocols/core_protocols/base.py +++ b/src/orcapod/protocols/core_protocols/base.py @@ -41,6 +41,10 @@ class ExecutionEngine(Protocol): "local", "threadpool", "processpool", or "ray" and is used for logging and diagnostics. """ + @property + def supports_async(self) -> bool: + """Indicate whether this engine supports async execution.""" + ... @property def name(self) -> str: