Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 56 additions & 2 deletions src/acp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,14 @@ describe("Connection", () => {
const appAgent = createAgent({ name: "app-agent" })
.onRequest(AGENT_METHODS.initialize, (c) => {
events.push(`initialize:${c.params.protocolVersion}`);
expect(Object.keys(c).sort()).toEqual(["client", "params", "signal"]);
expect(Object.keys(c).sort()).toEqual([
"client",
"params",
"requestId",
"signal",
]);
expect(c.requestId).toBe(0);
expect(c.client.requestId).toBe(0);
expect(c.signal.aborted).toBe(false);

return {
Expand All @@ -732,6 +739,8 @@ describe("Connection", () => {
})
.onRequest(AGENT_METHODS.session_new, (c) => {
events.push(`new:${c.params.cwd}`);
expect(c.requestId).toBe(1);
expect(c.client.requestId).toBe(1);
return { sessionId: "app-session" };
})
.onNotification(
Expand All @@ -744,11 +753,28 @@ describe("Connection", () => {
},
(c) => {
expect(Object.keys(c).sort()).toEqual(["client", "params", "signal"]);
expect("requestId" in c).toBe(false);
expect(c.client.requestId).toBeUndefined();
events.push(`agent-route:${String(c.params.message)}`);
},
)
.onRequest(AGENT_METHODS.session_prompt, async (c) => {
events.push(`prompt:${c.params.sessionId}`);
expect(c.requestId).toBe(2);
expect(c.client.requestId).toBe(2);
const elicitation = await c.client.request(
methods.client.elicitation.create,
{
requestId: c.requestId,
mode: "form",
message: "Need input",
requestedSchema: {
type: "object",
properties: { value: { type: "string" } },
},
},
);
events.push(`elicitation:${elicitation.action}`);
const pong = await c.client.request<{ message: string }>(
"vendor/ping",
{
Expand All @@ -769,16 +795,42 @@ describe("Connection", () => {

const appClient = createClient({ name: "app-client" })
.onNotification(CLIENT_METHODS.session_update, (c) => {
expect("requestId" in c).toBe(false);
expect(c.agent.requestId).toBeUndefined();
events.push(`update:${c.params.sessionId}`);
})
.onRequest(CLIENT_METHODS.elicitation_create, (c) => {
expect(Object.keys(c).sort()).toEqual([
"agent",
"params",
"requestId",
"signal",
]);
expect(c.requestId).toBe(0);
expect(c.agent.requestId).toBe(0);
if (!("requestId" in c.params)) {
throw new Error("Expected request-scoped elicitation");
}
expect(c.params.requestId).toBe(2);
events.push(`client-elicitation:${String(c.params.requestId)}`);

return { action: "decline" };
})
.onRequest(
"vendor/ping",
(params) => {
const message = (params as Record<string, unknown>).message;
return { message: String(message).toUpperCase() };
},
(c) => {
expect(Object.keys(c).sort()).toEqual(["agent", "params", "signal"]);
expect(Object.keys(c).sort()).toEqual([
"agent",
"params",
"requestId",
"signal",
]);
expect(c.requestId).toBe(1);
expect(c.agent.requestId).toBe(1);
expect(c.signal.aborted).toBe(false);
events.push(`client-route:${String(c.params.message)}`);

Expand Down Expand Up @@ -820,6 +872,8 @@ describe("Connection", () => {
"new:/app",
"agent-route:from-client:parsed",
"prompt:app-session",
"client-elicitation:2",
"elicitation:decline",
"client-route:HELLO",
"pong:HELLO",
"update:app-session",
Expand Down
142 changes: 116 additions & 26 deletions src/acp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ export type {
AnyRequest,
AnyResponse,
ErrorResponse,
JsonRpcId,
MaybePromise,
Result,
SendRequestOptions,
Expand All @@ -28,6 +29,7 @@ import type {
ConnectionContext,
HandleResult,
IncomingMessage,
JsonRpcId,
JsonRpcHandler,
MaybePromise,
SendRequestOptions,
Expand Down Expand Up @@ -186,7 +188,20 @@ export interface ClientConnection extends AcpConnection {

class AcpContext {
/** @internal */
constructor(private readonly cx: ConnectionContext) {}
constructor(
private readonly cx: ConnectionContext,
private readonly currentRequestId?: JsonRpcId,
) {}

/**
* JSON-RPC id of the request currently being handled.
*
* This is `undefined` for notification handlers and for contexts created
* outside an inbound request, such as `connect(...)` and `connectWith(...)`.
*/
get requestId(): JsonRpcId | undefined {
return this.currentRequestId;
}

/** @internal */
protected get connectionContext(): ConnectionContext {
Expand Down Expand Up @@ -221,13 +236,13 @@ class AcpContext {
* requests such as `session/prompt`.
*/
export class AgentContext extends AcpContext {
private constructor(cx: ConnectionContext) {
super(cx);
private constructor(cx: ConnectionContext, requestId?: JsonRpcId) {
super(cx, requestId);
}

/** @internal */
static create(cx: ConnectionContext): AgentContext {
return new AgentContext(cx);
static create(cx: ConnectionContext, requestId?: JsonRpcId): AgentContext {
return new AgentContext(cx, requestId);
}

/**
Expand Down Expand Up @@ -280,13 +295,13 @@ export class AgentContext extends AcpContext {
* receive one as `ctx.agent` when they need to call back into the agent.
*/
export class ClientContext extends AcpContext {
private constructor(cx: ConnectionContext) {
super(cx);
private constructor(cx: ConnectionContext, requestId?: JsonRpcId) {
super(cx, requestId);
}

/** @internal */
static create(cx: ConnectionContext): ClientContext {
return new ClientContext(cx);
static create(cx: ConnectionContext, requestId?: JsonRpcId): ClientContext {
return new ClientContext(cx, requestId);
}

/** @internal */
Expand Down Expand Up @@ -895,7 +910,7 @@ export type ParamsParser<Params> =
| ((params: unknown) => Params);

/**
* Context passed to agent-side request and notification handlers.
* Common context passed to agent-side handlers.
*/
export type AgentHandlerContext<Params> = {
/**
Expand All @@ -914,7 +929,24 @@ export type AgentHandlerContext<Params> = {
};

/**
* Context passed to client-side request and notification handlers.
* Context passed to agent-side request handlers.
*/
export type AgentRequestContext<Params> = AgentHandlerContext<Params> & {
/**
* JSON-RPC id of the request currently being handled.
*/
requestId: JsonRpcId;
};

/**
* Context passed to agent-side notification handlers.
*
* Notifications do not have JSON-RPC request ids.
*/
export type AgentNotificationContext<Params> = AgentHandlerContext<Params>;

/**
* Common context passed to client-side handlers.
*/
export type ClientHandlerContext<Params> = {
/**
Expand All @@ -932,32 +964,49 @@ export type ClientHandlerContext<Params> = {
agent: ClientContext;
};

/**
* Context passed to client-side request handlers.
*/
export type ClientRequestContext<Params> = ClientHandlerContext<Params> & {
/**
* JSON-RPC id of the request currently being handled.
*/
requestId: JsonRpcId;
};

/**
* Context passed to client-side notification handlers.
*
* Notifications do not have JSON-RPC request ids.
*/
export type ClientNotificationContext<Params> = ClientHandlerContext<Params>;

/**
* Request handler registered on an `AgentApp`.
*/
export type AgentRequestHandler<Params, Response> = (
context: AgentHandlerContext<Params>,
context: AgentRequestContext<Params>,
) => MaybePromise<Response>;

/**
* Notification handler registered on an `AgentApp`.
*/
export type AgentNotificationHandler<Params> = (
context: AgentHandlerContext<Params>,
context: AgentNotificationContext<Params>,
) => MaybePromise<void>;

/**
* Request handler registered on a `ClientApp`.
*/
export type ClientRequestHandler<Params, Response> = (
context: ClientHandlerContext<Params>,
context: ClientRequestContext<Params>,
) => MaybePromise<Response>;

/**
* Notification handler registered on a `ClientApp`.
*/
export type ClientNotificationHandler<Params> = (
context: ClientHandlerContext<Params>,
context: ClientNotificationContext<Params>,
) => MaybePromise<void>;

/**
Expand Down Expand Up @@ -1022,14 +1071,17 @@ function registerAppRequest<Params, Response, WireResponse, Context>(
params: Params,
cx: ConnectionContext,
signal: AbortSignal,
requestId: JsonRpcId,
) => Context,
handler: (context: Context) => MaybePromise<Response>,
): void {
builder.onReceiveRequest<Params, WireResponse>(
spec.method,
(params) => parseParams(spec.params, params),
async (params, responder, cx) => {
const response = await handler(context(params, cx, responder.signal));
const response = await handler(
context(params, cx, responder.signal, responder.id),
);
await responder.respond(
(spec.mapResponse
? spec.mapResponse(response)
Expand Down Expand Up @@ -1561,23 +1613,51 @@ export type ClientNotificationParamsByMethod = {
: never;
};

function agentHandlerContext<Params>(
function agentRequestContext<Params>(
params: Params,
client: AgentContext,
signal: AbortSignal,
requestId: JsonRpcId,
): AgentRequestContext<Params> {
return {
params,
requestId,
signal,
client,
};
}

function agentNotificationContext<Params>(
params: Params,
client: AgentContext,
signal: AbortSignal,
): AgentHandlerContext<Params> {
): AgentNotificationContext<Params> {
return {
params,
signal,
client,
};
}

function clientHandlerContext<Params>(
function clientRequestContext<Params>(
params: Params,
agent: ClientContext,
signal: AbortSignal,
): ClientHandlerContext<Params> {
requestId: JsonRpcId,
): ClientRequestContext<Params> {
return {
params,
requestId,
signal,
agent,
};
}

function clientNotificationContext<Params>(
params: Params,
agent: ClientContext,
signal: AbortSignal,
): ClientNotificationContext<Params> {
return {
params,
signal,
Expand Down Expand Up @@ -1875,8 +1955,13 @@ export class AgentApp {
registerAppRequest(
this.builder,
spec,
(params, cx, signal) =>
agentHandlerContext(params, AgentContext.create(cx), signal),
(params, cx, signal, requestId) =>
agentRequestContext(
params,
AgentContext.create(cx, requestId),
signal,
requestId,
),
handler,
);
return this;
Expand All @@ -1890,7 +1975,7 @@ export class AgentApp {
this.builder,
spec,
(params, cx, signal) =>
agentHandlerContext(params, AgentContext.create(cx), signal),
agentNotificationContext(params, AgentContext.create(cx), signal),
handler,
);
return this;
Expand Down Expand Up @@ -2120,8 +2205,13 @@ export class ClientApp {
registerAppRequest(
this.builder,
spec,
(params, cx, signal) =>
clientHandlerContext(params, ClientContext.create(cx), signal),
(params, cx, signal, requestId) =>
clientRequestContext(
params,
ClientContext.create(cx, requestId),
signal,
requestId,
),
handler,
);
return this;
Expand All @@ -2135,7 +2225,7 @@ export class ClientApp {
this.builder,
spec,
(params, cx, signal) =>
clientHandlerContext(params, ClientContext.create(cx), signal),
clientNotificationContext(params, ClientContext.create(cx), signal),
handler,
);
return this;
Expand Down
Loading