// ws-chat-transport.ts
/**
 * WebSocket-based ChatTransport for useAgentChat.
 *
 * Replaces the aiFetch + DefaultChatTransport indirection with a direct
 * WebSocket implementation that speaks the CF_AGENT protocol natively.
 *
 * Data flow (old): WS → aiFetch fake Response → DefaultChatTransport → useChat
 * Data flow (new): WS → WebSocketChatTransport → useChat
 */

import type { ChatTransport, UIMessage, UIMessageChunk } from "ai";
import { nanoid } from "nanoid";
import { MessageType, type OutgoingMessage } from "./types";

/**
 * Agent-like interface for sending/receiving WebSocket messages.
 * Matches the shape returned by useAgent from agents/react.
 *
 * This is the minimal surface WebSocketChatTransport needs from the
 * connection: a way to send string frames and a way to subscribe to
 * incoming events.
 */
export interface AgentConnection {
  /**
   * Sends one frame over the connection. The transport in this file
   * always passes the result of JSON.stringify here.
   */
  send: (data: string) => void;
  /**
   * Registers an event listener. The transport in this file only
   * subscribes to "message" events, and always passes `options.signal`
   * so that aborting the signal detaches the listener.
   */
  addEventListener: (
    type: string,
    listener: (event: MessageEvent) => void,
    options?: { signal?: AbortSignal }
  ) => void;
  /**
   * Unregisters a listener previously added with addEventListener.
   * Part of the required shape, but not called by the transport code
   * visible in this file (it detaches via AbortSignal instead).
   */
  removeEventListener: (
    type: string,
    listener: (event: MessageEvent) => void
  ) => void;
}

/**
 * Constructor options for WebSocketChatTransport.
 *
 * ChatMessage is the UI message type carried in the `messages` array;
 * it defaults to the AI SDK's UIMessage.
 */
export type WebSocketChatTransportOptions<
  ChatMessage extends UIMessage = UIMessage
> = {
  /** The agent connection from useAgent */
  agent: AgentConnection;
  /**
   * Callback to prepare the request body before sending.
   * May be sync or async. The record it returns is spread into the JSON
   * request body after `messages` and `trigger`, so it can add custom
   * fields (or replace those two). Keys in the per-call `body` passed to
   * sendMessages win over keys returned here.
   *
   * NOTE(review): the transport sends only a JSON body over the socket;
   * anything meant as "headers" or "credentials" has to travel as body
   * fields — confirm the server reads them from there.
   */
  prepareBody?: (options: {
    messages: ChatMessage[];
    trigger: "submit-message" | "regenerate-message";
    messageId?: string;
  }) => Promise<Record<string, unknown>> | Record<string, unknown>;
  /**
   * Optional set to track active request IDs.
   * IDs are added when a request starts and removed when it completes.
   * Used by the onAgentMessage handler to skip messages already handled by the transport.
   *
   * The set is shared by reference: the transport mutates the same Set
   * instance the caller passed in. After a client-side abort the ID is
   * intentionally left in the set by the transport.
   */
  activeRequestIds?: Set<string>;
};

/**
 * ChatTransport that sends messages over WebSocket and returns a
 * ReadableStream<UIMessageChunk> that the AI SDK's useChat consumes directly.
 * No fake fetch, no Response reconstruction, no double SSE parsing.
 *
 * Streams are produced by two entry points:
 * - sendMessages: sends a CF_AGENT_USE_CHAT_REQUEST frame and streams the
 *   CF_AGENT_USE_CHAT_RESPONSE frames that carry the same request ID.
 * - reconnectToStream: asks the server whether a stream is in flight and,
 *   if so, streams its replayed + live chunks.
 */
export class WebSocketChatTransport<
  ChatMessage extends UIMessage = UIMessage
> implements ChatTransport<ChatMessage> {
  /**
   * Connection used for every send and subscription. Public and mutable:
   * methods read it at call time so a replaced socket is picked up.
   */
  agent: AgentConnection;
  private prepareBody?: WebSocketChatTransportOptions<ChatMessage>["prepareBody"];
  private activeRequestIds?: Set<string>;

  // Pending resume resolver — set by reconnectToStream, called by
  // handleStreamResuming when onAgentMessage sees CF_AGENT_STREAM_RESUMING.
  private _resumeResolver: ((data: { id: string }) => void) | null = null;
  // Pending "no stream" resolver — called by handleStreamResumeNone
  // when onAgentMessage sees CF_AGENT_STREAM_RESUME_NONE.
  private _resumeNoneResolver: (() => void) | null = null;

  constructor(options: WebSocketChatTransportOptions<ChatMessage>) {
    this.agent = options.agent;
    this.prepareBody = options.prepareBody;
    this.activeRequestIds = options.activeRequestIds;
  }

  /**
   * Called by onAgentMessage when it receives CF_AGENT_STREAM_RESUMING.
   * If reconnectToStream is waiting, this handles the resume handshake
   * (ACK + stream creation) and returns true. Otherwise returns false
   * so the caller can use its own fallback path.
   *
   * @param data - The resuming notification; `id` is the request ID of
   *   the stream being resumed.
   * @returns true when a pending reconnectToStream consumed the event.
   */
  handleStreamResuming(data: { id: string }): boolean {
    if (!this._resumeResolver) return false;
    this._resumeResolver(data);
    return true;
  }

  /**
   * Called by onAgentMessage when it receives CF_AGENT_STREAM_RESUME_NONE.
   * If reconnectToStream is waiting, resolves the promise with null
   * immediately (no 5-second timeout). Returns true if handled.
   */
  handleStreamResumeNone(): boolean {
    if (!this._resumeNoneResolver) return false;
    this._resumeNoneResolver();
    return true;
  }

  /**
   * Sends a chat request over the WebSocket and returns a stream of the
   * response chunks for that request.
   *
   * The JSON body is `{ messages, trigger, ...prepareBody(), ...options.body }`.
   * `chatId`, `headers` and `metadata` are accepted for interface
   * compatibility but are not used by this transport; `messageId` is only
   * forwarded to `prepareBody`.
   *
   * Abort behaviour:
   * - If `abortSignal` is already aborted when the request is about to be
   *   sent, nothing is sent at all and the returned stream is errored with
   *   an AbortError.
   * - If it aborts later (or the consumer cancels the stream), a
   *   CF_AGENT_CHAT_REQUEST_CANCEL frame is sent and the stream is errored
   *   with an AbortError.
   *
   * @returns A stream that closes on the server's `done` frame and errors
   *   on an `error` frame or on abort.
   * @throws Whatever `prepareBody` or `agent.send` throws. When the request
   *   frame fails to send, the request ID and listeners are released
   *   before the error propagates.
   */
  async sendMessages(options: {
    chatId: string;
    messages: ChatMessage[];
    abortSignal: AbortSignal | undefined;
    trigger: "submit-message" | "regenerate-message";
    messageId?: string;
    body?: object;
    headers?: Record<string, string> | Headers;
    metadata?: unknown;
  }): Promise<ReadableStream<UIMessageChunk>> {
    const requestId = nanoid(8);
    const abortController = new AbortController();
    let completed = false;

    // Build the request body
    let extraBody: Record<string, unknown> = {};
    if (this.prepareBody) {
      extraBody = await this.prepareBody({
        messages: options.messages,
        trigger: options.trigger,
        messageId: options.messageId
      });
    }
    if (options.body) {
      // Per-call body fields take precedence over prepareBody's fields.
      extraBody = {
        ...extraBody,
        ...(options.body as Record<string, unknown>)
      };
    }

    const bodyPayload = JSON.stringify({
      messages: options.messages,
      trigger: options.trigger,
      ...extraBody
    });

    // Track this request so the onAgentMessage handler skips it
    this.activeRequestIds?.add(requestId);

    const agent = this.agent;
    const activeIds = this.activeRequestIds;
    const callerSignal = options.abortSignal;

    // Single cleanup helper — every terminal path (done, error, abort)
    // goes through here exactly once.
    // keepId: when true, do NOT remove requestId from activeIds. Used by
    // onAbort so that onAgentMessage continues to skip in-flight chunks
    // and the server's final done:true broadcast until cleanup happens there.
    const finish = (action: () => void, keepId = false) => {
      if (completed) return;
      completed = true;
      try {
        action();
      } catch {
        // Stream may already be closed
      }
      if (!keepId) {
        activeIds?.delete(requestId);
      }
      // Release the caller's signal so a long-lived signal does not keep
      // this request's closures alive after the stream has ended.
      callerSignal?.removeEventListener("abort", onAbort);
      // Detaches the "message" listener registered with this signal.
      abortController.abort();
    };

    const abortError = new Error("Aborted");
    abortError.name = "AbortError";

    // Abort handler: send cancel to server, then terminate the stream.
    // Used by both the caller's abortSignal and stream.cancel().
    // keepId=true: keep requestId in activeIds so onAgentMessage skips any
    // in-flight chunks the server broadcasts before its done:true signal.
    // The ID is removed by onAgentMessage when done:true is received.
    const onAbort = () => {
      if (completed) return;
      try {
        agent.send(
          JSON.stringify({
            id: requestId,
            type: MessageType.CF_AGENT_CHAT_REQUEST_CANCEL
          })
        );
      } catch {
        // Ignore failures (e.g. agent already disconnected)
      }
      finish(() => streamController.error(abortError), true);
    };

    // streamController is assigned synchronously by start(), so it is
    // always available by the time onAbort or a message can fire.
    let streamController!: ReadableStreamDefaultController<UIMessageChunk>;

    const stream = new ReadableStream<UIMessageChunk>({
      start: (controller) => {
        streamController = controller;
        this._forwardChunks(
          agent,
          requestId,
          controller,
          finish,
          abortController.signal
        );
      },
      cancel: () => {
        onAbort();
      }
    });

    // Handle abort from the caller
    if (callerSignal) {
      if (callerSignal.aborted) {
        // Aborted before anything reached the server: there is nothing to
        // cancel, and no done:true will ever arrive to release the ID, so
        // clean up fully and do not send the request.
        finish(() => streamController.error(abortError));
        return stream;
      }
      callerSignal.addEventListener("abort", onAbort, { once: true });
    }

    // Send the request over WebSocket
    try {
      agent.send(
        JSON.stringify({
          id: requestId,
          init: {
            method: "POST",
            body: bodyPayload
          },
          type: MessageType.CF_AGENT_USE_CHAT_REQUEST
        })
      );
    } catch (error) {
      // The request never left: release the ID and listeners, then let the
      // caller see the original failure.
      finish(() => streamController.error(error));
      throw error;
    }

    return stream;
  }

  /**
   * Asks the server whether a stream is in flight and, if so, returns a
   * stream of its chunks.
   *
   * Resolves with a stream when handleStreamResuming is called, with null
   * when handleStreamResumeNone is called, and with null after a 5-second
   * safety-net timeout if neither happens.
   *
   * Only the most recent call owns the resolver slots. An older call that
   * is still pending settles through its own timeout and leaves the newer
   * call's resolvers untouched.
   */
  async reconnectToStream(_options: {
    chatId: string;
  }): Promise<ReadableStream<UIMessageChunk> | null> {
    // Detect whether the server has an active stream for this chat.
    // Instead of registering our own addEventListener listener (which
    // races with onAgentMessage), we set _resumeResolver so that
    // onAgentMessage can call handleStreamResuming() synchronously
    // when it sees CF_AGENT_STREAM_RESUMING — eliminating the race.
    const activeIds = this.activeRequestIds;

    return new Promise<ReadableStream<UIMessageChunk> | null>((resolve) => {
      let resolved = false;
      let timeout: ReturnType<typeof setTimeout> | undefined;

      const done = (value: ReadableStream<UIMessageChunk> | null) => {
        if (resolved) return;
        resolved = true;
        // Clear the slots only if they still point at this call's
        // resolvers; a later reconnectToStream may have replaced them.
        if (this._resumeResolver === onResuming) {
          this._resumeResolver = null;
        }
        if (this._resumeNoneResolver === onResumeNone) {
          this._resumeNoneResolver = null;
        }
        if (timeout !== undefined) clearTimeout(timeout);
        resolve(value);
      };

      // The "no stream" resolver that handleStreamResumeNone() will call.
      // When onAgentMessage sees CF_AGENT_STREAM_RESUME_NONE, it calls
      // handleStreamResumeNone() which resolves immediately with null.
      const onResumeNone = () => done(null);

      // The resolver that handleStreamResuming() will call.
      // When onAgentMessage sees CF_AGENT_STREAM_RESUMING, it calls
      // handleStreamResuming() which invokes this callback.
      const onResuming = (data: { id: string }) => {
        const requestId = data.id;

        // Track this request so onAgentMessage skips subsequent chunks
        activeIds?.add(requestId);

        // Send ACK to server via the latest agent (the socket may
        // have been replaced since reconnectToStream was called).
        this.agent.send(
          JSON.stringify({
            type: MessageType.CF_AGENT_STREAM_RESUME_ACK,
            id: requestId
          })
        );

        // Return a ReadableStream fed by the replayed + live chunks
        done(this._createResumeStream(requestId));
      };

      this._resumeNoneResolver = onResumeNone;
      this._resumeResolver = onResuming;

      // Send the resume request. PartySocket queues sends when
      // the socket isn't open yet and flushes on connect, so
      // this works regardless of current readyState.
      try {
        this.agent.send(
          JSON.stringify({
            type: MessageType.CF_AGENT_STREAM_RESUME_REQUEST
          })
        );
      } catch {
        // WebSocket may already be closed
      }

      // Safety-net timeout: if the WebSocket never connects or the
      // server is unreachable, resolve null. Under normal operation
      // the server responds with STREAM_RESUMING or STREAM_RESUME_NONE
      // well before this fires.
      timeout = setTimeout(() => done(null), 5000);
    });
  }

  /**
   * Creates a ReadableStream that receives resumed stream chunks
   * and forwards them to useChat as UIMessageChunk objects.
   *
   * Cancelling the returned stream only detaches the local listener and
   * releases the request ID; no cancel frame is sent to the server.
   */
  private _createResumeStream(
    requestId: string
  ): ReadableStream<UIMessageChunk> {
    // Read agent at resolve time (not when reconnectToStream was called)
    // so chunk listener attaches to the latest socket after _pk changes.
    const agent = this.agent;
    const activeIds = this.activeRequestIds;
    const chunkController = new AbortController();
    let completed = false;

    // Terminal-path helper: runs at most once, releases the request ID
    // and detaches the message listener.
    const finish = (action: () => void) => {
      if (completed) return;
      completed = true;
      try {
        action();
      } catch {
        // Stream may already be closed
      }
      activeIds?.delete(requestId);
      chunkController.abort();
    };

    return new ReadableStream<UIMessageChunk>({
      start: (controller) => {
        this._forwardChunks(
          agent,
          requestId,
          controller,
          finish,
          chunkController.signal
        );
      },
      cancel: () => {
        finish(() => {});
      }
    });
  }

  /**
   * Subscribes to the agent's "message" events and pipes the
   * CF_AGENT_USE_CHAT_RESPONSE frames for `requestId` into `controller`.
   * Shared by sendMessages and _createResumeStream.
   *
   * @param agent - Connection to listen on.
   * @param requestId - Only frames whose `id` equals this are handled.
   * @param controller - Stream controller that receives parsed chunks.
   * @param finish - Terminal-path callback; invoked with the action that
   *   closes or errors the stream.
   * @param signal - Aborting this signal detaches the listener.
   */
  private _forwardChunks(
    agent: AgentConnection,
    requestId: string,
    controller: ReadableStreamDefaultController<UIMessageChunk>,
    finish: (action: () => void) => void,
    signal: AbortSignal
  ): void {
    const onMessage = (event: MessageEvent) => {
      try {
        const data = JSON.parse(
          event.data as string
        ) as OutgoingMessage<ChatMessage>;

        if (data.type !== MessageType.CF_AGENT_USE_CHAT_RESPONSE) return;
        if (data.id !== requestId) return;

        if (data.error) {
          finish(() =>
            controller.error(new Error(data.body || "Stream error"))
          );
          return;
        }

        // Parse the body as UIMessageChunk and enqueue
        if (data.body?.trim()) {
          try {
            const chunk = JSON.parse(data.body) as UIMessageChunk;
            controller.enqueue(chunk);
          } catch {
            // Skip malformed chunk bodies
          }
        }

        if (data.done) {
          finish(() => controller.close());
        }
      } catch {
        // Ignore non-JSON messages
      }
    };

    agent.addEventListener("message", onMessage, { signal });
  }
}