branch:
assistant-agent-loop.test.ts
11365 bytesRaw
import { env, exports } from "cloudflare:workers";
import { describe, expect, it } from "vitest";
import { getAgentByName } from "agents";
import type { UIMessage } from "ai";
import type { Session } from "../session/index";

// ── Wire protocol constants (must match agent.ts) ─────────────────
const MSG_CHAT_MESSAGES = "cf_agent_chat_messages";
const MSG_CHAT_REQUEST = "cf_agent_use_chat_request";
const MSG_CHAT_RESPONSE = "cf_agent_use_chat_response";

// ── Helpers ────────────────────────────────────────────────────────

function kebab(className: string): string {
  return className
    .replace(/([a-z])([A-Z])/g, "$1-$2")
    .replace(/([A-Z]+)([A-Z][a-z])/g, "$1-$2")
    .toLowerCase();
}

async function connectWS(agentClass: string, room: string) {
  const slug = kebab(agentClass);
  const res = await exports.default.fetch(
    `http://example.com/agents/${slug}/${room}`,
    {
      headers: { Upgrade: "websocket" }
    }
  );
  expect(res.status).toBe(101);
  const ws = res.webSocket as WebSocket;
  expect(ws).toBeDefined();
  ws.accept();
  return { ws };
}

function collectMessages(
  ws: WebSocket,
  count: number,
  timeout = 5000
): Promise<Array<Record<string, unknown>>> {
  return new Promise((resolve) => {
    const messages: Array<Record<string, unknown>> = [];
    const timer = setTimeout(() => resolve(messages), timeout);
    const handler = (e: MessageEvent) => {
      try {
        messages.push(JSON.parse(e.data as string) as Record<string, unknown>);
        if (messages.length >= count) {
          clearTimeout(timer);
          ws.removeEventListener("message", handler);
          resolve(messages);
        }
      } catch {
        // ignore parse errors
      }
    };
    ws.addEventListener("message", handler);
  });
}

function waitForDone(
  ws: WebSocket,
  timeout = 10000
): Promise<Array<Record<string, unknown>>> {
  return new Promise((resolve, reject) => {
    const messages: Array<Record<string, unknown>> = [];
    const timer = setTimeout(
      () => reject(new Error("Timeout waiting for done")),
      timeout
    );
    const handler = (e: MessageEvent) => {
      try {
        const msg = JSON.parse(e.data as string) as Record<string, unknown>;
        messages.push(msg);
        if (msg.type === MSG_CHAT_RESPONSE && msg.done === true) {
          clearTimeout(timer);
          ws.removeEventListener("message", handler);
          resolve(messages);
        }
      } catch {
        // ignore parse errors
      }
    };
    ws.addEventListener("message", handler);
  });
}

function closeWS(ws: WebSocket): Promise<void> {
  return new Promise((resolve) => {
    const timer = setTimeout(resolve, 200);
    ws.addEventListener(
      "close",
      () => {
        clearTimeout(timer);
        resolve();
      },
      { once: true }
    );
    ws.close();
  });
}

function sendChatRequest(ws: WebSocket, text: string, requestId?: string) {
  const id = requestId ?? crypto.randomUUID();
  const userMessage: UIMessage = {
    id: crypto.randomUUID(),
    role: "user",
    parts: [{ type: "text", text }]
  };
  ws.send(
    JSON.stringify({
      type: MSG_CHAT_REQUEST,
      id,
      init: {
        method: "POST",
        body: JSON.stringify({ messages: [userMessage] })
      }
    })
  );
  return { id, userMessage };
}

// ── Tests ─────────────────────────────────────────────────────────

describe("AssistantAgent — agentic loop", () => {
  describe("getModel() error", () => {
    it("returns an error when getModel is not overridden", async () => {
      const room = crypto.randomUUID();
      const { ws } = await connectWS("BareAssistantAgent", room);

      // Create a session first via RPC
      const agent = await getAgentByName(env.BareAssistantAgent, room);
      await (
        agent as unknown as { createSession(n: string): Promise<Session> }
      ).createSession("test");

      // Drain the session broadcast
      await collectMessages(ws, 1, 500);

      // Send a chat message — should get an error response
      const done = waitForDone(ws);
      sendChatRequest(ws, "hello");
      const messages = await done;

      // The last message should be a done message with error
      const errorMsg = messages.find(
        (m) =>
          m.type === MSG_CHAT_RESPONSE && m.done === true && m.error === true
      );
      expect(errorMsg).toBeDefined();
      expect(errorMsg!.body).toContain("getModel");

      await closeWS(ws);
    });
  });

  describe("default loop — text only", () => {
    it("streams a response using the mock model", async () => {
      const room = crypto.randomUUID();
      const { ws } = await connectWS("LoopTestAgent", room);

      // Create session via RPC
      const agent = await getAgentByName(env.LoopTestAgent, room);
      const rpc = agent as unknown as {
        createSession(n: string): Promise<Session>;
        getMessages(): Promise<UIMessage[]>;
        getSessionHistory(id: string): Promise<UIMessage[]>;
        getSessions(): Promise<Session[]>;
      };
      await rpc.createSession("loop-test");
      await collectMessages(ws, 1, 500);

      // Send chat and wait for done
      const done = waitForDone(ws);
      sendChatRequest(ws, "Say hi");
      const messages = await done;

      // Should have response chunks and a done message
      const responseChunks = messages.filter(
        (m) => m.type === MSG_CHAT_RESPONSE && m.done === false
      );
      expect(responseChunks.length).toBeGreaterThan(0);

      // Verify the stream contains text content
      const bodies = responseChunks
        .map((m) => m.body as string)
        .filter(Boolean);
      const hasText = bodies.some((b) => {
        try {
          const parsed = JSON.parse(b) as Record<string, unknown>;
          return parsed.type === "text-delta" || parsed.type === "text-start";
        } catch {
          return false;
        }
      });
      expect(hasText).toBe(true);

      await closeWS(ws);
    });

    it("persists assistant message after streaming", async () => {
      const room = crypto.randomUUID();
      const { ws } = await connectWS("LoopTestAgent", room);

      const agent = await getAgentByName(env.LoopTestAgent, room);
      const rpc = agent as unknown as {
        createSession(n: string): Promise<Session>;
        getMessages(): Promise<UIMessage[]>;
        getSessionHistory(id: string): Promise<UIMessage[]>;
        getSessions(): Promise<Session[]>;
      };
      const session = (await rpc.createSession(
        "persist-test"
      )) as unknown as Session;
      await collectMessages(ws, 1, 500);

      // Send chat and wait for completion
      const done = waitForDone(ws);
      sendChatRequest(ws, "Hello");
      await done;

      // Wait for the cf_agent_chat_messages broadcast after persistence
      const postStream = await collectMessages(ws, 1, 3000);
      const chatMsgs = postStream.find((m) => m.type === MSG_CHAT_MESSAGES);

      // If no broadcast arrived, check via RPC
      if (!chatMsgs) {
        const history = (await rpc.getSessionHistory(
          session.id
        )) as unknown as UIMessage[];
        // Should have user + assistant
        expect(history.length).toBeGreaterThanOrEqual(2);
        const assistantMsg = history.find((m) => m.role === "assistant");
        expect(assistantMsg).toBeDefined();
      } else {
        const msgs = chatMsgs.messages as UIMessage[];
        expect(msgs.length).toBeGreaterThanOrEqual(2);
        const assistantMsg = msgs.find((m) => m.role === "assistant");
        expect(assistantMsg).toBeDefined();
      }

      await closeWS(ws);
    });
  });

  describe("default loop — with tools", () => {
    it("executes a tool and returns text after", async () => {
      const room = crypto.randomUUID();
      const { ws } = await connectWS("LoopToolTestAgent", room);

      const agent = await getAgentByName(env.LoopToolTestAgent, room);
      const rpc = agent as unknown as {
        createSession(n: string): Promise<Session>;
        getMessages(): Promise<UIMessage[]>;
        getSessionHistory(id: string): Promise<UIMessage[]>;
        getSessions(): Promise<Session[]>;
      };
      await rpc.createSession("tool-test");
      await collectMessages(ws, 1, 500);

      // Send chat and wait for done
      const done = waitForDone(ws, 15000);
      sendChatRequest(ws, "Use the echo tool");
      const messages = await done;

      // Should have response chunks
      const responseChunks = messages.filter(
        (m) => m.type === MSG_CHAT_RESPONSE && m.done === false
      );
      expect(responseChunks.length).toBeGreaterThan(0);

      // After completion, check persisted messages
      await collectMessages(ws, 1, 2000);

      const sessions = (await rpc.getSessions()) as unknown as Session[];
      const history = (await rpc.getSessionHistory(
        sessions[0].id
      )) as unknown as UIMessage[];

      // Should have at least user + assistant messages
      expect(history.length).toBeGreaterThanOrEqual(2);

      await closeWS(ws);
    });

    it("custom getMaxSteps is respected", async () => {
      const room = crypto.randomUUID();
      const agent = await getAgentByName(env.LoopToolTestAgent, room);
      const rpc = agent as unknown as {
        createSession(n: string): Promise<Session>;
      };
      await rpc.createSession("steps-test");

      // LoopToolTestAgent has getMaxSteps() = 3
      const { ws } = await connectWS("LoopToolTestAgent", room);

      // Drain session switch
      await collectMessages(ws, 1, 500);

      const done = waitForDone(ws, 15000);
      sendChatRequest(ws, "test step limit");
      const messages = await done;

      // Should complete without timeout (step limit prevents runaway)
      const doneMsg = messages.find(
        (m) => m.type === MSG_CHAT_RESPONSE && m.done === true
      );
      expect(doneMsg).toBeDefined();

      await closeWS(ws);
    });
  });

  describe("assembleContext", () => {
    it("converts messages to model format", async () => {
      const room = crypto.randomUUID();
      const { ws } = await connectWS("LoopTestAgent", room);

      const agent = await getAgentByName(env.LoopTestAgent, room);
      const rpc = agent as unknown as {
        createSession(n: string): Promise<Session>;
        getMessages(): Promise<UIMessage[]>;
      };
      await rpc.createSession("context-test");
      await collectMessages(ws, 1, 500);

      // Send a message and let it complete
      const done = waitForDone(ws);
      sendChatRequest(ws, "Hello for context test");
      await done;

      // Wait for persistence
      await collectMessages(ws, 1, 2000);

      // Verify messages were persisted correctly
      const msgs = (await rpc.getMessages()) as unknown as UIMessage[];
      expect(msgs.length).toBeGreaterThanOrEqual(2);

      // User message should be present
      const userMsg = msgs.find((m) => m.role === "user");
      expect(userMsg).toBeDefined();
      expect(userMsg!.parts).toBeDefined();

      await closeWS(ws);
    });
  });
});