Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions python/packages/core/agent_framework/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import asyncio
import contextvars
import inspect
import json
import logging
Expand Down Expand Up @@ -1820,9 +1821,16 @@ async def invoke_with_termination_handling(
False,
)

execution_results = await asyncio.gather(*[
invoke_with_termination_handling(function_call, seq_idx) for seq_idx, function_call in enumerate(function_calls)
])
# Create each task inside a copied context so the active agent span is
# preserved for every parallel tool invocation.
execution_tasks = [
contextvars.copy_context().run(
asyncio.create_task,
invoke_with_termination_handling(function_call, seq_idx),
)
for seq_idx, function_call in enumerate(function_calls)
]
execution_results = await asyncio.gather(*execution_tasks)
Comment thread
2830500285 marked this conversation as resolved.

# Unpack results - each is (Content, terminate_flag)
contents: list[Content] = [result[0] for result in execution_results]
Expand Down
87 changes: 87 additions & 0 deletions python/packages/core/tests/core/test_observability.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.

import asyncio
import logging
from collections.abc import AsyncIterable, Awaitable, MutableSequence, Sequence
from typing import Any
Expand Down Expand Up @@ -4239,6 +4240,92 @@ async def _get() -> ChatResponse:
assert inner.context.trace_id == agent_span.context.trace_id


@pytest.mark.parametrize("enable_sensitive_data", [False], indirect=True)
async def test_parallel_function_call_spans_nested_under_agent_span(span_exporter: InMemorySpanExporter):
"""Parallel execute_tool spans should preserve the active agent span context."""
from agent_framework._tools import FunctionInvocationLayer

@tool(name="first_tool", description="First parallel tool", approval_mode="never_require")
async def first_tool() -> str:
await asyncio.sleep(0)
return "first"

@tool(name="second_tool", description="Second parallel tool", approval_mode="never_require")
async def second_tool() -> str:
await asyncio.sleep(0)
return "second"

class ParallelToolChatClient(FunctionInvocationLayer, ChatTelemetryLayer, BaseChatClient[Any]):
def __init__(self) -> None:
super().__init__()
self.call_count = 0

def service_url(self):
return "https://test.example.com"

def _inner_get_response(
self, *, messages: MutableSequence[Message], stream: bool, options: dict[str, Any], **kwargs: Any
) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
del stream, messages, options, kwargs
self.call_count += 1
if self.call_count == 1:

async def _get_tool_calls() -> ChatResponse:
return ChatResponse(
messages=[
Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_first",
name="first_tool",
arguments="{}",
),
Content.from_function_call(
call_id="call_second",
name="second_tool",
arguments="{}",
),
],
)
],
)

return _get_tool_calls()

async def _get_final() -> ChatResponse:
return ChatResponse(
messages=[Message(role="assistant", contents=["Both tools completed."])],
finish_reason="stop",
)

return _get_final()

agent = Agent(
client=ParallelToolChatClient(),
id="parallel_tool_agent_id",
name="parallel_tool_agent",
default_options={"model": "ToolModel", "tools": [first_tool, second_tool], "tool_choice": "auto"},
)

span_exporter.clear()
await agent.run("Call both tools.")

spans = span_exporter.get_finished_spans()
invoke_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.AGENT_INVOKE_OPERATION]
tool_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.TOOL_EXECUTION_OPERATION]

assert len(invoke_spans) == 1
assert len(tool_spans) == 2
assert {s.attributes.get(OtelAttr.TOOL_NAME.value) for s in tool_spans} == {"first_tool", "second_tool"}

agent_span = invoke_spans[0]
for tool_span in tool_spans:
assert tool_span.parent is not None, f"Span {tool_span.name} has no parent"
assert tool_span.parent.span_id == agent_span.context.span_id
assert tool_span.context.trace_id == agent_span.context.trace_id


@pytest.mark.parametrize("stream", [False, True])
async def test_chat_span_nested_under_explicit_outer_span(
span_exporter: InMemorySpanExporter, mock_chat_client, stream: bool
Expand Down