use crate::{Error, Method, Request, Result}; use futures_channel::mpsc::UnboundedReceiver; use futures_util::Stream; use js_sys::Uint8Array; use serde::Serialize; use url::Url; use worker_sys::ext::WebSocketExt; #[cfg(not(feature = "http"))] use crate::Fetch; use std::pin::Pin; use std::rc::Rc; use std::task::{Context, Poll}; use wasm_bindgen::convert::FromWasmAbi; use wasm_bindgen::prelude::Closure; use wasm_bindgen::JsCast; #[cfg(feature = "http")] use wasm_bindgen_futures::JsFuture; pub use crate::ws_events::*; pub use worker_sys::WebSocketRequestResponsePair; /// Struct holding the values for a JavaScript `WebSocketPair` #[derive(Debug, Clone, PartialEq, Eq)] pub struct WebSocketPair { pub client: WebSocket, pub server: WebSocket, } unsafe impl Send for WebSocketPair {} unsafe impl Sync for WebSocketPair {} impl WebSocketPair { /// Creates a new `WebSocketPair`. pub fn new() -> Result { let mut pair = worker_sys::WebSocketPair::new()?; let client = pair.client()?.into(); let server = pair.server()?.into(); Ok(Self { client, server }) } } /// Wrapper struct for underlying worker-sys `WebSocket` #[derive(Debug, Clone, PartialEq, Eq)] pub struct WebSocket { socket: web_sys::WebSocket, } unsafe impl Send for WebSocket {} unsafe impl Sync for WebSocket {} impl WebSocket { /// Attempts to establish a [`WebSocket`] connection to the provided [`Url`]. /// /// # Example: /// ```rust,ignore /// let ws = WebSocket::connect("wss://echo.zeb.workers.dev/".parse()?).await?; /// /// // It's important that we call this before we send our first message, otherwise we will /// // not have any event listeners on the socket to receive the echoed message. /// let mut event_stream = ws.events()?; /// /// ws.accept()?; /// ws.send_with_str("Hello, world!")?; /// /// while let Some(event) = event_stream.next().await { /// let event = event?; /// /// if let WebsocketEvent::Message(msg) = event { /// if let Some(text) = msg.text() { /// return Response::ok(text); /// } /// } /// } /// /// Response::error("never got a message echoed back :(", 500) /// ``` pub async fn connect(url: Url) -> Result { WebSocket::connect_with_protocols(url, None).await } /// Attempts to establish a [`WebSocket`] connection to the provided [`Url`] and protocol. /// /// # Example: /// ```rust,ignore /// let ws = WebSocket::connect_with_protocols("wss://echo.zeb.workers.dev/".parse()?, Some(vec!["GiggleBytes"])).await?; /// /// ``` pub async fn connect_with_protocols( mut url: Url, protocols: Option>, ) -> Result { let scheme: String = match url.scheme() { "ws" => "http".into(), "wss" => "https".into(), scheme => scheme.into(), }; // With fetch we can only make requests to http(s) urls, but Workers will allow us to upgrade // those connections into websockets if we use the `Upgrade` header. url.set_scheme(&scheme).unwrap(); let mut req = Request::new(url.as_str(), Method::Get)?; req.headers_mut()?.set("upgrade", "websocket")?; match protocols { None => {} Some(v) => { req.headers_mut()? .set("Sec-WebSocket-Protocol", v.join(",").as_str())?; } } #[cfg(not(feature = "http"))] let res = Fetch::Request(req).send().await?; #[cfg(feature = "http")] let res: crate::Response = fetch_with_request_raw(req).await?.into(); match res.websocket() { Some(ws) => Ok(ws), None => Err(Error::RustError("server did not accept".into())), } } /// Accepts the connection, allowing for messages to be sent to and from the `WebSocket`. pub fn accept(&self) -> Result<()> { self.socket.accept().map_err(Error::from) } /// Serialize data into a string using serde and send it through the `WebSocket` pub fn send(&self, data: &T) -> Result<()> { let value = serde_json::to_string(data)?; self.send_with_str(value.as_str()) } /// Sends a raw string through the `WebSocket` pub fn send_with_str>(&self, data: S) -> Result<()> { self.socket .send_with_str(data.as_ref()) .map_err(Error::from) } /// Sends raw binary data through the `WebSocket`. pub fn send_with_bytes>(&self, bytes: D) -> Result<()> { // This clone to Uint8Array must happen, because workerd // will not clone the supplied buffer and will send it asynchronously. // Rust believes that the lifetime ends when `send` returns, and frees // the memory, causing corruption. let uint8_array = Uint8Array::from(bytes.as_ref()); self.socket.send_with_array_buffer(&uint8_array.buffer())?; Ok(()) } /// Closes this channel. /// This method translates to three different underlying method calls based of the /// parameters passed. /// /// If the following parameters are Some: /// * `code` and `reason` -> `close_with_code_and_reason` /// * `code` -> `close_with_code` /// * `reason` or `none` -> `close` /// /// Effectively, if only `reason` is `Some`, the `reason` argument will be ignored. pub fn close>(&self, code: Option, reason: Option) -> Result<()> { if let Some((code, reason)) = code.zip(reason) { self.socket .close_with_code_and_reason(code, reason.as_ref()) } else if let Some(code) = code { self.socket.close_with_code(code) } else { self.socket.close() } .map_err(Error::from) } /// Internal utility method to avoid verbose code. /// This method registers a closure in the underlying JS environment, which calls back into the /// Rust/wasm environment. /// /// Since this is a 'long living closure', we need to keep the lifetime of this closure until /// we remove any references to the closure. So the caller of this must not drop the closure /// until they call [`Self::remove_event_handler`]. fn add_event_handler( &self, r#type: &str, fun: F, ) -> Result> { let js_callback = Closure::wrap_assert_unwind_safe(Box::new(fun) as Box); self.socket .add_event_listener_with_callback(r#type, js_callback.as_ref().unchecked_ref()) .map_err(Error::from)?; Ok(js_callback) } /// Internal utility method to avoid verbose code. /// This method registers a closure in the underlying JS environment, which calls back /// into the Rust/wasm environment. fn remove_event_handler( &self, r#type: &str, js_callback: Closure, ) -> Result<()> { self.socket .remove_event_listener_with_callback(r#type, js_callback.as_ref().unchecked_ref()) .map_err(Error::from) } /// Gets an implementation [`Stream`](futures::Stream) that yields events from the inner /// WebSocket. pub fn events(&self) -> Result> { let (tx, rx) = futures_channel::mpsc::unbounded::>(); let tx = Rc::new(tx); let close_closure = self.add_event_handler("close", { let tx = tx.clone(); move |event: web_sys::CloseEvent| { tx.unbounded_send(Ok(WebsocketEvent::Close(event.into()))) .unwrap(); } })?; let message_closure = self.add_event_handler("message", { let tx = tx.clone(); move |event: web_sys::MessageEvent| { tx.unbounded_send(Ok(WebsocketEvent::Message(event.into()))) .unwrap(); } })?; let error_closure = self.add_event_handler("error", move |event: web_sys::ErrorEvent| { let error = event.error(); tx.unbounded_send(Err(error.into())).unwrap(); })?; Ok(EventStream { ws: self, rx, closed: false, closures: Some((message_closure, error_closure, close_closure)), }) } pub fn serialize_attachment(&self, value: T) -> Result<()> { self.socket .serialize_attachment(serde_wasm_bindgen::to_value(&value)?) .map_err(Error::from) } pub fn deserialize_attachment(&self) -> Result> { let value = self.socket.deserialize_attachment().map_err(Error::from)?; if value.is_null() || value.is_undefined() { return Ok(None); } serde_wasm_bindgen::from_value::(value) .map(Some) .map_err(Error::from) } } type EvCallback = Closure; /// A [`Stream`](futures::Stream) that yields [`WebsocketEvent`](crate::ws_events::WebsocketEvent)s /// emitted by the inner [`WebSocket`](crate::WebSocket). The stream is guaranteed to always yield a /// `WebsocketEvent::Close` as the final non-none item. /// /// # Example /// ```rust,ignore /// use futures::StreamExt; /// /// let pair = WebSocketPair::new()?; /// let server = pair.server; /// /// server.accept()?; /// /// // Spawn a future for handling the stream of events from the websocket. /// wasm_bindgen_futures::spawn_local(async move { /// let mut event_stream = server.events().expect("could not open stream"); /// /// while let Some(event) = event_stream.next().await { /// match event.expect("received error in websocket") { /// WebsocketEvent::Message(msg) => console_log!("{:#?}", msg), /// WebsocketEvent::Close(event) => console_log!("Closed!"), /// } /// } /// }); /// ``` #[pin_project::pin_project(PinnedDrop)] #[derive(Debug)] pub struct EventStream<'ws> { ws: &'ws WebSocket, #[pin] rx: UnboundedReceiver>, closed: bool, /// Once we have decided we need to finish the stream, we need to remove any listeners we /// registered with the websocket. closures: Option<( EvCallback, EvCallback, EvCallback, )>, } impl Stream for EventStream<'_> { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.project(); if *this.closed { return Poll::Ready(None); } // Poll the inner receiver to check if theres any events from our event callbacks. let item = futures_util::ready!(this.rx.poll_next(cx)); // Mark the stream as closed if we get a close event and yield None next iteration. if let Some(item) = &item { if matches!(&item, Ok(WebsocketEvent::Close(_))) { *this.closed = true; } } Poll::Ready(item) } } // Because we don't want to receive messages once our stream is done we need to remove all our // listeners when we go to drop the stream. #[pin_project::pinned_drop] impl PinnedDrop for EventStream<'_> { fn drop(self: Pin<&'_ mut Self>) { let this = self.project(); // remove_event_handler takes an owned closure, so we'll do this little hack and wrap our // closures in an Option. This should never panic because we should never call drop twice. let (message_closure, error_closure, close_closure) = std::mem::take(this.closures).expect("double drop on worker::EventStream"); this.ws .remove_event_handler("message", message_closure) .expect("could not remove message handler"); this.ws .remove_event_handler("error", error_closure) .expect("could not remove error handler"); this.ws .remove_event_handler("close", close_closure) .expect("could not remove close handler"); } } impl From for WebSocket { fn from(socket: web_sys::WebSocket) -> Self { Self { socket } } } impl AsRef for WebSocket { fn as_ref(&self) -> &web_sys::WebSocket { &self.socket } } pub mod ws_events { use serde::de::DeserializeOwned; use wasm_bindgen::JsValue; use crate::Error; /// Events that can be yielded by a [`EventStream`](crate::EventStream). #[derive(Debug, Clone)] pub enum WebsocketEvent { Message(MessageEvent), Close(CloseEvent), } /// Wrapper/Utility struct for the `web_sys::MessageEvent` #[derive(Debug, Clone, PartialEq, Eq)] pub struct MessageEvent { event: web_sys::MessageEvent, } impl From for MessageEvent { fn from(event: web_sys::MessageEvent) -> Self { Self { event } } } impl AsRef for MessageEvent { fn as_ref(&self) -> &web_sys::MessageEvent { &self.event } } impl MessageEvent { /// Gets the data/payload from the message. fn data(&self) -> JsValue { self.event.data() } pub fn text(&self) -> Option { let value = self.data(); value.as_string() } pub fn bytes(&self) -> Option> { let value = self.data(); if value.is_object() { Some(js_sys::Uint8Array::new(&value).to_vec()) } else { None } } pub fn json(&self) -> crate::Result { let text = match self.text() { Some(text) => text, None => return Err(Error::from("data of message event is not text")), }; serde_json::from_str(&text).map_err(Error::from) } } /// Wrapper/Utility struct for the `web_sys::CloseEvent` #[derive(Debug, Clone, PartialEq, Eq)] pub struct CloseEvent { event: web_sys::CloseEvent, } impl CloseEvent { pub fn reason(&self) -> String { self.event.reason() } pub fn code(&self) -> u16 { self.event.code() } pub fn was_clean(&self) -> bool { self.event.was_clean() } } impl From for CloseEvent { fn from(event: web_sys::CloseEvent) -> Self { Self { event } } } impl AsRef for CloseEvent { fn as_ref(&self) -> &web_sys::CloseEvent { &self.event } } } /// TODO: Convert WebSocket to use `http` types and `reqwest`. #[cfg(feature = "http")] async fn fetch_with_request_raw(request: crate::Request) -> Result { let req = request.inner(); let fut = { let worker: web_sys::WorkerGlobalScope = js_sys::global().unchecked_into(); crate::send::SendFuture::new(JsFuture::from(worker.fetch_with_request(req))) }; let resp = fut.await?; Ok(resp.dyn_into()?) }