-
Notifications
You must be signed in to change notification settings - Fork 9
fix(streaming): StreamTaskMessageFull closes the coalescing buffer #426
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: next
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -177,6 +177,12 @@ async def add(self, update: StreamTaskMessageDelta) -> None: | |
| if self._closed: | ||
| return | ||
| async with self._lock: | ||
| # Re-check under the lock: a concurrent close() (e.g. from a racing | ||
| # Full) may have drained and shut down the ticker after the check | ||
| # above but before we acquired the lock. Appending now would strand | ||
| # the delta in a dead buffer, never published. | ||
| if self._closed: | ||
| return | ||
| self._buf.append(update) | ||
| self._buf_chars += _delta_char_len(update.delta) | ||
| if not self._first_flushed or self._buf_chars >= self.MAX_BUFFERED_CHARS: | ||
|
|
@@ -420,15 +426,17 @@ async def close(self) -> TaskMessage: | |
| if not self.task_message: | ||
| raise ValueError("Context not properly initialized - no task message") | ||
|
|
||
| if self._is_closed: | ||
| return self.task_message # Already done | ||
|
|
||
| # Drain any buffered deltas before announcing DONE so consumers see the | ||
| # full sequence in order. | ||
| # Reap the buffer (stopping its ticker) before the _is_closed | ||
| # short-circuit, so a context already marked done by a Full update can't | ||
| # leave the ticker orphaned. Draining here also lets consumers see the | ||
| # full delta sequence in order before DONE. | ||
| if self._buffer is not None: | ||
| await self._buffer.close() | ||
| self._buffer = None | ||
|
|
||
| if self._is_closed: | ||
| return self.task_message # Already done (buffer reaped above) | ||
|
|
||
| # Send the DONE event | ||
| done_event = StreamTaskMessageDone( | ||
| parent_task_message=self.task_message, | ||
|
|
@@ -486,6 +494,15 @@ async def stream_update(self, update: TaskMessageUpdate) -> TaskMessageUpdate | | |
| await self._buffer.add(update) | ||
| return update | ||
|
|
||
| # A Full ends the stream and supersedes buffered deltas. Drain and stop | ||
| # the buffer BEFORE publishing the Full, so leftover deltas land in order | ||
| # (deltas -> Full) instead of trailing the terminal Full as a stale | ||
| # duplicate tail. This also stops the ticker, which would otherwise be | ||
| # orphaned when __aexit__'s close() short-circuits on _is_closed. | ||
| if isinstance(update, StreamTaskMessageFull) and self._buffer is not None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I saw that Greptile also commented about this and looks like it resolved its own comment but it looks to me like it's still possible that we could have a race here? e.g.
In this case, consumers will still see a delta after full since it just falls through to publishing directly. If this is something we care about (and my logic checks out and you also believe there's a race), we could have a terminal-in-progress flag like _is_closing set before awaiting buffer close to reject deltas? |
||
| await self._buffer.close() | ||
| self._buffer = None | ||
|
greptile-apps[bot] marked this conversation as resolved.
|
||
|
|
||
| result = await self._streaming_service.stream_update(update) | ||
|
|
||
| if isinstance(update, StreamTaskMessageDone): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,7 +22,10 @@ | |
| ToolResponseDelta, | ||
| ReasoningSummaryDelta, | ||
| ) | ||
| from agentex.types.task_message_update import StreamTaskMessageDelta | ||
| from agentex.types.task_message_update import ( | ||
| StreamTaskMessageFull, | ||
| StreamTaskMessageDelta, | ||
| ) | ||
| from agentex.lib.core.services.adk.streaming import ( | ||
| CoalescingBuffer, | ||
| StreamingTaskMessageContext, | ||
|
|
@@ -352,6 +355,24 @@ async def on_flush(u: StreamTaskMessageDelta) -> None: | |
| await buf.add(_text(task_message, "after")) | ||
| assert flushed == [] | ||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_add_racing_close_is_not_stranded(self, task_message: TaskMessage) -> None: | ||
| """TOCTOU: a delta that passes add()'s pre-lock _closed check but only | ||
| acquires the lock after close() set _closed must be dropped, not appended | ||
| to a drained, ticker-less buffer where it would never be published.""" | ||
| buf = CoalescingBuffer(on_flush=AsyncMock()) | ||
| buf.start() | ||
| # Hold the lock so add() parks *after* its pre-lock _closed check. | ||
| await buf._lock.acquire() | ||
| add_task = asyncio.create_task(buf.add(_text(task_message, "racing"))) | ||
| await asyncio.sleep(0) # add() passes the _closed check, blocks on the lock | ||
| buf._closed = True # close() wins the race | ||
| buf._lock.release() | ||
| await add_task | ||
|
|
||
| assert buf._buf == [], "racing delta was stranded in the closed buffer" | ||
| await buf.close() # cleanup | ||
|
|
||
|
|
||
| class TestCoalescingBufferCloseDuringFlush: | ||
| @pytest.mark.asyncio | ||
|
|
@@ -520,3 +541,57 @@ async def test_open_without_created_at_passes_omit(self) -> None: | |
|
|
||
| kwargs = client.messages.create.call_args.kwargs | ||
| assert kwargs["created_at"] is omit | ||
|
|
||
|
|
||
| class TestFullMessageClosesBuffer: | ||
| """A StreamTaskMessageFull must stop the buffer ticker and drain its deltas | ||
| before the terminal Full. Marking the context done without closing the | ||
| buffer leaves close()'s _is_closed short-circuit to orphan the ticker, and | ||
| publishing buffered deltas after the Full reads as a stale duplicate tail.""" | ||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_full_message_stops_ticker(self) -> None: | ||
| ctx, _svc, tm = await _make_context("coalesced") | ||
| # A delta makes the buffer and its ticker live. | ||
| await ctx.stream_update(_text(tm, "hello")) | ||
| buf = ctx._buffer | ||
| assert buf is not None | ||
| task = buf._task | ||
| assert task is not None and not task.done() | ||
|
|
||
| await ctx.stream_update( | ||
| StreamTaskMessageFull( | ||
| parent_task_message=tm, | ||
| content=TextContent(author="agent", content="final", format="markdown"), | ||
| type="full", | ||
| ) | ||
| ) | ||
|
|
||
| assert ctx._buffer is None, "Full message left the buffer un-closed" | ||
| assert task.done(), "coalescing-buffer ticker still running after Full (orphaned)" | ||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_full_is_terminal_publish_no_trailing_deltas(self) -> None: | ||
| # Buffered deltas must publish BEFORE the Full, never after (a trailing | ||
| # delta after the terminal Full reads as a stale duplicate tail). | ||
| ctx, svc, tm = await _make_context("coalesced") | ||
| # "alpha" flushes immediately; "beta" stays buffered in the window. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. technically, "beta" may or may not stay buffered in the window since the assertions only need the Full to be the terminal publish and 1 or more delta before it, right? if so, i think this test is still a solid one as written but maybe just worth updating the comment to be more precise (or adding a clarification to the assertions below) |
||
| await ctx.stream_update(_text(tm, "alpha")) | ||
| await ctx.stream_update(_text(tm, "beta")) | ||
|
|
||
| full = StreamTaskMessageFull( | ||
| parent_task_message=tm, | ||
| content=TextContent(author="agent", content="alphabeta", format="markdown"), | ||
| type="full", | ||
| ) | ||
| await ctx.stream_update(full) | ||
|
|
||
| published = [c.args[0] for c in svc.stream_update.await_args_list] | ||
| assert published, "nothing was published" | ||
| assert published[-1] is full, ( | ||
| f"Full must be the terminal publish; saw trailing " | ||
| f"{type(published[-1]).__name__} after it (stale duplicate tail)" | ||
| ) | ||
| assert any(isinstance(u, StreamTaskMessageDelta) for u in published[:-1]), ( | ||
| "expected the buffered deltas to be published before the Full" | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
think it would be ever so slightly cleaner to just abstract this into _reap_buffer even if we're only calling it into two places and then call it here and below