Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ client = ["httpx[http2]"]
adk = ["google-adk>=1.20.0"]
openai = ["openai-agents>=0.6.1"]
pydantic_ai = ["pydantic-ai-slim>=1.68.0"]
tracing = ["opentelemetry-api>=1.36.0"]

[build-system]
requires = ["maturin>=1.6,<2.0"]
Expand Down
9 changes: 3 additions & 6 deletions python/restate/ext/adk/summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,14 @@ def from_summarizer(
"""Create a RestateEventSummarizer wrapping a custom summarizer."""
return RestateEventSummarizer(summarizer, max_retries=max_retries)

async def maybe_summarize_events(
self, *, events: list[Event]
) -> Optional[Event]:
async def maybe_summarize_events(self, *, events: list[Event]) -> Optional[Event]:
if not events:
return None

ctx = current_context()
if ctx is None:
raise RuntimeError(
"No Restate context found. "
"RestateEventSummarizer must be used from within a Restate handler."
"No Restate context found. RestateEventSummarizer must be used from within a Restate handler."
)

inner = self._inner
Expand All @@ -92,4 +89,4 @@ async def call_inner() -> Optional[Event]:
max_attempts=self._max_retries,
initial_retry_interval=timedelta(seconds=1),
),
)
)
1 change: 1 addition & 0 deletions python/restate/ext/pydantic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ._serde import PydanticTypeAdapter
from ._toolset import RestateContextRunToolSet


def restate_object_context() -> ObjectContext:
"""Get the current Restate ObjectContext."""
ctx = current_context()
Expand Down
3 changes: 3 additions & 0 deletions python/restate/ext/tracing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ._tracing import RestateTracer, RestateTracerProvider

__all__ = ["RestateTracer", "RestateTracerProvider"]
146 changes: 146 additions & 0 deletions python/restate/ext/tracing/_tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""Restate OTEL tracer wrapper that flattens all spans under the Restate trace.

Wraps any tracer so that every span — regardless of framework nesting — becomes a
direct child of the Restate invocation trace. Works transparently with any
OTEL-integrated agent framework (Google ADK, Pydantic AI, OpenAI Agents, etc.).

Usage:
tracer = RestateTracer(trace_api.get_tracer("my-tracer"))
# All spans created by this tracer are flat children of the Restate trace.
"""

from opentelemetry.trace import INVALID_SPAN, use_span, Tracer, TracerProvider
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from restate.server_context import (
current_context,
get_extension_data,
set_extension_data,
restate_context_is_replaying,
)

_propagator = TraceContextTextMapPropagator()
_EXTENSION_KEY = "otel_span_cleanup"


class _SpanCleanup:
"""Stored as Restate extension data. ``__close__`` is called automatically
when the Restate invocation context is torn down, ending any spans that
were never properly closed (e.g. because the handler raised)."""

def __init__(self):
self._spans = []
Comment on lines +30 to +31
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you really need to collect spans like this? if you create the parent span for hte invocation attempt, and close that one, all the child spans should be closed as well.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gvdongen worth to double check if this creates an issues?


def track(self, span):
self._spans.append(span)

def __close__(self):
for span in self._spans:
if span.is_recording():
span.end()
self._spans.clear()


class RestateTracer(Tracer):
"""Wraps a ``Tracer`` to always parent spans under the Restate root context.

During Restate replay, returns no-op spans to avoid duplicates."""

def __init__(self, tracer):
self._tracer = tracer

@staticmethod
def _get_root_context():
"""Extract the Restate trace parent from the current handler, or None."""
ctx = current_context()
if ctx is None:
raise Exception("You are not in a Restate handler")
return _propagator.extract(ctx.request().attempt_headers)

def start_span(
self,
name,
context=None,
kind=None,
attributes=None,
links=None,
start_time=None,
record_exception=True,
set_status_on_exception=True,
):
if restate_context_is_replaying.get(False):
return INVALID_SPAN
root = self._get_root_context()
if root is not None:
context = root
span = self._tracer.start_span(
name,
context=context,
kind=kind,
attributes=attributes,
links=links,
start_time=start_time,
record_exception=record_exception,
set_status_on_exception=set_status_on_exception,
)
self._track_span(span)
return span

def start_as_current_span(
self,
name,
context=None,
kind=None,
attributes=None,
links=None,
start_time=None,
record_exception=True,
set_status_on_exception=True,
end_on_exit=True,
):
if restate_context_is_replaying.get(False):
return use_span(INVALID_SPAN, end_on_exit=False)
root = self._get_root_context()
if root is not None:
context = root
return self._tracer.start_as_current_span(
name,
context=context,
kind=kind,
attributes=attributes,
links=links,
start_time=start_time,
record_exception=record_exception,
set_status_on_exception=set_status_on_exception,
end_on_exit=end_on_exit,
)

@staticmethod
def _track_span(span):
"""Register a span for cleanup when the Restate invocation ends."""
ctx = current_context()
if ctx is None:
return
cleanup = get_extension_data(ctx, _EXTENSION_KEY)
if cleanup is None:
cleanup = _SpanCleanup()
set_extension_data(ctx, _EXTENSION_KEY, cleanup)
cleanup.track(span)

def __getattr__(self, name):
return getattr(self._tracer, name)


class RestateTracerProvider(TracerProvider):
"""Wraps a ``TracerProvider`` to return ``RestateTracer`` instances.

Pass this to instrumentors (e.g. ``GoogleADKInstrumentor``) so that every
span they create is automatically parented under the Restate invocation."""

def __init__(self, provider):
self._provider = provider

def get_tracer(self, *args, **kwargs):
return RestateTracer(self._provider.get_tracer(*args, **kwargs))

def __getattr__(self, name):
return getattr(self._provider, name)
7 changes: 6 additions & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading