Skip to content
Open
93 changes: 63 additions & 30 deletions src/agentex/lib/core/services/adk/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,11 @@ def __init__(self, on_flush: Callable[[StreamTaskMessageDelta], Awaitable[object
self._first_flushed = False
self._closed = False
self._lock = asyncio.Lock()
self._flush_signal = asyncio.Event()
# _wake lets the ticker park at zero CPU when idle (set on empty ->
# non-empty); _flush_now bypasses the coalescing window (first delta /
# size threshold / close).
self._wake = asyncio.Event()
self._flush_now = asyncio.Event()
self._task: asyncio.Task[None] | None = None

def start(self) -> None:
Expand All @@ -177,22 +181,36 @@ async def add(self, update: StreamTaskMessageDelta) -> None:
if self._closed:
return
async with self._lock:
was_empty = not self._buf
self._buf.append(update)
self._buf_chars += _delta_char_len(update.delta)
if not self._first_flushed or self._buf_chars >= self.MAX_BUFFERED_CHARS:
self._first_flushed = True
self._flush_signal.set()
self._flush_now.set()
# Unpark the ticker; it applies the coalescing window itself.
if was_empty:
self._wake.set()

async def _run(self) -> None:
try:
while True:
try:
await asyncio.wait_for(self._flush_signal.wait(), timeout=self.FLUSH_INTERVAL_S)
except asyncio.TimeoutError:
pass
# Park at zero CPU until there is data (or close()). A
# fixed-interval ticker instead leaked CPU on buffers orphaned
# without close() — one task spinning every FLUSH_INTERVAL_S.
await self._wake.wait()
self._wake.clear()
# Coalesce for up to FLUSH_INTERVAL_S unless an immediate flush
# is already pending.
if not self._flush_now.is_set() and not self._closed:
try:
await asyncio.wait_for(self._flush_now.wait(), timeout=self.FLUSH_INTERVAL_S)
except asyncio.TimeoutError:
pass
async with self._lock:
self._flush_signal.clear()
self._flush_now.clear()
drained = self._drain_locked()
# Deltas arriving during the _on_flush awaits below re-arm the
# ticker via add(), so they get flushed on the next loop.
for u in drained:
try:
await self._on_flush(u)
Expand All @@ -215,13 +233,14 @@ async def close(self) -> None:
# producing the duplicate-tail symptom seen on the UI stream.
self._closed = True
if self._task is not None:
self._flush_signal.set()
try:
await self._task
except asyncio.CancelledError:
# Propagate if our caller is being cancelled; the task itself
# swallows CancelledError so this only fires on outer cancel.
raise
self._wake.set()
self._flush_now.set()
# Shield the ticker: if close() is cancelled, awaiting the task bare
# would propagate the cancel into the ticker mid-flush, and _run
# swallows CancelledError — silently dropping an already-drained
# batch. Shielded, the ticker finishes its in-flight flush and exits
# on _closed; the cancel still propagates to our caller.
await asyncio.shield(self._task)
self._task = None
async with self._lock:
drained = self._drain_locked()
Expand Down Expand Up @@ -415,26 +434,30 @@ async def open(self) -> "StreamingTaskMessageContext":

return self

async def close(self) -> TaskMessage:
"""Close the streaming context."""
async def close(self, done_event: StreamTaskMessageDone | None = None) -> TaskMessage:
"""Close the streaming context.

``done_event`` is the caller-provided terminal update when close is
driven by a streamed ``StreamTaskMessageDone`` — published as-is so its
``index``/``parent_task_message`` survive. An implicit close (``__aexit__``)
passes nothing and a terminal Done is synthesized.
"""
if not self.task_message:
raise ValueError("Context not properly initialized - no task message")

if self._is_closed:
return self.task_message # Already done

# Drain any buffered deltas before announcing DONE so consumers see the
# full sequence in order.
# Reap the buffer ticker before the _is_closed short-circuit, so a
# context already marked done by another path can't orphan it.
if self._buffer is not None:
await self._buffer.close()
self._buffer = None

# Send the DONE event
done_event = StreamTaskMessageDone(
parent_task_message=self.task_message,
type="done",
if self._is_closed:
return self.task_message # Already done (buffer reaped above)

# Send the DONE event (the caller's, if provided, so its metadata survives).
await self._streaming_service.stream_update(
done_event or StreamTaskMessageDone(parent_task_message=self.task_message, type="done")
)
await self._streaming_service.stream_update(done_event)

# Update the task message with the final content
has_deltas = (
Expand Down Expand Up @@ -486,12 +509,22 @@ async def stream_update(self, update: TaskMessageUpdate) -> TaskMessageUpdate |
await self._buffer.add(update)
return update

result = await self._streaming_service.stream_update(update)

if isinstance(update, StreamTaskMessageDone):
await self.close()
# close() drains the buffer first, then publishes this exact Done
# once (preserving its index/parent), persists, and marks closed.
await self.close(done_event=update)
return update
Comment thread
greptile-apps[bot] marked this conversation as resolved.
elif isinstance(update, StreamTaskMessageFull):

# Full publishes below, so drain and stop the buffer first → leftover
# deltas land in order (deltas -> Full) instead of trailing the terminal
# Full as a stale duplicate tail. Also stops the ticker.
if isinstance(update, StreamTaskMessageFull) and self._buffer is not None:
await self._buffer.close()
self._buffer = None

result = await self._streaming_service.stream_update(update)

if isinstance(update, StreamTaskMessageFull):
await self._agentex_client.messages.update(
task_id=self.task_id,
message_id=update.parent_task_message.id, # type: ignore[union-attr]
Expand Down
172 changes: 171 additions & 1 deletion tests/lib/core/services/adk/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
ToolResponseDelta,
ReasoningSummaryDelta,
)
from agentex.types.task_message_update import StreamTaskMessageDelta
from agentex.types.task_message_update import (
StreamTaskMessageDone,
StreamTaskMessageFull,
StreamTaskMessageDelta,
)
from agentex.lib.core.services.adk.streaming import (
CoalescingBuffer,
StreamingTaskMessageContext,
Expand Down Expand Up @@ -303,6 +307,50 @@ async def on_flush(u: StreamTaskMessageDelta) -> None:
await buf.close()


class TestCoalescingBufferIdleParks:
    """Regression tests for the idle behaviour of the buffer's ticker.

    With nothing buffered, the ticker is expected to wait on its wake event
    instead of waking every FLUSH_INTERVAL_S; otherwise a buffer that is never
    closed keeps a task spinning and pins CPU.
    """

    @staticmethod
    def _count_drains(buf: CoalescingBuffer) -> list[int]:
        """Wrap ``buf._drain_locked`` and return a one-slot call counter.

        Slot 0 of the returned list is incremented each time the wrapped
        method is invoked; the wrapper forwards the original return value.
        """
        tally = [0]
        wrapped = buf._drain_locked

        def _drain_and_count() -> list[StreamTaskMessageDelta]:
            tally[0] = tally[0] + 1
            return wrapped()

        buf._drain_locked = _drain_and_count  # type: ignore[method-assign]
        return tally

    @pytest.mark.asyncio
    async def test_idle_buffer_does_not_spin(self) -> None:
        """A started buffer that never receives a delta performs no drains."""
        idle_buffer = CoalescingBuffer(on_flush=AsyncMock())
        drains = self._count_drains(idle_buffer)
        idle_buffer.start()
        try:
            # Sleep across several flush windows (about eight, per the original
            # note); a polling ticker would have drained once per window.
            await asyncio.sleep(0.4)
            assert drains[0] == 0, f"idle ticker woke {drains[0]}x (must park at 0)"
        finally:
            await idle_buffer.close()

    @pytest.mark.asyncio
    async def test_orphaned_buffer_parks_after_flush(self, task_message: TaskMessage) -> None:
        """A buffer whose close() never runs must still park once it is drained."""
        orphan = CoalescingBuffer(on_flush=AsyncMock())
        orphan.start()
        try:
            await orphan.add(_text(task_message, "hi"))
            # Give the immediate first-delta flush time to land so the ticker
            # is parked again before instrumentation starts.
            await asyncio.sleep(0.020)
            # Instrument only now, so the flush above is not counted.
            drains = self._count_drains(orphan)
            await asyncio.sleep(0.4)
            assert drains[0] == 0, f"orphaned ticker woke {drains[0]}x (must park at 0)"
        finally:
            await orphan.close()


class TestCoalescingBufferClose:
@pytest.mark.asyncio
async def test_close_drains_remaining_buffered_items(self, task_message: TaskMessage) -> None:
Expand Down Expand Up @@ -352,6 +400,35 @@ async def on_flush(u: StreamTaskMessageDelta) -> None:
await buf.add(_text(task_message, "after"))
assert flushed == []

@pytest.mark.asyncio
async def test_cancelled_close_does_not_drop_in_flight_batch(self, task_message: TaskMessage) -> None:
    """Cancelling close() while the ticker is mid-flush must not lose the batch.

    The ticker is blocked inside ``on_flush`` with an already-drained item,
    ``close()`` is started and then cancelled, and the flush is released. The
    item must still be recorded rather than being lost to a force-cancel.

    Both waits on the ticker are bounded so that a regression fails the test
    quickly instead of hanging the whole suite.
    """
    flushed: list[StreamTaskMessageDelta] = []
    gate = asyncio.Event()
    entered = asyncio.Event()

    async def on_flush(u: StreamTaskMessageDelta) -> None:
        entered.set()
        await gate.wait()  # block mid-flush, before the item is recorded
        flushed.append(u)

    buf = CoalescingBuffer(on_flush=on_flush)
    buf.start()
    await buf.add(_text(task_message, "hi"))  # first delta → immediate flush
    # Bounded wait: if the immediate flush never happens, time out rather than
    # block forever. The ticker is now parked inside on_flush.
    await asyncio.wait_for(entered.wait(), timeout=1.0)

    close_task = asyncio.create_task(buf.close())
    await asyncio.sleep(0)  # let close() reach its await on the ticker task
    close_task.cancel()
    with pytest.raises(asyncio.CancelledError):
        await close_task

    gate.set()  # release the in-flight flush
    assert buf._task is not None
    # The ticker should finish the batch and exit because the buffer is marked
    # closed; bound the wait so a ticker that never exits fails the test.
    await asyncio.wait_for(buf._task, timeout=1.0)
    assert len(flushed) == 1, "in-flight batch was dropped on cancelled close()"


class TestCoalescingBufferCloseDuringFlush:
@pytest.mark.asyncio
Expand Down Expand Up @@ -520,3 +597,96 @@ async def test_open_without_created_at_passes_omit(self) -> None:

kwargs = client.messages.create.call_args.kwargs
assert kwargs["created_at"] is omit


class TestFullMessageClosesBuffer:
    """Terminal updates (Full / Done) must stop the coalescing-buffer ticker.

    If a StreamTaskMessageFull marks the context done without closing the
    buffer, close()'s _is_closed short-circuit leaves the ticker orphaned
    (the worker CPU leak).
    """

    @pytest.mark.asyncio
    async def test_full_message_stops_ticker(self) -> None:
        """Streaming a Full clears the context's buffer and finishes its ticker task."""
        ctx, _svc, tm = await _make_context("coalesced")
        terminal = StreamTaskMessageFull(
            parent_task_message=tm,
            content=TextContent(author="agent", content="final", format="markdown"),
            type="full",
        )

        # One delta is enough to bring the buffer and its ticker to life.
        await ctx.stream_update(_text(tm, "hello"))
        live_buffer = ctx._buffer
        assert live_buffer is not None
        ticker = live_buffer._task
        assert ticker is not None
        assert not ticker.done()

        await ctx.stream_update(terminal)

        assert ctx._buffer is None, "Full message left the buffer un-closed"
        assert ticker.done(), "coalescing-buffer ticker still running after Full (orphaned)"

    @pytest.mark.asyncio
    async def test_full_is_terminal_publish_no_trailing_deltas(self) -> None:
        """Buffered deltas are published ahead of the Full, never after it.

        A delta trailing the terminal Full would read as a stale duplicate tail.
        """
        ctx, svc, tm = await _make_context("coalesced")
        full = StreamTaskMessageFull(
            parent_task_message=tm,
            content=TextContent(author="agent", content="alphabeta", format="markdown"),
            type="full",
        )

        # "alpha" flushes immediately; "beta" stays buffered in the window.
        for chunk in ("alpha", "beta"):
            await ctx.stream_update(_text(tm, chunk))
        await ctx.stream_update(full)

        published = [call.args[0] for call in svc.stream_update.await_args_list]
        assert published, "nothing was published"
        assert published[-1] is full, (
            f"Full must be the terminal publish; saw trailing "
            f"{type(published[-1]).__name__} after it (stale duplicate tail)"
        )
        before_terminal = published[:-1]
        assert any(isinstance(item, StreamTaskMessageDelta) for item in before_terminal), (
            "expected the buffered deltas to be published before the Full"
        )

    @pytest.mark.asyncio
    async def test_done_is_single_terminal_publish_no_trailing_deltas(self) -> None:
        """Done gets the same ordering guarantee as Full, published exactly once.

        Buffered deltas go out before the terminal Done, close() does not emit
        a second Done, and the caller's Done object is the one published so
        its metadata (index) survives.
        """
        ctx, svc, tm = await _make_context("coalesced")
        done = StreamTaskMessageDone(parent_task_message=tm, type="done", index=7)

        # "alpha" flushes immediately; "beta" stays buffered in the window.
        for chunk in ("alpha", "beta"):
            await ctx.stream_update(_text(tm, chunk))
        await ctx.stream_update(done)

        published = [call.args[0] for call in svc.stream_update.await_args_list]
        dones = [item for item in published if isinstance(item, StreamTaskMessageDone)]
        assert len(dones) == 1, f"Done must publish exactly once, saw {len(dones)}"
        assert published[-1] is done, (
            f"the caller's Done must be the terminal publish (metadata preserved); "
            f"saw trailing {type(published[-1]).__name__}"
        )
        assert dones[0].index == 7, "caller's Done index must be preserved, not synthesized away"
        before_terminal = published[:-1]
        assert any(isinstance(item, StreamTaskMessageDelta) for item in before_terminal), (
            "expected the buffered deltas to be published before the Done"
        )

    @pytest.mark.asyncio
    async def test_close_reaps_buffer_even_if_already_marked_closed(self) -> None:
        """close() stops the ticker even when _is_closed was already set."""
        ctx, _svc, tm = await _make_context("coalesced")
        await ctx.stream_update(_text(tm, "hi"))
        live_buffer = ctx._buffer
        assert live_buffer is not None
        ticker = live_buffer._task
        assert ticker is not None
        assert not ticker.done()

        # Simulate a stray "already done" mark while the buffer is still live.
        ctx._is_closed = True
        await ctx.close()

        assert ticker.done(), "close() must reap the buffer even when already marked closed"
Loading