// branch:
// server.ts
// 8926 bytes (raw)
import { createWorkersAI } from "workers-ai-provider";
import { callable, routeAgentRequest } from "agents";
import { AIChatAgent } from "@cloudflare/ai-chat";
import {
convertToModelMessages,
type StreamTextOnFinishCallback,
stepCountIs,
streamText,
type ToolSet,
type LanguageModel
} from "ai";
import { cleanupMessages } from "./utils";
import { nanoid } from "nanoid";
import { createAiGateway } from "ai-gateway-provider";
import { createOpenAI as createOpenAIGateway } from "ai-gateway-provider/providers/openai";
import { createAnthropic as createAnthropicGateway } from "ai-gateway-provider/providers/anthropic";
import { createGoogleGenerativeAI as createGoogleGateway } from "ai-gateway-provider/providers/google";
import { createOpenAI } from "@ai-sdk/openai";
import { createAnthropic } from "@ai-sdk/anthropic";
import { createGoogleGenerativeAI } from "@ai-sdk/google";
import { env } from "cloudflare:workers";
// The AI Gateway (unified billing) wrapper is created per request from the gatewayAccountId and gatewayId stored in agent state.
export interface PlaygroundState {
  /** Workers AI model id used when no external provider is in effect. */
  model: string;
  /** Sampling temperature passed to the model on every turn. */
  temperature: number;
  /**
   * Streaming preference kept in state.
   * NOTE(review): Playground.onChatMessage always streams; confirm whether
   * anything else reads this flag.
   */
  stream: boolean;
  /** System prompt sent with every chat request. */
  system: string;

  /**
   * Opt-in switch for external provider models. When true and both
   * `externalProvider` and `externalModel` are set, chat goes to that provider.
   */
  useExternalProvider?: boolean;
  /** Which external provider to call. */
  externalProvider?: "openai" | "anthropic" | "google" | "xai";
  /** External model id, optionally prefixed with its provider (e.g. "openai/gpt-5.2"). */
  externalModel?: string;
  /** How requests to the external provider are authenticated. */
  authMethod?: "provider-key" | "gateway";

  /** User's own provider API key (BYOK); used when `authMethod` is "provider-key". */
  providerApiKey?: string;

  /** AI Gateway account id (unified billing); used when `authMethod` is "gateway". */
  gatewayAccountId?: string;
  /** AI Gateway id (unified billing); used when `authMethod` is "gateway". */
  gatewayId?: string;
  /** AI Gateway API key (unified billing); used when `authMethod` is "gateway". */
  gatewayApiKey?: string;
}
/**
 * Chat agent backing the playground.
 *
 * Streams model responses for incoming chat messages. Manages MCP servers
 * through methods marked `@callable()`. Schedules its own `destroy` 15 minutes
 * after the most recent chat message.
 *
 * The model for each turn is chosen from agent state:
 * - Workers AI (`state.model`) by default.
 * - Otherwise an external provider (OpenAI, Anthropic, Google, xAI). It is
 *   reached either directly with the user's provider key (BYOK) or through
 *   AI Gateway unified billing.
 * - Incomplete external configuration falls back to Workers AI.
 *
 * NOTE(review): provider and gateway API keys are stored in agent state
 * (`providerApiKey`, `gatewayApiKey`). If agent state is persisted or
 * synchronised to connected clients, confirm that exposure is acceptable.
 */
export class Playground extends AIChatAgent<Env, PlaygroundState> {
  initialState: PlaygroundState = {
    model: "@cf/moonshotai/kimi-k2.5",
    temperature: 1,
    stream: true,
    system:
      "You are a helpful assistant that can do various tasks using MCP tools.",
    useExternalProvider: false,
    externalProvider: "openai",
    authMethod: "provider-key"
  };

  /**
   * Sets the MCP OAuth callback to return a small HTML page that closes the
   * browser window that loaded it.
   */
  onStart() {
    this.mcp.configureOAuthCallback({
      customHandler: () => {
        return new Response("<script>window.close();</script>", {
          headers: { "content-type": "text/html" },
          status: 200
        });
      }
    });
  }

  /**
   * Handles incoming chat messages and streams the assistant response.
   *
   * @param onFinish - Callback handed to `streamText`, invoked when generation finishes.
   * @param options - Optional transport settings. `abortSignal` is forwarded
   *   to `streamText` so that cancelling on the client stops generation
   *   instead of letting it run to completion.
   * @returns A UI message stream `Response`.
   */
  async onChatMessage(
    onFinish: StreamTextOnFinishCallback<ToolSet>,
    options?: { abortSignal?: AbortSignal }
  ) {
    let tools: ToolSet = {};
    try {
      tools = this.mcp.getAITools();
    } catch (e) {
      // Best effort: a failing MCP server should not prevent chatting.
      console.error("Failed to get AI tools", e);
    }

    // Every chat turn pushes the idle self-destroy out again.
    await this.ensureDestroy();

    // Clean up incomplete tool calls to prevent API errors
    const cleanedMessages = cleanupMessages(this.messages);

    const result = streamText({
      system: this.state.system,
      messages: await convertToModelMessages(cleanedMessages),
      model: this.resolveLanguageModel(),
      tools,
      onFinish,
      temperature: this.state.temperature,
      stopWhen: stepCountIs(10),
      abortSignal: options?.abortSignal
    });
    return result.toUIMessageStreamResponse();
  }

  /**
   * Chooses the language model for the current turn from agent state.
   * Falls back to Workers AI whenever the external-provider configuration or
   * the credentials for the selected auth method are incomplete.
   */
  private resolveLanguageModel(): LanguageModel {
    const {
      useExternalProvider,
      externalProvider,
      externalModel,
      authMethod,
      providerApiKey,
      gatewayAccountId,
      gatewayId,
      gatewayApiKey
    } = this.state;

    if (!useExternalProvider || !externalProvider || !externalModel) {
      // Use Workers AI (default)
      return this.createWorkersAiModel();
    }

    // Strip the provider prefix ("openai/gpt-5.2" -> "gpt-5.2"). Everything
    // after the first "/" is kept, so ids that themselves contain "/" are not
    // truncated.
    const slash = externalModel.indexOf("/");
    const modelName =
      slash === -1 ? externalModel : externalModel.slice(slash + 1);

    if (
      authMethod === "gateway" &&
      gatewayAccountId &&
      gatewayId &&
      gatewayApiKey
    ) {
      return this.createGatewayModel(
        externalProvider,
        modelName,
        gatewayAccountId,
        gatewayId,
        gatewayApiKey
      );
    }

    if (authMethod === "provider-key" && providerApiKey) {
      return this.createProviderKeyModel(
        externalProvider,
        modelName,
        providerApiKey
      );
    }

    // Missing required auth, fallback to Workers AI
    return this.createWorkersAiModel();
  }

  /**
   * Builds a Workers AI model for `state.model`. The model uses the `env.AI`
   * binding, is routed through the "playground" AI Gateway, and is pinned
   * with this agent's session affinity.
   */
  private createWorkersAiModel(): LanguageModel {
    const workersAi = createWorkersAI({
      binding: env.AI,
      gateway: {
        id: "playground"
      }
    });
    return workersAi(this.state.model as Parameters<typeof workersAi>[0], {
      sessionAffinity: this.sessionAffinity
    });
  }

  /**
   * Builds an external provider model wrapped in AI Gateway (unified
   * billing) using the caller-supplied gateway credentials.
   * An unrecognised provider falls back to plain Workers AI, consistent with
   * the other fallbacks. The Workers AI binding model is not wrapped in the
   * external gateway.
   */
  private createGatewayModel(
    provider: NonNullable<PlaygroundState["externalProvider"]>,
    modelName: string,
    accountId: string,
    gatewayId: string,
    apiKey: string
  ): LanguageModel {
    let baseModel: LanguageModel;
    switch (provider) {
      case "openai":
        baseModel = createOpenAIGateway().chat(modelName);
        break;
      case "anthropic":
        baseModel = createAnthropicGateway().chat(modelName);
        break;
      case "google":
        baseModel = createGoogleGateway().chat(modelName);
        break;
      case "xai":
        // NOTE(review): this reuses the OpenAI gateway provider without an
        // xAI base URL, so the request is presumably sent to OpenAI with an
        // xAI model id. Confirm whether ai-gateway-provider offers an xAI
        // provider (or accepts a baseURL) and switch to it.
        baseModel = createOpenAIGateway().chat(modelName);
        break;
      default:
        return this.createWorkersAiModel();
    }
    const gateway = createAiGateway({
      accountId,
      gateway: gatewayId,
      apiKey
    });
    return gateway(baseModel);
  }

  /**
   * Builds an external provider model that calls the provider SDK directly
   * with the user's own API key (BYOK).
   * An unrecognised provider falls back to Workers AI.
   */
  private createProviderKeyModel(
    provider: NonNullable<PlaygroundState["externalProvider"]>,
    modelName: string,
    apiKey: string
  ): LanguageModel {
    switch (provider) {
      case "openai":
        return createOpenAI({ apiKey })(modelName);
      case "anthropic":
        return createAnthropic({ apiKey })(modelName);
      case "google":
        return createGoogleGenerativeAI({ apiKey })(modelName);
      case "xai":
        // xAI is reached through its OpenAI-compatible endpoint. `.chat()`
        // targets Chat Completions, matching the gateway xAI path, instead of
        // the Responses API that `createOpenAI()(id)` defaults to.
        return createOpenAI({
          apiKey,
          baseURL: "https://api.x.ai/v1"
        }).chat(modelName);
      default:
        return this.createWorkersAiModel();
    }
  }

  /**
   * Replaces any pending `destroy` schedule with a fresh one 15 minutes out,
   * so the agent is torn down only after that long without chat activity.
   */
  async ensureDestroy() {
    const pending = this.getSchedules().filter(
      (s) => s.callback === "destroy"
    );
    // Cancel previously set destroy schedules
    for (const s of pending) {
      await this.cancelSchedule(s.id);
    }
    // Destroy after 15 minutes of inactivity
    await this.schedule(60 * 15, "destroy");
  }

  /**
   * Connects a new MCP server under a freshly generated `mcp-…` id.
   *
   * @param url - Server URL. It must not already be connected.
   * @param headers - Optional headers. When given, they are sent using the
   *   "auto" transport.
   * @throws Error if a server with the same URL is already connected.
   */
  @callable()
  async connectMCPServer(url: string, headers?: Record<string, string>) {
    const { servers } = await this.getMcpServers();
    // Check for duplicate URL
    const existingServer = Object.values(servers).find(
      (server) => server.server_url === url
    );
    if (existingServer) {
      throw new Error(`Server with URL "${url}" is already connected`);
    }
    // Generate unique server ID
    const serverId = `mcp-${nanoid(8)}`;
    if (!headers) {
      return await this.addMcpServer(serverId, url, {
        callbackHost: this.env.HOST
      });
    }
    return await this.addMcpServer(serverId, url, {
      callbackHost: this.env.HOST,
      transport: {
        type: "auto",
        headers
      }
    });
  }

  /**
   * Removes one MCP server by id. When no id is given, removes every
   * server, one at a time.
   */
  @callable()
  async disconnectMCPServer(serverId?: string) {
    if (serverId) {
      // Disconnect specific server
      await this.removeMcpServer(serverId);
      return;
    }
    // Disconnect all servers if no serverId provided
    const { servers } = await this.getMcpServers();
    for (const id of Object.keys(servers)) {
      await this.removeMcpServer(id);
    }
  }

  /** Re-runs tool discovery for the given MCP server if it is connected. */
  @callable()
  async refreshMcpTools(serverId: string) {
    await this.mcp.discoverIfConnected(serverId);
  }

  /** Lists available Workers AI models (up to 1000 per request). */
  @callable()
  async getModels() {
    // TODO: get finetunes when the binding supports finetunes.public.list endpoint
    return await this.env.AI.models({ per_page: 1000 });
  }

  /** Intentionally a no-op. */
  onStateChanged() {}
}
/**
 * Worker entry point.
 *
 * Requests that `routeAgentRequest` claims are handed to the matching agent.
 * Everything else receives a plain 404.
 */
export default {
  async fetch(request: Request, env: Env, _ctx: ExecutionContext) {
    const agentResponse = await routeAgentRequest(request, env);
    if (agentResponse) {
      return agentResponse;
    }
    return new Response("Not found", { status: 404 });
  }
} satisfies ExportedHandler<Env>;