From 56510b325027952f66a63db98906a5396f6e03fe Mon Sep 17 00:00:00 2001 From: westerberg Date: Wed, 10 Jun 2026 15:42:35 +0000 Subject: [PATCH 1/3] feat(workflow): Allow ToolNode to accept JSON string or Content inputs Enable ToolNode to receive JSON-formatted strings or types.Content objects in workflows. This supports flexible upstream formats, such as output from LLM agents or raw user content, automatically parsing them into dictionary arguments for downstream tools. --- src/google/adk/workflow/_tool_node.py | 15 +++ tests/unittests/workflow/test_tool_node.py | 133 +++++++++++++++++++++ 2 files changed, 148 insertions(+) create mode 100644 tests/unittests/workflow/test_tool_node.py diff --git a/src/google/adk/workflow/_tool_node.py b/src/google/adk/workflow/_tool_node.py index fad28bb714..b603ac91dd 100644 --- a/src/google/adk/workflow/_tool_node.py +++ b/src/google/adk/workflow/_tool_node.py @@ -66,7 +66,22 @@ async def _run_impl( function_call_id=str(uuid.uuid4()), ) + from google.genai import types + from ..utils.content_utils import extract_text_from_content + import json + args = node_input + if isinstance(args, types.Content): + args = extract_text_from_content(args) + + if isinstance(args, str): + args = args.strip() + if args: + try: + args = json.loads(args) + except json.JSONDecodeError: + pass + if args is None: args = {} elif not isinstance(args, dict): diff --git a/tests/unittests/workflow/test_tool_node.py b/tests/unittests/workflow/test_tool_node.py new file mode 100644 index 0000000000..7a6ccf3ed8 --- /dev/null +++ b/tests/unittests/workflow/test_tool_node.py @@ -0,0 +1,133 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for ToolNode input parsing and execution.""" + +from typing import Any + +from google.adk.events.event import Event +from google.adk.tools.base_tool import BaseTool +from google.adk.workflow import START +from google.adk.workflow._tool_node import _ToolNode as ToolNode +from google.adk.workflow._workflow import Workflow +from google.genai import types +import pytest + +from . import workflow_testing_utils +from .. import testing_utils + + +class MockTool(BaseTool): + """A mock tool that returns the args it was called with.""" + + def __init__(self, name="mock_tool", description="Mock tool"): + super().__init__(name=name, description=description) + + async def run_async(self, *, args: dict[str, Any], tool_context) -> Any: + return args + + +async def _run_tool_node_wf(node_input: Any) -> list[Any]: + """Runs a workflow with a ToolNode that receives node_input.""" + tool_node = ToolNode(tool=MockTool()) + + def start_node(): + return Event(output=node_input) + + wf = Workflow( + name="tool_node_test_wf", + edges=[ + (START, start_node), + (start_node, tool_node), + ], + ) + app_instance = testing_utils.App(name="test_app", root_agent=wf) + runner = testing_utils.InMemoryRunner(app=app_instance) + events = await runner.run_async("start") + return workflow_testing_utils.simplify_events_with_node(events) + + +@pytest.mark.asyncio +async def test_tool_node_accepts_dict(): + """Tests that ToolNode accepts a dict as input and passes it to the tool.""" + input_dict = {"param_a": 1, "param_b": "value"} + simplified = await _run_tool_node_wf(input_dict) + assert ( + "tool_node_test_wf@1/mock_tool@1", + {"output": input_dict}, + ) in simplified + + +@pytest.mark.asyncio +async def test_tool_node_accepts_none(): + """Tests that ToolNode accepts None, converting it to an empty dict.""" + simplified = await _run_tool_node_wf(None) + assert ("tool_node_test_wf@1/mock_tool@1", {"output": {}}) in simplified + + +@pytest.mark.asyncio +async def test_tool_node_accepts_json_string(): + """Tests that ToolNode accepts a valid JSON string representing a dict.""" + json_str = '{"param_a": 1, "param_b": "value"}' + simplified = await _run_tool_node_wf(json_str) + assert ( + "tool_node_test_wf@1/mock_tool@1", + {"output": {"param_a": 1, "param_b": "value"}}, + ) in simplified + + +@pytest.mark.asyncio +async def test_tool_node_accepts_content_with_json_string(): + """Tests that ToolNode accepts a types.Content containing a JSON string.""" + json_str = '{"param_a": 1, "param_b": "value"}' + content = types.Content( + parts=[types.Part.from_text(text=json_str)], role="user" + ) + simplified = await _run_tool_node_wf(content) + assert ( + "tool_node_test_wf@1/mock_tool@1", + {"output": {"param_a": 1, "param_b": "value"}}, + ) in simplified + + +@pytest.mark.asyncio +async def test_tool_node_rejects_non_dict_json_string(): + """Tests that ToolNode raises TypeError if JSON string represents a non-dict (e.g. list).""" + json_str = "[1, 2, 3]" + with pytest.raises( + TypeError, match="The input to ToolNode must be a dictionary" + ): + await _run_tool_node_wf(json_str) + + +@pytest.mark.asyncio +async def test_tool_node_rejects_invalid_json_string(): + """Tests that ToolNode raises TypeError if string input is not valid JSON.""" + invalid_str = "not a json" + with pytest.raises( + TypeError, match="The input to ToolNode must be a dictionary" + ): + await _run_tool_node_wf(invalid_str) + + +@pytest.mark.asyncio +async def test_tool_node_rejects_non_dict_content(): + """Tests that ToolNode raises TypeError if Content contains non-dict text.""" + content = types.Content( + parts=[types.Part.from_text(text="not a json")], role="user" + ) + with pytest.raises( + TypeError, match="The input to ToolNode must be a dictionary" + ): + await _run_tool_node_wf(content) From 020c5377eb83500e6347f22339c9eb60e746f897 Mon Sep 17 00:00:00 2001 From: westerberg Date: Wed, 10 Jun 2026 16:01:10 +0000 Subject: [PATCH 2/3] lint fix --- src/google/adk/workflow/_tool_node.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/google/adk/workflow/_tool_node.py b/src/google/adk/workflow/_tool_node.py index b603ac91dd..83a49b77d6 100644 --- a/src/google/adk/workflow/_tool_node.py +++ b/src/google/adk/workflow/_tool_node.py @@ -66,9 +66,10 @@ async def _run_impl( function_call_id=str(uuid.uuid4()), ) + import json + from google.genai import types from ..utils.content_utils import extract_text_from_content - import json args = node_input if isinstance(args, types.Content): From bfb7df84aab056ba14aca97a2983865368828df5 Mon Sep 17 00:00:00 2001 From: westerberg Date: Thu, 11 Jun 2026 08:26:35 +0000 Subject: [PATCH 3/3] pre-commit fixes --- src/google/adk/flows/llm_flows/base_llm_flow.py | 4 ++-- src/google/adk/telemetry/_instrumentation.py | 2 +- src/google/adk/workflow/_tool_node.py | 1 + 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 6878c9a5c6..0f2f6cc31a 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -28,6 +28,8 @@ from websockets.exceptions import ConnectionClosed from websockets.exceptions import ConnectionClosedOK +from . import _output_schema_processor +from . import functions from ...agents.base_agent import BaseAgent from ...agents.callback_context import CallbackContext from ...agents.invocation_context import InvocationContext @@ -50,8 +52,6 @@ from ...tools.tool_context import ToolContext from ...utils import model_name_utils from ...utils.context_utils import Aclosing -from . import _output_schema_processor -from . import functions from .audio_cache_manager import AudioCacheManager from .functions import build_auth_request_event diff --git a/src/google/adk/telemetry/_instrumentation.py b/src/google/adk/telemetry/_instrumentation.py index 8ce2797628..ea5dac4bff 100644 --- a/src/google/adk/telemetry/_instrumentation.py +++ b/src/google/adk/telemetry/_instrumentation.py @@ -26,9 +26,9 @@ from opentelemetry import trace import opentelemetry.context as context_api -from ..events import event as event_lib from . import _metrics from . import tracing +from ..events import event as event_lib if TYPE_CHECKING: from ..agents.base_agent import BaseAgent diff --git a/src/google/adk/workflow/_tool_node.py b/src/google/adk/workflow/_tool_node.py index 83a49b77d6..a007ce08c1 100644 --- a/src/google/adk/workflow/_tool_node.py +++ b/src/google/adk/workflow/_tool_node.py @@ -69,6 +69,7 @@ async def _run_impl( import json from google.genai import types + from ..utils.content_utils import extract_text_from_content args = node_input