use std::{ convert::TryFrom, pin::Pin, task::{Context, Poll}, }; use crate::Result; use crate::{r2::js_object, Error}; use futures_util::FutureExt; use js_sys::{ Boolean as JsBoolean, Error as JsError, JsString, Number as JsNumber, Object as JsObject, Reflect, Uint8Array, }; use std::convert::TryInto; use std::io::Error as IoError; use std::io::Result as IoResult; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use wasm_bindgen::{JsCast, JsValue}; use wasm_bindgen_futures::JsFuture; use web_sys::{ ReadableStream, ReadableStreamDefaultReader, WritableStream, WritableStreamDefaultWriter, }; #[derive(Debug)] pub struct SocketInfo { pub remote_address: Option, pub local_address: Option, } impl TryFrom for SocketInfo { type Error = Error; fn try_from(value: JsValue) -> Result { let remote_address_value = js_sys::Reflect::get(&value, &JsValue::from_str("remoteAddress"))?; let local_address_value = js_sys::Reflect::get(&value, &JsValue::from_str("localAddress"))?; Ok(Self { remote_address: remote_address_value.as_string(), local_address: local_address_value.as_string(), }) } } #[derive(Debug, Default)] enum Reading { #[default] None, Pending(JsFuture, ReadableStreamDefaultReader), Ready(Vec), } #[derive(Debug, Default)] enum Writing { Pending(JsFuture, WritableStreamDefaultWriter, usize), #[default] None, } #[derive(Debug, Default)] enum Closing { Pending(JsFuture), #[default] None, } /// Represents an outbound TCP connection from your Worker. #[derive(Debug)] pub struct Socket { inner: worker_sys::Socket, writable: WritableStream, readable: ReadableStream, write: Option, read: Option, close: Option, } // This can only be done because workers are single threaded. unsafe impl Send for Socket {} unsafe impl Sync for Socket {} impl Socket { fn new(inner: worker_sys::Socket) -> Self { let writable = inner.writable().unwrap(); let readable = inner.readable().unwrap(); Socket { inner, writable, readable, read: None, write: None, close: None, } } /// Closes the TCP socket. Both the readable and writable streams are forcibly closed. pub async fn close(&mut self) -> Result<()> { JsFuture::from(self.inner.close()?).await?; Ok(()) } /// This Future is resolved when the socket is closed /// and is rejected if the socket encounters an error. pub async fn closed(&self) -> Result<()> { JsFuture::from(self.inner.closed()?).await?; Ok(()) } pub async fn opened(&self) -> Result { let value = JsFuture::from(self.inner.opened()?).await?; value.try_into() } /// Upgrades an insecure socket to a secure one that uses TLS, /// returning a new Socket. Note that in order to call this method, /// you must set [`secure_transport`](SocketOptions::secure_transport) /// to [`StartTls`](SecureTransport::StartTls) when initially /// calling [`connect`](connect) to create the socket. pub fn start_tls(self) -> Socket { let inner = self.inner.start_tls().unwrap(); Socket::new(inner) } pub fn builder() -> ConnectionBuilder { ConnectionBuilder::default() } fn handle_write_future( cx: &mut Context<'_>, mut fut: JsFuture, writer: WritableStreamDefaultWriter, len: usize, ) -> (Writing, Poll>) { match fut.poll_unpin(cx) { Poll::Pending => (Writing::Pending(fut, writer, len), Poll::Pending), Poll::Ready(res) => { writer.release_lock(); match res { Ok(_) => (Writing::None, Poll::Ready(Ok(len))), Err(e) => (Writing::None, Poll::Ready(Err(js_value_to_std_io_error(e)))), } } } } } fn js_value_to_std_io_error(value: JsValue) -> IoError { let s = if value.is_string() { value.as_string().unwrap() } else if let Some(value) = value.dyn_ref::() { value.to_string().into() } else { format!("Error interpreting JsError: {value:?}") }; IoError::other(s) } impl AsyncRead for Socket { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { fn handle_future( cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, mut fut: JsFuture, reader: ReadableStreamDefaultReader, ) -> (Reading, Poll>) { match fut.poll_unpin(cx) { Poll::Pending => (Reading::Pending(fut, reader), Poll::Pending), Poll::Ready(res) => match res { Ok(value) => { reader.release_lock(); let done: JsBoolean = match Reflect::get(&value, &JsValue::from("done")) { Ok(value) => value.into(), Err(error) => { let msg = format!("Unable to interpret field 'done' in ReadableStreamDefaultReader.read(): {error:?}"); return (Reading::None, Poll::Ready(Err(IoError::other(msg)))); } }; if done.is_truthy() { (Reading::None, Poll::Ready(Ok(()))) } else { let arr: Uint8Array = match Reflect::get( &value, &JsValue::from("value"), ) { Ok(value) => value.into(), Err(error) => { let msg = format!("Unable to interpret field 'value' in ReadableStreamDefaultReader.read(): {error:?}"); return (Reading::None, Poll::Ready(Err(IoError::other(msg)))); } }; let data = arr.to_vec(); handle_data(buf, data) } } Err(e) => (Reading::None, Poll::Ready(Err(js_value_to_std_io_error(e)))), }, } } let (new_reading, poll) = match self.read.take().unwrap_or_default() { Reading::None => { let reader: ReadableStreamDefaultReader = match self.readable.get_reader().dyn_into() { Ok(reader) => reader, Err(error) => { let msg = format!( "Unable to cast JsObject to ReadableStreamDefaultReader: {error:?}" ); return Poll::Ready(Err(IoError::other(msg))); } }; handle_future(cx, buf, JsFuture::from(reader.read()), reader) } Reading::Pending(fut, reader) => handle_future(cx, buf, fut, reader), Reading::Ready(data) => handle_data(buf, data), }; self.read = Some(new_reading); poll } } impl AsyncWrite for Socket { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { let (new_writing, poll) = match self.write.take().unwrap_or_default() { Writing::None => { let obj = JsValue::from(Uint8Array::from(buf)); let writer: WritableStreamDefaultWriter = match self.writable.get_writer() { Ok(writer) => writer, Err(error) => { let msg = format!("Could not retrieve Writer: {error:?}"); return Poll::Ready(Err(IoError::other(msg))); } }; Self::handle_write_future( cx, JsFuture::from(writer.write_with_chunk(&obj)), writer, buf.len(), ) } Writing::Pending(fut, writer, len) => Self::handle_write_future(cx, fut, writer, len), }; self.write = Some(new_writing); poll } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { // Poll existing write future if it exists. let (new_writing, poll) = match self.write.take().unwrap_or_default() { Writing::Pending(fut, writer, len) => { let (writing, poll) = Self::handle_write_future(cx, fut, writer, len); // Map poll output to () (writing, poll.map(|res| res.map(|_| ()))) } writing => (writing, Poll::Ready(Ok(()))), }; self.write = Some(new_writing); poll } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { fn handle_future(cx: &mut Context<'_>, mut fut: JsFuture) -> (Closing, Poll>) { match fut.poll_unpin(cx) { Poll::Pending => (Closing::Pending(fut), Poll::Pending), Poll::Ready(res) => match res { Ok(_) => (Closing::None, Poll::Ready(Ok(()))), Err(e) => (Closing::None, Poll::Ready(Err(js_value_to_std_io_error(e)))), }, } } let (new_closing, poll) = match self.close.take().unwrap_or_default() { Closing::None => handle_future(cx, JsFuture::from(self.writable.close())), Closing::Pending(fut) => handle_future(cx, fut), }; self.close = Some(new_closing); poll } } /// Secure transport options for outbound TCP connections. #[derive(Debug, Clone)] pub enum SecureTransport { /// Do not use TLS. Off, /// Use TLS. On, /// Do not use TLS initially, but allow the socket to be upgraded to /// use TLS by calling [`Socket.start_tls`](Socket::start_tls). StartTls, } /// Used to configure outbound TCP connections. #[derive(Debug, Clone)] pub struct SocketOptions { /// Specifies whether or not to use TLS when creating the TCP socket. pub secure_transport: SecureTransport, /// Defines whether the writable side of the TCP socket will automatically /// close on end-of-file (EOF). When set to false, the writable side of the /// TCP socket will automatically close on EOF. When set to true, the /// writable side of the TCP socket will remain open on EOF. pub allow_half_open: bool, } impl Default for SocketOptions { fn default() -> Self { SocketOptions { secure_transport: SecureTransport::Off, allow_half_open: false, } } } /// The host and port that you wish to connect to. #[derive(Debug, Clone)] pub struct SocketAddress { /// The hostname to connect to. Example: `cloudflare.com`. pub hostname: String, /// The port number to connect to. Example: `5432`. pub port: u16, } #[derive(Default, Debug, Clone)] pub struct ConnectionBuilder { options: SocketOptions, } impl ConnectionBuilder { /// Create a new `ConnectionBuilder` with default settings. pub fn new() -> Self { ConnectionBuilder { options: SocketOptions::default(), } } /// Set whether the writable side of the TCP socket will automatically /// close on end-of-file (EOF). pub fn allow_half_open(mut self, allow_half_open: bool) -> Self { self.options.allow_half_open = allow_half_open; self } // Specify whether or not to use TLS when creating the TCP socket. pub fn secure_transport(mut self, secure_transport: SecureTransport) -> Self { self.options.secure_transport = secure_transport; self } /// Open the connection to `hostname` on port `port`, returning a [`Socket`](Socket). pub fn connect(self, hostname: impl Into, port: u16) -> Result { let address: JsValue = js_object!( "hostname" => JsObject::from(JsString::from(hostname.into())), "port" => JsNumber::from(port) ) .into(); let options: JsValue = js_object!( "allowHalfOpen" => JsBoolean::from(self.options.allow_half_open), "secureTransport" => JsString::from(match self.options.secure_transport { SecureTransport::On => "on", SecureTransport::Off => "off", SecureTransport::StartTls => "starttls", }) ) .into(); let inner = worker_sys::connect(address, options)?; Ok(Socket::new(inner)) } } // Writes as much as possible to buf, and stores the rest in internal buffer fn handle_data(buf: &mut ReadBuf<'_>, mut data: Vec) -> (Reading, Poll>) { let idx = buf.remaining().min(data.len()); let store = data.split_off(idx); buf.put_slice(&data); if store.is_empty() { (Reading::None, Poll::Ready(Ok(()))) } else { (Reading::Ready(store), Poll::Ready(Ok(()))) } } #[cfg(feature = "tokio-postgres")] /// Implements [`TlsConnect`](tokio_postgres::TlsConnect) for /// [`Socket`](crate::Socket) to enable `tokio_postgres` connections /// to databases using TLS. pub mod postgres_tls { use super::Socket; use futures_util::future::{ready, Ready}; use std::error::Error; use std::fmt::{self, Display, Formatter}; use tokio_postgres::tls::{ChannelBinding, TlsConnect, TlsStream}; /// Supply this to `connect_raw` in place of `NoTls` to specify TLS /// when using Workers. /// /// ```rust /// let config = tokio_postgres::config::Config::new(); /// let socket = Socket::builder() /// .secure_transport(SecureTransport::StartTls) /// .connect("database_url", 5432)?; /// let _ = config.connect_raw(socket, PassthroughTls).await?; /// ``` #[derive(Debug, Clone, Default)] pub struct PassthroughTls; #[derive(Debug)] /// Error type for PassthroughTls. /// Should never be returned. pub struct PassthroughTlsError; impl Error for PassthroughTlsError {} impl Display for PassthroughTlsError { fn fmt(&self, fmt: &mut Formatter<'_>) -> fmt::Result { fmt.write_str("PassthroughTlsError") } } impl TlsConnect for PassthroughTls { type Stream = Socket; type Error = PassthroughTlsError; type Future = Ready>; fn connect(self, s: Self::Stream) -> Self::Future { let tls = s.start_tls(); ready(Ok(tls)) } } impl TlsStream for Socket { fn channel_binding(&self) -> ChannelBinding { ChannelBinding::none() } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_handle_data() { let mut arr = vec![0u8; 32]; let mut buf = ReadBuf::new(&mut arr); let data = vec![1u8; 32]; let (reading, _) = handle_data(&mut buf, data); assert!(matches!(reading, Reading::None)); assert_eq!(buf.remaining(), 0); assert_eq!(buf.filled().len(), 32); } #[test] fn test_handle_large_data() { let mut arr = vec![0u8; 32]; let mut buf = ReadBuf::new(&mut arr); let data = vec![1u8; 64]; let (reading, _) = handle_data(&mut buf, data); assert!(matches!(reading, Reading::Ready(store) if store.len() == 32)); assert_eq!(buf.remaining(), 0); assert_eq!(buf.filled().len(), 32); } #[test] fn test_handle_small_data() { let mut arr = vec![0u8; 32]; let mut buf = ReadBuf::new(&mut arr); let data = vec![1u8; 16]; let (reading, _) = handle_data(&mut buf, data); assert!(matches!(reading, Reading::None)); assert_eq!(buf.remaining(), 16); assert_eq!(buf.filled().len(), 16); } }