1111from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
1212from opentelemetry import context as otel_context
1313from opentelemetry .propagate import extract , inject
14+ from opentelemetry .trace import TracerProvider , get_tracer
1415from pydantic import BaseModel , TypeAdapter
1516from typing_extensions import Self
1617
18+ from mcp .shared ._otel_utils import mcp_client_span
1719from mcp .shared .exceptions import MCPError
1820from mcp .shared .message import MessageMetadata , ServerMessageMetadata , SessionMessage
1921from mcp .shared .response_router import ResponseRouter
@@ -190,6 +192,8 @@ def __init__(
190192 write_stream : MemoryObjectSendStream [SessionMessage ],
191193 # If none, reading will never time out
192194 read_timeout_seconds : float | None = None ,
195+ * ,
196+ tracer_provider : TracerProvider | None = None ,
193197 ) -> None :
194198 self ._read_stream = read_stream
195199 self ._write_stream = write_stream
@@ -200,6 +204,7 @@ def __init__(
200204 self ._progress_callbacks = {}
201205 self ._response_routers = []
202206 self ._exit_stack = AsyncExitStack ()
207+ self ._tracer = get_tracer ("mcp" , tracer_provider = tracer_provider )
203208
204209 def add_response_router (self , router : ResponseRouter ) -> None :
205210 """Register a response router to handle responses for non-standard requests.
@@ -256,22 +261,22 @@ async def send_request(
256261 response_stream , response_stream_reader = anyio .create_memory_object_stream [JSONRPCResponse | JSONRPCError ](1 )
257262 self ._response_streams [request_id ] = response_stream
258263
259- # Set up progress token if progress callback is provided
260- request_data = request .model_dump (by_alias = True , mode = "json" , exclude_none = True )
261- if progress_callback is not None :
262- # Use request_id as progress token
263- if "params" not in request_data : # pragma: lax no cover
264- request_data ["params" ] = {}
265- if "_meta" not in request_data ["params" ]: # pragma: lax no cover
266- request_data ["params" ]["_meta" ] = {}
267- request_data ["params" ]["_meta" ]["progressToken" ] = request_id
268- # Store the callback for this request
269- self ._progress_callbacks [request_id ] = progress_callback
264+ async def make_request () -> ReceiveResultT :
265+ # Set up progress token if progress callback is provided
266+ request_data = request .model_dump (by_alias = True , mode = "json" , exclude_none = True )
267+ if progress_callback is not None :
268+ # Use request_id as progress token
269+ if "params" not in request_data : # pragma: lax no cover
270+ request_data ["params" ] = {}
271+ if "_meta" not in request_data ["params" ]: # pragma: lax no cover
272+ request_data ["params" ]["_meta" ] = {}
273+ request_data ["params" ]["_meta" ]["progressToken" ] = request_id
274+ # Store the callback for this request
275+ self ._progress_callbacks [request_id ] = progress_callback
276+
277+ # Propagate opentelemetry trace context
278+ self ._inject_otel_context (request_data )
270279
271- # Propagate opentelemetry trace context
272- self ._inject_otel_context (request_data )
273-
274- try :
275280 jsonrpc_request = JSONRPCRequest (jsonrpc = "2.0" , id = request_id , ** request_data )
276281 await self ._write_stream .send (SessionMessage (message = jsonrpc_request , metadata = metadata ))
277282
@@ -291,6 +296,9 @@ async def send_request(
291296 else :
292297 return result_type .model_validate (response_or_error .result , by_name = False )
293298
299+ try :
300+ with mcp_client_span (self ._tracer , request , json_rpc_request_id = request_id ):
301+ return await make_request ()
294302 finally :
295303 self ._response_streams .pop (request_id , None )
296304 self ._progress_callbacks .pop (request_id , None )
@@ -315,7 +323,9 @@ async def send_notification(
315323 message = jsonrpc_notification ,
316324 metadata = ServerMessageMetadata (related_request_id = related_request_id ) if related_request_id else None ,
317325 )
318- await self ._write_stream .send (session_message )
326+
327+ with mcp_client_span (self ._tracer , notification ):
328+ await self ._write_stream .send (session_message )
319329
320330 def _inject_otel_context (self , request : dict [str , Any ]) -> None :
321331 """Propagate OpenTelemetry context in `_meta`.
0 commit comments