// workers-ai-providers.ts
/**
 * Workers AI provider implementations for the voice pipeline.
 *
 * These are convenience classes that wrap the Workers AI binding
 * (env.AI) for STT, TTS, and VAD. They are not required — any
 * object satisfying the provider interfaces works.
 */

import type {
  STTProvider,
  TTSProvider,
  VADProvider,
  StreamingSTTProvider,
  StreamingSTTSession,
  StreamingSTTSessionOptions
} from "./types";

// --- Audio utilities ---

/** Wrap an ArrayBuffer in a ReadableStream that emits it as one chunk. */
function toStream(buffer: ArrayBuffer): ReadableStream<Uint8Array> {
  const bytes = new Uint8Array(buffer);
  return new ReadableStream({
    start: (ctrl) => {
      ctrl.enqueue(bytes);
      ctrl.close();
    }
  });
}

/** Write a string's UTF-16 code units as bytes into a DataView at `offset`. */
function writeString(view: DataView, offset: number, str: string) {
  str
    .split("")
    .forEach((ch, i) => view.setUint8(offset + i, ch.charCodeAt(0)));
}

/**
 * Convert raw PCM audio to WAV format by prepending a 44-byte RIFF
 * header. Exported for custom providers.
 *
 * @param pcmData Raw PCM sample bytes (interleaved if multi-channel).
 * @param sampleRate Samples per second (e.g. 16000).
 * @param channels Number of channels (1 = mono).
 * @param bitsPerSample Bits per sample (e.g. 16).
 * @returns A new ArrayBuffer containing header followed by the PCM payload.
 */
export function pcmToWav(
  pcmData: ArrayBuffer,
  sampleRate: number,
  channels: number,
  bitsPerSample: number
): ArrayBuffer {
  const HEADER_SIZE = 44;
  const dataSize = pcmData.byteLength;
  const out = new ArrayBuffer(HEADER_SIZE + dataSize);
  const header = new DataView(out);

  // Write a 4-char ASCII tag into the header.
  const ascii = (offset: number, tag: string) => {
    for (let i = 0; i < tag.length; i++) {
      header.setUint8(offset + i, tag.charCodeAt(i));
    }
  };

  ascii(0, "RIFF");
  header.setUint32(4, 36 + dataSize, true); // RIFF chunk size = rest of header + data
  ascii(8, "WAVE");
  ascii(12, "fmt ");
  header.setUint32(16, 16, true); // fmt sub-chunk length
  header.setUint16(20, 1, true); // audio format 1 = uncompressed PCM
  header.setUint16(22, channels, true);
  header.setUint32(24, sampleRate, true);
  header.setUint32(28, (sampleRate * channels * bitsPerSample) / 8, true); // byte rate
  header.setUint16(32, (channels * bitsPerSample) / 8, true); // block align
  header.setUint16(34, bitsPerSample, true);
  ascii(36, "data");
  header.setUint32(40, dataSize, true);

  // Copy the PCM payload after the header.
  new Uint8Array(out, HEADER_SIZE).set(new Uint8Array(pcmData));
  return out;
}

// --- Loose AI binding type ---

/** Loose type for the Workers AI binding — avoids hard dependency on @cloudflare/workers-types. */
interface AiLike {
  // Runs a model by name. `options` carries binding-level flags used in
  // this file (`returnRawResponse`, `websocket`, `signal`); the result
  // shape depends on the model and options, hence `unknown`.
  run(
    model: string,
    input: Record<string, unknown>,
    options?: Record<string, unknown>
  ): Promise<unknown>;
}

// --- STT ---

/** Configuration for {@link WorkersAISTT}. */
export interface WorkersAISTTOptions {
  /** STT model name. @default "@cf/deepgram/nova-3" */
  model?: string;
  /** Language code (e.g. "en", "es", "fr"). @default "en" */
  language?: string;
}

/**
 * Workers AI speech-to-text provider.
 *
 * Wraps the `env.AI` binding: the supplied PCM buffer is converted to
 * WAV (16 kHz, mono, 16-bit — see the `pcmToWav` call) and run through
 * the configured Deepgram model.
 *
 * @example
 * ```ts
 * class MyAgent extends VoiceAgent<Env> {
 *   stt = new WorkersAISTT(this.env.AI);
 * }
 * ```
 */
export class WorkersAISTT implements STTProvider {
  #ai: AiLike;
  #model: string;
  #language: string;

  constructor(ai: AiLike, options?: WorkersAISTTOptions) {
    this.#ai = ai;
    this.#model = options?.model ?? "@cf/deepgram/nova-3";
    this.#language = options?.language ?? "en";
  }

  /**
   * Transcribe a PCM audio buffer to text.
   *
   * @param audioData Raw PCM audio (converted as 16 kHz, mono, 16-bit).
   * @param signal Optional abort signal forwarded to the AI binding.
   * @returns The transcript, or "" when nothing was recognized.
   */
  async transcribe(
    audioData: ArrayBuffer,
    signal?: AbortSignal
  ): Promise<string> {
    // Deepgram nests the transcript several levels deep in the result.
    type NovaResult = {
      results?: {
        channels?: Array<{
          alternatives?: Array<{ transcript?: string }>;
        }>;
      };
    };

    const wav = pcmToWav(audioData, 16000, 1, 16);
    const raw = (await this.#ai.run(
      this.#model,
      {
        audio: {
          body: toStream(wav),
          contentType: "audio/wav"
        },
        language: this.#language,
        punctuate: true,
        smart_format: true
      },
      signal ? { signal } : undefined
    )) as NovaResult;

    const [firstChannel] = raw?.results?.channels ?? [];
    return firstChannel?.alternatives?.[0]?.transcript ?? "";
  }
}

// --- TTS ---

/** Configuration for {@link WorkersAITTS}. */
export interface WorkersAITTSOptions {
  /** TTS model name. @default "@cf/deepgram/aura-1" */
  model?: string;
  /** TTS speaker voice. @default "asteria" */
  speaker?: string;
}

/**
 * Workers AI text-to-speech provider.
 *
 * Wraps the `env.AI` binding to synthesize speech with the configured
 * Deepgram Aura model, returning the raw audio bytes.
 *
 * @example
 * ```ts
 * class MyAgent extends VoiceAgent<Env> {
 *   tts = new WorkersAITTS(this.env.AI);
 * }
 * ```
 */
export class WorkersAITTS implements TTSProvider {
  #ai: AiLike;
  #model: string;
  #speaker: string;

  constructor(ai: AiLike, options?: WorkersAITTSOptions) {
    this.#ai = ai;
    this.#model = options?.model ?? "@cf/deepgram/aura-1";
    this.#speaker = options?.speaker ?? "asteria";
  }

  /**
   * Synthesize speech for the given text.
   *
   * @param text Text to speak.
   * @param signal Optional abort signal forwarded to the AI binding.
   * @returns Audio bytes, or null when the model call fails.
   */
  async synthesize(
    text: string,
    signal?: AbortSignal
  ): Promise<ArrayBuffer | null> {
    // returnRawResponse yields a fetch Response whose body is the audio.
    const response = (await this.#ai.run(
      this.#model,
      { text, speaker: this.#speaker },
      { returnRawResponse: true, ...(signal ? { signal } : {}) }
    )) as Response;

    // Fix: previously a non-2xx response had its *error body* returned
    // as if it were audio. Surface failure as null instead — the
    // declared return type already allows it.
    if (!response.ok) {
      console.error(
        `[WorkersAITTS] TTS request failed with status ${response.status}`
      );
      return null;
    }

    return await response.arrayBuffer();
  }
}

// --- Streaming STT (Flux) ---

/** Configuration for {@link WorkersAIFluxSTT}. */
export interface WorkersAIFluxSTTOptions {
  /** End-of-turn confidence threshold (0.5-0.9). @default 0.7 */
  eotThreshold?: number;
  /**
   * Eager end-of-turn threshold (0.3-0.9). When set, enables
   * EagerEndOfTurn and TurnResumed events for speculative processing.
   */
  eagerEotThreshold?: number;
  /** EOT timeout in milliseconds. @default 5000 */
  eotTimeoutMs?: number;
  /**
   * Keyterms to boost recognition of specialized terminology.
   * NOTE(review): only the first entry is currently forwarded to the
   * model (see the session's connect logic) — extra terms are dropped.
   */
  keyterms?: string[];
  /** Sample rate in Hz. @default 16000 */
  sampleRate?: number;
}

/**
 * Workers AI streaming speech-to-text provider using the Flux model.
 *
 * Flux is a conversational STT model with built-in end-of-turn
 * detection, reached over a WebSocket opened through the Workers AI
 * binding — no external API key required. Because Flux detects
 * end-of-turn natively, the separate VAD provider becomes optional:
 * client-side silence detection still triggers the pipeline, but the
 * server-side VAD call can be skipped for lower latency.
 *
 * @example
 * ```ts
 * import { Agent } from "agents";
 * import { withVoice, WorkersAIFluxSTT, WorkersAITTS } from "agents/experimental/voice";
 *
 * const VoiceAgent = withVoice(Agent);
 *
 * class MyAgent extends VoiceAgent<Env> {
 *   streamingStt = new WorkersAIFluxSTT(this.env.AI);
 *   tts = new WorkersAITTS(this.env.AI);
 *   // No VAD needed — Flux handles turn detection
 *
 *   async onTurn(transcript, context) { ... }
 * }
 * ```
 */
export class WorkersAIFluxSTT implements StreamingSTTProvider {
  #ai: AiLike;
  // Per-session config, resolved once at construction.
  #config: {
    sampleRate: number;
    eotThreshold?: number;
    eagerEotThreshold?: number;
    eotTimeoutMs?: number;
    keyterms?: string[];
  };

  constructor(ai: AiLike, options?: WorkersAIFluxSTTOptions) {
    this.#ai = ai;
    this.#config = {
      sampleRate: options?.sampleRate ?? 16000,
      eotThreshold: options?.eotThreshold,
      eagerEotThreshold: options?.eagerEotThreshold,
      eotTimeoutMs: options?.eotTimeoutMs,
      keyterms: options?.keyterms
    };
  }

  /** Open a fresh Flux WebSocket session for one speech turn. */
  createSession(options?: StreamingSTTSessionOptions): StreamingSTTSession {
    return new FluxSTTSession(this.#ai, { ...this.#config }, options);
  }
}

/** Internal per-session Flux configuration, resolved from provider options. */
interface FluxSessionConfig {
  /** PCM sample rate in Hz. */
  sampleRate: number;
  /** End-of-turn confidence threshold; omitted = model default. */
  eotThreshold?: number;
  /** Eager EOT threshold; presence enables speculative EOT events. */
  eagerEotThreshold?: number;
  /** EOT timeout in milliseconds. */
  eotTimeoutMs?: number;
  /** Keyterms to boost recognition (only the first is forwarded). */
  keyterms?: string[];
}

/** Shape of a JSON message received over the Flux WebSocket. */
interface FluxEvent {
  /** Event discriminator sent by Flux. */
  event:
    | "Update"
    | "StartOfTurn"
    | "EagerEndOfTurn"
    | "TurnResumed"
    | "EndOfTurn";
  /** Transcript so far (present on transcript-bearing events). */
  transcript?: string;
  /** Model confidence that the speaker's turn has ended. */
  end_of_turn_confidence?: number;
}

/**
 * A single streaming STT session backed by a Flux WebSocket via env.AI.
 *
 * Lifecycle: created at start-of-speech, receives audio via feed(),
 * flushed via finish() at end-of-speech, or aborted on interrupt.
 *
 * Concurrency notes: the WebSocket is opened asynchronously from the
 * constructor, so feed()/finish()/abort() may run before it is ready.
 * Audio fed early is queued in #pendingChunks and flushed once open;
 * a finish() that races the connection defers its safety-net timeout
 * until #connect() has flushed the queue.
 */
class FluxSTTSession implements StreamingSTTSession {
  // Callbacks supplied by the pipeline (all optional).
  #onInterim: ((text: string) => void) | undefined;
  #onFinal: ((text: string) => void) | undefined;
  #onEndOfTurn: ((text: string) => void) | undefined;

  #ws: WebSocket | null = null;
  #connected = false;
  #aborted = false;

  // Audio chunks queued before the WebSocket is open
  #pendingChunks: ArrayBuffer[] = [];

  // Latest transcript from Update events (may still change)
  #latestTranscript = "";

  // Transcript from EndOfTurn event (stable)
  #endOfTurnTranscript: string | null = null;

  // finish() state
  #finishing = false;
  #finishResolve: ((transcript: string) => void) | null = null;
  #finishPromise: Promise<string> | null = null;
  #finishTimeout: ReturnType<typeof setTimeout> | null = null;

  constructor(
    ai: AiLike,
    config: FluxSessionConfig,
    options?: StreamingSTTSessionOptions
  ) {
    this.#onInterim = options?.onInterim;
    this.#onFinal = options?.onFinal;
    this.#onEndOfTurn = options?.onEndOfTurn;
    // Fire-and-forget: connection errors are handled inside #connect().
    this.#connect(ai, config);
  }

  /**
   * Open the Flux WebSocket through the AI binding, attach handlers,
   * and flush any audio queued while connecting.
   */
  async #connect(ai: AiLike, config: FluxSessionConfig): Promise<void> {
    try {
      // Config values are passed as strings — NOTE(review): presumably
      // the binding forwards these as query parameters; confirm against
      // the Workers AI Flux documentation.
      const input: Record<string, unknown> = {
        encoding: "linear16",
        sample_rate: String(config.sampleRate)
      };
      if (config.eotThreshold != null)
        input.eot_threshold = String(config.eotThreshold);
      if (config.eagerEotThreshold != null)
        input.eager_eot_threshold = String(config.eagerEotThreshold);
      if (config.eotTimeoutMs != null)
        input.eot_timeout_ms = String(config.eotTimeoutMs);
      // Only the first keyterm is forwarded; extra entries are dropped.
      if (config.keyterms?.length) input.keyterm = config.keyterms[0];

      const resp = await ai.run("@cf/deepgram/flux", input, {
        websocket: true
      });

      // Session was aborted while we were connecting: accept then close
      // immediately so the server side is cleaned up.
      if (this.#aborted) {
        const ws = (resp as { webSocket?: WebSocket }).webSocket;
        if (ws) {
          ws.accept();
          ws.close();
        }
        return;
      }

      const ws = (resp as { webSocket?: WebSocket }).webSocket;
      if (!ws) {
        console.error("[FluxSTT] Failed to establish WebSocket connection");
        this.#resolveFinish();
        return;
      }

      ws.accept();
      this.#ws = ws;
      this.#connected = true;

      ws.addEventListener("message", (event: MessageEvent) => {
        this.#handleMessage(event);
      });

      // A close (normal or not) ends the session: resolve any pending
      // finish() with whatever transcript we have.
      ws.addEventListener("close", () => {
        this.#clearFinishTimeout();
        this.#connected = false;
        this.#resolveFinish();
      });

      ws.addEventListener("error", (event: Event) => {
        console.error("[FluxSTT] WebSocket error:", event);
        this.#connected = false;
        this.#resolveFinish();
      });

      // Flush any audio chunks that arrived before the WS was open
      for (const chunk of this.#pendingChunks) {
        ws.send(chunk);
      }
      this.#pendingChunks = [];

      // If finish() was called while we were connecting, start the
      // finish timeout instead of closing immediately. This gives Flux
      // time to process the audio we just flushed.
      if (this.#finishing) {
        this.#startFinishTimeout();
      }
    } catch (err) {
      console.error("[FluxSTT] Connection error:", err);
      this.#resolveFinish();
    }
  }

  /**
   * Push a chunk of audio into the session. Chunks arriving before the
   * WebSocket is open are queued and flushed by #connect(). Silently
   * dropped once the session is finishing or aborted.
   */
  feed(chunk: ArrayBuffer): void {
    if (this.#aborted || this.#finishing) return;

    if (this.#connected && this.#ws) {
      this.#ws.send(chunk);
    } else {
      // Queue until connected
      this.#pendingChunks.push(chunk);
    }
  }

  /**
   * Signal end-of-speech and resolve with the best transcript available.
   * Resolves immediately if Flux already emitted EndOfTurn; otherwise
   * waits — bounded by the finish timeout — for EndOfTurn or socket
   * close. Resolves "" if the session was already aborted.
   */
  async finish(): Promise<string> {
    if (this.#aborted) return "";

    this.#finishing = true;

    // If we already got an EndOfTurn, return immediately
    if (this.#endOfTurnTranscript !== null) {
      this.#close();
      return this.#endOfTurnTranscript;
    }

    // Create the promise that will resolve when we have the transcript
    if (!this.#finishPromise) {
      this.#finishPromise = new Promise<string>((resolve) => {
        this.#finishResolve = resolve;
      });
    }

    // Don't close the WS immediately — keep it open so Flux can finish
    // processing buffered audio and send EndOfTurn. The timeout is a
    // safety net: if Flux doesn't respond in time, resolve with whatever
    // partial transcript we have.
    if (this.#connected && this.#ws) {
      this.#startFinishTimeout();
    }
    // else: #connect() will start the timeout after flushing

    return this.#finishPromise;
  }

  /**
   * Cancel the session: drop queued audio, close the socket, and
   * resolve any pending finish() (with the partial transcript, if one
   * was already captured). Idempotent.
   */
  abort(): void {
    if (this.#aborted) return;
    this.#aborted = true;
    this.#clearFinishTimeout();
    this.#pendingChunks = [];
    this.#close();
    this.#resolveFinish();
  }

  /** Close and forget the WebSocket, ignoring close errors. */
  #close(): void {
    if (this.#ws) {
      try {
        this.#ws.close();
      } catch {
        // ignore close errors
      }
      this.#ws = null;
    }
    this.#connected = false;
  }

  /** Tear down the timeout and socket, then resolve a pending finish(). */
  #closeAndResolve(): void {
    this.#clearFinishTimeout();
    this.#close();
    this.#resolveFinish();
  }

  /**
   * Start a timeout that gives Flux time to process remaining audio.
   * If EndOfTurn arrives before the timeout, it resolves immediately
   * (via the EndOfTurn handler). If the WS closes, the close handler
   * resolves. The timeout is the safety net for neither happening.
   */
  #startFinishTimeout(): void {
    if (this.#finishTimeout) return; // already running
    this.#finishTimeout = setTimeout(() => {
      this.#finishTimeout = null;
      this.#close();
      this.#resolveFinish();
    }, 3000);
  }

  /** Cancel the finish safety-net timeout, if one is running. */
  #clearFinishTimeout(): void {
    if (this.#finishTimeout) {
      clearTimeout(this.#finishTimeout);
      this.#finishTimeout = null;
    }
  }

  /**
   * Resolve a pending finish() exactly once, preferring the stable
   * EndOfTurn transcript over the latest interim one. No-op if finish()
   * has not been awaited or was already resolved.
   */
  #resolveFinish(): void {
    if (this.#finishResolve) {
      const transcript = this.#endOfTurnTranscript ?? this.#latestTranscript;
      this.#finishResolve(transcript.trim());
      this.#finishResolve = null;
    }
  }

  /**
   * Dispatch one incoming Flux WebSocket message. Non-string and
   * malformed payloads are ignored; transcript-bearing events update
   * the cached transcripts and invoke the pipeline callbacks.
   */
  #handleMessage(event: MessageEvent): void {
    if (this.#aborted) return;

    try {
      const data: FluxEvent =
        typeof event.data === "string" ? JSON.parse(event.data) : null;

      if (!data || !data.event) return;

      const transcript = data.transcript ?? "";

      switch (data.event) {
        case "Update":
          if (transcript) {
            this.#latestTranscript = transcript;
            this.#onInterim?.(transcript);
          }
          break;

        case "EndOfTurn":
          if (transcript) {
            this.#endOfTurnTranscript = transcript;
            this.#latestTranscript = transcript;
            this.#onFinal?.(transcript);
            this.#onEndOfTurn?.(transcript);
          }
          // If finish() was already called and waiting, resolve now.
          // Clear the timeout — we got a proper EndOfTurn.
          if (this.#finishing) {
            this.#clearFinishTimeout();
            this.#closeAndResolve();
          }
          break;

        case "EagerEndOfTurn":
          // Speculative EOT — transcript is current but may change
          // if TurnResumed fires. Fire onInterim, not onFinal.
          if (transcript) {
            this.#latestTranscript = transcript;
            this.#onInterim?.(transcript);
          }
          break;

        case "TurnResumed":
          // User resumed speaking after EagerEndOfTurn — keep accumulating.
          break;

        case "StartOfTurn":
          // New turn started.
          break;
      }
    } catch {
      // Ignore non-JSON or malformed messages
    }
  }
}

// --- VAD ---

/** Configuration for {@link WorkersAIVAD}. */
export interface WorkersAIVADOptions {
  /** VAD model name. @default "@cf/pipecat-ai/smart-turn-v2" */
  model?: string;
  /** Audio window in seconds (uses last N seconds of audio). @default 2 */
  windowSeconds?: number;
}

/**
 * Workers AI voice activity detection provider.
 *
 * Runs the configured smart-turn model over the tail of the audio
 * buffer to estimate whether the speaker has finished their turn.
 *
 * @example
 * ```ts
 * class MyAgent extends VoiceAgent<Env> {
 *   vad = new WorkersAIVAD(this.env.AI);
 * }
 * ```
 */
export class WorkersAIVAD implements VADProvider {
  #ai: AiLike;
  #model: string;
  #windowSeconds: number;

  constructor(ai: AiLike, options?: WorkersAIVADOptions) {
    this.#ai = ai;
    this.#model = options?.model ?? "@cf/pipecat-ai/smart-turn-v2";
    this.#windowSeconds = options?.windowSeconds ?? 2;
  }

  /**
   * Estimate whether the current speech turn is complete.
   *
   * Only the final `windowSeconds` of audio are analyzed; the byte
   * window assumes 16 kHz mono 16-bit PCM (32000 bytes per second).
   *
   * @param audioData Raw PCM audio buffer.
   * @returns Completion flag plus the model's probability estimate.
   */
  async checkEndOfTurn(
    audioData: ArrayBuffer
  ): Promise<{ isComplete: boolean; probability: number }> {
    // 16000 samples/s * 2 bytes/sample
    const windowBytes = this.#windowSeconds * 16000 * 2;
    const tail =
      audioData.byteLength <= windowBytes
        ? audioData
        : audioData.slice(audioData.byteLength - windowBytes);

    const response = (await this.#ai.run(this.#model, {
      audio: {
        body: toStream(pcmToWav(tail, 16000, 1, 16)),
        contentType: "application/octet-stream"
      }
    })) as { is_complete?: boolean; probability?: number };

    return {
      isComplete: response.is_complete ?? false,
      probability: response.probability ?? 0
    };
  }
}