Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions src/agentex/lib/core/services/adk/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,12 @@ async def add(self, update: StreamTaskMessageDelta) -> None:
if self._closed:
return
async with self._lock:
# Re-check under the lock: a concurrent close() (e.g. from a racing
# Full) may have drained and shut down the ticker after the check
# above but before we acquired the lock. Appending now would strand
# the delta in a dead buffer, never published.
if self._closed:
return
self._buf.append(update)
self._buf_chars += _delta_char_len(update.delta)
if not self._first_flushed or self._buf_chars >= self.MAX_BUFFERED_CHARS:
Expand Down Expand Up @@ -420,15 +426,17 @@ async def close(self) -> TaskMessage:
if not self.task_message:
raise ValueError("Context not properly initialized - no task message")

if self._is_closed:
return self.task_message # Already done

# Drain any buffered deltas before announcing DONE so consumers see the
# full sequence in order.
# Reap the buffer (stopping its ticker) before the _is_closed
# short-circuit, so a context already marked done by a Full update can't
# leave the ticker orphaned. Draining here also lets consumers see the
# full delta sequence in order before DONE.
if self._buffer is not None:
await self._buffer.close()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be ever so slightly cleaner to abstract this into a `_reap_buffer` helper, even if we're only calling it in two places, and then call it here and below.

self._buffer = None

if self._is_closed:
return self.task_message # Already done (buffer reaped above)

# Send the DONE event
done_event = StreamTaskMessageDone(
parent_task_message=self.task_message,
Expand Down Expand Up @@ -486,6 +494,15 @@ async def stream_update(self, update: TaskMessageUpdate) -> TaskMessageUpdate |
await self._buffer.add(update)
return update

# A Full ends the stream and supersedes buffered deltas. Drain and stop
# the buffer BEFORE publishing the Full, so leftover deltas land in order
# (deltas -> Full) instead of trailing the terminal Full as a stale
# duplicate tail. This also stops the ticker, which would otherwise be
# orphaned when __aexit__'s close() short-circuits on _is_closed.
if isinstance(update, StreamTaskMessageFull) and self._buffer is not None:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I saw that Greptile also commented on this and appears to have resolved its own comment, but it looks to me like a race is still possible here. For example:

  1. This code block closes the buffer and sets it to `None`.
  2. `stream_update` (or `messages.update`) below is being awaited.
  3. Another delta comes in and skips this block, since the buffer is already `None` but `_is_closed` is not yet set to `True`.

In this case, consumers will still see a delta after the Full, since that delta just falls through to being published directly. If this is something we care about (and my logic checks out and you also believe there's a race), we could add a terminal-in-progress flag such as `_is_closing`, set before awaiting the buffer close, to reject late deltas.

await self._buffer.close()
self._buffer = None
Comment thread
greptile-apps[bot] marked this conversation as resolved.

result = await self._streaming_service.stream_update(update)

if isinstance(update, StreamTaskMessageDone):
Expand Down
77 changes: 76 additions & 1 deletion tests/lib/core/services/adk/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
ToolResponseDelta,
ReasoningSummaryDelta,
)
from agentex.types.task_message_update import StreamTaskMessageDelta
from agentex.types.task_message_update import (
StreamTaskMessageFull,
StreamTaskMessageDelta,
)
from agentex.lib.core.services.adk.streaming import (
CoalescingBuffer,
StreamingTaskMessageContext,
Expand Down Expand Up @@ -352,6 +355,24 @@ async def on_flush(u: StreamTaskMessageDelta) -> None:
await buf.add(_text(task_message, "after"))
assert flushed == []

@pytest.mark.asyncio
async def test_add_racing_close_is_not_stranded(self, task_message: TaskMessage) -> None:
    """Check-then-act race between add() and a concurrent close().

    add() looks at ``_closed`` once before it takes the lock. When the flag
    flips while add() is still parked on the lock, the delta has to be
    discarded instead of being appended to the already-closed buffer, where
    nothing would ever publish it.
    """
    coalescer = CoalescingBuffer(on_flush=AsyncMock())
    coalescer.start()
    # Own the lock for the whole race window, so add() gets past its
    # pre-lock _closed check and then waits for the lock.
    async with coalescer._lock:
        racing_add = asyncio.create_task(coalescer.add(_text(task_message, "racing")))
        # Yield once: add() runs up to the point where it blocks on the lock.
        await asyncio.sleep(0)
        # Stand in for close() winning the race.
        coalescer._closed = True
    await racing_add

    assert coalescer._buf == [], "racing delta was stranded in the closed buffer"
    await coalescer.close()  # cleanup


class TestCoalescingBufferCloseDuringFlush:
@pytest.mark.asyncio
Expand Down Expand Up @@ -520,3 +541,57 @@ async def test_open_without_created_at_passes_omit(self) -> None:

kwargs = client.messages.create.call_args.kwargs
assert kwargs["created_at"] is omit


class TestFullMessageClosesBuffer:
"""A StreamTaskMessageFull must stop the buffer ticker and drain its deltas
before the terminal Full. Marking the context done without closing the
buffer leaves close()'s _is_closed short-circuit to orphan the ticker, and
publishing buffered deltas after the Full reads as a stale duplicate tail."""

@pytest.mark.asyncio
async def test_full_message_stops_ticker(self) -> None:
    """Streaming a Full must clear the context's buffer and finish its ticker task."""
    context, _service, message = await _make_context("coalesced")
    # Stream one delta first so that a buffer, and its ticker task, exist.
    await context.stream_update(_text(message, "hello"))
    buffer = context._buffer
    assert buffer is not None
    ticker = buffer._task
    assert ticker is not None
    assert not ticker.done()

    terminal = StreamTaskMessageFull(
        parent_task_message=message,
        content=TextContent(author="agent", content="final", format="markdown"),
        type="full",
    )
    await context.stream_update(terminal)

    assert context._buffer is None, "Full message left the buffer un-closed"
    assert ticker.done(), "coalescing-buffer ticker still running after Full (orphaned)"

@pytest.mark.asyncio
async def test_full_is_terminal_publish_no_trailing_deltas(self) -> None:
# Buffered deltas must publish BEFORE the Full, never after (a trailing
# delta after the terminal Full reads as a stale duplicate tail).
ctx, svc, tm = await _make_context("coalesced")
# "alpha" flushes immediately; "beta" normally stays buffered in the window, but the assertions below hold either way.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically, "beta" may or may not stay buffered in the window, since the assertions only need the Full to be the terminal publish with one or more deltas before it, right? If so, I think this test is still a solid one as written, but it may be worth updating the comment to be more precise (or adding a clarification to the assertions below).

await ctx.stream_update(_text(tm, "alpha"))
await ctx.stream_update(_text(tm, "beta"))

full = StreamTaskMessageFull(
parent_task_message=tm,
content=TextContent(author="agent", content="alphabeta", format="markdown"),
type="full",
)
await ctx.stream_update(full)

published = [c.args[0] for c in svc.stream_update.await_args_list]
assert published, "nothing was published"
assert published[-1] is full, (
f"Full must be the terminal publish; saw trailing "
f"{type(published[-1]).__name__} after it (stale duplicate tail)"
)
assert any(isinstance(u, StreamTaskMessageDelta) for u in published[:-1]), (
"expected the buffered deltas to be published before the Full"
)
Loading