// voice-input.ts
/**
 * Voice-to-text input mixin for the Agents SDK.
 *
 * Unlike `withVoice` (which builds a full conversational voice agent with
 * STT → LLM → TTS), `withVoiceInput` only does STT and sends the
 * transcript back to the client. There is no TTS, no `onTurn`, and no
 * response generation — making it ideal for dictation / voice input UIs.
 *
 * Usage:
 *   import { Agent } from "agents";
 *   import { withVoiceInput, WorkersAIFluxSTT } from "@cloudflare/voice";
 *
 *   const InputAgent = withVoiceInput(Agent);
 *
 *   class MyAgent extends InputAgent<Env> {
 *     streamingStt = new WorkersAIFluxSTT(this.env.AI);
 *
 *     onTranscript(text, connection) {
 *       console.log("User said:", text);
 *     }
 *   }
 *
 * @experimental This API is not yet stable and may change.
 */

import type { Connection, WSMessage } from "agents";
import { VOICE_PROTOCOL_VERSION } from "./types";
import type { STTProvider, VADProvider, StreamingSTTProvider } from "./types";
import {
  AudioConnectionManager,
  sendVoiceJSON,
  DEFAULT_VAD_THRESHOLD,
  DEFAULT_MIN_AUDIO_BYTES,
  DEFAULT_VAD_PUSHBACK_SECONDS,
  DEFAULT_VAD_RETRY_MS
} from "./audio-pipeline";

// --- Public types ---

/**
 * Configuration options for the voice input mixin.
 *
 * All fields are optional; unset fields fall back to the shared pipeline
 * defaults exported from `./audio-pipeline`.
 */
export interface VoiceInputAgentOptions {
  /** Minimum audio bytes to process (16kHz mono 16-bit PCM). Shorter utterances are dropped. @default 16000 (0.5s) */
  minAudioBytes?: number;
  /** VAD probability threshold — only used when `vad` is set. @default 0.5 */
  vadThreshold?: number;
  /** Seconds of audio to push back to the buffer when VAD rejects an end-of-turn. @default 2 */
  vadPushbackSeconds?: number;
  /** Milliseconds to wait after VAD rejects before retrying without VAD. @default 3000 */
  vadRetryMs?: number;
}

// --- Mixin ---

// Generic mixin constructor constraint: any class whose constructor takes any
// arguments. `any[]` (not `unknown[]`) is required so subclass constructors
// remain assignable — this is the standard TypeScript mixin idiom.
// oxlint-disable-next-line @typescript-eslint/no-explicit-any -- mixin constructor constraint
type Constructor<T = object> = new (...args: any[]) => T;

/**
 * Voice-to-text input mixin. Adds STT-only voice input to an Agent class.
 *
 * Subclasses must set an `stt` or `streamingStt` provider property.
 * No TTS provider is needed. Override `onTranscript` to handle each
 * transcribed utterance.
 *
 * @param Base - The Agent class to extend (e.g. `Agent`).
 * @param voiceInputOptions - Optional pipeline configuration.
 *
 * @example
 * ```typescript
 * import { Agent } from "agents";
 * import { withVoiceInput, WorkersAIFluxSTT } from "@cloudflare/voice";
 *
 * const InputAgent = withVoiceInput(Agent);
 *
 * class MyAgent extends InputAgent<Env> {
 *   streamingStt = new WorkersAIFluxSTT(this.env.AI);
 *
 *   onTranscript(text, connection) {
 *     console.log("User said:", text);
 *   }
 * }
 * ```
 */
export function withVoiceInput<TBase extends Constructor>(
  Base: TBase,
  voiceInputOptions?: VoiceInputAgentOptions
) {
  console.log(
    "[@cloudflare/voice] Note: The voice API is experimental and may change between releases. Pin your version to avoid surprises."
  );

  const opts = voiceInputOptions ?? {};

  /** Read a config option, falling back to the shared pipeline default. */
  function opt<K extends keyof VoiceInputAgentOptions>(
    key: K,
    fallback: NonNullable<VoiceInputAgentOptions[K]>
  ): NonNullable<VoiceInputAgentOptions[K]> {
    return (opts[key] ?? fallback) as NonNullable<VoiceInputAgentOptions[K]>;
  }

  class VoiceInputMixin extends Base {
    // --- Provider properties (set by subclass) ---

    /** Speech-to-text provider (batch). Required unless streamingStt is set. */
    stt?: STTProvider;
    /** Streaming speech-to-text provider. Optional — if set, used instead of batch `stt`. */
    streamingStt?: StreamingSTTProvider;
    /** Voice activity detection provider. Optional. */
    vad?: VADProvider;

    // Shared per-connection audio state manager
    #cm = new AudioConnectionManager("VoiceInput");

    // Voice protocol message types handled internally
    static #VOICE_MESSAGES = new Set([
      "hello",
      "start_call",
      "end_call",
      "start_of_speech",
      "end_of_speech",
      "interrupt"
    ]);

    // oxlint-disable-next-line @typescript-eslint/no-explicit-any -- mixin constructor must accept any args
    constructor(...args: any[]) {
      super(...args);

      // Capture the consumer's lifecycle methods (defined on the subclass
      // prototype) and wrap them so voice logic always runs first.
      // This is the same pattern used by Agent and PartyServer.

      // oxlint-disable-next-line @typescript-eslint/no-explicit-any -- binding consumer methods
      const _onConnect = (this as any).onConnect?.bind(this);
      // oxlint-disable-next-line @typescript-eslint/no-explicit-any -- binding consumer methods
      const _onClose = (this as any).onClose?.bind(this);
      // oxlint-disable-next-line @typescript-eslint/no-explicit-any -- binding consumer methods
      const _onMessage = (this as any).onMessage?.bind(this);

      // oxlint-disable-next-line @typescript-eslint/no-explicit-any -- overwriting lifecycle
      (this as any).onConnect = (
        connection: Connection,
        ...rest: unknown[]
      ) => {
        sendVoiceJSON(
          connection,
          {
            type: "welcome",
            protocol_version: VOICE_PROTOCOL_VERSION
          },
          "VoiceInput"
        );
        sendVoiceJSON(
          connection,
          { type: "status", status: "idle" },
          "VoiceInput"
        );
        return _onConnect?.(connection, ...rest);
      };

      // oxlint-disable-next-line @typescript-eslint/no-explicit-any -- overwriting lifecycle
      (this as any).onClose = (connection: Connection, ...rest: unknown[]) => {
        this.#cm.cleanup(connection.id);
        return _onClose?.(connection, ...rest);
      };

      // oxlint-disable-next-line @typescript-eslint/no-explicit-any -- overwriting lifecycle
      (this as any).onMessage = (
        connection: Connection,
        message: WSMessage
      ) => {
        // Binary audio — always handled by voice, never forwarded
        if (message instanceof ArrayBuffer) {
          this.#cm.bufferAudio(connection.id, message);
          return;
        }

        if (typeof message !== "string") {
          return _onMessage?.(connection, message);
        }

        // Try to parse as voice protocol. Note: JSON.parse can yield null,
        // numbers, strings, or arrays — validate the shape before reading
        // `.type`, otherwise a client sending e.g. "null" would crash us.
        let parsed: unknown;
        try {
          parsed = JSON.parse(message);
        } catch {
          // Not JSON — forward to consumer
          return _onMessage?.(connection, message);
        }

        if (
          typeof parsed !== "object" ||
          parsed === null ||
          typeof (parsed as { type?: unknown }).type !== "string"
        ) {
          // Valid JSON but not a protocol envelope — forward to consumer
          return _onMessage?.(connection, message);
        }
        const msgType = (parsed as { type: string }).type;

        // Voice protocol message — handle internally.
        // Async handlers are fire-and-forget by design; attach rejection
        // handlers so a provider failure never becomes an unhandled rejection.
        if (VoiceInputMixin.#VOICE_MESSAGES.has(msgType)) {
          switch (msgType) {
            case "hello":
              break;
            case "start_call":
              void this.#handleStartCall(connection).catch((err) =>
                console.error("[VoiceInput] start_call error:", err)
              );
              break;
            case "end_call":
              this.#handleEndCall(connection);
              break;
            case "start_of_speech":
              this.#handleStartOfSpeech(connection);
              break;
            case "end_of_speech":
              this.#cm.clearVadRetry(connection.id);
              void this.#handleEndOfSpeech(connection).catch((err) =>
                console.error("[VoiceInput] end_of_speech error:", err)
              );
              break;
            case "interrupt":
              this.#handleInterrupt(connection);
              break;
          }
          return;
        }

        // Not a voice message — forward to consumer
        return _onMessage?.(connection, message);
      };
    }

    // --- User-overridable hooks ---

    /**
     * Called after each utterance is transcribed.
     * Override this to process the transcript (e.g. save to storage,
     * trigger a search, or forward to another service).
     *
     * @param text - The transcribed text.
     * @param connection - The WebSocket connection that sent the audio.
     */
    onTranscript(
      _text: string,
      _connection: Connection
    ): void | Promise<void> {}

    /**
     * Called before accepting a call. Return `false` to reject.
     */
    beforeCallStart(_connection: Connection): boolean | Promise<boolean> {
      return true;
    }

    /** Called after a call is accepted and the connection is initialized. */
    onCallStart(_connection: Connection): void | Promise<void> {}
    /** Called after a call ends and per-connection state is cleaned up. */
    onCallEnd(_connection: Connection): void | Promise<void> {}
    /** Called after the client interrupts the pipeline. */
    onInterrupt(_connection: Connection): void | Promise<void> {}

    /**
     * Hook to transform audio before STT. Return null to skip this utterance.
     */
    beforeTranscribe(
      audio: ArrayBuffer,
      _connection: Connection
    ): ArrayBuffer | null | Promise<ArrayBuffer | null> {
      return audio;
    }

    /**
     * Hook to transform or filter the transcript after STT.
     * Return null to discard this utterance.
     */
    afterTranscribe(
      transcript: string,
      _connection: Connection
    ): string | null | Promise<string | null> {
      return transcript;
    }

    // --- Streaming STT session management ---

    #handleStartOfSpeech(connection: Connection) {
      if (!this.streamingStt) return;
      if (this.#cm.hasSTTSession(connection.id)) return;
      if (!this.#cm.isInCall(connection.id)) return;

      // Clear EOT flag from any previous turn
      this.#cm.clearEOT(connection.id);

      // Accumulate finalized segments for the full transcript
      let accumulated = "";

      this.#cm.startSTTSession(connection.id, this.streamingStt, {
        onFinal: (text: string) => {
          accumulated += (accumulated ? " " : "") + text;
          sendVoiceJSON(
            connection,
            {
              type: "transcript_interim",
              text: accumulated
            },
            "VoiceInput"
          );
        },
        onInterim: (text: string) => {
          const display = accumulated ? accumulated + " " + text : text;
          sendVoiceJSON(
            connection,
            {
              type: "transcript_interim",
              text: display
            },
            "VoiceInput"
          );
        },
        // Provider-driven end-of-turn: transcribe immediately
        onEndOfTurn: (transcript: string) => {
          if (this.#cm.isEOTTriggered(connection.id)) return;
          this.#cm.setEOTTriggered(connection.id);

          this.#cm.removeSTTSession(connection.id);
          this.#cm.clearAudioBuffer(connection.id);
          this.#cm.clearVadRetry(connection.id);

          // Emit transcript and go straight back to listening
          void this.#emitTranscript(connection, transcript);
        }
      });
    }

    // --- Internal: call lifecycle ---

    async #handleStartCall(connection: Connection) {
      const allowed = await this.beforeCallStart(connection);
      if (!allowed) return;

      this.#cm.initConnection(connection.id);
      sendVoiceJSON(
        connection,
        { type: "status", status: "listening" },
        "VoiceInput"
      );

      await this.onCallStart(connection);
    }

    #handleEndCall(connection: Connection) {
      this.#cm.cleanup(connection.id);
      sendVoiceJSON(
        connection,
        { type: "status", status: "idle" },
        "VoiceInput"
      );

      // Hook may be async — contain rejections so they can't go unhandled.
      void Promise.resolve(this.onCallEnd(connection)).catch((err) =>
        console.error("[VoiceInput] onCallEnd error:", err)
      );
    }

    #handleInterrupt(connection: Connection) {
      this.#cm.abortPipeline(connection.id);
      this.#cm.abortSTTSession(connection.id);
      this.#cm.clearVadRetry(connection.id);
      this.#cm.clearEOT(connection.id);
      this.#cm.clearAudioBuffer(connection.id);
      sendVoiceJSON(
        connection,
        { type: "status", status: "listening" },
        "VoiceInput"
      );

      // Hook may be async — contain rejections so they can't go unhandled.
      void Promise.resolve(this.onInterrupt(connection)).catch((err) =>
        console.error("[VoiceInput] onInterrupt error:", err)
      );
    }

    // --- Internal: audio pipeline ---

    async #handleEndOfSpeech(connection: Connection, skipVad = false) {
      // If already triggered by provider-driven EOT, ignore
      if (this.#cm.isEOTTriggered(connection.id)) {
        this.#cm.clearEOT(connection.id);
        return;
      }

      const audioData = this.#cm.getAndClearAudio(connection.id);
      if (!audioData) return;

      const hasStreamingSession = this.#cm.hasSTTSession(connection.id);

      const minAudioBytes = opt("minAudioBytes", DEFAULT_MIN_AUDIO_BYTES);
      if (audioData.byteLength < minAudioBytes) {
        // Too short to be a real utterance — drop it and keep listening.
        this.#cm.abortSTTSession(connection.id);
        sendVoiceJSON(
          connection,
          { type: "status", status: "listening" },
          "VoiceInput"
        );
        return;
      }

      if (this.vad && !skipVad) {
        const vadResult = await this.vad.checkEndOfTurn(audioData);
        const vadThreshold = opt("vadThreshold", DEFAULT_VAD_THRESHOLD);
        const shouldProceed =
          vadResult.isComplete || vadResult.probability > vadThreshold;

        if (!shouldProceed) {
          // VAD says the user isn't done: push back a tail of the audio so
          // the next attempt has context, then schedule a VAD-free retry.
          const pushbackSeconds = opt(
            "vadPushbackSeconds",
            DEFAULT_VAD_PUSHBACK_SECONDS
          );
          // 16kHz mono 16-bit PCM → 32000 bytes per second
          const maxPushbackBytes = pushbackSeconds * 16000 * 2;
          const pushback =
            audioData.byteLength > maxPushbackBytes
              ? audioData.slice(audioData.byteLength - maxPushbackBytes)
              : audioData;
          this.#cm.pushbackAudio(connection.id, pushback);
          sendVoiceJSON(
            connection,
            { type: "status", status: "listening" },
            "VoiceInput"
          );
          this.#cm.scheduleVadRetry(
            connection.id,
            () =>
              void this.#handleEndOfSpeech(connection, true).catch((err) =>
                console.error("[VoiceInput] VAD retry error:", err)
              ),
            opt("vadRetryMs", DEFAULT_VAD_RETRY_MS)
          );
          return;
        }
      }

      // --- STT phase ---

      const signal = this.#cm.createPipelineAbort(connection.id);

      sendVoiceJSON(
        connection,
        { type: "status", status: "thinking" },
        "VoiceInput"
      );

      try {
        let userText: string | null;

        if (hasStreamingSession) {
          // Streaming STT path — flush and get final transcript
          const rawTranscript = await this.#cm.flushSTTSession(connection.id);

          if (signal.aborted) return;

          if (!rawTranscript || rawTranscript.trim().length === 0) {
            sendVoiceJSON(
              connection,
              {
                type: "status",
                status: "listening"
              },
              "VoiceInput"
            );
            return;
          }

          userText = await this.afterTranscribe(rawTranscript, connection);
        } else {
          // Batch STT path
          if (!this.stt) {
            sendVoiceJSON(
              connection,
              {
                type: "status",
                status: "listening"
              },
              "VoiceInput"
            );
            return;
          }

          const processedAudio = await this.beforeTranscribe(
            audioData,
            connection
          );
          if (!processedAudio || signal.aborted) {
            sendVoiceJSON(
              connection,
              {
                type: "status",
                status: "listening"
              },
              "VoiceInput"
            );
            return;
          }

          const rawTranscript = await this.stt.transcribe(
            processedAudio,
            signal
          );
          if (signal.aborted) return;

          if (!rawTranscript || rawTranscript.trim().length === 0) {
            sendVoiceJSON(
              connection,
              {
                type: "status",
                status: "listening"
              },
              "VoiceInput"
            );
            return;
          }

          userText = await this.afterTranscribe(rawTranscript, connection);
        }

        if (!userText || signal.aborted) {
          sendVoiceJSON(
            connection,
            { type: "status", status: "listening" },
            "VoiceInput"
          );
          return;
        }

        // Emit the transcript and go straight back to listening
        await this.#emitTranscript(connection, userText);
      } catch (error) {
        if (signal.aborted) return;
        console.error("[VoiceInput] STT pipeline error:", error);
        sendVoiceJSON(
          connection,
          {
            type: "error",
            message:
              error instanceof Error ? error.message : "Voice input failed"
          },
          "VoiceInput"
        );
        sendVoiceJSON(
          connection,
          { type: "status", status: "listening" },
          "VoiceInput"
        );
      } finally {
        this.#cm.clearPipelineAbort(connection.id);
      }
    }

    /**
     * Send the user transcript to the client and call the onTranscript hook.
     * Then immediately return to listening — no LLM/TTS pipeline.
     */
    async #emitTranscript(connection: Connection, text: string) {
      // Clear interim transcript
      sendVoiceJSON(
        connection,
        {
          type: "transcript_interim",
          text: ""
        },
        "VoiceInput"
      );

      // Send the final user transcript
      sendVoiceJSON(
        connection,
        {
          type: "transcript",
          role: "user",
          text
        },
        "VoiceInput"
      );

      // Call the user hook
      try {
        await this.onTranscript(text, connection);
      } catch (err) {
        console.error("[VoiceInput] onTranscript error:", err);
      }

      // Back to listening immediately
      sendVoiceJSON(
        connection,
        { type: "status", status: "listening" },
        "VoiceInput"
      );
    }
  }

  return VoiceInputMixin;
}