branch: main
websocket.rs
15578 bytesRaw
use crate::{Error, Method, Request, Result};
use futures_channel::mpsc::UnboundedReceiver;
use futures_util::Stream;
use js_sys::Uint8Array;
use serde::Serialize;
use url::Url;
use worker_sys::ext::WebSocketExt;

#[cfg(not(feature = "http"))]
use crate::Fetch;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use wasm_bindgen::convert::FromWasmAbi;
use wasm_bindgen::prelude::Closure;
use wasm_bindgen::JsCast;
#[cfg(feature = "http")]
use wasm_bindgen_futures::JsFuture;

pub use crate::ws_events::*;
pub use worker_sys::WebSocketRequestResponsePair;

/// Struct holding the values for a JavaScript `WebSocketPair`: both connected ends, as produced
/// by [`WebSocketPair::new`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WebSocketPair {
    /// The `client` end of the pair.
    pub client: WebSocket,
    /// The `server` end of the pair.
    pub server: WebSocket,
}

// SAFETY: NOTE(review): both fields wrap JS handles, which are not thread-safe in general.
// These impls presumably rely on the Workers runtime executing each isolate on a single
// thread — confirm before ever moving this type across real threads.
unsafe impl Send for WebSocketPair {}
unsafe impl Sync for WebSocketPair {}

impl WebSocketPair {
    /// Creates a new `WebSocketPair`.
    pub fn new() -> Result<Self> {
        let mut pair = worker_sys::WebSocketPair::new()?;
        let client = pair.client()?.into();
        let server = pair.server()?.into();
        Ok(Self { client, server })
    }
}

/// Wrapper struct for underlying worker-sys `WebSocket`
///
/// Cloning only clones the JS handle, so every clone refers to the same socket.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WebSocket {
    /// The raw JavaScript socket all methods delegate to.
    socket: web_sys::WebSocket,
}

// SAFETY: NOTE(review): `socket` is a JS handle, which is not thread-safe in general. These
// impls presumably rely on the Workers runtime executing each isolate on a single thread —
// confirm before ever moving this type across real threads.
unsafe impl Send for WebSocket {}
unsafe impl Sync for WebSocket {}

impl WebSocket {
    /// Attempts to establish a [`WebSocket`] connection to the provided [`Url`].
    ///
    /// # Example:
    /// ```rust,ignore
    /// let ws = WebSocket::connect("wss://echo.zeb.workers.dev/".parse()?).await?;
    ///
    /// // It's important that we call this before we send our first message, otherwise we will
    /// // not have any event listeners on the socket to receive the echoed message.
    /// let mut event_stream = ws.events()?;
    ///
    /// ws.accept()?;
    /// ws.send_with_str("Hello, world!")?;
    ///
    /// while let Some(event) = event_stream.next().await {
    ///     let event = event?;
    ///
    ///     if let WebsocketEvent::Message(msg) = event {
    ///         if let Some(text) = msg.text() {
    ///             return Response::ok(text);
    ///         }
    ///     }
    /// }
    ///
    /// Response::error("never got a message echoed back :(", 500)
    /// ```
    pub async fn connect(url: Url) -> Result<WebSocket> {
        WebSocket::connect_with_protocols(url, None).await
    }

    /// Attempts to establish a [`WebSocket`] connection to the provided [`Url`] and protocol.
    ///
    /// # Example:
    /// ```rust,ignore
    /// let ws = WebSocket::connect_with_protocols("wss://echo.zeb.workers.dev/".parse()?, Some(vec!["GiggleBytes"])).await?;
    ///
    /// ```
    pub async fn connect_with_protocols(
        mut url: Url,
        protocols: Option<Vec<&str>>,
    ) -> Result<WebSocket> {
        let scheme: String = match url.scheme() {
            "ws" => "http".into(),
            "wss" => "https".into(),
            scheme => scheme.into(),
        };

        // With fetch we can only make requests to http(s) urls, but Workers will allow us to upgrade
        // those connections into websockets if we use the `Upgrade` header.
        url.set_scheme(&scheme).unwrap();

        let mut req = Request::new(url.as_str(), Method::Get)?;
        req.headers_mut()?.set("upgrade", "websocket")?;

        match protocols {
            None => {}
            Some(v) => {
                req.headers_mut()?
                    .set("Sec-WebSocket-Protocol", v.join(",").as_str())?;
            }
        }

        #[cfg(not(feature = "http"))]
        let res = Fetch::Request(req).send().await?;
        #[cfg(feature = "http")]
        let res: crate::Response = fetch_with_request_raw(req).await?.into();

        match res.websocket() {
            Some(ws) => Ok(ws),
            None => Err(Error::RustError("server did not accept".into())),
        }
    }

    /// Accepts the connection, allowing for messages to be sent to and from the `WebSocket`.
    pub fn accept(&self) -> Result<()> {
        self.socket.accept().map_err(Error::from)
    }

    /// Serialize data into a string using serde and send it through the `WebSocket`
    pub fn send<T: Serialize>(&self, data: &T) -> Result<()> {
        let value = serde_json::to_string(data)?;
        self.send_with_str(value.as_str())
    }

    /// Sends a raw string through the `WebSocket`
    pub fn send_with_str<S: AsRef<str>>(&self, data: S) -> Result<()> {
        self.socket
            .send_with_str(data.as_ref())
            .map_err(Error::from)
    }

    /// Sends raw binary data through the `WebSocket`.
    pub fn send_with_bytes<D: AsRef<[u8]>>(&self, bytes: D) -> Result<()> {
        // This clone to Uint8Array must happen, because workerd
        // will not clone the supplied buffer and will send it asynchronously.
        // Rust believes that the lifetime ends when `send` returns, and frees
        // the memory, causing corruption.
        let uint8_array = Uint8Array::from(bytes.as_ref());
        self.socket.send_with_array_buffer(&uint8_array.buffer())?;
        Ok(())
    }

    /// Closes this channel.
    /// This method translates to three different underlying method calls based of the
    /// parameters passed.
    ///
    /// If the following parameters are Some:
    /// * `code` and `reason` -> `close_with_code_and_reason`
    /// * `code`              -> `close_with_code`
    /// * `reason` or `none`  -> `close`
    ///
    /// Effectively, if only `reason` is `Some`, the `reason` argument will be ignored.
    pub fn close<S: AsRef<str>>(&self, code: Option<u16>, reason: Option<S>) -> Result<()> {
        if let Some((code, reason)) = code.zip(reason) {
            self.socket
                .close_with_code_and_reason(code, reason.as_ref())
        } else if let Some(code) = code {
            self.socket.close_with_code(code)
        } else {
            self.socket.close()
        }
        .map_err(Error::from)
    }

    /// Internal utility method to avoid verbose code.
    /// This method registers a closure in the underlying JS environment, which calls back into the
    /// Rust/wasm environment.
    ///
    /// Since this is a 'long living closure', we need to keep the lifetime of this closure until
    /// we remove any references to the closure. So the caller of this must not drop the closure
    /// until they call [`Self::remove_event_handler`].
    fn add_event_handler<T: FromWasmAbi + 'static, F: FnMut(T) + 'static>(
        &self,
        r#type: &str,
        fun: F,
    ) -> Result<Closure<dyn FnMut(T)>> {
        let js_callback = Closure::wrap_assert_unwind_safe(Box::new(fun) as Box<dyn FnMut(T)>);
        self.socket
            .add_event_listener_with_callback(r#type, js_callback.as_ref().unchecked_ref())
            .map_err(Error::from)?;

        Ok(js_callback)
    }

    /// Internal utility method to avoid verbose code.
    /// This method registers a closure in the underlying JS environment, which calls back
    /// into the Rust/wasm environment.
    fn remove_event_handler<T: FromWasmAbi + 'static>(
        &self,
        r#type: &str,
        js_callback: Closure<dyn FnMut(T)>,
    ) -> Result<()> {
        self.socket
            .remove_event_listener_with_callback(r#type, js_callback.as_ref().unchecked_ref())
            .map_err(Error::from)
    }

    /// Gets an implementation [`Stream`](futures::Stream) that yields events from the inner
    /// WebSocket.
    pub fn events(&self) -> Result<EventStream<'_>> {
        let (tx, rx) = futures_channel::mpsc::unbounded::<Result<WebsocketEvent>>();
        let tx = Rc::new(tx);

        let close_closure = self.add_event_handler("close", {
            let tx = tx.clone();
            move |event: web_sys::CloseEvent| {
                tx.unbounded_send(Ok(WebsocketEvent::Close(event.into())))
                    .unwrap();
            }
        })?;
        let message_closure = self.add_event_handler("message", {
            let tx = tx.clone();
            move |event: web_sys::MessageEvent| {
                tx.unbounded_send(Ok(WebsocketEvent::Message(event.into())))
                    .unwrap();
            }
        })?;
        let error_closure =
            self.add_event_handler("error", move |event: web_sys::ErrorEvent| {
                let error = event.error();
                tx.unbounded_send(Err(error.into())).unwrap();
            })?;

        Ok(EventStream {
            ws: self,
            rx,
            closed: false,
            closures: Some((message_closure, error_closure, close_closure)),
        })
    }

    pub fn serialize_attachment<T: Serialize>(&self, value: T) -> Result<()> {
        self.socket
            .serialize_attachment(serde_wasm_bindgen::to_value(&value)?)
            .map_err(Error::from)
    }

    pub fn deserialize_attachment<T: serde::de::DeserializeOwned>(&self) -> Result<Option<T>> {
        let value = self.socket.deserialize_attachment().map_err(Error::from)?;

        if value.is_null() || value.is_undefined() {
            return Ok(None);
        }

        serde_wasm_bindgen::from_value::<T>(value)
            .map(Some)
            .map_err(Error::from)
    }
}

/// A listener closure handed to the JS socket. It must stay alive until the listener has been
/// removed again.
type EvCallback<T> = Closure<dyn FnMut(T)>;

/// A [`Stream`](futures::Stream) that yields [`WebsocketEvent`](crate::ws_events::WebsocketEvent)s
/// emitted by the inner [`WebSocket`](crate::WebSocket).
///
/// Once a `WebsocketEvent::Close` has been yielded, every later poll returns `None`. So if the
/// socket closes, the close event is the last item the stream produces.
///
/// Errors from the socket's `error` event are yielded as `Err` items and do not end the stream.
///
/// Dropping the stream unregisters the listeners it installed on the socket.
///
/// # Example
/// ```rust,ignore
/// use futures::StreamExt;
///
/// let pair = WebSocketPair::new()?;
/// let server = pair.server;
///
/// server.accept()?;
///
/// // Spawn a future for handling the stream of events from the websocket.
/// wasm_bindgen_futures::spawn_local(async move {
///     let mut event_stream = server.events().expect("could not open stream");
///
///     while let Some(event) = event_stream.next().await {
///         match event.expect("received error in websocket") {
///             WebsocketEvent::Message(msg) => console_log!("{:#?}", msg),
///             WebsocketEvent::Close(event) => console_log!("Closed!"),
///         }
///     }
/// });
/// ```
#[pin_project::pin_project(PinnedDrop)]
#[derive(Debug)]
pub struct EventStream<'ws> {
    /// The socket the listeners were registered on. It is needed again to unregister them on drop.
    ws: &'ws WebSocket,
    /// The receiving end of the channel fed by the listener closures.
    #[pin]
    rx: UnboundedReceiver<Result<WebsocketEvent>>,
    /// Set after a close event has been yielded. From then on the stream only returns `None`.
    closed: bool,
    /// Once we have decided we need to finish the stream, we need to remove any listeners we
    /// registered with the websocket.
    /// The tuple order is (message, error, close). The field is an `Option` so `drop` can move
    /// the closures out.
    closures: Option<(
        EvCallback<web_sys::MessageEvent>,
        EvCallback<web_sys::ErrorEvent>,
        EvCallback<web_sys::CloseEvent>,
    )>,
}

impl Stream for EventStream<'_> {
    type Item = Result<WebsocketEvent>;

    /// Forwards events queued by the socket listeners.
    ///
    /// After a close event has been handed out, this returns `Ready(None)` forever without
    /// touching the channel again.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let projected = self.project();

        if *projected.closed {
            return Poll::Ready(None);
        }

        match projected.rx.poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(next) => {
                // A close event is the last thing this stream ever yields.
                if matches!(&next, Some(Ok(WebsocketEvent::Close(_)))) {
                    *projected.closed = true;
                }
                Poll::Ready(next)
            }
        }
    }
}

// Because we don't want to receive messages once our stream is done we need to remove all our
// listeners when we go to drop the stream.
#[pin_project::pinned_drop]
impl PinnedDrop for EventStream<'_> {
    fn drop(self: Pin<&'_ mut Self>) {
        // Best-effort cleanup. `drop` has no way to report failures, and panicking here would
        // abort if we are already unwinding. Each removal is attempted independently, so one
        // failure does not leave the remaining listeners attached.
        let this = self.project();

        // `remove_event_handler` takes each closure by value, so they live in an `Option` that
        // we move out of here. `events()` always stores `Some`; `None` simply means there is
        // nothing left to unregister.
        if let Some((message_closure, error_closure, close_closure)) = this.closures.take() {
            let _ = this.ws.remove_event_handler("message", message_closure);
            let _ = this.ws.remove_event_handler("error", error_closure);
            let _ = this.ws.remove_event_handler("close", close_closure);
        }
    }
}

impl From<web_sys::WebSocket> for WebSocket {
    fn from(socket: web_sys::WebSocket) -> Self {
        Self { socket }
    }
}

impl AsRef<web_sys::WebSocket> for WebSocket {
    fn as_ref(&self) -> &web_sys::WebSocket {
        &self.socket
    }
}

pub mod ws_events {
    use serde::de::DeserializeOwned;
    use wasm_bindgen::JsValue;

    use crate::Error;

    /// Events that can be yielded by a [`EventStream`](crate::EventStream).
    #[derive(Debug, Clone)]
    pub enum WebsocketEvent {
        /// A message arrived on the socket.
        Message(MessageEvent),
        /// The socket was closed.
        Close(CloseEvent),
    }

    /// Thin wrapper around `web_sys::MessageEvent` with helpers for reading its payload.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub struct MessageEvent {
        event: web_sys::MessageEvent,
    }

    impl From<web_sys::MessageEvent> for MessageEvent {
        fn from(raw: web_sys::MessageEvent) -> Self {
            MessageEvent { event: raw }
        }
    }

    impl AsRef<web_sys::MessageEvent> for MessageEvent {
        fn as_ref(&self) -> &web_sys::MessageEvent {
            &self.event
        }
    }

    impl MessageEvent {
        /// The raw JavaScript payload of the message.
        fn data(&self) -> JsValue {
            self.event.data()
        }

        /// Returns the payload as a `String` if it is a JavaScript string, otherwise `None`.
        pub fn text(&self) -> Option<String> {
            self.data().as_string()
        }

        /// Returns `None` if the payload is not a JavaScript object. Otherwise, returns its
        /// contents viewed through a `Uint8Array` and copied into a `Vec<u8>`.
        pub fn bytes(&self) -> Option<Vec<u8>> {
            let payload = self.data();
            if !payload.is_object() {
                return None;
            }
            Some(js_sys::Uint8Array::new(&payload).to_vec())
        }

        /// Parses a text payload as JSON into `T`.
        ///
        /// # Errors
        /// Fails if the payload is not text, or if it is not valid JSON for `T`.
        pub fn json<T: DeserializeOwned>(&self) -> crate::Result<T> {
            let text = self
                .text()
                .ok_or_else(|| Error::from("data of message event is not text"))?;
            serde_json::from_str(&text).map_err(Error::from)
        }
    }

    /// Thin wrapper around `web_sys::CloseEvent` exposing the close details.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub struct CloseEvent {
        event: web_sys::CloseEvent,
    }

    impl From<web_sys::CloseEvent> for CloseEvent {
        fn from(raw: web_sys::CloseEvent) -> Self {
            CloseEvent { event: raw }
        }
    }

    impl AsRef<web_sys::CloseEvent> for CloseEvent {
        fn as_ref(&self) -> &web_sys::CloseEvent {
            &self.event
        }
    }

    impl CloseEvent {
        /// The reason string supplied with the close.
        pub fn reason(&self) -> String {
            let CloseEvent { event } = self;
            event.reason()
        }

        /// The close status code.
        pub fn code(&self) -> u16 {
            let CloseEvent { event } = self;
            event.code()
        }

        /// Whether the connection was closed cleanly.
        pub fn was_clean(&self) -> bool {
            let CloseEvent { event } = self;
            event.was_clean()
        }
    }
}

/// Sends `request` through the worker's global `fetch` and returns the raw `web_sys::Response`.
///
/// TODO: Convert WebSocket to use `http` types and `reqwest`.
#[cfg(feature = "http")]
async fn fetch_with_request_raw(request: crate::Request) -> Result<web_sys::Response> {
    let pending = {
        // The global-scope handle is confined to this block so it is dropped before the
        // `.await` below. NOTE(review): presumably this keeps the enclosing future free of
        // non-`Send` state, matching the `SendFuture` wrapper — confirm.
        let global: web_sys::WorkerGlobalScope = js_sys::global().unchecked_into();
        let promise = global.fetch_with_request(request.inner());
        crate::send::SendFuture::new(JsFuture::from(promise))
    };
    let raw = pending.await?;
    raw.dyn_into().map_err(Error::from)
}