-
Notifications
You must be signed in to change notification settings - Fork 844
feat: pass invocation_state to edge condition calls #2305
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,11 +16,14 @@ | |
|
|
||
| import asyncio | ||
| import copy | ||
| import inspect | ||
| import json | ||
| import logging | ||
| import time | ||
| import weakref | ||
| from collections.abc import AsyncIterator, Callable, Mapping | ||
| from dataclasses import dataclass, field | ||
| from typing import Any, cast | ||
| from typing import Any, Protocol, TypeGuard, cast | ||
|
|
||
| from opentelemetry import trace as trace_api | ||
|
|
||
|
|
@@ -62,6 +65,57 @@ | |
| _DEFAULT_GRAPH_ID = "default_graph" | ||
|
|
||
|
|
||
| class EdgeConditionWithContext(Protocol): | ||
| """Protocol for edge conditions that receive invocation_state. | ||
|
|
||
| This allows conditions to make routing decisions based on runtime context | ||
| passed during graph invocation, such as feature flags, user roles, or | ||
| environment-specific configuration. | ||
|
|
||
| Designed with **kwargs for future extensibility without breaking changes. | ||
|
|
||
| Note: not @runtime_checkable — isinstance() cannot distinguish callable signatures | ||
| structurally; use _is_context_condition() for dispatch instead. | ||
| """ | ||
|
|
||
| def __call__(self, state: "GraphState", *, invocation_state: dict[str, Any], **kwargs: Any) -> bool: | ||
| """Evaluate whether the edge should be traversed.""" | ||
| ... | ||
|
|
||
|
|
||
| LegacyEdgeCondition = Callable[["GraphState"], bool] | ||
| EdgeCondition = LegacyEdgeCondition | EdgeConditionWithContext | ||
|
|
||
| # GIL-protected: concurrent async graphs may read/write simultaneously, but | ||
| # CPython's GIL ensures dict mutation is atomic. Ephemeral callables (e.g. | ||
| # lambdas recreated per-call) will bypass the cache — this is benign; the | ||
| # fallback path is a single inspect.signature() call. | ||
| _context_condition_cache: weakref.WeakKeyDictionary[EdgeCondition, bool] = weakref.WeakKeyDictionary() | ||
|
|
||
|
|
||
| def _is_context_condition(condition: EdgeCondition) -> TypeGuard[EdgeConditionWithContext]: | ||
| """Check if a condition function accepts invocation_state parameter. | ||
|
|
||
| Uses inspect.signature() for reliable detection, returning a TypeGuard | ||
| so mypy can narrow the type at call sites. Results are cached per condition | ||
| using weak references so entries are evicted when the function is collected. | ||
| """ | ||
| try: | ||
| return _context_condition_cache[condition] | ||
| except (KeyError, TypeError): | ||
| pass | ||
| try: | ||
| sig = inspect.signature(condition) | ||
| result = "invocation_state" in sig.parameters | ||
| except (ValueError, TypeError): | ||
| result = False | ||
| try: | ||
| _context_condition_cache[condition] = result | ||
| except TypeError: | ||
| pass | ||
| return result | ||
|
|
||
|
|
||
| @dataclass | ||
| class GraphState: | ||
| """Graph execution state. | ||
|
|
@@ -147,17 +201,28 @@ class GraphEdge: | |
|
|
||
| from_node: "GraphNode" | ||
| to_node: "GraphNode" | ||
| condition: Callable[[GraphState], bool] | None = None | ||
| condition: EdgeCondition | None = None | ||
|
|
||
| def __hash__(self) -> int: | ||
| """Return hash for GraphEdge based on from_node and to_node.""" | ||
| return hash((self.from_node.node_id, self.to_node.node_id)) | ||
|
|
||
| def should_traverse(self, state: GraphState) -> bool: | ||
| """Check if this edge should be traversed based on condition.""" | ||
| if self.condition is None: | ||
| def should_traverse(self, state: GraphState, *, invocation_state: dict[str, Any] | None = None) -> bool: | ||
| """Check if this edge should be traversed based on condition. | ||
|
|
||
| Args: | ||
| state: The current graph execution state. | ||
| invocation_state: Runtime context passed during graph invocation. | ||
| New-style conditions (EdgeConditionWithContext) receive this parameter. | ||
| Legacy conditions (Callable[[GraphState], bool]) are called with state only. | ||
| """ | ||
| condition = self.condition | ||
| if condition is None: | ||
| return True | ||
| return self.condition(state) | ||
| if _is_context_condition(condition): | ||
| return condition(state, invocation_state=invocation_state or {}) | ||
| legacy_condition = cast(LegacyEdgeCondition, condition) | ||
| return legacy_condition(state) | ||
|
|
||
|
|
||
| @dataclass | ||
|
|
@@ -276,9 +341,14 @@ def add_edge( | |
| self, | ||
| from_node: str | GraphNode, | ||
| to_node: str | GraphNode, | ||
| condition: Callable[[GraphState], bool] | None = None, | ||
| condition: EdgeCondition | None = None, | ||
| ) -> GraphEdge: | ||
| """Add an edge between two nodes with optional condition function that receives full GraphState.""" | ||
| """Add an edge between two nodes with optional condition function. | ||
|
|
||
| The condition can be either: | ||
| - A legacy callable: Callable[[GraphState], bool] - receives only graph state | ||
| - A new-style callable: EdgeConditionWithContext - receives graph state and invocation_state | ||
| """ | ||
|
|
||
| def resolve_node(node: str | GraphNode, node_type: str) -> GraphNode: | ||
| if isinstance(node, str): | ||
|
|
@@ -491,6 +561,7 @@ def __init__( | |
|
|
||
| self._resume_next_nodes: list[GraphNode] = [] | ||
| self._resume_from_session = False | ||
| self._current_invocation_state: dict[str, Any] = {} | ||
| self.id = id | ||
|
|
||
| run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) | ||
|
|
@@ -569,6 +640,10 @@ async def stream_async( | |
| if invocation_state is None: | ||
| invocation_state = {} | ||
|
|
||
| if self.session_manager is not None: | ||
| self._validate_invocation_state(invocation_state) | ||
| self._current_invocation_state = invocation_state | ||
|
|
||
| await self.hooks.invoke_callbacks_async(BeforeMultiAgentInvocationEvent(self, invocation_state)) | ||
|
|
||
| logger.debug("task=<%s> | starting graph execution", task) | ||
|
|
@@ -889,7 +964,7 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[ | |
| # Check if at least one incoming edge condition is satisfied | ||
| for edge in incoming_edges: | ||
| if edge.from_node in completed_batch: | ||
| if edge.should_traverse(self.state): | ||
| if edge.should_traverse(self.state, invocation_state=self._current_invocation_state): | ||
| logger.debug( | ||
| "from=<%s>, to=<%s> | edge ready via satisfied condition", edge.from_node.node_id, node.node_id | ||
| ) | ||
|
|
@@ -1125,7 +1200,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: | |
| and edge.from_node in self.state.completed_nodes | ||
| and edge.from_node.node_id in self.state.results | ||
| ): | ||
| if edge.should_traverse(self.state): | ||
| if edge.should_traverse(self.state, invocation_state=self._current_invocation_state): | ||
| dependency_results[edge.from_node.node_id] = self.state.results[edge.from_node.node_id] | ||
|
|
||
| if not dependency_results: | ||
|
|
@@ -1186,6 +1261,20 @@ def _build_result(self, interrupts: list[Interrupt]) -> GraphResult: | |
| interrupts=interrupts, | ||
| ) | ||
|
|
||
| @staticmethod | ||
| def _validate_invocation_state(invocation_state: dict[str, Any]) -> None: | ||
| """Validate that invocation_state is JSON-serializable. | ||
|
|
||
| Raises: | ||
| TypeError: If invocation_state contains non-JSON-serializable values. | ||
| """ | ||
| try: | ||
| json.dumps(invocation_state) | ||
| except (TypeError, ValueError) as e: | ||
| raise TypeError( | ||
| f"invocation_state must be JSON-serializable for session persistence: {e}" | ||
| ) from e | ||
|
|
||
| def serialize_state(self) -> dict[str, Any]: | ||
| """Serialize the current graph state to a dictionary.""" | ||
| compute_nodes = self._compute_ready_nodes_for_resume() | ||
|
|
@@ -1201,6 +1290,7 @@ def serialize_state(self) -> dict[str, Any]: | |
| "next_nodes_to_execute": next_nodes, | ||
| "current_task": encode_bytes_values(self.state.task), | ||
| "execution_order": [n.node_id for n in self.state.execution_order], | ||
| "invocation_state": self._current_invocation_state, | ||
| "_internal_state": { | ||
| "interrupt_state": self._interrupt_state.to_dict(), | ||
| }, | ||
|
|
@@ -1223,6 +1313,10 @@ def deserialize_state(self, payload: dict[str, Any]) -> None: | |
| internal_state = payload["_internal_state"] | ||
| self._interrupt_state = _InterruptState.from_dict(internal_state["interrupt_state"]) | ||
|
|
||
| invocation_state = payload.get("invocation_state", {}) | ||
| self._validate_invocation_state(invocation_state) | ||
| self._current_invocation_state = invocation_state | ||
|
|
||
| if not payload.get("next_nodes_to_execute"): | ||
| # Reset all nodes | ||
| for node in self.nodes.values(): | ||
|
|
@@ -1246,11 +1340,40 @@ def _compute_ready_nodes_for_resume(self) -> list[GraphNode]: | |
| incoming = [e for e in self.edges if e.to_node is node] | ||
| if not incoming: | ||
| ready_nodes.append(node) | ||
| elif all(e.from_node in completed_nodes and e.should_traverse(self.state) for e in incoming): | ||
| elif self._is_node_ready_for_resume(node, incoming, completed_nodes): | ||
| ready_nodes.append(node) | ||
|
|
||
| return ready_nodes | ||
|
|
||
| def _is_node_ready_for_resume( | ||
| self, | ||
| node: GraphNode, | ||
| incoming: list[GraphEdge], | ||
| completed_nodes: set[GraphNode], | ||
| ) -> bool: | ||
| """Check if a node is ready for resume, accounting for conditional edges. | ||
|
|
||
| A node is ready if all TRAVERSABLE incoming edges have their source completed. | ||
| Edges whose condition evaluates to False are excluded from the check — they | ||
| represent paths that were intentionally skipped. | ||
|
|
||
| Re-evaluates conditions (rather than caching traversal results) intentionally: | ||
| invocation_state may change between invocations, so conditions must reflect | ||
| current runtime context. This means condition logic changes between serialize | ||
| and resume will also take effect — consistent with the graph being defined in code. | ||
| """ | ||
| traversable_edges = [ | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Issue: In the Also, is calling traversable_edges = [
e for e in incoming
if e.should_traverse(self.state, invocation_state=self._current_invocation_state)
]Suggestion: The explicit
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Addressed in |
||
| e | ||
| for e in incoming | ||
| # Short-circuit: skip signature inspection + cache lookup for unconditional edges. | ||
| if e.condition is None or e.should_traverse(self.state, invocation_state=self._current_invocation_state) | ||
| ] | ||
|
|
||
| if not traversable_edges: | ||
| return False | ||
|
|
||
| return all(e.from_node in completed_nodes for e in traversable_edges) | ||
|
|
||
| def _from_dict(self, payload: dict[str, Any]) -> None: | ||
| self.state.status = Status(payload["status"]) | ||
| # Hydrate completed nodes & results | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Issue:
invocation_stateis serialized directly into the session payload without validation. If a user passes non-JSON-serializable objects (e.g., class instances, functions) ininvocation_state, this will fail silently or raise an unclear error during serialization.Suggestion: Consider either:
invocation_statevalues must be JSON-serializable, orserialize_state()that provides a clear error message if serialization fails due to non-serializable invocation_state values.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Addressed in
7cc0d04— added_validate_invocation_state()that callsjson.dumps()and raises a clearTypeErrorif the value isn't serializable. It's gated onsession_manageris notNone(serialization only matters when sessions persist), and also validated symmetrically on the deserialization path indeserialize_state.