diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 065324289f..82957e4a65 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import contextvars import inspect import json import logging @@ -1820,9 +1821,16 @@ async def invoke_with_termination_handling( False, ) - execution_results = await asyncio.gather(*[ - invoke_with_termination_handling(function_call, seq_idx) for seq_idx, function_call in enumerate(function_calls) - ]) + # Create each task inside a copied context so the active agent span is + # preserved for every parallel tool invocation. + execution_tasks = [ + contextvars.copy_context().run( + asyncio.create_task, + invoke_with_termination_handling(function_call, seq_idx), + ) + for seq_idx, function_call in enumerate(function_calls) + ] + execution_results = await asyncio.gather(*execution_tasks) # Unpack results - each is (Content, terminate_flag) contents: list[Content] = [result[0] for result in execution_results] diff --git a/python/packages/core/tests/core/test_observability.py b/python/packages/core/tests/core/test_observability.py index 46f6e2c151..03347d0359 100644 --- a/python/packages/core/tests/core/test_observability.py +++ b/python/packages/core/tests/core/test_observability.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. +import asyncio import logging from collections.abc import AsyncIterable, Awaitable, MutableSequence, Sequence from typing import Any @@ -4239,6 +4240,92 @@ async def _get() -> ChatResponse: assert inner.context.trace_id == agent_span.context.trace_id +@pytest.mark.parametrize("enable_sensitive_data", [False], indirect=True) +async def test_parallel_function_call_spans_nested_under_agent_span(span_exporter: InMemorySpanExporter): + """Parallel execute_tool spans should preserve the active agent span context.""" + from agent_framework._tools import FunctionInvocationLayer + + @tool(name="first_tool", description="First parallel tool", approval_mode="never_require") + async def first_tool() -> str: + await asyncio.sleep(0) + return "first" + + @tool(name="second_tool", description="Second parallel tool", approval_mode="never_require") + async def second_tool() -> str: + await asyncio.sleep(0) + return "second" + + class ParallelToolChatClient(FunctionInvocationLayer, ChatTelemetryLayer, BaseChatClient[Any]): + def __init__(self) -> None: + super().__init__() + self.call_count = 0 + + def service_url(self): + return "https://test.example.com" + + def _inner_get_response( + self, *, messages: MutableSequence[Message], stream: bool, options: dict[str, Any], **kwargs: Any + ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]: + del stream, messages, options, kwargs + self.call_count += 1 + if self.call_count == 1: + + async def _get_tool_calls() -> ChatResponse: + return ChatResponse( + messages=[ + Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_first", + name="first_tool", + arguments="{}", + ), + Content.from_function_call( + call_id="call_second", + name="second_tool", + arguments="{}", + ), + ], + ) + ], + ) + + return _get_tool_calls() + + async def _get_final() -> ChatResponse: + return ChatResponse( + messages=[Message(role="assistant", contents=["Both tools completed."])], + finish_reason="stop", + ) + + return _get_final() + + agent = Agent( + client=ParallelToolChatClient(), + id="parallel_tool_agent_id", + name="parallel_tool_agent", + default_options={"model": "ToolModel", "tools": [first_tool, second_tool], "tool_choice": "auto"}, + ) + + span_exporter.clear() + await agent.run("Call both tools.") + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.AGENT_INVOKE_OPERATION] + tool_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.TOOL_EXECUTION_OPERATION] + + assert len(invoke_spans) == 1 + assert len(tool_spans) == 2 + assert {s.attributes.get(OtelAttr.TOOL_NAME.value) for s in tool_spans} == {"first_tool", "second_tool"} + + agent_span = invoke_spans[0] + for tool_span in tool_spans: + assert tool_span.parent is not None, f"Span {tool_span.name} has no parent" + assert tool_span.parent.span_id == agent_span.context.span_id + assert tool_span.context.trace_id == agent_span.context.trace_id + + @pytest.mark.parametrize("stream", [False, True]) async def test_chat_span_nested_under_explicit_outer_span( span_exporter: InMemorySpanExporter, mock_chat_client, stream: bool