branch: main
ai.rs
5721 bytesRaw
use crate::streams::ByteStream;
use crate::{env::EnvBinding, send::SendFuture};
use crate::{Error, Result};
use serde::de::DeserializeOwned;
use serde::Serialize;
use wasm_bindgen::{JsCast, JsValue};
use wasm_bindgen_futures::JsFuture;
use web_sys::ReadableStream;
use worker_sys::Ai as AiSys;
/// Enables access to Workers AI functionality.
///
/// Wraps the raw `worker_sys::Ai` handle. Instances are obtained from an
/// environment binding whose constructor name is `"Ai"` (see the
/// `EnvBinding` implementation for this type).
#[derive(Debug)]
// `repr(transparent)` guarantees `Ai` has exactly the layout of the wrapped
// `AiSys` handle. `JsCast::unchecked_from_js_ref` depends on this, because it
// reinterprets a `&JsValue` as a `&Ai` through a raw pointer cast.
#[repr(transparent)]
pub struct Ai(AiSys);
impl Ai {
/// Execute a Workers AI operation using the specified model.
/// Various forms of the input are documented in the Workers
/// AI documentation.
pub async fn run<T: Serialize, U: DeserializeOwned>(
&self,
model: impl AsRef<str>,
input: T,
) -> Result<U> {
let fut = SendFuture::new(JsFuture::from(
self.0
.run(model.as_ref(), serde_wasm_bindgen::to_value(&input)?),
));
match fut.await {
Ok(output) => Ok(serde_wasm_bindgen::from_value(output)?),
Err(err) => Err(Error::from(err)),
}
}
/// Execute a Workers AI operation that returns binary data as a [`ByteStream`].
///
/// This method is designed for AI models that return raw bytes, such as:
/// - Image generation models (e.g., Stable Diffusion)
/// - Text-to-speech models
/// - Any other model that returns binary output
///
/// The returned [`ByteStream`] implements [`Stream`](futures_util::Stream) and can be:
/// - Streamed directly to a [`Response`] using [`Response::from_stream`]
/// - Collected into a `Vec<u8>` by iterating over the chunks
///
/// # Examples
///
/// ## Streaming directly to a response (recommended)
///
/// This approach is more memory-efficient as it doesn't buffer the entire
/// response in memory:
///
/// ```ignore
/// use worker::*;
/// use serde::Serialize;
///
/// #[derive(Serialize)]
/// struct ImageGenRequest {
/// prompt: String,
/// }
///
/// async fn generate_image(env: &Env) -> Result<Response> {
/// let ai = env.ai("AI")?;
/// let request = ImageGenRequest {
/// prompt: "a beautiful sunset".to_string(),
/// };
/// let stream = ai.run_bytes(
/// "@cf/stabilityai/stable-diffusion-xl-base-1.0",
/// &request
/// ).await?;
///
/// // Stream directly to the response
/// let mut response = Response::from_stream(stream)?;
/// response.headers_mut().set("Content-Type", "image/png")?;
/// Ok(response)
/// }
/// ```
///
/// ## Collecting into bytes
///
/// Use this approach if you need to inspect or modify the bytes before sending:
///
/// ```ignore
/// use worker::*;
/// use serde::Serialize;
/// use futures_util::StreamExt;
///
/// #[derive(Serialize)]
/// struct ImageGenRequest {
/// prompt: String,
/// }
///
/// async fn generate_image(env: &Env) -> Result<Response> {
/// let ai = env.ai("AI")?;
/// let request = ImageGenRequest {
/// prompt: "a beautiful sunset".to_string(),
/// };
/// let mut stream = ai.run_bytes(
/// "@cf/stabilityai/stable-diffusion-xl-base-1.0",
/// &request
/// ).await?;
///
/// // Collect all chunks into a Vec<u8>
/// let mut bytes = Vec::new();
/// while let Some(chunk) = stream.next().await {
/// bytes.extend_from_slice(&chunk?);
/// }
///
/// let mut response = Response::from_bytes(bytes)?;
/// response.headers_mut().set("Content-Type", "image/png")?;
/// Ok(response)
/// }
/// ```
pub async fn run_bytes<T: Serialize>(
&self,
model: impl AsRef<str>,
input: T,
) -> Result<ByteStream> {
let fut = SendFuture::new(JsFuture::from(
self.0
.run(model.as_ref(), serde_wasm_bindgen::to_value(&input)?),
));
match fut.await {
Ok(output) => {
if output.is_instance_of::<ReadableStream>() {
let stream = ReadableStream::unchecked_from_js(output);
Ok(ByteStream::from(stream))
} else {
Err(Error::RustError(
"AI model did not return binary data. Use run() for non-binary responses."
.into(),
))
}
}
Err(err) => Err(Error::from(err)),
}
}
}
// SAFETY: `Ai` wraps a JavaScript object handle (it converts to and from
// `JsValue`), and such handles are not inherently thread-safe.
// NOTE(review): these impls presumably rely on the Workers wasm32 target
// running each isolate on a single thread, so an `Ai` is never actually
// shared or moved across OS threads. Confirm this before relying on it on
// any multi-threaded target.
unsafe impl Sync for Ai {}
unsafe impl Send for Ai {}
impl From<AiSys> for Ai {
fn from(inner: AiSys) -> Self {
Self(inner)
}
}
impl AsRef<JsValue> for Ai {
fn as_ref(&self) -> &JsValue {
&self.0
}
}
impl From<Ai> for JsValue {
fn from(database: Ai) -> Self {
JsValue::from(database.0)
}
}
/// Lets `Ai` take part in `wasm_bindgen`'s dynamic casting (`dyn_into`,
/// `unchecked_into`, ...). The runtime type check is delegated to the
/// underlying `AiSys` binding type.
impl JsCast for Ai {
    /// Returns `true` if `val` is an instance of the JavaScript class backing `AiSys`.
    fn instanceof(val: &JsValue) -> bool {
        val.is_instance_of::<AiSys>()
    }
    /// Wraps `val` as an `Ai` without checking its type.
    fn unchecked_from_js(val: JsValue) -> Self {
        Self(val.into())
    }
    /// Reinterprets a borrowed `JsValue` as a borrowed `Ai` without checking its type.
    fn unchecked_from_js_ref(val: &JsValue) -> &Self {
        // SAFETY: this pointer cast is only sound if `Ai` has exactly the same
        // layout as `JsValue`. `Ai` is a single-field wrapper around `AiSys`,
        // a wasm-bindgen imported type. NOTE(review): that layout equality is
        // only *guaranteed* when both `Ai` and `AiSys` are
        // `#[repr(transparent)]`; confirm the `Ai` struct declaration carries
        // that attribute.
        unsafe { &*(val as *const JsValue as *const Self) }
    }
}
/// Resolves an `Ai` from an environment binding value.
impl EnvBinding for Ai {
    const TYPE_NAME: &'static str = "Ai";

    /// Accepts `val` only when its JavaScript constructor is named `"Ai"`.
    /// Any other constructor yields an error naming both types.
    fn get(val: JsValue) -> Result<Self> {
        let obj = js_sys::Object::from(val);
        let ctor_name = obj.constructor().name();
        if ctor_name != Self::TYPE_NAME {
            return Err(format!(
                "Binding cannot be cast to the type {} from {}",
                Self::TYPE_NAME,
                ctor_name
            )
            .into());
        }
        Ok(obj.unchecked_into())
    }
}