/**
 * WebSocket-based ChatTransport for useAgentChat.
 *
 * Replaces the aiFetch + DefaultChatTransport indirection with a direct
 * WebSocket implementation that speaks the CF_AGENT protocol natively.
 *
 * Data flow (old): WS → aiFetch fake Response → DefaultChatTransport → useChat
 * Data flow (new): WS → WebSocketChatTransport → useChat
 */
import type { ChatTransport, UIMessage, UIMessageChunk } from "ai";
import { nanoid } from "nanoid";
import { MessageType, type OutgoingMessage } from "./types";

/**
 * Agent-like interface for sending/receiving WebSocket messages.
 * Matches the shape returned by useAgent from agents/react.
 */
export interface AgentConnection {
  /** Sends one frame; the transport always passes a JSON-serialized string. */
  send: (data: string) => void;
  /**
   * Registers an event listener. The transport never calls
   * removeEventListener for its "message" listeners; it passes
   * `options.signal` and relies on aborting that signal to detach them,
   * so implementations must honor it.
   */
  addEventListener: (
    type: string,
    listener: (event: MessageEvent) => void,
    options?: { signal?: AbortSignal }
  ) => void;
  removeEventListener: (
    type: string,
    listener: (event: MessageEvent) => void
  ) => void;
}

export type WebSocketChatTransportOptions<
  ChatMessage extends UIMessage = UIMessage
> = {
  /** The agent connection from useAgent */
  agent: AgentConnection;
  /**
   * Callback to prepare the request body before sending.
   * The returned fields are spread into the JSON request body after
   * `messages` and `trigger` (so they can override them) and before any
   * per-call `body` passed to sendMessages (which overrides them).
   * Only a JSON body is sent — the request frame's `init` carries `method`
   * and `body`, no headers — so extra metadata must be body fields.
   */
  prepareBody?: (options: {
    messages: ChatMessage[];
    trigger: "submit-message" | "regenerate-message";
    messageId?: string;
  }) => Promise<Record<string, unknown>> | Record<string, unknown>;
  /**
   * Optional set to track active request IDs.
   * IDs are added when a request starts and removed by the transport when it
   * completes. Exception: after a caller abort the ID is kept, so the
   * onAgentMessage handler keeps skipping in-flight chunks; that handler is
   * expected to remove it when the server's final done:true frame arrives.
   * Used by the onAgentMessage handler to skip messages already handled by the transport.
   */
  activeRequestIds?: Set<string>;
};

/**
 * ChatTransport that sends messages over WebSocket and returns a
 * ReadableStream that the AI SDK's useChat consumes directly.
 * No fake fetch, no Response reconstruction, no double SSE parsing.
 *
 * onAgentMessage must forward CF_AGENT_STREAM_RESUMING and
 * CF_AGENT_STREAM_RESUME_NONE to handleStreamResuming() /
 * handleStreamResumeNone(); otherwise reconnectToStream() can only settle
 * through its 5-second safety timeout.
*/
export class WebSocketChatTransport<
  ChatMessage extends UIMessage = UIMessage
> implements ChatTransport<ChatMessage>
{
  /**
   * Connection used for all traffic. Left public and mutable on purpose:
   * reconnectToStream() and _createResumeStream() read it lazily so they use
   * the latest socket if it has been replaced (presumably by the owning hook).
   */
  agent: AgentConnection;
  private prepareBody?: WebSocketChatTransportOptions<ChatMessage>["prepareBody"];
  private activeRequestIds?: Set<string>;

  // Pending resume resolver — set by reconnectToStream, called by
  // handleStreamResuming when onAgentMessage sees CF_AGENT_STREAM_RESUMING.
  // Owned by at most one pending reconnectToStream call at a time.
  private _resumeResolver: ((data: { id: string }) => void) | null = null;

  // Pending "no stream" resolver — called by handleStreamResumeNone
  // when onAgentMessage sees CF_AGENT_STREAM_RESUME_NONE. Always set and
  // cleared together with _resumeResolver.
  private _resumeNoneResolver: (() => void) | null = null;

  constructor(options: WebSocketChatTransportOptions<ChatMessage>) {
    this.agent = options.agent;
    this.prepareBody = options.prepareBody;
    this.activeRequestIds = options.activeRequestIds;
  }

  /**
   * Called by onAgentMessage when it receives CF_AGENT_STREAM_RESUMING.
   * If reconnectToStream is waiting, this handles the resume handshake
   * (ACK + stream creation) and returns true. Otherwise returns false
   * so the caller can use its own fallback path.
   */
  handleStreamResuming(data: { id: string }): boolean {
    if (!this._resumeResolver) return false;
    this._resumeResolver(data);
    return true;
  }

  /**
   * Called by onAgentMessage when it receives CF_AGENT_STREAM_RESUME_NONE.
   * If reconnectToStream is waiting, resolves the promise with null
   * immediately (no 5-second timeout). Returns true if handled.
   */
  handleStreamResumeNone(): boolean {
    if (!this._resumeNoneResolver) return false;
    this._resumeNoneResolver();
    return true;
  }

  /**
   * Sends a CF_AGENT_USE_CHAT_REQUEST frame and returns a stream of the
   * UIMessageChunks the server sends back for that request ID.
   *
   * - Aborting `abortSignal`, or cancelling the returned stream, sends
   *   CF_AGENT_CHAT_REQUEST_CANCEL and errors the stream with an error named
   *   "AbortError". The request ID stays in `activeRequestIds` so
   *   onAgentMessage keeps skipping in-flight chunks until done:true.
   * - If `abortSignal` is already aborted once the body has been prepared,
   *   nothing is sent or tracked; the returned stream is already errored with
   *   that AbortError.
   * - If sending the request frame throws, the tracked ID and all listeners
   *   are released and the error is rethrown.
   *
   * `chatId`, `headers` and `metadata` are accepted for interface
   * compatibility but not sent.
   */
  async sendMessages(options: {
    chatId: string;
    messages: ChatMessage[];
    abortSignal: AbortSignal | undefined;
    trigger: "submit-message" | "regenerate-message";
    messageId?: string;
    body?: object;
    headers?: Record<string, string> | Headers;
    metadata?: unknown;
  }): Promise<ReadableStream<UIMessageChunk>> {
    const requestId = nanoid(8);
    const { abortSignal } = options;

    const abortError = new Error("Aborted");
    abortError.name = "AbortError";

    // Request body: prepareBody() output, then the per-call body, both
    // spread after messages/trigger.
    let extraBody: Record<string, unknown> = {};
    if (this.prepareBody) {
      extraBody = await this.prepareBody({
        messages: options.messages,
        trigger: options.trigger,
        messageId: options.messageId
      });
    }
    if (options.body) {
      extraBody = {
        ...extraBody,
        ...(options.body as Record<string, unknown>)
      };
    }
    const bodyPayload = JSON.stringify({
      messages: options.messages,
      trigger: options.trigger,
      ...extraBody
    });

    // Already aborted (possibly while prepareBody was pending): send nothing.
    // Sending a cancel followed by the request would put the cancel before
    // the very request it is meant to stop.
    if (abortSignal?.aborted) {
      return new ReadableStream<UIMessageChunk>({
        start(controller) {
          controller.error(abortError);
        }
      });
    }

    const agent = this.agent;
    const activeIds = this.activeRequestIds;
    // Aborted exactly once by finish() to detach the "message" listener.
    const listenerAbort = new AbortController();
    let completed = false;

    // Assigned synchronously by start() below, so it is always set before
    // finish/onAbort can touch it.
    let streamController!: ReadableStreamDefaultController<UIMessageChunk>;

    // Single cleanup helper — every terminal path (done, error, abort,
    // failed send) goes through here exactly once.
    // keepId: when true, do NOT remove requestId from activeIds (see onAbort).
    const finish = (action: () => void, keepId = false) => {
      if (completed) return;
      completed = true;
      try {
        action();
      } catch {
        // Stream may already be closed
      }
      if (!keepId) {
        activeIds?.delete(requestId);
      }
      abortSignal?.removeEventListener("abort", onAbort);
      listenerAbort.abort();
    };

    // Abort handler for both the caller's abortSignal and stream.cancel():
    // tell the server to stop, then error the stream.
    // keepId=true: keep requestId in activeIds so onAgentMessage skips any
    // in-flight chunks the server broadcasts before its done:true signal.
    // The ID is removed by onAgentMessage when done:true is received.
    const onAbort = () => {
      if (completed) return;
      try {
        agent.send(
          JSON.stringify({
            id: requestId,
            type: MessageType.CF_AGENT_CHAT_REQUEST_CANCEL
          })
        );
      } catch {
        // Ignore failures (e.g. agent already disconnected)
      }
      finish(() => streamController.error(abortError), true);
    };

    // Track this request so the onAgentMessage handler skips it
    activeIds?.add(requestId);

    const stream = new ReadableStream<UIMessageChunk>({
      start: (controller) => {
        streamController = controller;
        this._forwardResponseChunks(
          agent,
          requestId,
          controller,
          finish,
          listenerAbort.signal
        );
      },
      cancel: () => {
        onAbort();
      }
    });

    abortSignal?.addEventListener("abort", onAbort, { once: true });

    try {
      agent.send(
        JSON.stringify({
          id: requestId,
          init: {
            method: "POST",
            body: bodyPayload
          },
          type: MessageType.CF_AGENT_USE_CHAT_REQUEST
        })
      );
    } catch (error) {
      // The request never left: release the tracked ID and both listeners
      // instead of leaking them, then surface the failure to the caller.
      finish(() => streamController.error(error));
      throw error;
    }

    return stream;
  }

  /**
   * Asks the server whether a stream is active and, if so, returns a stream
   * fed by the replayed + live chunks; otherwise resolves null.
   *
   * Settles when onAgentMessage calls handleStreamResuming() or
   * handleStreamResumeNone(), or after a 5-second safety timeout. Only one
   * call can be pending: starting a new one settles the previous with null.
   */
  async reconnectToStream(_options: {
    chatId: string;
  }): Promise<ReadableStream<UIMessageChunk> | null> {
    // The resolver slots hold one pending call. Settle an earlier waiter
    // now; otherwise its timeout would later clear the slots installed below.
    this._resumeNoneResolver?.();

    const activeIds = this.activeRequestIds;

    return new Promise<ReadableStream<UIMessageChunk> | null>((resolve) => {
      let settled = false;
      let timeout: ReturnType<typeof setTimeout> | undefined;

      // Invoked via handleStreamResumeNone(): no active stream.
      const onResumeNone = () => settle(null);

      // Invoked via handleStreamResuming(): ACK and stream the chunks.
      const onResuming = (data: { id: string }) => {
        const requestId = data.id;
        // Track this request so onAgentMessage skips subsequent chunks
        activeIds?.add(requestId);
        try {
          // Send ACK via the latest agent (the socket may have been
          // replaced since reconnectToStream was called).
          this.agent.send(
            JSON.stringify({
              type: MessageType.CF_AGENT_STREAM_RESUME_ACK,
              id: requestId
            })
          );
        } catch {
          // The socket rejected the ACK, so chunks are unlikely to arrive on
          // it: report "no stream" now and drop the tracked ID instead of
          // hanging until the timeout.
          activeIds?.delete(requestId);
          settle(null);
          return;
        }
        settle(this._createResumeStream(requestId));
      };

      const settle = (value: ReadableStream<UIMessageChunk> | null) => {
        if (settled) return;
        settled = true;
        // Clear only slots this call still owns.
        if (this._resumeResolver === onResuming) this._resumeResolver = null;
        if (this._resumeNoneResolver === onResumeNone) {
          this._resumeNoneResolver = null;
        }
        if (timeout !== undefined) clearTimeout(timeout);
        resolve(value);
      };

      this._resumeNoneResolver = onResumeNone;
      this._resumeResolver = onResuming;

      // Send the resume request. PartySocket queues sends when
      // the socket isn't open yet and flushes on connect, so
      // this works regardless of current readyState.
      try {
        this.agent.send(
          JSON.stringify({
            type: MessageType.CF_AGENT_STREAM_RESUME_REQUEST
          })
        );
      } catch {
        // WebSocket may already be closed
      }

      // Safety-net timeout: if the WebSocket never connects or the
      // server is unreachable, resolve null. Under normal operation
      // the server responds with STREAM_RESUMING or STREAM_RESUME_NONE
      // well before this fires.
      timeout = setTimeout(() => settle(null), 5000);
    });
  }

  /**
   * Creates a ReadableStream that receives resumed stream chunks
   * and forwards them to useChat as UIMessageChunk objects.
   * Cancelling it only stops listening and untracks the ID; no cancel frame
   * is sent to the server.
   */
  private _createResumeStream(
    requestId: string
  ): ReadableStream<UIMessageChunk> {
    // Read agent at resolve time (not when reconnectToStream was called)
    // so the chunk listener attaches to the latest socket.
    const agent = this.agent;
    const activeIds = this.activeRequestIds;
    const chunkController = new AbortController();
    let completed = false;

    // Single terminal path: run action, untrack the ID, detach the listener.
    const finish = (action: () => void) => {
      if (completed) return;
      completed = true;
      try {
        action();
      } catch {
        // Stream may already be closed
      }
      activeIds?.delete(requestId);
      chunkController.abort();
    };

    return new ReadableStream<UIMessageChunk>({
      start: (controller) => {
        this._forwardResponseChunks(
          agent,
          requestId,
          controller,
          finish,
          chunkController.signal
        );
      },
      cancel: () => {
        finish(() => {});
      }
    });
  }

  /**
   * Attaches a "message" listener to `agent` that forwards the
   * CF_AGENT_USE_CHAT_RESPONSE frames for `requestId` into `controller`:
   * - an `error` frame errors the stream (message: the frame body, or
   *   "Stream error");
   * - a non-blank body is parsed as a UIMessageChunk and enqueued
   *   (malformed bodies are skipped);
   * - a `done` frame closes the stream.
   * Terminal transitions go through `finish`. The listener is detached when
   * `signal` aborts. Non-JSON or unrelated frames are ignored.
   */
  private _forwardResponseChunks(
    agent: AgentConnection,
    requestId: string,
    controller: ReadableStreamDefaultController<UIMessageChunk>,
    finish: (action: () => void) => void,
    signal: AbortSignal
  ): void {
    const onMessage = (event: MessageEvent) => {
      try {
        const data = JSON.parse(event.data as string) as OutgoingMessage;
        if (data.type !== MessageType.CF_AGENT_USE_CHAT_RESPONSE) return;
        if (data.id !== requestId) return;

        if (data.error) {
          finish(() =>
            controller.error(new Error(data.body || "Stream error"))
          );
          return;
        }

        // Parse the body as UIMessageChunk and enqueue
        if (data.body?.trim()) {
          try {
            const chunk = JSON.parse(data.body) as UIMessageChunk;
            controller.enqueue(chunk);
          } catch {
            // Skip malformed chunk bodies
          }
        }

        if (data.done) {
          finish(() => controller.close());
        }
      } catch {
        // Ignore non-JSON messages
      }
    };

    agent.addEventListener("message", onMessage, { signal });
  }
}