/**
 * AgentChatTransport — bridges the AI SDK's useChat hook with an Agent
 * WebSocket connection that speaks Think's streaming protocol.
 *
 * Features:
 * - Request ID correlation: each request gets a unique ID, only matching
 *   WS messages are processed
 * - Cancel: sends { type: "cancel", requestId } to stop server-side streaming
 * - Completion guard: close/error/abort are idempotent
 * - Signal-based cleanup: uses AbortController signal on addEventListener
 * - Stream resumption: reconnectToStream sends resume-request, server replays
 *   buffered chunks via ChunkRelay
 *
 * @example
 * ```tsx
 * import { AgentChatTransport } from "@cloudflare/think/transport";
 * import { useAgent } from "agents/react";
 * import { useChat } from "@ai-sdk/react";
 *
 * const agent = useAgent({ agent: "MyAssistant" });
 * const transport = useMemo(() => new AgentChatTransport(agent), [agent]);
 * const { messages, sendMessage, status } = useChat({ transport });
 * ```
 */

import type { UIMessage, UIMessageChunk, ChatTransport } from "ai";

/**
 * Minimal interface for the agent connection object.
 * Satisfied by the return value of `useAgent()` from `agents/react`.
 */
export interface AgentSocket {
  addEventListener(
    type: "message",
    handler: (event: MessageEvent) => void,
    options?: { signal?: AbortSignal }
  ): void;
  removeEventListener(
    type: "message",
    handler: (event: MessageEvent) => void
  ): void;
  call(method: string, args?: unknown[]): Promise<unknown>;
  send(data: string): void;
}

/**
 * Options for constructing an AgentChatTransport.
 */
export interface AgentChatTransportOptions {
  /**
   * The server-side RPC method to call when sending a message.
   * Receives `[text, requestId]` as arguments.
   * @default "sendMessage"
   */
  sendMethod?: string;
  /**
   * Timeout in milliseconds for reconnectToStream to wait for a
   * stream-resuming response before giving up.
   * @default 500
   */
  resumeTimeout?: number;
}

/** Controller type of every stream this transport hands out. */
type StreamController = ReadableStreamDefaultController<UIMessageChunk>;

/**
 * Runs `action` against the stream controller exactly once and then releases
 * everything the stream holds. Calls after the first are no-ops.
 */
type FinishStream = (action: (controller: StreamController) => void) => void;

/** The only fields of an incoming socket frame that this transport reads. */
interface WireMessage {
  type?: unknown;
  requestId?: unknown;
  event?: unknown;
}

/**
 * Decode a socket frame into a WireMessage.
 *
 * @returns the decoded object, or `null` when the frame is not a string,
 *   is not valid JSON, or does not decode to an object.
 */
function parseWireMessage(data: unknown): WireMessage | null {
  if (typeof data !== "string") return null;
  try {
    const parsed: unknown = JSON.parse(data);
    return typeof parsed === "object" && parsed !== null
      ? (parsed as WireMessage)
      : null;
  } catch {
    return null;
  }
}

/** Build the error used to fail a stream that the caller aborted. */
function createAbortError(): Error {
  return Object.assign(new Error("Aborted"), { name: "AbortError" });
}

/**
 * Extract the text content from a UIMessage's parts.
 *
 * Non-text parts are skipped; text parts are concatenated without a separator.
 */
function getMessageText(msg: UIMessage): string {
  return msg.parts
    .filter((p): p is { type: "text"; text: string } => p.type === "text")
    .map((p) => p.text)
    .join("");
}

/**
 * ChatTransport implementation for Agent WebSocket connections.
 *
 * Speaks the wire protocol used by Think's `chat()` method
 * and ChunkRelay on the server:
 * - `stream-start` → new stream with requestId
 * - `stream-event` → UIMessageChunk payload
 * - `stream-done` → stream complete
 * - `stream-resuming` → replay after reconnect
 * - `cancel` → client→server abort
 */
export class AgentChatTransport implements ChatTransport<UIMessage> {
  readonly #agent: AgentSocket;
  readonly #sendMethod: string;
  readonly #resumeTimeout: number;
  /** Request IDs of streams that are currently open (bookkeeping only). */
  readonly #activeRequestIds = new Set<string>();
  /**
   * One closer per open stream. Each stream removes only its own entry when it
   * finishes, so an older stream completing can never hide a newer one from
   * `detach()`.
   */
  readonly #streamClosers = new Set<() => void>();

  /**
   * @param agent - Socket-like connection used for RPC calls and raw frames.
   * @param options - Optional overrides for the RPC method name and the
   *   resume timeout.
   */
  constructor(agent: AgentSocket, options?: AgentChatTransportOptions) {
    this.#agent = agent;
    this.#sendMethod = options?.sendMethod ?? "sendMessage";
    this.#resumeTimeout = options?.resumeTimeout ?? 500;
  }

  /**
   * Detach from every open stream. Call this before switching agents
   * or cleaning up to ensure the stream controllers are closed and their
   * socket listeners are removed. Safe to call repeatedly.
   */
  detach(): void {
    // Snapshot first: every closer removes itself from the set as it runs.
    for (const close of [...this.#streamClosers]) close();
  }

  /**
   * Send the text of the last message to the agent and stream back the reply.
   *
   * @param options - AI SDK request options; only `messages` and
   *   `abortSignal` are read.
   * @returns A stream of UIMessageChunks for this request. The stream closes
   *   on `stream-done`, and errors when the RPC call rejects or the request
   *   is aborted (with an `AbortError`).
   * @throws Error (as a rejected promise) when `messages` is empty.
   */
  async sendMessages({
    messages,
    abortSignal
  }: Parameters<ChatTransport<UIMessage>["sendMessages"]>[0]): Promise<
    ReadableStream<UIMessageChunk>
  > {
    const lastMessage = messages[messages.length - 1];
    if (lastMessage === undefined) {
      throw new Error(
        "AgentChatTransport.sendMessages requires at least one message"
      );
    }

    const text = getMessageText(lastMessage);
    const requestId = crypto.randomUUID().slice(0, 8);

    // `abort` is only invoked after the stream has been handed out, so the
    // forward reference from the cancel callback is safe.
    const { stream, finish, isDone, signal } = this.#openStream(
      requestId,
      () => abort()
    );

    const abort = (): void => {
      if (isDone()) return;
      try {
        this.#agent.send(JSON.stringify({ type: "cancel", requestId }));
      } catch {
        /* ignore send failures */
      }
      finish((controller) => controller.error(createAbortError()));
    };

    if (abortSignal?.aborted) {
      // Nothing has been sent yet: fail the stream locally instead of
      // starting (and immediately cancelling) work on the server.
      finish((controller) => controller.error(createAbortError()));
      return stream;
    }

    // Tied to the stream's cleanup signal so the listener does not outlive
    // the stream when the caller reuses a long-lived AbortSignal.
    abortSignal?.addEventListener("abort", abort, { once: true, signal });

    void this.#agent
      .call(this.#sendMethod, [text, requestId])
      .catch((error: unknown) => {
        finish((controller) => controller.error(error));
      });

    return stream;
  }

  /**
   * Ask the server whether a stream is still in flight and, if so, attach
   * to it.
   *
   * Sends `{ type: "resume-request" }` and waits up to `resumeTimeout`
   * milliseconds for a `stream-resuming` frame carrying a string `requestId`.
   *
   * @returns A stream that receives the resumed chunks, or `null` when no
   *   `stream-resuming` frame arrives in time.
   */
  async reconnectToStream(): Promise<ReadableStream<UIMessageChunk> | null> {
    const resumeTimeout = this.#resumeTimeout;

    return new Promise<ReadableStream<UIMessageChunk> | null>((resolve) => {
      let settled = false;
      let timer: ReturnType<typeof setTimeout> | undefined;

      const settle = (value: ReadableStream<UIMessageChunk> | null): void => {
        if (settled) return;
        settled = true;
        if (timer !== undefined) clearTimeout(timer);
        this.#agent.removeEventListener("message", onMessage);
        resolve(value);
      };

      const onMessage = (event: MessageEvent): void => {
        const msg = parseWireMessage(event.data);
        if (msg === null || msg.type !== "stream-resuming") return;
        // Without a request ID there is nothing to correlate chunks with.
        if (typeof msg.requestId !== "string") return;
        settle(this.#createResumeStream(msg.requestId));
      };

      this.#agent.addEventListener("message", onMessage);

      try {
        this.#agent.send(JSON.stringify({ type: "resume-request" }));
      } catch {
        /* WebSocket may not be open yet */
      }

      timer = setTimeout(() => settle(null), resumeTimeout);
    });
  }

  /**
   * Create a stream that receives the chunks of an already-running request.
   * Cancelling it only stops listening; no `cancel` frame is sent.
   */
  #createResumeStream(requestId: string): ReadableStream<UIMessageChunk> {
    return this.#openStream(requestId, (finish) => finish(() => {})).stream;
  }

  /**
   * Create a stream fed by the socket frames that carry `requestId`.
   *
   * `stream-event` frames are parsed and enqueued; `stream-done` closes the
   * stream. The stream is registered with `detach()` until it finishes.
   *
   * @param requestId - Correlation ID; frames with any other ID are ignored.
   * @param onCancel - Invoked when the consumer cancels the stream.
   * @returns The stream, its idempotent `finish` function, a completion
   *   probe, and a signal that aborts once the stream has finished.
   */
  #openStream(
    requestId: string,
    onCancel: (finish: FinishStream) => void
  ): {
    stream: ReadableStream<UIMessageChunk>;
    finish: FinishStream;
    isDone: () => boolean;
    signal: AbortSignal;
  } {
    const cleanup = new AbortController();
    let completed = false;
    // Assigned synchronously by `start` while the stream is constructed.
    let controller!: StreamController;

    const finish: FinishStream = (action) => {
      if (completed) return;
      completed = true;
      this.#streamClosers.delete(close);
      try {
        action(controller);
      } catch {
        /* stream may already be closed */
      }
      this.#activeRequestIds.delete(requestId);
      // Removes the socket listener and any listener tied to this signal.
      cleanup.abort();
    };
    const close = (): void => finish((c) => c.close());

    const stream = new ReadableStream<UIMessageChunk>({
      start: (c) => {
        controller = c;
      },
      cancel: () => {
        onCancel(finish);
      }
    });

    this.#activeRequestIds.add(requestId);
    this.#streamClosers.add(close);

    this.#agent.addEventListener(
      "message",
      (event: MessageEvent) => {
        const msg = parseWireMessage(event.data);
        if (msg === null || msg.requestId !== requestId) return;

        if (msg.type === "stream-event") {
          if (typeof msg.event !== "string") return;
          try {
            controller.enqueue(JSON.parse(msg.event) as UIMessageChunk);
          } catch {
            /* ignore malformed payloads and enqueue on a closed stream */
          }
        } else if (msg.type === "stream-done") {
          close();
        }
      },
      { signal: cleanup.signal }
    );

    return {
      stream,
      finish,
      isDone: () => completed,
      signal: cleanup.signal
    };
  }
}