diff --git a/src/agentex/lib/core/services/adk/streaming.py b/src/agentex/lib/core/services/adk/streaming.py index 7215f084c..33ca7bc1c 100644 --- a/src/agentex/lib/core/services/adk/streaming.py +++ b/src/agentex/lib/core/services/adk/streaming.py @@ -177,6 +177,12 @@ async def add(self, update: StreamTaskMessageDelta) -> None: if self._closed: return async with self._lock: + # Re-check under the lock: a concurrent close() (e.g. from a racing + # Full) may have drained and shut down the ticker after the check + # above but before we acquired the lock. Appending now would strand + # the delta in a dead buffer, never published. + if self._closed: + return self._buf.append(update) self._buf_chars += _delta_char_len(update.delta) if not self._first_flushed or self._buf_chars >= self.MAX_BUFFERED_CHARS: @@ -415,19 +421,28 @@ async def open(self) -> "StreamingTaskMessageContext": return self + async def _reap_buffer(self) -> None: + """Drain and stop the coalescing buffer, releasing its background ticker. + + Idempotent: a no-op once the buffer has already been reaped. + """ + if self._buffer is not None: + await self._buffer.close() + self._buffer = None + async def close(self) -> TaskMessage: """Close the streaming context.""" if not self.task_message: raise ValueError("Context not properly initialized - no task message") - if self._is_closed: - return self.task_message # Already done + # Reap the buffer (stopping its ticker) before the _is_closed + # short-circuit, so a context already marked done by a Full update can't + # leave the ticker orphaned. Draining here also lets consumers see the + # full delta sequence in order before DONE. + await self._reap_buffer() - # Drain any buffered deltas before announcing DONE so consumers see the - # full sequence in order. - if self._buffer is not None: - await self._buffer.close() - self._buffer = None + if self._is_closed: + return self.task_message # Already done (buffer reaped above) # Send the DONE event done_event = StreamTaskMessageDone( @@ -486,6 +501,14 @@ async def stream_update(self, update: TaskMessageUpdate) -> TaskMessageUpdate | await self._buffer.add(update) return update + # A Full ends the stream and supersedes buffered deltas. Drain and stop + # the buffer BEFORE publishing the Full, so leftover deltas land in order + # (deltas -> Full) instead of trailing the terminal Full as a stale + # duplicate tail. This also stops the ticker, which would otherwise be + # orphaned when __aexit__'s close() short-circuits on _is_closed. + if isinstance(update, StreamTaskMessageFull): + await self._reap_buffer() + result = await self._streaming_service.stream_update(update) if isinstance(update, StreamTaskMessageDone): diff --git a/tests/lib/core/services/adk/test_streaming.py b/tests/lib/core/services/adk/test_streaming.py index b07c55f74..a8068f307 100644 --- a/tests/lib/core/services/adk/test_streaming.py +++ b/tests/lib/core/services/adk/test_streaming.py @@ -22,7 +22,10 @@ ToolResponseDelta, ReasoningSummaryDelta, ) -from agentex.types.task_message_update import StreamTaskMessageDelta +from agentex.types.task_message_update import ( + StreamTaskMessageFull, + StreamTaskMessageDelta, +) from agentex.lib.core.services.adk.streaming import ( CoalescingBuffer, StreamingTaskMessageContext, @@ -352,6 +355,24 @@ async def on_flush(u: StreamTaskMessageDelta) -> None: await buf.add(_text(task_message, "after")) assert flushed == [] + @pytest.mark.asyncio + async def test_add_racing_close_is_not_stranded(self, task_message: TaskMessage) -> None: + """TOCTOU: a delta that passes add()'s pre-lock _closed check but only + acquires the lock after close() set _closed must be dropped, not appended + to a drained, ticker-less buffer where it would never be published.""" + buf = CoalescingBuffer(on_flush=AsyncMock()) + buf.start() + # Hold the lock so add() parks *after* its pre-lock _closed check. + await buf._lock.acquire() + add_task = asyncio.create_task(buf.add(_text(task_message, "racing"))) + await asyncio.sleep(0) # add() passes the _closed check, blocks on the lock + buf._closed = True # close() wins the race + buf._lock.release() + await add_task + + assert buf._buf == [], "racing delta was stranded in the closed buffer" + await buf.close() # cleanup + class TestCoalescingBufferCloseDuringFlush: @pytest.mark.asyncio @@ -520,3 +541,59 @@ async def test_open_without_created_at_passes_omit(self) -> None: kwargs = client.messages.create.call_args.kwargs assert kwargs["created_at"] is omit + + +class TestFullMessageClosesBuffer: + """A StreamTaskMessageFull must stop the buffer ticker and drain its deltas + before the terminal Full. Marking the context done without closing the + buffer leaves close()'s _is_closed short-circuit to orphan the ticker, and + publishing buffered deltas after the Full reads as a stale duplicate tail.""" + + @pytest.mark.asyncio + async def test_full_message_stops_ticker(self) -> None: + ctx, _svc, tm = await _make_context("coalesced") + # A delta makes the buffer and its ticker live. + await ctx.stream_update(_text(tm, "hello")) + buf = ctx._buffer + assert buf is not None + task = buf._task + assert task is not None and not task.done() + + await ctx.stream_update( + StreamTaskMessageFull( + parent_task_message=tm, + content=TextContent(author="agent", content="final", format="markdown"), + type="full", + ) + ) + + assert ctx._buffer is None, "Full message left the buffer un-closed" + assert task.done(), "coalescing-buffer ticker still running after Full (orphaned)" + + @pytest.mark.asyncio + async def test_full_is_terminal_publish_no_trailing_deltas(self) -> None: + # Buffered deltas must publish BEFORE the Full, never after (a trailing + # delta after the terminal Full reads as a stale duplicate tail). + ctx, svc, tm = await _make_context("coalesced") + # Two deltas through the buffer. Regardless of how the coalescing window + # batches them (1 or 2 publishes), the invariant under test is the same: + # every delta publishes before the terminal Full, never after it. + await ctx.stream_update(_text(tm, "alpha")) + await ctx.stream_update(_text(tm, "beta")) + + full = StreamTaskMessageFull( + parent_task_message=tm, + content=TextContent(author="agent", content="alphabeta", format="markdown"), + type="full", + ) + await ctx.stream_update(full) + + published = [c.args[0] for c in svc.stream_update.await_args_list] + assert published, "nothing was published" + assert published[-1] is full, ( + f"Full must be the terminal publish; saw trailing " + f"{type(published[-1]).__name__} after it (stale duplicate tail)" + ) + assert any(isinstance(u, StreamTaskMessageDelta) for u in published[:-1]), ( + "expected the buffered deltas to be published before the Full" + )