From 1f860dcbbad35ab35fa8ee6efd9571becee47789 Mon Sep 17 00:00:00 2001 From: Yana Harris Date: Tue, 19 May 2026 17:25:57 +0000 Subject: [PATCH 1/2] feat: pass invocation_state to edge condition calls Add support for edge conditions that receive invocation_state, enabling conditional routing based on runtime context (feature flags, user roles, environment config) passed during graph invocation. Also fixes a deadlock in _compute_ready_nodes_for_resume() where conditional edges evaluating to False would block downstream nodes from ever becoming ready on interrupt/resume workflows. Resolves #1346 --- src/strands/multiagent/__init__.py | 4 +- src/strands/multiagent/graph.py | 101 +++++- tests/strands/multiagent/test_graph.py | 433 +++++++++++++++++++++++++ tests_integ/test_multiagent_graph.py | 318 +++++++++++++++++- 4 files changed, 843 insertions(+), 13 deletions(-) diff --git a/src/strands/multiagent/__init__.py b/src/strands/multiagent/__init__.py index ad99944a8..8dd78c38c 100644 --- a/src/strands/multiagent/__init__.py +++ b/src/strands/multiagent/__init__.py @@ -9,10 +9,12 @@ """ from .base import MultiAgentBase, MultiAgentResult, Status -from .graph import GraphBuilder, GraphResult +from .graph import EdgeCondition, EdgeConditionWithContext, GraphBuilder, GraphResult from .swarm import Swarm, SwarmResult __all__ = [ + "EdgeCondition", + "EdgeConditionWithContext", "GraphBuilder", "GraphResult", "MultiAgentBase", diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 146a31563..e6c77e79c 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -16,11 +16,12 @@ import asyncio import copy +import inspect import logging import time from collections.abc import AsyncIterator, Callable, Mapping from dataclasses import dataclass, field -from typing import Any, cast +from typing import Any, Protocol, TypeGuard, cast, runtime_checkable from opentelemetry import trace as trace_api @@ -62,6 +63,39 @@ _DEFAULT_GRAPH_ID = "default_graph" +@runtime_checkable +class EdgeConditionWithContext(Protocol): + """Protocol for edge conditions that receive invocation_state. + + This allows conditions to make routing decisions based on runtime context + passed during graph invocation, such as feature flags, user roles, or + environment-specific configuration. + + Designed with **kwargs for future extensibility without breaking changes. + """ + + def __call__(self, state: "GraphState", *, invocation_state: dict[str, Any], **kwargs: Any) -> bool: + """Evaluate whether the edge should be traversed.""" + ... + + +LegacyEdgeCondition = Callable[["GraphState"], bool] +EdgeCondition = LegacyEdgeCondition | EdgeConditionWithContext + + +def _is_context_condition(condition: EdgeCondition) -> TypeGuard[EdgeConditionWithContext]: + """Check if a condition function accepts invocation_state parameter. + + Uses inspect.signature() for reliable detection, returning a TypeGuard + so mypy can narrow the type at call sites. + """ + try: + sig = inspect.signature(condition) + return "invocation_state" in sig.parameters + except (ValueError, TypeError): + return False + + @dataclass class GraphState: """Graph execution state. @@ -147,17 +181,28 @@ class GraphEdge: from_node: "GraphNode" to_node: "GraphNode" - condition: Callable[[GraphState], bool] | None = None + condition: EdgeCondition | None = None def __hash__(self) -> int: """Return hash for GraphEdge based on from_node and to_node.""" return hash((self.from_node.node_id, self.to_node.node_id)) - def should_traverse(self, state: GraphState) -> bool: - """Check if this edge should be traversed based on condition.""" - if self.condition is None: + def should_traverse(self, state: GraphState, *, invocation_state: dict[str, Any] | None = None) -> bool: + """Check if this edge should be traversed based on condition. + + Args: + state: The current graph execution state. + invocation_state: Runtime context passed during graph invocation. + New-style conditions (EdgeConditionWithContext) receive this parameter. + Legacy conditions (Callable[[GraphState], bool]) are called with state only. + """ + condition = self.condition + if condition is None: return True - return self.condition(state) + if _is_context_condition(condition): + return condition(state, invocation_state=invocation_state or {}) + legacy_condition = cast(LegacyEdgeCondition, condition) + return legacy_condition(state) @dataclass @@ -276,9 +321,14 @@ def add_edge( self, from_node: str | GraphNode, to_node: str | GraphNode, - condition: Callable[[GraphState], bool] | None = None, + condition: EdgeCondition | None = None, ) -> GraphEdge: - """Add an edge between two nodes with optional condition function that receives full GraphState.""" + """Add an edge between two nodes with optional condition function. + + The condition can be either: + - A legacy callable: Callable[[GraphState], bool] - receives only graph state + - A new-style callable: EdgeConditionWithContext - receives graph state and invocation_state + """ def resolve_node(node: str | GraphNode, node_type: str) -> GraphNode: if isinstance(node, str): @@ -491,6 +541,7 @@ def __init__( self._resume_next_nodes: list[GraphNode] = [] self._resume_from_session = False + self._current_invocation_state: dict[str, Any] = {} self.id = id run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) @@ -569,6 +620,8 @@ async def stream_async( if invocation_state is None: invocation_state = {} + self._current_invocation_state = invocation_state + await self.hooks.invoke_callbacks_async(BeforeMultiAgentInvocationEvent(self, invocation_state)) logger.debug("task=<%s> | starting graph execution", task) @@ -889,7 +942,7 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[ # Check if at least one incoming edge condition is satisfied for edge in incoming_edges: if edge.from_node in completed_batch: - if edge.should_traverse(self.state): + if edge.should_traverse(self.state, invocation_state=self._current_invocation_state): logger.debug( "from=<%s>, to=<%s> | edge ready via satisfied condition", edge.from_node.node_id, node.node_id ) @@ -1125,7 +1178,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: and edge.from_node in self.state.completed_nodes and edge.from_node.node_id in self.state.results ): - if edge.should_traverse(self.state): + if edge.should_traverse(self.state, invocation_state=self._current_invocation_state): dependency_results[edge.from_node.node_id] = self.state.results[edge.from_node.node_id] if not dependency_results: @@ -1201,6 +1254,7 @@ def serialize_state(self) -> dict[str, Any]: "next_nodes_to_execute": next_nodes, "current_task": encode_bytes_values(self.state.task), "execution_order": [n.node_id for n in self.state.execution_order], + "invocation_state": self._current_invocation_state, "_internal_state": { "interrupt_state": self._interrupt_state.to_dict(), }, @@ -1223,6 +1277,8 @@ def deserialize_state(self, payload: dict[str, Any]) -> None: internal_state = payload["_internal_state"] self._interrupt_state = _InterruptState.from_dict(internal_state["interrupt_state"]) + self._current_invocation_state = payload.get("invocation_state", {}) + if not payload.get("next_nodes_to_execute"): # Reset all nodes for node in self.nodes.values(): @@ -1246,11 +1302,34 @@ def _compute_ready_nodes_for_resume(self) -> list[GraphNode]: incoming = [e for e in self.edges if e.to_node is node] if not incoming: ready_nodes.append(node) - elif all(e.from_node in completed_nodes and e.should_traverse(self.state) for e in incoming): + elif self._is_node_ready_for_resume(node, incoming, completed_nodes): ready_nodes.append(node) return ready_nodes + def _is_node_ready_for_resume( + self, + node: GraphNode, + incoming: list[GraphEdge], + completed_nodes: set[GraphNode], + ) -> bool: + """Check if a node is ready for resume, accounting for conditional edges. + + A node is ready if all TRAVERSABLE incoming edges have their source completed. + Edges whose condition evaluates to False are excluded from the check — they + represent paths that were intentionally skipped. + """ + traversable_edges = [ + e + for e in incoming + if e.condition is None or e.should_traverse(self.state, invocation_state=self._current_invocation_state) + ] + + if not traversable_edges: + return False + + return all(e.from_node in completed_nodes for e in traversable_edges) + def _from_dict(self, payload: dict[str, Any]) -> None: self.state.status = Status(payload["status"]) # Hydrate completed nodes & results diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index a6085627c..1eb40d8a4 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -2468,6 +2468,7 @@ def test_find_newly_ready_nodes_only_evaluates_outbound_edges(): GraphEdge(from_node=node_d, to_node=node_e), ] graph.state = GraphState() + graph._current_invocation_state = {} # When A completes, only B should be ready (not E) ready = graph._find_newly_ready_nodes([node_a]) @@ -2478,3 +2479,435 @@ def test_find_newly_ready_nodes_only_evaluates_outbound_edges(): ready = graph._find_newly_ready_nodes([node_d]) ready_ids = {n.node_id for n in ready} assert ready_ids == {"E"}, f"Expected only E, got {ready_ids}" + + +# ============================================================================= +# Tests for EdgeConditionWithContext (invocation_state in edge conditions) +# ============================================================================= + + +class TestEdgeConditionProtocol: + """Tests for the EdgeConditionWithContext protocol and dispatch logic.""" + + def test_legacy_condition_still_works(self): + """Verify Callable[[GraphState], bool] conditions work unchanged.""" + node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) + node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) + + def legacy_condition(state: GraphState) -> bool: + return len(state.completed_nodes) > 0 + + edge = GraphEdge(from_node=node_a, to_node=node_b, condition=legacy_condition) + + assert not edge.should_traverse(GraphState()) + assert edge.should_traverse(GraphState(completed_nodes={node_a})) + + def test_legacy_condition_not_affected_by_invocation_state(self): + """Legacy conditions should work even when invocation_state is passed.""" + node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) + node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) + + def legacy_condition(state: GraphState) -> bool: + return True + + edge = GraphEdge(from_node=node_a, to_node=node_b, condition=legacy_condition) + assert edge.should_traverse(GraphState(), invocation_state={"key": "value"}) + + def test_new_style_condition_receives_invocation_state(self): + """Verify EdgeConditionWithContext receives invocation_state kwarg.""" + node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) + node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) + + received_invocation_state = {} + + def context_condition(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: + received_invocation_state.update(invocation_state) + return invocation_state.get("enable_path", False) + + edge = GraphEdge(from_node=node_a, to_node=node_b, condition=context_condition) + + # Without the flag, should not traverse + assert not edge.should_traverse(GraphState(), invocation_state={"enable_path": False}) + assert received_invocation_state == {"enable_path": False} + + # With the flag, should traverse + received_invocation_state.clear() + assert edge.should_traverse(GraphState(), invocation_state={"enable_path": True}) + assert received_invocation_state == {"enable_path": True} + + def test_condition_none_always_traverses(self): + """Verify edges without conditions always traverse.""" + node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) + node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) + + edge = GraphEdge(from_node=node_a, to_node=node_b, condition=None) + assert edge.should_traverse(GraphState()) + assert edge.should_traverse(GraphState(), invocation_state={"anything": True}) + + def test_new_style_condition_with_kwargs_extensibility(self): + """Verify conditions with **kwargs work for future extensibility.""" + node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) + node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) + + def extensible_condition(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: + return True + + edge = GraphEdge(from_node=node_a, to_node=node_b, condition=extensible_condition) + assert edge.should_traverse(GraphState(), invocation_state={}) + + def test_invocation_state_defaults_to_empty_dict_when_none(self): + """Verify graceful behavior when invocation_state is None.""" + node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) + node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) + + received = [] + + def context_condition(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: + received.append(invocation_state) + return True + + edge = GraphEdge(from_node=node_a, to_node=node_b, condition=context_condition) + assert edge.should_traverse(GraphState(), invocation_state=None) + assert received == [{}] + + +class TestInvocationStatePropagation: + """Tests that invocation_state flows correctly through graph execution paths.""" + + def test_is_node_ready_with_conditions_passes_invocation_state(self): + """Verify _is_node_ready_with_conditions passes invocation_state to edge conditions.""" + node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) + node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) + node_b.dependencies.add(node_a) + + received_state = {} + + def context_condition(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: + received_state.update(invocation_state) + return invocation_state.get("activate", False) + + edge = GraphEdge(from_node=node_a, to_node=node_b, condition=context_condition) + + graph = Graph.__new__(Graph) + graph.nodes = {"A": node_a, "B": node_b} + graph.edges = [edge] + graph.state = GraphState(completed_nodes={node_a}) + graph._current_invocation_state = {"activate": True} + + assert graph._is_node_ready_with_conditions(node_b, [node_a]) + assert received_state == {"activate": True} + + def test_is_node_ready_with_conditions_invocation_state_false(self): + """Verify condition returning False blocks node readiness.""" + node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) + node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) + node_b.dependencies.add(node_a) + + def context_condition(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: + return invocation_state.get("activate", False) + + edge = GraphEdge(from_node=node_a, to_node=node_b, condition=context_condition) + + graph = Graph.__new__(Graph) + graph.nodes = {"A": node_a, "B": node_b} + graph.edges = [edge] + graph.state = GraphState(completed_nodes={node_a}) + graph._current_invocation_state = {"activate": False} + + assert not graph._is_node_ready_with_conditions(node_b, [node_a]) + + def test_build_node_input_passes_invocation_state(self): + """Verify _build_node_input uses invocation_state for edge condition evaluation.""" + node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) + node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) + node_b.dependencies.add(node_a) + + def context_condition(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: + return invocation_state.get("include_dep", False) + + edge = GraphEdge(from_node=node_a, to_node=node_b, condition=context_condition) + + mock_result = AgentResult( + message={"role": "assistant", "content": [{"text": "result from A"}]}, + stop_reason="end_turn", + state={}, + metrics=Mock( + accumulated_usage={"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, + accumulated_metrics={"latencyMs": 100.0}, + ), + ) + + graph = Graph.__new__(Graph) + graph.nodes = {"A": node_a, "B": node_b} + graph.edges = [edge] + graph.state = GraphState( + task="test task", + completed_nodes={node_a}, + results={"A": NodeResult(result=mock_result)}, + ) + graph._current_invocation_state = {"include_dep": False} + graph._interrupt_state = _InterruptState() + + # With condition=False, dependency is excluded -> gets raw task + node_input = graph._build_node_input(node_b) + assert any("test task" in str(block) for block in node_input) + + # With condition=True, dependency result is included + graph._current_invocation_state = {"include_dep": True} + node_input = graph._build_node_input(node_b) + input_text = " ".join(str(block) for block in node_input) + assert "result from A" in input_text + + +class TestResumeDeadlockFix: + """Tests for the _compute_ready_nodes_for_resume deadlock fix with conditional edges.""" + + def test_resume_skips_false_condition_edges(self): + """Graph: A->(cond=False)->B, A->(unconditional)->C. C should be ready on resume.""" + node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) + node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) + node_c = GraphNode(node_id="C", executor=create_mock_agent("C")) + node_b.dependencies.add(node_a) + node_c.dependencies.add(node_a) + + def always_false(state: GraphState) -> bool: + return False + + graph = Graph.__new__(Graph) + graph.nodes = {"A": node_a, "B": node_b, "C": node_c} + graph.edges = [ + GraphEdge(from_node=node_a, to_node=node_b, condition=always_false), + GraphEdge(from_node=node_a, to_node=node_c), # unconditional + ] + graph.state = GraphState(status=Status.INTERRUPTED, completed_nodes={node_a}) + graph._current_invocation_state = {} + + ready = graph._compute_ready_nodes_for_resume() + ready_ids = {n.node_id for n in ready} + # C should be ready (unconditional edge from A), B should not (condition=False) + assert "C" in ready_ids + assert "B" not in ready_ids + + def test_resume_diamond_with_conditional_skip(self): + """Exact scenario from issue comment: A->(cond=True)->B->C, A->(cond=False)->C. + + When condition is False, B is skipped. C has two incoming edges: + - B->C (unconditional, but B never ran) + - A->C (condition=False, should be excluded from readiness check) + + Without the fix, C is stuck because all() requires both edges satisfied. + With the fix, the A->C edge is excluded (condition=False), and since there + are no other traversable edges with incomplete sources, we need B->C. + But B never ran, so C can't be ready via B->C either. + + The correct fix scenario: when condition selects the FAST path (True), + B runs and C should be ready via B->C (excluding A->C which is False). + """ + node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) + node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) + node_c = GraphNode(node_id="C", executor=create_mock_agent("C")) + node_b.dependencies.add(node_a) + node_c.dependencies.add(node_a) + node_c.dependencies.add(node_b) + + def use_fast_path(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: + return invocation_state.get("fast", False) + + def skip_direct(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: + return not invocation_state.get("fast", False) + + graph = Graph.__new__(Graph) + graph.nodes = {"A": node_a, "B": node_b, "C": node_c} + graph.edges = [ + GraphEdge(from_node=node_a, to_node=node_b, condition=use_fast_path), + GraphEdge(from_node=node_a, to_node=node_c, condition=skip_direct), + GraphEdge(from_node=node_b, to_node=node_c), + ] + graph.state = GraphState(status=Status.INTERRUPTED, completed_nodes={node_a, node_b}) + graph._current_invocation_state = {"fast": True} + + ready = graph._compute_ready_nodes_for_resume() + ready_ids = {n.node_id for n in ready} + # C should be ready: A->C edge excluded (condition=False), B->C is unconditional and B completed + assert "C" in ready_ids + + def test_resume_all_conditions_false_blocks_node(self): + """If ALL incoming edges have conditions that are False, node should not be ready.""" + node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) + node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) + node_b.dependencies.add(node_a) + + def always_false(state: GraphState) -> bool: + return False + + graph = Graph.__new__(Graph) + graph.nodes = {"A": node_a, "B": node_b} + graph.edges = [ + GraphEdge(from_node=node_a, to_node=node_b, condition=always_false), + ] + graph.state = GraphState(status=Status.INTERRUPTED, completed_nodes={node_a}) + graph._current_invocation_state = {} + + ready = graph._compute_ready_nodes_for_resume() + ready_ids = {n.node_id for n in ready} + assert "B" not in ready_ids + + def test_resume_with_invocation_state_condition(self): + """Condition uses invocation_state; on resume with same state, correct routing.""" + node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) + node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) + node_c = GraphNode(node_id="C", executor=create_mock_agent("C")) + node_b.dependencies.add(node_a) + node_c.dependencies.add(node_a) + + def check_role(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: + return invocation_state.get("role") == "admin" + + def check_not_admin(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: + return invocation_state.get("role") != "admin" + + graph = Graph.__new__(Graph) + graph.nodes = {"A": node_a, "B": node_b, "C": node_c} + graph.edges = [ + GraphEdge(from_node=node_a, to_node=node_b, condition=check_role), + GraphEdge(from_node=node_a, to_node=node_c, condition=check_not_admin), + ] + graph.state = GraphState(status=Status.INTERRUPTED, completed_nodes={node_a}) + + # As admin: only B should be ready + graph._current_invocation_state = {"role": "admin"} + ready = graph._compute_ready_nodes_for_resume() + ready_ids = {n.node_id for n in ready} + assert ready_ids == {"B"} + + # As non-admin: only C should be ready + graph._current_invocation_state = {"role": "user"} + ready = graph._compute_ready_nodes_for_resume() + ready_ids = {n.node_id for n in ready} + assert ready_ids == {"C"} + + def test_resume_mixed_conditional_unconditional_edges(self): + """Node with both conditional (False) and unconditional edges: ready if unconditional source completed.""" + node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) + node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) + node_c = GraphNode(node_id="C", executor=create_mock_agent("C")) + node_b.dependencies.add(node_a) + node_c.dependencies.add(node_a) + node_c.dependencies.add(node_b) + + def always_false(state: GraphState) -> bool: + return False + + graph = Graph.__new__(Graph) + graph.nodes = {"A": node_a, "B": node_b, "C": node_c} + graph.edges = [ + GraphEdge(from_node=node_a, to_node=node_b), # unconditional + GraphEdge(from_node=node_a, to_node=node_c, condition=always_false), # conditional (False) + GraphEdge(from_node=node_b, to_node=node_c), # unconditional + ] + graph.state = GraphState(status=Status.INTERRUPTED, completed_nodes={node_a, node_b}) + graph._current_invocation_state = {} + + ready = graph._compute_ready_nodes_for_resume() + ready_ids = {n.node_id for n in ready} + # C should be ready: A->C is excluded (condition=False), B->C is unconditional and B completed + assert "C" in ready_ids + + +class TestSerializationWithInvocationState: + """Tests for serialization/deserialization of invocation_state.""" + + def test_serialize_includes_invocation_state(self): + """Verify invocation_state appears in serialized payload.""" + node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) + + graph = Graph.__new__(Graph) + graph.nodes = {"A": node_a} + graph.edges = [] + graph.state = GraphState(status=Status.COMPLETED, completed_nodes={node_a}, task="test") + graph._current_invocation_state = {"feature_flag": True, "user_id": "123"} + graph._interrupt_state = _InterruptState() + graph.id = "test_graph" + + serialized = graph.serialize_state() + assert "invocation_state" in serialized + assert serialized["invocation_state"] == {"feature_flag": True, "user_id": "123"} + + def test_deserialize_restores_invocation_state(self): + """Verify invocation_state is restored on deserialization.""" + node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) + + graph = Graph.__new__(Graph) + graph.nodes = {"A": node_a} + graph.edges = set() + graph.state = GraphState() + graph._interrupt_state = _InterruptState() + graph._resume_from_session = False + graph._resume_next_nodes = [] + + payload = { + "status": "completed", + "completed_nodes": [], + "next_nodes_to_execute": [], + "invocation_state": {"role": "admin"}, + } + graph.deserialize_state(payload) + assert graph._current_invocation_state == {"role": "admin"} + + def test_deserialize_missing_invocation_state_defaults_empty(self): + """Backwards compat: old serialized payloads without invocation_state still work.""" + node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) + + graph = Graph.__new__(Graph) + graph.nodes = {"A": node_a} + graph.edges = set() + graph.state = GraphState() + graph._interrupt_state = _InterruptState() + graph._resume_from_session = False + graph._resume_next_nodes = [] + + payload = { + "status": "completed", + "completed_nodes": [], + "next_nodes_to_execute": [], + } + graph.deserialize_state(payload) + assert graph._current_invocation_state == {} + + +class TestConditionSignatureDetection: + """Tests for the _is_context_condition helper.""" + + def test_detects_legacy_condition(self): + """Legacy condition without invocation_state param.""" + from strands.multiagent.graph import _is_context_condition + + def legacy(state: GraphState) -> bool: + return True + + assert not _is_context_condition(legacy) + + def test_detects_new_style_condition(self): + """New-style condition with invocation_state param.""" + from strands.multiagent.graph import _is_context_condition + + def new_style(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: + return True + + assert _is_context_condition(new_style) + + def test_detects_positional_invocation_state(self): + """Condition with invocation_state as positional param (also supported).""" + from strands.multiagent.graph import _is_context_condition + + def positional(state: GraphState, invocation_state: dict) -> bool: + return True + + assert _is_context_condition(positional) + + def test_lambda_without_invocation_state(self): + """Lambda conditions (legacy pattern).""" + from strands.multiagent.graph import _is_context_condition + + cond = lambda state: len(state.completed_nodes) > 0 # noqa: E731 + assert not _is_context_condition(cond) diff --git a/tests_integ/test_multiagent_graph.py b/tests_integ/test_multiagent_graph.py index b80a0f82d..2b7f17547 100644 --- a/tests_integ/test_multiagent_graph.py +++ b/tests_integ/test_multiagent_graph.py @@ -15,7 +15,7 @@ MessageAddedEvent, ) from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult, Status -from strands.multiagent.graph import GraphBuilder +from strands.multiagent.graph import GraphBuilder, GraphState from strands.session.file_session_manager import FileSessionManager from strands.types.content import ContentBlock from tests.fixtures.mock_hook_provider import MockHookProvider @@ -586,3 +586,319 @@ async def failing_after_two(*args, **kwargs): assert result.status == Status.COMPLETED assert len(result.execution_order) == 5 assert all(node.node_id == "loop_node" for node in result.execution_order) + + +@pytest.mark.asyncio +async def test_conditional_routing_with_invocation_state(): + """Test that edge conditions can use invocation_state for routing decisions. + + Graph structure: + entry -> (condition: use_detailed) -> detailed_agent + -> (condition: not use_detailed) -> brief_agent + """ + detailed_agent = Agent( + name="detailed", + model="us.amazon.nova-pro-v1:0", + system_prompt="Provide a very detailed, multi-paragraph explanation.", + ) + brief_agent = Agent( + name="brief", + model="us.amazon.nova-lite-v1:0", + system_prompt="Provide a one-sentence answer.", + ) + entry_agent = Agent( + name="entry", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are a router. Just say 'routing complete'.", + ) + + def use_detailed(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: + return invocation_state.get("detail_level") == "high" + + def use_brief(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: + return invocation_state.get("detail_level") != "high" + + builder = GraphBuilder() + builder.add_node(entry_agent, "entry") + builder.add_node(detailed_agent, "detailed") + builder.add_node(brief_agent, "brief") + builder.add_edge("entry", "detailed", condition=use_detailed) + builder.add_edge("entry", "brief", condition=use_brief) + builder.set_entry_point("entry") + graph = builder.build() + + # With detail_level=high, only detailed_agent should execute + result = await graph.invoke_async("What is Python?", invocation_state={"detail_level": "high"}) + assert result.status == Status.COMPLETED + executed_nodes = {n.node_id for n in result.execution_order} + assert "detailed" in executed_nodes + assert "brief" not in executed_nodes + + # With detail_level=low, only brief_agent should execute + result = await graph.invoke_async("What is Python?", invocation_state={"detail_level": "low"}) + assert result.status == Status.COMPLETED + executed_nodes = {n.node_id for n in result.execution_order} + assert "brief" in executed_nodes + assert "detailed" not in executed_nodes + + +@pytest.mark.asyncio +async def test_legacy_conditions_unaffected_by_invocation_state(): + """Test that existing graphs with legacy conditions still work when invocation_state is passed.""" + agent1 = Agent( + name="agent1", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are agent 1. Just say hello.", + ) + agent2 = Agent( + name="agent2", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are agent 2. Just say goodbye.", + ) + + def legacy_condition(state: GraphState) -> bool: + return any(n.node_id == "agent1" for n in state.completed_nodes) + + builder = GraphBuilder() + builder.add_node(agent1, "agent1") + builder.add_node(agent2, "agent2") + builder.add_edge("agent1", "agent2", condition=legacy_condition) + builder.set_entry_point("agent1") + graph = builder.build() + + # Legacy condition should still work fine even with invocation_state passed + result = await graph.invoke_async("Hello", invocation_state={"some_key": "some_value"}) + assert result.status == Status.COMPLETED + executed_nodes = {n.node_id for n in result.execution_order} + assert "agent1" in executed_nodes + assert "agent2" in executed_nodes + + +@pytest.mark.asyncio +async def test_condition_combining_graph_state_and_invocation_state(): + """Test condition that uses both GraphState and invocation_state for decisions. + + Graph: entry -> (condition: completed entry AND feature_flag) -> feature_agent -> final + -> (unconditional) -> final + """ + entry_agent = Agent( + name="entry", + model="us.amazon.nova-lite-v1:0", + system_prompt="Just say 'entry done'.", + ) + feature_agent = Agent( + name="feature", + model="us.amazon.nova-lite-v1:0", + system_prompt="Execute the new feature path. Say 'feature executed'.", + ) + final_agent = Agent( + name="final", + model="us.amazon.nova-lite-v1:0", + system_prompt="Summarize. Say 'done'.", + ) + + def feature_gate(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: + entry_completed = any(n.node_id == "entry" for n in state.completed_nodes) + flag_enabled = invocation_state.get("enable_feature_x", False) + return entry_completed and flag_enabled + + builder = GraphBuilder() + builder.add_node(entry_agent, "entry") + builder.add_node(feature_agent, "feature") + builder.add_node(final_agent, "final") + builder.add_edge("entry", "feature", condition=feature_gate) + builder.add_edge("entry", "final") + builder.add_edge("feature", "final") + builder.set_entry_point("entry") + graph = builder.build() + + # With flag disabled: entry -> final (feature skipped) + result = await graph.invoke_async("Run task", invocation_state={"enable_feature_x": False}) + assert result.status == Status.COMPLETED + executed_nodes = {n.node_id for n in result.execution_order} + assert "entry" in executed_nodes + assert "final" in executed_nodes + assert "feature" not in executed_nodes + + # With flag enabled: entry -> feature -> final + result = await graph.invoke_async("Run task", invocation_state={"enable_feature_x": True}) + assert result.status == Status.COMPLETED + executed_nodes = {n.node_id for n in result.execution_order} + assert "entry" in executed_nodes + assert "feature" in executed_nodes + assert "final" in executed_nodes + + +@pytest.mark.asyncio +async def test_diamond_graph_conditional_convergence(): + """Test diamond graph where one path is conditionally skipped and downstream converges. + + Graph: + entry -> (cond=True) -> fast_path -> merger + -> (cond=False) -> slow_path -> merger + + This tests the deadlock fix: merger should still execute even though slow_path's + incoming edge evaluates to False. + """ + entry_agent = Agent( + name="entry", + model="us.amazon.nova-lite-v1:0", + system_prompt="Say 'routed'.", + ) + fast_agent = Agent( + name="fast", + model="us.amazon.nova-lite-v1:0", + system_prompt="Say 'fast result'.", + ) + slow_agent = Agent( + name="slow", + model="us.amazon.nova-lite-v1:0", + system_prompt="Say 'slow result'.", + ) + merger_agent = Agent( + name="merger", + model="us.amazon.nova-lite-v1:0", + system_prompt="Merge results. Say 'merged'.", + ) + + def is_fast_mode(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: + return invocation_state.get("mode") == "fast" + + def is_slow_mode(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: + return invocation_state.get("mode") == "slow" + + builder = GraphBuilder() + builder.add_node(entry_agent, "entry") + builder.add_node(fast_agent, "fast") + builder.add_node(slow_agent, "slow") + builder.add_node(merger_agent, "merger") + builder.add_edge("entry", "fast", condition=is_fast_mode) + builder.add_edge("entry", "slow", condition=is_slow_mode) + builder.add_edge("fast", "merger") + builder.add_edge("slow", "merger") + builder.set_entry_point("entry") + graph = builder.build() + + # Fast mode: entry -> fast -> merger (slow skipped, merger not deadlocked) + result = await graph.invoke_async("Process", invocation_state={"mode": "fast"}) + assert result.status == Status.COMPLETED + executed_nodes = {n.node_id for n in result.execution_order} + assert executed_nodes == {"entry", "fast", "merger"} + + # Slow mode: entry -> slow -> merger (fast skipped) + result = await graph.invoke_async("Process", invocation_state={"mode": "slow"}) + assert result.status == Status.COMPLETED + executed_nodes = {n.node_id for n in result.execution_order} + assert executed_nodes == {"entry", "slow", "merger"} + + +@pytest.mark.asyncio +async def test_invocation_state_persisted_on_resume(tmp_path): + """Test that invocation_state is serialized and correctly used on resume after failure. + + Verifies that when a graph fails mid-execution and resumes from persisted state, + the invocation_state is restored and edge conditions evaluate correctly. + """ + session_id = f"invocation_state_resume_{uuid4()}" + session_manager = FileSessionManager(session_id=session_id, storage_dir=str(tmp_path)) + + agent1 = Agent(model="us.amazon.nova-lite-v1:0", system_prompt="Say 'step 1 done'.", name="agent1") + agent2 = Agent(model="us.amazon.nova-lite-v1:0", system_prompt="Say 'step 2 done'.", name="agent2") + agent3 = Agent(model="us.amazon.nova-lite-v1:0", system_prompt="Say 'step 3 done'.", name="agent3") + + def requires_premium(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: + return invocation_state.get("tier") == "premium" + + builder = GraphBuilder() + builder.add_node(agent1, "step1") + builder.add_node(agent2, "step2") + builder.add_node(agent3, "step3") + builder.add_edge("step1", "step2", condition=requires_premium) + builder.add_edge("step1", "step3") + builder.add_edge("step2", "step3") + builder.set_entry_point("step1") + builder.set_session_manager(session_manager) + graph = builder.build() + + # First invocation: step2 fails, step1 completed + async def failing_stream(*args, **kwargs): + raise Exception("Simulated failure in step2") + yield + + with patch.object(agent2, "stream_async", side_effect=failing_stream): + try: + await graph.invoke_async("Premium task", invocation_state={"tier": "premium"}) + raise AssertionError("Expected exception") + except Exception as e: + assert "Simulated failure in step2" in str(e) + + # Verify invocation_state was persisted + persisted = session_manager.read_multi_agent(session_id, graph.id) + assert persisted is not None + assert persisted.get("invocation_state") == {"tier": "premium"} + assert "step1" in persisted["completed_nodes"] + assert "step2" in persisted["failed_nodes"] + + # Resume: step2 should retry (condition still True), then step3 + result = await graph.invoke_async("Premium task", invocation_state={"tier": "premium"}) + assert result.status == Status.COMPLETED + executed_nodes = {n.node_id for n in result.execution_order} + assert "step2" in executed_nodes + assert "step3" in executed_nodes + + session_manager.delete_session(session_id) + + +@pytest.mark.asyncio +async def test_invocation_state_streaming_with_conditional_edges(): + """Test that streaming events are correctly emitted for conditional edge graphs. + + Verifies that only the activated path's nodes produce stream events. + """ + agent_a = Agent( + name="agent_a", + model="us.amazon.nova-lite-v1:0", + system_prompt="Say 'A done'.", + ) + agent_b = Agent( + name="agent_b", + model="us.amazon.nova-lite-v1:0", + system_prompt="Say 'B done'.", + ) + agent_c = Agent( + name="agent_c", + model="us.amazon.nova-lite-v1:0", + system_prompt="Say 'C done'.", + ) + + def go_to_b(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: + return invocation_state.get("path") == "B" + + def go_to_c(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: + return invocation_state.get("path") == "C" + + builder = GraphBuilder() + builder.add_node(agent_a, "A") + builder.add_node(agent_b, "B") + builder.add_node(agent_c, "C") + builder.add_edge("A", "B", condition=go_to_b) + builder.add_edge("A", "C", condition=go_to_c) + builder.set_entry_point("A") + graph = builder.build() + + # Stream with path=B — only A and B should have events + events = [] + async for event in graph.stream_async("Do something", invocation_state={"path": "B"}): + events.append(event) + + node_start_events = [e for e in events if e.get("type") == "multiagent_node_start"] + started_nodes = {e["node_id"] for e in node_start_events} + assert "A" in started_nodes + assert "B" in started_nodes + assert "C" not in started_nodes + + # Verify final result + result_events = [e for e in events if "result" in e and e.get("type") != "multiagent_node_stream"] + assert len(result_events) >= 1 + final_result = result_events[-1]["result"] + assert final_result.status == Status.COMPLETED From 7cc0d047332a9f6741b41e1b3da8c56fe18ed56e Mon Sep 17 00:00:00 2001 From: Yana Harris Date: Wed, 20 May 2026 20:48:57 +0000 Subject: [PATCH 2/2] fix: address review feedback for invocation_state edge conditions - Cache inspect.signature() results via WeakKeyDictionary - Remove unused @runtime_checkable decorator - Gate serialization validation on session_manager presence - Add validation on deserialization path for symmetry - Move json import to module level - Add inline comments for short-circuit and cache behavior - Extract _make_graph() test helper, fix list->set type consistency --- src/strands/multiagent/graph.py | 56 +++++++- tests/strands/multiagent/test_graph.py | 187 +++++++++++++------------ 2 files changed, 148 insertions(+), 95 deletions(-) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index e6c77e79c..b896e4cba 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -17,11 +17,13 @@ import asyncio import copy import inspect +import json import logging import time +import weakref from collections.abc import AsyncIterator, Callable, Mapping from dataclasses import dataclass, field -from typing import Any, Protocol, TypeGuard, cast, runtime_checkable +from typing import Any, Protocol, TypeGuard, cast from opentelemetry import trace as trace_api @@ -63,7 +65,6 @@ _DEFAULT_GRAPH_ID = "default_graph" -@runtime_checkable class EdgeConditionWithContext(Protocol): """Protocol for edge conditions that receive invocation_state. @@ -72,6 +73,9 @@ class EdgeConditionWithContext(Protocol): environment-specific configuration. Designed with **kwargs for future extensibility without breaking changes. + + Note: not @runtime_checkable — isinstance() cannot distinguish callable signatures + structurally; use _is_context_condition() for dispatch instead. """ def __call__(self, state: "GraphState", *, invocation_state: dict[str, Any], **kwargs: Any) -> bool: @@ -82,18 +86,34 @@ def __call__(self, state: "GraphState", *, invocation_state: dict[str, Any], **k LegacyEdgeCondition = Callable[["GraphState"], bool] EdgeCondition = LegacyEdgeCondition | EdgeConditionWithContext +# GIL-protected: concurrent async graphs may read/write simultaneously, but +# CPython's GIL ensures dict mutation is atomic. Ephemeral callables (e.g. +# lambdas recreated per-call) will bypass the cache — this is benign; the +# fallback path is a single inspect.signature() call. +_context_condition_cache: weakref.WeakKeyDictionary[EdgeCondition, bool] = weakref.WeakKeyDictionary() + def _is_context_condition(condition: EdgeCondition) -> TypeGuard[EdgeConditionWithContext]: """Check if a condition function accepts invocation_state parameter. Uses inspect.signature() for reliable detection, returning a TypeGuard - so mypy can narrow the type at call sites. + so mypy can narrow the type at call sites. Results are cached per condition + using weak references so entries are evicted when the function is collected. """ + try: + return _context_condition_cache[condition] + except (KeyError, TypeError): + pass try: sig = inspect.signature(condition) - return "invocation_state" in sig.parameters + result = "invocation_state" in sig.parameters except (ValueError, TypeError): - return False + result = False + try: + _context_condition_cache[condition] = result + except TypeError: + pass + return result @dataclass @@ -620,6 +640,8 @@ async def stream_async( if invocation_state is None: invocation_state = {} + if self.session_manager is not None: + self._validate_invocation_state(invocation_state) self._current_invocation_state = invocation_state await self.hooks.invoke_callbacks_async(BeforeMultiAgentInvocationEvent(self, invocation_state)) @@ -1239,6 +1261,20 @@ def _build_result(self, interrupts: list[Interrupt]) -> GraphResult: interrupts=interrupts, ) + @staticmethod + def _validate_invocation_state(invocation_state: dict[str, Any]) -> None: + """Validate that invocation_state is JSON-serializable. + + Raises: + TypeError: If invocation_state contains non-JSON-serializable values. + """ + try: + json.dumps(invocation_state) + except (TypeError, ValueError) as e: + raise TypeError( + f"invocation_state must be JSON-serializable for session persistence: {e}" + ) from e + def serialize_state(self) -> dict[str, Any]: """Serialize the current graph state to a dictionary.""" compute_nodes = self._compute_ready_nodes_for_resume() @@ -1277,7 +1313,9 @@ def deserialize_state(self, payload: dict[str, Any]) -> None: internal_state = payload["_internal_state"] self._interrupt_state = _InterruptState.from_dict(internal_state["interrupt_state"]) - self._current_invocation_state = payload.get("invocation_state", {}) + invocation_state = payload.get("invocation_state", {}) + self._validate_invocation_state(invocation_state) + self._current_invocation_state = invocation_state if not payload.get("next_nodes_to_execute"): # Reset all nodes @@ -1318,10 +1356,16 @@ def _is_node_ready_for_resume( A node is ready if all TRAVERSABLE incoming edges have their source completed. Edges whose condition evaluates to False are excluded from the check — they represent paths that were intentionally skipped. + + Re-evaluates conditions (rather than caching traversal results) intentionally: + invocation_state may change between invocations, so conditions must reflect + current runtime context. This means condition logic changes between serialize + and resume will also take effect — consistent with the graph being defined in code. """ traversable_edges = [ e for e in incoming + # Short-circuit: skip signature inspection + cache lookup for unconditional edges. if e.condition is None or e.should_traverse(self.state, invocation_state=self._current_invocation_state) ] diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 1eb40d8a4..7ff31dd19 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -16,6 +16,26 @@ from strands.types._events import MultiAgentNodeCancelEvent +def _make_graph( + nodes: dict, + edges=None, + state=None, + invocation_state=None, + interrupt_state=None, +) -> Graph: + """Create a minimally-valid Graph instance for unit tests without invoking __init__.""" + graph = Graph.__new__(Graph) + graph.nodes = nodes + graph.edges = edges if edges is not None else set() + graph.state = state or GraphState() + graph._current_invocation_state = invocation_state or {} + graph._interrupt_state = interrupt_state or _InterruptState() + graph._resume_from_session = False + graph._resume_next_nodes = [] + graph.id = "test_graph" + return graph + + def create_mock_agent(name, response_text="Default response", metrics=None, agent_id=None): """Create a mock Agent with specified properties.""" agent = Mock(spec=Agent) @@ -2460,15 +2480,14 @@ def test_find_newly_ready_nodes_only_evaluates_outbound_edges(): node_d = GraphNode(node_id="D", executor=create_mock_agent("D")) node_e = GraphNode(node_id="E", executor=create_mock_agent("E")) - graph = Graph.__new__(Graph) - graph.nodes = {"A": node_a, "B": node_b, "C": node_c, "D": node_d, "E": node_e} - graph.edges = [ - GraphEdge(from_node=node_a, to_node=node_b), - GraphEdge(from_node=node_b, to_node=node_c), - GraphEdge(from_node=node_d, to_node=node_e), - ] - graph.state = GraphState() - graph._current_invocation_state = {} + graph = _make_graph( + nodes={"A": node_a, "B": node_b, "C": node_c, "D": node_d, "E": node_e}, + edges={ + GraphEdge(from_node=node_a, to_node=node_b), + GraphEdge(from_node=node_b, to_node=node_c), + GraphEdge(from_node=node_d, to_node=node_e), + }, + ) # When A completes, only B should be ready (not E) ready = graph._find_newly_ready_nodes([node_a]) @@ -2588,11 +2607,12 @@ def context_condition(state: GraphState, *, invocation_state: dict, **kwargs) -> edge = GraphEdge(from_node=node_a, to_node=node_b, condition=context_condition) - graph = Graph.__new__(Graph) - graph.nodes = {"A": node_a, "B": node_b} - graph.edges = [edge] - graph.state = GraphState(completed_nodes={node_a}) - graph._current_invocation_state = {"activate": True} + graph = _make_graph( + nodes={"A": node_a, "B": node_b}, + edges={edge}, + state=GraphState(completed_nodes={node_a}), + invocation_state={"activate": True}, + ) assert graph._is_node_ready_with_conditions(node_b, [node_a]) assert received_state == {"activate": True} @@ -2608,11 +2628,12 @@ def context_condition(state: GraphState, *, invocation_state: dict, **kwargs) -> edge = GraphEdge(from_node=node_a, to_node=node_b, condition=context_condition) - graph = Graph.__new__(Graph) - graph.nodes = {"A": node_a, "B": node_b} - graph.edges = [edge] - graph.state = GraphState(completed_nodes={node_a}) - graph._current_invocation_state = {"activate": False} + graph = _make_graph( + nodes={"A": node_a, "B": node_b}, + edges={edge}, + state=GraphState(completed_nodes={node_a}), + invocation_state={"activate": False}, + ) assert not graph._is_node_ready_with_conditions(node_b, [node_a]) @@ -2637,16 +2658,16 @@ def context_condition(state: GraphState, *, invocation_state: dict, **kwargs) -> ), ) - graph = Graph.__new__(Graph) - graph.nodes = {"A": node_a, "B": node_b} - graph.edges = [edge] - graph.state = GraphState( - task="test task", - completed_nodes={node_a}, - results={"A": NodeResult(result=mock_result)}, + graph = _make_graph( + nodes={"A": node_a, "B": node_b}, + edges={edge}, + state=GraphState( + task="test task", + completed_nodes={node_a}, + results={"A": NodeResult(result=mock_result)}, + ), + invocation_state={"include_dep": False}, ) - graph._current_invocation_state = {"include_dep": False} - graph._interrupt_state = _InterruptState() # With condition=False, dependency is excluded -> gets raw task node_input = graph._build_node_input(node_b) @@ -2673,14 +2694,14 @@ def test_resume_skips_false_condition_edges(self): def always_false(state: GraphState) -> bool: return False - graph = Graph.__new__(Graph) - graph.nodes = {"A": node_a, "B": node_b, "C": node_c} - graph.edges = [ - GraphEdge(from_node=node_a, to_node=node_b, condition=always_false), - GraphEdge(from_node=node_a, to_node=node_c), # unconditional - ] - graph.state = GraphState(status=Status.INTERRUPTED, completed_nodes={node_a}) - graph._current_invocation_state = {} + graph = _make_graph( + nodes={"A": node_a, "B": node_b, "C": node_c}, + edges={ + GraphEdge(from_node=node_a, to_node=node_b, condition=always_false), + GraphEdge(from_node=node_a, to_node=node_c), # unconditional + }, + state=GraphState(status=Status.INTERRUPTED, completed_nodes={node_a}), + ) ready = graph._compute_ready_nodes_for_resume() ready_ids = {n.node_id for n in ready} @@ -2716,15 +2737,16 @@ def use_fast_path(state: GraphState, *, invocation_state: dict, **kwargs) -> boo def skip_direct(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: return not invocation_state.get("fast", False) - graph = Graph.__new__(Graph) - graph.nodes = {"A": node_a, "B": node_b, "C": node_c} - graph.edges = [ - GraphEdge(from_node=node_a, to_node=node_b, condition=use_fast_path), - GraphEdge(from_node=node_a, to_node=node_c, condition=skip_direct), - GraphEdge(from_node=node_b, to_node=node_c), - ] - graph.state = GraphState(status=Status.INTERRUPTED, completed_nodes={node_a, node_b}) - graph._current_invocation_state = {"fast": True} + graph = _make_graph( + nodes={"A": node_a, "B": node_b, "C": node_c}, + edges={ + GraphEdge(from_node=node_a, to_node=node_b, condition=use_fast_path), + GraphEdge(from_node=node_a, to_node=node_c, condition=skip_direct), + GraphEdge(from_node=node_b, to_node=node_c), + }, + state=GraphState(status=Status.INTERRUPTED, completed_nodes={node_a, node_b}), + invocation_state={"fast": True}, + ) ready = graph._compute_ready_nodes_for_resume() ready_ids = {n.node_id for n in ready} @@ -2740,13 +2762,13 @@ def test_resume_all_conditions_false_blocks_node(self): def always_false(state: GraphState) -> bool: return False - graph = Graph.__new__(Graph) - graph.nodes = {"A": node_a, "B": node_b} - graph.edges = [ - GraphEdge(from_node=node_a, to_node=node_b, condition=always_false), - ] - graph.state = GraphState(status=Status.INTERRUPTED, completed_nodes={node_a}) - graph._current_invocation_state = {} + graph = _make_graph( + nodes={"A": node_a, "B": node_b}, + edges={ + GraphEdge(from_node=node_a, to_node=node_b, condition=always_false), + }, + state=GraphState(status=Status.INTERRUPTED, completed_nodes={node_a}), + ) ready = graph._compute_ready_nodes_for_resume() ready_ids = {n.node_id for n in ready} @@ -2766,13 +2788,14 @@ def check_role(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: def check_not_admin(state: GraphState, *, invocation_state: dict, **kwargs) -> bool: return invocation_state.get("role") != "admin" - graph = Graph.__new__(Graph) - graph.nodes = {"A": node_a, "B": node_b, "C": node_c} - graph.edges = [ - GraphEdge(from_node=node_a, to_node=node_b, condition=check_role), - GraphEdge(from_node=node_a, to_node=node_c, condition=check_not_admin), - ] - graph.state = GraphState(status=Status.INTERRUPTED, completed_nodes={node_a}) + graph = _make_graph( + nodes={"A": node_a, "B": node_b, "C": node_c}, + edges={ + GraphEdge(from_node=node_a, to_node=node_b, condition=check_role), + GraphEdge(from_node=node_a, to_node=node_c, condition=check_not_admin), + }, + state=GraphState(status=Status.INTERRUPTED, completed_nodes={node_a}), + ) # As admin: only B should be ready graph._current_invocation_state = {"role": "admin"} @@ -2798,15 +2821,15 @@ def test_resume_mixed_conditional_unconditional_edges(self): def always_false(state: GraphState) -> bool: return False - graph = Graph.__new__(Graph) - graph.nodes = {"A": node_a, "B": node_b, "C": node_c} - graph.edges = [ - GraphEdge(from_node=node_a, to_node=node_b), # unconditional - GraphEdge(from_node=node_a, to_node=node_c, condition=always_false), # conditional (False) - GraphEdge(from_node=node_b, to_node=node_c), # unconditional - ] - graph.state = GraphState(status=Status.INTERRUPTED, completed_nodes={node_a, node_b}) - graph._current_invocation_state = {} + graph = _make_graph( + nodes={"A": node_a, "B": node_b, "C": node_c}, + edges={ + GraphEdge(from_node=node_a, to_node=node_b), # unconditional + GraphEdge(from_node=node_a, to_node=node_c, condition=always_false), # conditional (False) + GraphEdge(from_node=node_b, to_node=node_c), # unconditional + }, + state=GraphState(status=Status.INTERRUPTED, completed_nodes={node_a, node_b}), + ) ready = graph._compute_ready_nodes_for_resume() ready_ids = {n.node_id for n in ready} @@ -2821,13 +2844,11 @@ def test_serialize_includes_invocation_state(self): """Verify invocation_state appears in serialized payload.""" node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) - graph = Graph.__new__(Graph) - graph.nodes = {"A": node_a} - graph.edges = [] - graph.state = GraphState(status=Status.COMPLETED, completed_nodes={node_a}, task="test") - graph._current_invocation_state = {"feature_flag": True, "user_id": "123"} - graph._interrupt_state = _InterruptState() - graph.id = "test_graph" + graph = _make_graph( + nodes={"A": node_a}, + state=GraphState(status=Status.COMPLETED, completed_nodes={node_a}, task="test"), + invocation_state={"feature_flag": True, "user_id": "123"}, + ) serialized = graph.serialize_state() assert "invocation_state" in serialized @@ -2837,13 +2858,7 @@ def test_deserialize_restores_invocation_state(self): """Verify invocation_state is restored on deserialization.""" node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) - graph = Graph.__new__(Graph) - graph.nodes = {"A": node_a} - graph.edges = set() - graph.state = GraphState() - graph._interrupt_state = _InterruptState() - graph._resume_from_session = False - graph._resume_next_nodes = [] + graph = _make_graph(nodes={"A": node_a}) payload = { "status": "completed", @@ -2858,13 +2873,7 @@ def test_deserialize_missing_invocation_state_defaults_empty(self): """Backwards compat: old serialized payloads without invocation_state still work.""" node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) - graph = Graph.__new__(Graph) - graph.nodes = {"A": node_a} - graph.edges = set() - graph.state = GraphState() - graph._interrupt_state = _InterruptState() - graph._resume_from_session = False - graph._resume_next_nodes = [] + graph = _make_graph(nodes={"A": node_a}) payload = { "status": "completed",