Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/acp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from __future__ import annotations

import asyncio
from typing import Any

from .agent.connection import AgentSideConnection
Expand Down Expand Up @@ -69,7 +70,10 @@ async def run_agent(
use_unstable_protocol=use_unstable_protocol,
**connection_kwargs,
)
await conn.listen()
try:
await conn.listen()
finally:
await asyncio.shield(conn.close())


def connect_to_agent(
Expand Down
55 changes: 55 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from __future__ import annotations

import asyncio
import contextlib
from typing import Any

import pytest

from acp.core import run_agent


@pytest.mark.asyncio
async def test_run_agent_closes_connection_when_cancelled(server, agent) -> None:
sender_created = asyncio.Event()
sender_closed = asyncio.Event()
dispatcher_started = asyncio.Event()
dispatcher_stopped = asyncio.Event()

class TrackingSender:
def __init__(self, writer: asyncio.StreamWriter, supervisor: Any) -> None:
sender_created.set()

async def send(self, payload: dict[str, Any]) -> None:
msg = "test does not send messages"
raise AssertionError(msg)

async def close(self) -> None:
sender_closed.set()

class TrackingDispatcher:
def start(self) -> None:
dispatcher_started.set()

async def stop(self) -> None:
dispatcher_stopped.set()

task = asyncio.create_task(
run_agent(
agent,
server.server_writer,
server.server_reader,
sender_factory=TrackingSender,
dispatcher_factory=lambda *args: TrackingDispatcher(),
)
)

await asyncio.wait_for(sender_created.wait(), timeout=1)
await asyncio.wait_for(dispatcher_started.wait(), timeout=1)

task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await asyncio.wait_for(task, timeout=1)

await asyncio.wait_for(dispatcher_stopped.wait(), timeout=1)
await asyncio.wait_for(sender_closed.wait(), timeout=1)
Loading