branch:
transport.ts
8263 bytesRaw
/**
 * AgentChatTransport — bridges the AI SDK's useChat hook with an Agent
 * WebSocket connection that speaks Think's streaming protocol.
 *
 * Features:
 *   - Request ID correlation: each request gets a unique ID, only matching
 *     WS messages are processed
 *   - Cancel: sends { type: "cancel", requestId } to stop server-side streaming
 *   - Completion guard: close/error/abort are idempotent
 *   - Signal-based cleanup: uses AbortController signal on addEventListener
 *   - Stream resumption: reconnectToStream sends resume-request, server replays
 *     buffered chunks via ChunkRelay
 *
 * @example
 * ```tsx
 * import { AgentChatTransport } from "@cloudflare/think/transport";
 * import { useAgent } from "agents/react";
 * import { useChat } from "@ai-sdk/react";
 * import { useMemo } from "react";
 *
 * const agent = useAgent({ agent: "MyAssistant" });
 * const transport = useMemo(() => new AgentChatTransport(agent), [agent]);
 * const { messages, sendMessage, status } = useChat({ transport });
 * ```
 */

import type { UIMessage, UIMessageChunk, ChatTransport } from "ai";

/**
 * The subset of an agent connection that AgentChatTransport depends on.
 * The object returned by `useAgent()` from `agents/react` satisfies it.
 */
export interface AgentSocket {
  /** Send a raw string frame over the socket. */
  send(data: string): void;

  /** Invoke a server-side RPC method with positional arguments. */
  call(method: string, args?: unknown[]): Promise<unknown>;

  /**
   * Subscribe to incoming frames. The transport relies on the listener being
   * removed when `options.signal` aborts (standard EventTarget semantics),
   * and never calls `removeEventListener` for signal-scoped listeners.
   */
  addEventListener(
    type: "message",
    handler: (event: MessageEvent) => void,
    options?: { signal?: AbortSignal }
  ): void;

  /** Unsubscribe a handler previously passed to `addEventListener`. */
  removeEventListener(
    type: "message",
    handler: (event: MessageEvent) => void
  ): void;
}

/**
 * Optional configuration accepted by the AgentChatTransport constructor.
 */
export interface AgentChatTransportOptions {
  /**
   * How long, in milliseconds, `reconnectToStream` waits for the server to
   * answer a resume request with `stream-resuming` before resolving `null`.
   * @default 500
   */
  resumeTimeout?: number;

  /**
   * Name of the server-side RPC method invoked by `sendMessages`.
   * It is called with the positional arguments `[text, requestId]`.
   * @default "sendMessage"
   */
  sendMethod?: string;
}

/**
 * Join the text parts of a UIMessage in their original order, with no
 * separator. Any part whose `type` is not `"text"` is skipped.
 */
function getMessageText(msg: UIMessage): string {
  let text = "";
  for (const part of msg.parts) {
    if (part.type === "text") text += part.text;
  }
  return text;
}

/**
 * ChatTransport implementation for Agent WebSocket connections.
 *
 * Speaks the wire protocol used by Think's `chat()` method
 * and ChunkRelay on the server:
 *   - `stream-start`    → new stream with requestId (not consumed here)
 *   - `stream-event`    → UIMessageChunk payload, JSON-encoded in `event`
 *   - `stream-done`     → stream complete
 *   - `stream-resuming` → replay after reconnect
 *   - `cancel`          → client→server abort
 *   - `resume-request`  → client→server request to replay an in-flight stream
 *
 * Each open stream only consumes frames whose `requestId` matches its own.
 */
export class AgentChatTransport implements ChatTransport<UIMessage> {
  readonly #agent: AgentSocket;
  readonly #sendMethod: string;
  readonly #resumeTimeout: number;

  /**
   * Close callbacks for every stream this transport currently has open,
   * whether started by sendMessages or by reconnectToStream. Each stream
   * removes only its own entry when it settles, so one stream finishing can
   * never unregister another (a single shared "current" slot would be
   * clobbered that way).
   */
  readonly #activeStreams = new Set<() => void>();

  constructor(agent: AgentSocket, options?: AgentChatTransportOptions) {
    this.#agent = agent;
    this.#sendMethod = options?.sendMethod ?? "sendMessage";
    this.#resumeTimeout = options?.resumeTimeout ?? 500;
  }

  /**
   * Detach from every open stream. Call this before switching agents or
   * cleaning up: each stream is closed and its message listener removed.
   * No `cancel` frame is sent to the server.
   */
  detach(): void {
    // Snapshot first: each close callback deletes itself from the set.
    for (const close of [...this.#activeStreams]) close();
  }

  /**
   * Send the text of the last message to the agent and stream the reply.
   *
   * Calls the configured RPC method with `[text, requestId]`; frames tagged
   * with `requestId` are enqueued on the returned stream until `stream-done`.
   * Aborting `abortSignal`, or cancelling the returned stream, sends
   * `{ type: "cancel", requestId }` and errors the stream with an
   * `AbortError`. If the RPC call rejects, the stream errors with its reason.
   * If `abortSignal` is already aborted, the RPC call is not made at all.
   *
   * Rejects if `messages` is empty.
   */
  async sendMessages({
    messages,
    abortSignal
  }: Parameters<ChatTransport<UIMessage>["sendMessages"]>[0]): Promise<
    ReadableStream<UIMessageChunk>
  > {
    const lastMessage = messages[messages.length - 1];
    if (lastMessage === undefined) {
      throw new Error("AgentChatTransport.sendMessages: no message to send");
    }
    const text = getMessageText(lastMessage);
    const requestId = crypto.randomUUID().slice(0, 8);

    // Cancelling the returned stream is treated exactly like an abort.
    // (onAbort is declared below; it is only invoked after this returns.)
    const handle = this.#openStream(requestId, () => onAbort());

    const onAbort = (): void => {
      if (handle.isSettled()) return;
      try {
        this.#agent.send(JSON.stringify({ type: "cancel", requestId }));
      } catch {
        /* best-effort: the socket may already be closed */
      }
      handle.fail(Object.assign(new Error("Aborted"), { name: "AbortError" }));
    };

    if (abortSignal) {
      // Scope the listener to the stream's lifetime so a long-lived signal
      // does not keep this request's closures alive after it settles.
      abortSignal.addEventListener("abort", onAbort, {
        once: true,
        signal: handle.signal
      });
      if (abortSignal.aborted) onAbort();
    }

    // Already aborted: don't ask the server for a response nobody will read.
    if (!handle.isSettled()) {
      this.#agent
        .call(this.#sendMethod, [text, requestId])
        .catch((error: unknown) => handle.fail(error));
    }

    return handle.stream;
  }

  /**
   * Ask the server to replay an in-flight stream after a reconnect.
   *
   * Sends `{ type: "resume-request" }` and waits up to `resumeTimeout` ms for
   * a `stream-resuming` frame. Resolves with a stream bound to that frame's
   * `requestId`, or `null` if none arrives in time.
   */
  async reconnectToStream(): Promise<ReadableStream<UIMessageChunk> | null> {
    return new Promise<ReadableStream<UIMessageChunk> | null>((resolve) => {
      let resolved = false;

      const done = (value: ReadableStream<UIMessageChunk> | null): void => {
        if (resolved) return;
        resolved = true;
        clearTimeout(timeout);
        this.#agent.removeEventListener("message", handler);
        resolve(value);
      };

      const handler = (event: MessageEvent): void => {
        if (typeof event.data !== "string") return;
        try {
          const msg = JSON.parse(event.data);
          if (msg.type === "stream-resuming") {
            done(this.#createResumeStream(msg.requestId));
          }
        } catch {
          /* ignore frames that are not JSON */
        }
      };

      // Arm the timer before sending so that a reply delivered synchronously
      // still finds a timer to clear instead of leaving one dangling.
      const timeout = setTimeout(() => done(null), this.#resumeTimeout);
      this.#agent.addEventListener("message", handler);

      try {
        this.#agent.send(JSON.stringify({ type: "resume-request" }));
      } catch {
        /* WebSocket may not be open yet; the timer resolves null */
      }
    });
  }

  /**
   * Stream for a server-announced replay. No cancel callback is passed:
   * cancelling a resumed stream only detaches locally and, unlike
   * sendMessages, does not send `cancel` to the server.
   */
  #createResumeStream(requestId: string): ReadableStream<UIMessageChunk> {
    return this.#openStream(requestId).stream;
  }

  /**
   * Open a ReadableStream fed by agent frames tagged with `requestId`.
   *
   * - `stream-event` payloads are JSON-decoded and enqueued as chunks.
   * - `stream-done` closes the stream.
   * - The stream settles (closes, errors, or is released) at most once; on
   *   settling, `signal` aborts, which removes the message listener and any
   *   other listener registered with that signal.
   * - While open, the stream is registered in #activeStreams so detach()
   *   can close it.
   *
   * @param requestId - Correlation ID of the server-side stream.
   * @param onCancel - Called when the consumer cancels the stream; it is
   *   responsible for settling via `fail`. When omitted, cancellation just
   *   releases the stream.
   * @returns The stream plus hooks to error it, query whether it has
   *   settled, and the signal that aborts when it does.
   */
  #openStream(
    requestId: string,
    onCancel?: () => void
  ): {
    stream: ReadableStream<UIMessageChunk>;
    fail: (error: unknown) => void;
    isSettled: () => boolean;
    signal: AbortSignal;
  } {
    const lifetime = new AbortController();
    let controller!: ReadableStreamDefaultController<UIMessageChunk>;
    let settled = false;

    const settle = (action: () => void): void => {
      if (settled) return;
      settled = true;
      this.#activeStreams.delete(close);
      try {
        action();
      } catch {
        /* the stream may already be closed or cancelled */
      }
      lifetime.abort();
    };
    const close = (): void => settle(() => controller.close());
    const fail = (error: unknown): void =>
      settle(() => controller.error(error));

    const stream = new ReadableStream<UIMessageChunk>({
      start(c) {
        controller = c;
      },
      cancel() {
        if (onCancel) onCancel();
        else settle(() => {});
      }
    });

    this.#agent.addEventListener(
      "message",
      (event: MessageEvent) => {
        if (typeof event.data !== "string") return;
        try {
          const msg = JSON.parse(event.data);
          if (msg.requestId !== requestId) return;
          if (msg.type === "stream-event") {
            const chunk: UIMessageChunk = JSON.parse(msg.event);
            controller.enqueue(chunk);
          } else if (msg.type === "stream-done") {
            close();
          }
        } catch {
          /* ignore malformed frames */
        }
      },
      { signal: lifetime.signal }
    );
    this.#activeStreams.add(close);

    return { stream, fail, isSettled: () => settled, signal: lifetime.signal };
  }
}