diff --git a/src/agentex/lib/core/services/adk/streaming.py b/src/agentex/lib/core/services/adk/streaming.py index 7215f084c..c6ab24503 100644 --- a/src/agentex/lib/core/services/adk/streaming.py +++ b/src/agentex/lib/core/services/adk/streaming.py @@ -166,7 +166,11 @@ def __init__(self, on_flush: Callable[[StreamTaskMessageDelta], Awaitable[object self._first_flushed = False self._closed = False self._lock = asyncio.Lock() - self._flush_signal = asyncio.Event() + # _wake lets the ticker park at zero CPU when idle (set on empty -> + # non-empty); _flush_now bypasses the coalescing window (first delta / + # size threshold / close). + self._wake = asyncio.Event() + self._flush_now = asyncio.Event() self._task: asyncio.Task[None] | None = None def start(self) -> None: @@ -177,22 +181,36 @@ async def add(self, update: StreamTaskMessageDelta) -> None: if self._closed: return async with self._lock: + was_empty = not self._buf self._buf.append(update) self._buf_chars += _delta_char_len(update.delta) if not self._first_flushed or self._buf_chars >= self.MAX_BUFFERED_CHARS: self._first_flushed = True - self._flush_signal.set() + self._flush_now.set() + # Unpark the ticker; it applies the coalescing window itself. + if was_empty: + self._wake.set() async def _run(self) -> None: try: while True: - try: - await asyncio.wait_for(self._flush_signal.wait(), timeout=self.FLUSH_INTERVAL_S) - except asyncio.TimeoutError: - pass + # Park at zero CPU until there is data (or close()). A + # fixed-interval ticker instead leaked CPU on buffers orphaned + # without close() — one task spinning every FLUSH_INTERVAL_S. + await self._wake.wait() + self._wake.clear() + # Coalesce for up to FLUSH_INTERVAL_S unless an immediate flush + # is already pending. + if not self._flush_now.is_set() and not self._closed: + try: + await asyncio.wait_for(self._flush_now.wait(), timeout=self.FLUSH_INTERVAL_S) + except asyncio.TimeoutError: + pass async with self._lock: - self._flush_signal.clear() + self._flush_now.clear() drained = self._drain_locked() + # Deltas arriving during the _on_flush awaits below re-arm the + # ticker via add(), so they get flushed on the next loop. for u in drained: try: await self._on_flush(u) @@ -215,13 +233,14 @@ async def close(self) -> None: # producing the duplicate-tail symptom seen on the UI stream. self._closed = True if self._task is not None: - self._flush_signal.set() - try: - await self._task - except asyncio.CancelledError: - # Propagate if our caller is being cancelled; the task itself - # swallows CancelledError so this only fires on outer cancel. - raise + self._wake.set() + self._flush_now.set() + # Shield the ticker: if close() is cancelled, awaiting the task bare + # would propagate the cancel into the ticker mid-flush, and _run + # swallows CancelledError — silently dropping an already-drained + # batch. Shielded, the ticker finishes its in-flight flush and exits + # on _closed; the cancel still propagates to our caller. + await asyncio.shield(self._task) self._task = None async with self._lock: drained = self._drain_locked() @@ -415,26 +434,30 @@ async def open(self) -> "StreamingTaskMessageContext": return self - async def close(self) -> TaskMessage: - """Close the streaming context.""" + async def close(self, done_event: StreamTaskMessageDone | None = None) -> TaskMessage: + """Close the streaming context. + + ``done_event`` is the caller-provided terminal update when close is + driven by a streamed ``StreamTaskMessageDone`` — published as-is so its + ``index``/``parent_task_message`` survive. An implicit close (``__aexit__``) + passes nothing and a terminal Done is synthesized. + """ if not self.task_message: raise ValueError("Context not properly initialized - no task message") - if self._is_closed: - return self.task_message # Already done - - # Drain any buffered deltas before announcing DONE so consumers see the - # full sequence in order. + # Reap the buffer ticker before the _is_closed short-circuit, so a + # context already marked done by another path can't orphan it. if self._buffer is not None: await self._buffer.close() self._buffer = None - # Send the DONE event - done_event = StreamTaskMessageDone( - parent_task_message=self.task_message, - type="done", + if self._is_closed: + return self.task_message # Already done (buffer reaped above) + + # Send the DONE event (the caller's, if provided, so its metadata survives). + await self._streaming_service.stream_update( + done_event or StreamTaskMessageDone(parent_task_message=self.task_message, type="done") ) - await self._streaming_service.stream_update(done_event) # Update the task message with the final content has_deltas = ( @@ -486,12 +509,22 @@ async def stream_update(self, update: TaskMessageUpdate) -> TaskMessageUpdate | await self._buffer.add(update) return update - result = await self._streaming_service.stream_update(update) - if isinstance(update, StreamTaskMessageDone): - await self.close() + # close() drains the buffer first, then publishes this exact Done + # once (preserving its index/parent), persists, and marks closed. + await self.close(done_event=update) return update - elif isinstance(update, StreamTaskMessageFull): + + # Full publishes below, so drain and stop the buffer first → leftover + # deltas land in order (deltas -> Full) instead of trailing the terminal + # Full as a stale duplicate tail. Also stops the ticker. + if isinstance(update, StreamTaskMessageFull) and self._buffer is not None: + await self._buffer.close() + self._buffer = None + + result = await self._streaming_service.stream_update(update) + + if isinstance(update, StreamTaskMessageFull): await self._agentex_client.messages.update( task_id=self.task_id, message_id=update.parent_task_message.id, # type: ignore[union-attr] diff --git a/tests/lib/core/services/adk/test_streaming.py b/tests/lib/core/services/adk/test_streaming.py index b07c55f74..3145ef8a6 100644 --- a/tests/lib/core/services/adk/test_streaming.py +++ b/tests/lib/core/services/adk/test_streaming.py @@ -22,7 +22,11 @@ ToolResponseDelta, ReasoningSummaryDelta, ) -from agentex.types.task_message_update import StreamTaskMessageDelta +from agentex.types.task_message_update import ( + StreamTaskMessageDone, + StreamTaskMessageFull, + StreamTaskMessageDelta, +) from agentex.lib.core.services.adk.streaming import ( CoalescingBuffer, StreamingTaskMessageContext, @@ -303,6 +307,50 @@ async def on_flush(u: StreamTaskMessageDelta) -> None: await buf.close() +class TestCoalescingBufferIdleParks: + """The ticker must park on its wake event when idle, not poll every + FLUSH_INTERVAL_S — a buffer orphaned without close() otherwise pins CPU.""" + + @staticmethod + def _count_drains(buf: CoalescingBuffer) -> list[int]: + """Instrument _drain_locked to count drain cycles.""" + n = [0] + orig = buf._drain_locked + + def counting() -> list[StreamTaskMessageDelta]: + n[0] += 1 + return orig() + + buf._drain_locked = counting # type: ignore[method-assign] + return n + + @pytest.mark.asyncio + async def test_idle_buffer_does_not_spin(self) -> None: + buf = CoalescingBuffer(on_flush=AsyncMock()) + drains = self._count_drains(buf) + buf.start() + try: + # ~8 FLUSH_INTERVAL_S windows; a polling ticker would drain ~8x. + await asyncio.sleep(0.4) + assert drains[0] == 0, f"idle ticker woke {drains[0]}x (must park at 0)" + finally: + await buf.close() + + @pytest.mark.asyncio + async def test_orphaned_buffer_parks_after_flush(self, task_message: TaskMessage) -> None: + """An orphaned buffer (close() never runs) must still park once drained.""" + buf = CoalescingBuffer(on_flush=AsyncMock()) + buf.start() + try: + await buf.add(_text(task_message, "hi")) + await asyncio.sleep(0.020) # let the immediate flush land and park + drains = self._count_drains(buf) # count only post-flush cycles + await asyncio.sleep(0.4) + assert drains[0] == 0, f"orphaned ticker woke {drains[0]}x (must park at 0)" + finally: + await buf.close() + + class TestCoalescingBufferClose: @pytest.mark.asyncio async def test_close_drains_remaining_buffered_items(self, task_message: TaskMessage) -> None: @@ -352,6 +400,35 @@ async def on_flush(u: StreamTaskMessageDelta) -> None: await buf.add(_text(task_message, "after")) assert flushed == [] + @pytest.mark.asyncio + async def test_cancelled_close_does_not_drop_in_flight_batch(self, task_message: TaskMessage) -> None: + """If close() is cancelled while the ticker is mid-flush, the already- + drained batch must still publish — not be lost to a force-cancel.""" + flushed: list[StreamTaskMessageDelta] = [] + gate = asyncio.Event() + entered = asyncio.Event() + + async def on_flush(u: StreamTaskMessageDelta) -> None: + entered.set() + await gate.wait() # block mid-flush, before the item is recorded + flushed.append(u) + + buf = CoalescingBuffer(on_flush=on_flush) + buf.start() + await buf.add(_text(task_message, "hi")) # first delta → immediate flush + await entered.wait() # ticker is now blocked inside on_flush + + close_task = asyncio.create_task(buf.close()) + await asyncio.sleep(0) # let close() reach `await self._task` + close_task.cancel() + with pytest.raises(asyncio.CancelledError): + await close_task + + gate.set() # release the in-flight flush + assert buf._task is not None + await buf._task # ticker finishes the batch and exits on _closed + assert len(flushed) == 1, "in-flight batch was dropped on cancelled close()" + class TestCoalescingBufferCloseDuringFlush: @pytest.mark.asyncio @@ -520,3 +597,96 @@ async def test_open_without_created_at_passes_omit(self) -> None: kwargs = client.messages.create.call_args.kwargs assert kwargs["created_at"] is omit + + +class TestFullMessageClosesBuffer: + """A StreamTaskMessageFull must stop the buffer ticker. If it marks the + context done without closing the buffer, close()'s _is_closed short-circuit + leaves the ticker orphaned (the worker CPU leak).""" + + @pytest.mark.asyncio + async def test_full_message_stops_ticker(self) -> None: + ctx, _svc, tm = await _make_context("coalesced") + # A delta makes the buffer and its ticker live. + await ctx.stream_update(_text(tm, "hello")) + buf = ctx._buffer + assert buf is not None + task = buf._task + assert task is not None and not task.done() + + await ctx.stream_update( + StreamTaskMessageFull( + parent_task_message=tm, + content=TextContent(author="agent", content="final", format="markdown"), + type="full", + ) + ) + + assert ctx._buffer is None, "Full message left the buffer un-closed" + assert task.done(), "coalescing-buffer ticker still running after Full (orphaned)" + + @pytest.mark.asyncio + async def test_full_is_terminal_publish_no_trailing_deltas(self) -> None: + # Buffered deltas must publish BEFORE the Full, never after (a trailing + # delta after the terminal Full reads as a stale duplicate tail). + ctx, svc, tm = await _make_context("coalesced") + # "alpha" flushes immediately; "beta" stays buffered in the window. + await ctx.stream_update(_text(tm, "alpha")) + await ctx.stream_update(_text(tm, "beta")) + + full = StreamTaskMessageFull( + parent_task_message=tm, + content=TextContent(author="agent", content="alphabeta", format="markdown"), + type="full", + ) + await ctx.stream_update(full) + + published = [c.args[0] for c in svc.stream_update.await_args_list] + assert published, "nothing was published" + assert published[-1] is full, ( + f"Full must be the terminal publish; saw trailing " + f"{type(published[-1]).__name__} after it (stale duplicate tail)" + ) + assert any(isinstance(u, StreamTaskMessageDelta) for u in published[:-1]), ( + "expected the buffered deltas to be published before the Full" + ) + + @pytest.mark.asyncio + async def test_done_is_single_terminal_publish_no_trailing_deltas(self) -> None: + # Same guarantee as Full: buffered deltas publish BEFORE the terminal + # Done, Done is published exactly once (not duplicated by close()), and + # the caller's Done is published as-is so its metadata (index) survives. + ctx, svc, tm = await _make_context("coalesced") + # "alpha" flushes immediately; "beta" stays buffered in the window. + await ctx.stream_update(_text(tm, "alpha")) + await ctx.stream_update(_text(tm, "beta")) + + done = StreamTaskMessageDone(parent_task_message=tm, type="done", index=7) + await ctx.stream_update(done) + + published = [c.args[0] for c in svc.stream_update.await_args_list] + dones = [u for u in published if isinstance(u, StreamTaskMessageDone)] + assert len(dones) == 1, f"Done must publish exactly once, saw {len(dones)}" + assert published[-1] is done, ( + f"the caller's Done must be the terminal publish (metadata preserved); " + f"saw trailing {type(published[-1]).__name__}" + ) + assert dones[0].index == 7, "caller's Done index must be preserved, not synthesized away" + assert any(isinstance(u, StreamTaskMessageDelta) for u in published[:-1]), ( + "expected the buffered deltas to be published before the Done" + ) + + @pytest.mark.asyncio + async def test_close_reaps_buffer_even_if_already_marked_closed(self) -> None: + # close() must stop the ticker even when _is_closed is already set. + ctx, _svc, tm = await _make_context("coalesced") + await ctx.stream_update(_text(tm, "hi")) + buf = ctx._buffer + assert buf is not None + task = buf._task + assert task is not None and not task.done() + + ctx._is_closed = True # stray "already done" mark with a live buffer + await ctx.close() + + assert task.done(), "close() must reap the buffer even when already marked closed"