diff --git a/pyproject.toml b/pyproject.toml index 3fc90a9..2361913 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ client = ["httpx[http2]"] adk = ["google-adk>=1.20.0"] openai = ["openai-agents>=0.6.1"] pydantic_ai = ["pydantic-ai-slim>=1.68.0"] +tracing = ["opentelemetry-api>=1.36.0"] [build-system] requires = ["maturin>=1.6,<2.0"] diff --git a/python/restate/ext/adk/summarizer.py b/python/restate/ext/adk/summarizer.py index c50f17a..73ed47e 100644 --- a/python/restate/ext/adk/summarizer.py +++ b/python/restate/ext/adk/summarizer.py @@ -67,17 +67,14 @@ def from_summarizer( """Create a RestateEventSummarizer wrapping a custom summarizer.""" return RestateEventSummarizer(summarizer, max_retries=max_retries) - async def maybe_summarize_events( - self, *, events: list[Event] - ) -> Optional[Event]: + async def maybe_summarize_events(self, *, events: list[Event]) -> Optional[Event]: if not events: return None ctx = current_context() if ctx is None: raise RuntimeError( - "No Restate context found. " - "RestateEventSummarizer must be used from within a Restate handler." + "No Restate context found. RestateEventSummarizer must be used from within a Restate handler." ) inner = self._inner @@ -92,4 +89,4 @@ async def call_inner() -> Optional[Event]: max_attempts=self._max_retries, initial_retry_interval=timedelta(seconds=1), ), - ) \ No newline at end of file + ) diff --git a/python/restate/ext/pydantic/__init__.py b/python/restate/ext/pydantic/__init__.py index 76d633d..2b16785 100644 --- a/python/restate/ext/pydantic/__init__.py +++ b/python/restate/ext/pydantic/__init__.py @@ -8,6 +8,7 @@ from ._serde import PydanticTypeAdapter from ._toolset import RestateContextRunToolSet + def restate_object_context() -> ObjectContext: """Get the current Restate ObjectContext.""" ctx = current_context() diff --git a/python/restate/ext/tracing/__init__.py b/python/restate/ext/tracing/__init__.py new file mode 100644 index 0000000..16b8df3 --- /dev/null +++ b/python/restate/ext/tracing/__init__.py @@ -0,0 +1,3 @@ +from ._tracing import RestateTracer, RestateTracerProvider + +__all__ = ["RestateTracer", "RestateTracerProvider"] diff --git a/python/restate/ext/tracing/_tracing.py b/python/restate/ext/tracing/_tracing.py new file mode 100644 index 0000000..b272185 --- /dev/null +++ b/python/restate/ext/tracing/_tracing.py @@ -0,0 +1,146 @@ +"""Restate OTEL tracer wrapper that flattens all spans under the Restate trace. + +Wraps any tracer so that every span — regardless of framework nesting — becomes a +direct child of the Restate invocation trace. Works transparently with any +OTEL-integrated agent framework (Google ADK, Pydantic AI, OpenAI Agents, etc.). + +Usage: + tracer = RestateTracer(trace_api.get_tracer("my-tracer")) + # All spans created by this tracer are flat children of the Restate trace. +""" + +from opentelemetry.trace import INVALID_SPAN, use_span, Tracer, TracerProvider +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator +from restate.server_context import ( + current_context, + get_extension_data, + set_extension_data, + restate_context_is_replaying, +) + +_propagator = TraceContextTextMapPropagator() +_EXTENSION_KEY = "otel_span_cleanup" + + +class _SpanCleanup: + """Stored as Restate extension data. ``__close__`` is called automatically + when the Restate invocation context is torn down, ending any spans that + were never properly closed (e.g. because the handler raised).""" + + def __init__(self): + self._spans = [] + + def track(self, span): + self._spans.append(span) + + def __close__(self): + for span in self._spans: + if span.is_recording(): + span.end() + self._spans.clear() + + +class RestateTracer(Tracer): + """Wraps a ``Tracer`` to always parent spans under the Restate root context. + + During Restate replay, returns no-op spans to avoid duplicates.""" + + def __init__(self, tracer): + self._tracer = tracer + + @staticmethod + def _get_root_context(): + """Extract the Restate trace parent from the current handler, or None.""" + ctx = current_context() + if ctx is None: + raise Exception("You are not in a Restate handler") + return _propagator.extract(ctx.request().attempt_headers) + + def start_span( + self, + name, + context=None, + kind=None, + attributes=None, + links=None, + start_time=None, + record_exception=True, + set_status_on_exception=True, + ): + if restate_context_is_replaying.get(False): + return INVALID_SPAN + root = self._get_root_context() + if root is not None: + context = root + span = self._tracer.start_span( + name, + context=context, + kind=kind, + attributes=attributes, + links=links, + start_time=start_time, + record_exception=record_exception, + set_status_on_exception=set_status_on_exception, + ) + self._track_span(span) + return span + + def start_as_current_span( + self, + name, + context=None, + kind=None, + attributes=None, + links=None, + start_time=None, + record_exception=True, + set_status_on_exception=True, + end_on_exit=True, + ): + if restate_context_is_replaying.get(False): + return use_span(INVALID_SPAN, end_on_exit=False) + root = self._get_root_context() + if root is not None: + context = root + return self._tracer.start_as_current_span( + name, + context=context, + kind=kind, + attributes=attributes, + links=links, + start_time=start_time, + record_exception=record_exception, + set_status_on_exception=set_status_on_exception, + end_on_exit=end_on_exit, + ) + + @staticmethod + def _track_span(span): + """Register a span for cleanup when the Restate invocation ends.""" + ctx = current_context() + if ctx is None: + return + cleanup = get_extension_data(ctx, _EXTENSION_KEY) + if cleanup is None: + cleanup = _SpanCleanup() + set_extension_data(ctx, _EXTENSION_KEY, cleanup) + cleanup.track(span) + + def __getattr__(self, name): + return getattr(self._tracer, name) + + +class RestateTracerProvider(TracerProvider): + """Wraps a ``TracerProvider`` to return ``RestateTracer`` instances. + + Pass this to instrumentors (e.g. ``GoogleADKInstrumentor``) so that every + span they create is automatically parented under the Restate invocation.""" + + def __init__(self, provider): + self._provider = provider + + def get_tracer(self, *args, **kwargs): + return RestateTracer(self._provider.get_tracer(*args, **kwargs)) + + def __getattr__(self, name): + return getattr(self._provider, name) diff --git a/uv.lock b/uv.lock index a5a2389..8b3e41c 100644 --- a/uv.lock +++ b/uv.lock @@ -983,6 +983,7 @@ wheels = [ name = "griffelib" version = "2.0.0" source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ad/06/eccbd311c9e2b3ca45dbc063b93134c57a1ccc7607c5e545264ad092c4a9/griffelib-2.0.0.tar.gz", hash = "sha256:e504d637a089f5cab9b5daf18f7645970509bf4f53eda8d79ed71cce8bd97934", size = 166312, upload-time = "2026-03-23T21:06:55.954Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/4d/51/c936033e16d12b627ea334aaaaf42229c37620d0f15593456ab69ab48161/griffelib-2.0.0-py3-none-any.whl", hash = "sha256:01284878c966508b6d6f1dbff9b6fa607bc062d8261c5c7253cb285b06422a7f", size = 142004, upload-time = "2026-02-09T19:09:40.561Z" }, ] @@ -2512,6 +2513,9 @@ test = [ { name = "hypercorn" }, { name = "pytest" }, ] +tracing = [ + { name = "opentelemetry-api" }, +] [package.metadata] requires-dist = [ @@ -2525,6 +2529,7 @@ requires-dist = [ { name = "msgspec", marker = "extra == 'serde'" }, { name = "mypy", marker = "extra == 'lint'", specifier = ">=1.11.2" }, { name = "openai-agents", marker = "extra == 'openai'", specifier = ">=0.6.1" }, + { name = "opentelemetry-api", marker = "extra == 'tracing'", specifier = ">=1.36.0" }, { name = "pydantic", marker = "extra == 'serde'" }, { name = "pydantic-ai-slim", marker = "extra == 'pydantic-ai'", specifier = ">=1.68.0" }, { name = "pyright", marker = "extra == 'lint'", specifier = ">=1.1.390" }, @@ -2532,7 +2537,7 @@ requires-dist = [ { name = "ruff", marker = "extra == 'lint'", specifier = ">=0.6.9" }, { name = "testcontainers", marker = "extra == 'harness'" }, ] -provides-extras = ["adk", "client", "harness", "lint", "openai", "pydantic-ai", "serde", "test"] +provides-extras = ["adk", "client", "harness", "lint", "openai", "pydantic-ai", "serde", "test", "tracing"] [[package]] name = "rpds-py"