Skip to content

Commit 295a7fb

Browse files
committed
LangGraph: Unify plugin for Graph API and Functional API
- Create unified LangGraphPlugin supporting both StateGraph and @entrypoint - Add auto-detection to compile() for returning correct runner type - Fix _is_entrypoint() to distinguish @entrypoint (Pregel) from StateGraph.compile() (CompiledStateGraph) by checking class type and presence of __start__ node - Fix timedelta serialization in _filter_config() by excluding temporal options from metadata (handled separately by _get_node_activity_options) - Update tests to use real graphs instead of MagicMock
1 parent 02c4ca7 commit 295a7fb

File tree

7 files changed

+657
-153
lines changed

7 files changed

+657
-153
lines changed

temporalio/contrib/langgraph/__init__.py

Lines changed: 217 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from __future__ import annotations
99

1010
from datetime import timedelta
11-
from typing import Any
11+
from typing import TYPE_CHECKING, Any, Union
1212

1313
import temporalio.common
1414
import temporalio.workflow
@@ -19,24 +19,37 @@
1919
GraphAlreadyRegisteredError,
2020
)
2121
from temporalio.contrib.langgraph._functional_activity import execute_langgraph_task
22-
from temporalio.contrib.langgraph._functional_plugin import LangGraphFunctionalPlugin
22+
23+
# Backward compatibility - LangGraphFunctionalPlugin is now deprecated
24+
# Use LangGraphPlugin with entrypoints in the graphs parameter instead
25+
from temporalio.contrib.langgraph._functional_plugin import (
26+
LangGraphFunctionalPlugin,
27+
)
2328
from temporalio.contrib.langgraph._functional_registry import (
2429
get_entrypoint,
2530
register_entrypoint,
2631
)
32+
from temporalio.contrib.langgraph._functional_registry import (
33+
get_global_entrypoint_registry as _get_functional_registry,
34+
)
2735
from temporalio.contrib.langgraph._functional_runner import (
2836
TemporalFunctionalRunner,
29-
compile_functional,
3037
)
3138
from temporalio.contrib.langgraph._graph_registry import (
3239
get_default_activity_options,
3340
get_graph,
3441
get_per_node_activity_options,
3542
)
43+
from temporalio.contrib.langgraph._graph_registry import (
44+
get_global_registry as _get_graph_registry,
45+
)
3646
from temporalio.contrib.langgraph._models import StateSnapshot
3747
from temporalio.contrib.langgraph._plugin import LangGraphPlugin
3848
from temporalio.contrib.langgraph._runner import TemporalLangGraphRunner
3949

50+
if TYPE_CHECKING:
51+
from temporalio.contrib.langgraph._plugin import ActivityOptionsKey
52+
4053

4154
def activity_options(
4255
*,
@@ -53,13 +66,30 @@ def activity_options(
5366
) -> dict[str, Any]:
5467
"""Create activity options for LangGraph integration.
5568
56-
Use with Graph API:
57-
- ``graph.add_node(metadata=activity_options(...))`` for node activities
58-
- ``LangGraphPlugin(per_node_activity_options={"node": activity_options(...)})``
69+
Use with plugin registration:
70+
plugin = LangGraphPlugin(
71+
graphs={"my_graph": graph.compile(), "my_entrypoint": entrypoint_func},
72+
default_activity_options=activity_options(
73+
start_to_close_timeout=timedelta(minutes=5),
74+
),
75+
activity_options={
76+
# Global: applies to any graph/entrypoint with this node/task
77+
"call_model": activity_options(start_to_close_timeout=timedelta(minutes=2)),
5978
60-
Use with Functional API:
61-
- ``compile_functional(task_options={"task_name": activity_options(...)})``
62-
- ``LangGraphFunctionalPlugin(task_options={"task": activity_options(...)})``
79+
# Scoped to specific graph or entrypoint
80+
("my_graph", "expensive_node"): activity_options(
81+
start_to_close_timeout=timedelta(minutes=10),
82+
),
83+
},
84+
)
85+
86+
Use with compile() for workflow-level overrides:
87+
app = compile(
88+
"my_graph",
89+
activity_options={
90+
"slow_node": activity_options(start_to_close_timeout=timedelta(minutes=20)),
91+
},
92+
)
6393
6494
Parameters mirror ``workflow.execute_activity()``.
6595
"""
@@ -95,7 +125,7 @@ def temporal_node_metadata(
95125
"""Create node metadata combining activity options and execution flags.
96126
97127
Args:
98-
activity_options: Options from ``node_activity_options()``.
128+
activity_options: Options from ``activity_options()``.
99129
run_in_workflow: If True, run in workflow instead of as activity.
100130
"""
101131
# Start with activity options if provided, otherwise empty temporal config
@@ -118,23 +148,95 @@ def compile(
118148
graph_id: str,
119149
*,
120150
default_activity_options: dict[str, Any] | None = None,
121-
per_node_activity_options: dict[str, dict[str, Any]] | None = None,
151+
activity_options: dict[str, dict[str, Any]] | None = None,
122152
checkpoint: dict | None = None,
123-
) -> TemporalLangGraphRunner:
124-
"""Compile a registered graph for Temporal execution.
153+
) -> Union[TemporalLangGraphRunner, TemporalFunctionalRunner]:
154+
"""Compile a registered graph or entrypoint for Temporal execution.
155+
156+
This function auto-detects whether the ID refers to a Graph API graph
157+
(StateGraph) or a Functional API entrypoint (@entrypoint/@task).
125158
126159
.. warning::
127160
This API is experimental and may change in future versions.
128161
129162
Args:
130-
graph_id: ID of graph registered with LangGraphPlugin.
131-
default_activity_options: Default options for all nodes.
132-
per_node_activity_options: Per-node options by node name.
163+
graph_id: ID of graph or entrypoint registered with LangGraphPlugin.
164+
default_activity_options: Default options for all nodes/tasks.
165+
Use activity_options() helper to create.
166+
activity_options: Per-node/task options by name.
167+
Use activity_options() helper to create values.
133168
checkpoint: Checkpoint from previous get_state() for continue-as-new.
169+
Only applies to Graph API graphs.
170+
171+
Returns:
172+
TemporalLangGraphRunner for Graph API graphs, or
173+
TemporalFunctionalRunner for Functional API entrypoints.
134174
135175
Raises:
136-
ApplicationError: If no graph with the given ID is registered.
176+
ApplicationError: If no graph or entrypoint with the given ID is registered.
177+
178+
Example (Graph API):
179+
```python
180+
@workflow.defn
181+
class MyWorkflow:
182+
@workflow.run
183+
async def run(self, query: str) -> dict:
184+
app = compile("my_graph")
185+
return await app.ainvoke({"messages": [("user", query)]})
186+
```
187+
188+
Example (Functional API):
189+
```python
190+
@workflow.defn
191+
class MyWorkflow:
192+
@workflow.run
193+
async def run(self, topic: str) -> dict:
194+
app = compile("my_entrypoint")
195+
return await app.ainvoke(topic)
196+
```
137197
"""
198+
# Check which registry has this ID
199+
graph_registry = _get_graph_registry()
200+
functional_registry = _get_functional_registry()
201+
202+
is_graph = graph_registry.is_registered(graph_id)
203+
is_entrypoint = functional_registry.is_registered(graph_id)
204+
205+
if is_graph:
206+
return _compile_graph(
207+
graph_id,
208+
default_activity_options=default_activity_options,
209+
per_node_activity_options=activity_options,
210+
checkpoint=checkpoint,
211+
)
212+
elif is_entrypoint:
213+
return _compile_entrypoint(
214+
graph_id,
215+
default_activity_options=default_activity_options,
216+
task_options=activity_options,
217+
)
218+
else:
219+
# Neither registry has it - raise error
220+
from temporalio.exceptions import ApplicationError
221+
222+
graph_ids = graph_registry.list_graphs()
223+
entrypoint_ids = functional_registry.list_entrypoints()
224+
all_ids = graph_ids + entrypoint_ids
225+
raise ApplicationError(
226+
f"'{graph_id}' not found. Available: {all_ids}",
227+
type=GRAPH_NOT_FOUND_ERROR,
228+
non_retryable=True,
229+
)
230+
231+
232+
def _compile_graph(
233+
graph_id: str,
234+
*,
235+
default_activity_options: dict[str, Any] | None = None,
236+
per_node_activity_options: dict[str, dict[str, Any]] | None = None,
237+
checkpoint: dict | None = None,
238+
) -> TemporalLangGraphRunner:
239+
"""Compile a Graph API graph for Temporal execution."""
138240
# Get graph from registry
139241
pregel = get_graph(graph_id)
140242

@@ -145,11 +247,7 @@ def compile(
145247
def _merge_activity_options(
146248
base: dict[str, Any], override: dict[str, Any]
147249
) -> dict[str, Any]:
148-
"""Merge activity options, with override taking precedence.
149-
150-
Both dicts have structure {"temporal": {...}} from node_activity_options().
151-
We need to merge the inner "temporal" dicts.
152-
"""
250+
"""Merge activity options, with override taking precedence."""
153251
base_temporal = base.get("temporal", {})
154252
override_temporal = override.get("temporal", {})
155253
return {"temporal": {**base_temporal, **override_temporal}}
@@ -186,21 +284,115 @@ def _merge_activity_options(
186284
)
187285

188286

287+
def _compile_entrypoint(
288+
entrypoint_id: str,
289+
*,
290+
default_activity_options: dict[str, Any] | None = None,
291+
task_options: dict[str, dict[str, Any]] | None = None,
292+
) -> TemporalFunctionalRunner:
293+
"""Compile a Functional API entrypoint for Temporal execution."""
294+
from temporalio.contrib.langgraph._functional_registry import (
295+
get_entrypoint_default_options,
296+
get_entrypoint_task_options,
297+
)
298+
299+
# Get plugin-level options from registry
300+
plugin_default_options = get_entrypoint_default_options(entrypoint_id)
301+
plugin_task_options = get_entrypoint_task_options(entrypoint_id)
302+
303+
# Merge default options
304+
merged_default_options: dict[str, Any] | None = None
305+
if plugin_default_options or default_activity_options:
306+
# Unwrap activity_options format if needed
307+
base = plugin_default_options or {}
308+
override = default_activity_options or {}
309+
if "temporal" in base:
310+
base = base.get("temporal", {})
311+
if "temporal" in override:
312+
override = override.get("temporal", {})
313+
merged_default_options = {**base, **override}
314+
315+
# Merge per-task options
316+
merged_task_options: dict[str, dict[str, Any]] | None = None
317+
if plugin_task_options or task_options:
318+
merged_task_options = {}
319+
# Start with plugin options
320+
for task_name, opts in (plugin_task_options or {}).items():
321+
merged_task_options[task_name] = opts
322+
# Merge compile options
323+
if task_options:
324+
for task_name, opts in task_options.items():
325+
if task_name in merged_task_options:
326+
# Merge the options
327+
base = merged_task_options[task_name]
328+
if "temporal" in base:
329+
base = base.get("temporal", {})
330+
override = opts
331+
if "temporal" in override:
332+
override = override.get("temporal", {})
333+
merged_task_options[task_name] = {**base, **override}
334+
else:
335+
# Unwrap if needed
336+
if "temporal" in opts:
337+
merged_task_options[task_name] = opts.get("temporal", {})
338+
else:
339+
merged_task_options[task_name] = opts
340+
341+
# Get default timeout from merged options
342+
default_timeout = timedelta(minutes=5)
343+
if merged_default_options:
344+
if "start_to_close_timeout" in merged_default_options:
345+
default_timeout = merged_default_options["start_to_close_timeout"]
346+
347+
return TemporalFunctionalRunner(
348+
entrypoint_id=entrypoint_id,
349+
default_task_timeout=default_timeout,
350+
task_options=merged_task_options,
351+
)
352+
353+
354+
# Keep compile_functional for backward compatibility (deprecated)
355+
def compile_functional(
356+
entrypoint_id: str,
357+
default_task_timeout: timedelta = timedelta(minutes=5),
358+
task_options: dict[str, dict[str, Any]] | None = None,
359+
) -> TemporalFunctionalRunner:
360+
"""Compile a registered entrypoint for Temporal execution.
361+
362+
.. deprecated::
363+
Use ``compile()`` instead, which auto-detects graph vs entrypoint.
364+
365+
Args:
366+
entrypoint_id: ID of the registered entrypoint.
367+
default_task_timeout: Default timeout for task activities.
368+
task_options: Per-task activity options.
369+
370+
Returns:
371+
A TemporalFunctionalRunner that can be used to invoke the entrypoint.
372+
"""
373+
return TemporalFunctionalRunner(
374+
entrypoint_id=entrypoint_id,
375+
default_task_timeout=default_task_timeout,
376+
task_options=task_options,
377+
)
378+
379+
189380
__all__ = [
190-
# Main API - Graph API
381+
# Main unified API
191382
"activity_options",
192383
"compile",
193384
"LangGraphPlugin",
194385
"StateSnapshot",
195386
"temporal_node_metadata",
387+
# Runner types (for type annotations)
196388
"TemporalLangGraphRunner",
197-
# Main API - Functional API
389+
"TemporalFunctionalRunner",
390+
# Deprecated (kept for backward compatibility)
198391
"compile_functional",
199392
"execute_langgraph_task",
200393
"get_entrypoint",
201394
"LangGraphFunctionalPlugin",
202395
"register_entrypoint",
203-
"TemporalFunctionalRunner",
204396
# Exception types (for catching configuration errors)
205397
"GraphAlreadyRegisteredError",
206398
# Error type constants (for catching ApplicationError.type)

0 commit comments

Comments
 (0)