branch: main
socket.rs
16381 bytesRaw
use std::{
convert::TryFrom,
pin::Pin,
task::{Context, Poll},
};
use crate::Result;
use crate::{r2::js_object, Error};
use futures_util::FutureExt;
use js_sys::{
Boolean as JsBoolean, Error as JsError, JsString, Number as JsNumber, Object as JsObject,
Reflect, Uint8Array,
};
use std::convert::TryInto;
use std::io::Error as IoError;
use std::io::Result as IoResult;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use wasm_bindgen::{JsCast, JsValue};
use wasm_bindgen_futures::JsFuture;
use web_sys::{
ReadableStream, ReadableStreamDefaultReader, WritableStream, WritableStreamDefaultWriter,
};
/// Addresses reported for a [`Socket`] once it has opened.
///
/// Produced by [`Socket::opened`] from the value its `opened` promise resolves with.
/// Each field is `None` when the corresponding JS property is absent or is not a string.
#[derive(Debug)]
pub struct SocketInfo {
    /// Value of the `remoteAddress` property, if it is a string.
    pub remote_address: Option<String>,
    /// Value of the `localAddress` property, if it is a string.
    pub local_address: Option<String>,
}
impl TryFrom<JsValue> for SocketInfo {
    type Error = Error;

    /// Builds a [`SocketInfo`] by reading the `remoteAddress` and `localAddress`
    /// properties of `value`.
    ///
    /// A property that exists but is not a string becomes `None`; an error is
    /// returned only if the property lookup itself throws.
    fn try_from(value: JsValue) -> Result<Self> {
        // Look up `key` on the JS object and keep it only if it is a string.
        let string_property = |key: &str| -> Result<Option<String>> {
            let property = js_sys::Reflect::get(&value, &JsValue::from_str(key))?;
            Ok(property.as_string())
        };

        let remote_address = string_property("remoteAddress")?;
        let local_address = string_property("localAddress")?;
        Ok(Self {
            remote_address,
            local_address,
        })
    }
}
/// Read-side state of a [`Socket`], carried between `poll_read` calls.
#[derive(Debug, Default)]
enum Reading {
    /// No read is in flight and no leftover bytes are buffered.
    #[default]
    None,
    /// A `reader.read()` call is in flight. The reader was acquired with
    /// `get_reader()` and stays locked on the stream until the read is handled.
    Pending(JsFuture, ReadableStreamDefaultReader),
    /// Bytes from an earlier chunk that did not fit in the caller's buffer;
    /// they are served before the stream is read again.
    Ready(Vec<u8>),
}
/// Write-side state of a [`Socket`], carried between `poll_write` / `poll_flush` calls.
#[derive(Debug, Default)]
enum Writing {
    /// A `writer.write()` call is in flight: the JS promise, the writer that issued
    /// it (released once the promise settles), and the byte count to report on success.
    Pending(JsFuture, WritableStreamDefaultWriter, usize),
    /// No write is in flight.
    #[default]
    None,
}
/// Shutdown state of a [`Socket`], carried between `poll_shutdown` calls.
#[derive(Debug, Default)]
enum Closing {
    /// The promise returned by the writable stream's `close()` has not settled yet.
    Pending(JsFuture),
    /// No close is in flight.
    #[default]
    None,
}
/// Represents an outbound TCP connection from your Worker.
///
/// Bytes are read from and written to the socket's JS readable/writable streams
/// through the tokio `AsyncRead` / `AsyncWrite` implementations in this module.
#[derive(Debug)]
pub struct Socket {
    /// Underlying JS socket handle.
    inner: worker_sys::Socket,
    /// Writable side of `inner`, fetched once at construction.
    writable: WritableStream,
    /// Readable side of `inner`, fetched once at construction.
    readable: ReadableStream,
    /// In-flight write state; `None` until the first write.
    write: Option<Writing>,
    /// In-flight read state (or leftover bytes); `None` until the first read.
    read: Option<Reading>,
    /// In-flight shutdown state; `None` until the first shutdown.
    close: Option<Closing>,
}
// This can only be done because workers are single threaded.
// SAFETY: `Socket` holds JS handles (`worker_sys::Socket`, the JS streams, and the
// `JsFuture`s / readers / writers inside the state enums) that belong to the JS thread
// which created them. Marking it `Send`/`Sync` is sound only as long as the code runs on
// a single thread, so no value is ever actually moved to, or touched from, another thread.
// NOTE(review): revisit if this crate is ever built for a multi-threaded wasm target
// (e.g. with atomics / shared memory).
unsafe impl Send for Socket {}
unsafe impl Sync for Socket {}
impl Socket {
    /// Wraps a raw `worker_sys::Socket` and fetches its writable and readable streams.
    ///
    /// Read, write and close state all start idle.
    ///
    /// # Panics
    ///
    /// Panics if fetching either stream from `inner` fails.
    /// NOTE(review): assumes a live socket always exposes both streams — confirm.
    fn new(inner: worker_sys::Socket) -> Self {
        let writable = inner
            .writable()
            .expect("worker_sys::Socket did not expose a writable stream");
        let readable = inner
            .readable()
            .expect("worker_sys::Socket did not expose a readable stream");
        Socket {
            inner,
            writable,
            readable,
            read: None,
            write: None,
            close: None,
        }
    }

    /// Closes the TCP socket. Both the readable and writable streams are forcibly closed.
    ///
    /// # Errors
    ///
    /// Returns an error if calling `close()` on the underlying socket throws, or if the
    /// promise it returns rejects.
    pub async fn close(&mut self) -> Result<()> {
        JsFuture::from(self.inner.close()?).await?;
        Ok(())
    }

    /// This Future is resolved when the socket is closed
    /// and is rejected if the socket encounters an error.
    ///
    /// # Errors
    ///
    /// Returns the rejection of the socket's `closed` promise as an error.
    pub async fn closed(&self) -> Result<()> {
        JsFuture::from(self.inner.closed()?).await?;
        Ok(())
    }

    /// Waits for the socket's `opened` promise and returns the addresses it reports.
    ///
    /// # Errors
    ///
    /// Returns an error if the promise rejects, or if its resolved value cannot be
    /// converted into a [`SocketInfo`].
    pub async fn opened(&self) -> Result<SocketInfo> {
        let value = JsFuture::from(self.inner.opened()?).await?;
        value.try_into()
    }

    /// Upgrades an insecure socket to a secure one that uses TLS,
    /// returning a new Socket. Note that in order to call this method,
    /// you must set [`secure_transport`](SocketOptions::secure_transport)
    /// to [`StartTls`](SecureTransport::StartTls) when initially
    /// calling [`connect`](ConnectionBuilder::connect) to create the socket.
    ///
    /// # Panics
    ///
    /// Panics if the underlying `startTls()` call fails — for example, when the socket
    /// was not created with [`SecureTransport::StartTls`].
    pub fn start_tls(self) -> Socket {
        let inner = self.inner.start_tls().expect(
            "Socket::start_tls failed; the socket must be created with SecureTransport::StartTls",
        );
        Socket::new(inner)
    }

    /// Returns a [`ConnectionBuilder`] with default [`SocketOptions`], used to open a new socket.
    pub fn builder() -> ConnectionBuilder {
        ConnectionBuilder::default()
    }

    /// Polls a pending `writer.write()` future.
    ///
    /// Returns the next [`Writing`] state together with the poll result. While the
    /// write is still pending, the future, writer and length are handed back so polling
    /// can resume later. Once it settles, the writer's lock is released and the result
    /// is `Ok(len)` on success, or the rejection converted to an [`IoError`].
    fn handle_write_future(
        cx: &mut Context<'_>,
        mut fut: JsFuture,
        writer: WritableStreamDefaultWriter,
        len: usize,
    ) -> (Writing, Poll<IoResult<usize>>) {
        match fut.poll_unpin(cx) {
            Poll::Pending => (Writing::Pending(fut, writer, len), Poll::Pending),
            Poll::Ready(res) => {
                // Release on both success and failure so the next write can lock the stream.
                writer.release_lock();
                match res {
                    Ok(_) => (Writing::None, Poll::Ready(Ok(len))),
                    Err(e) => (Writing::None, Poll::Ready(Err(js_value_to_std_io_error(e)))),
                }
            }
        }
    }
}
/// Converts a JS rejection/exception value into a [`std::io::Error`] of kind `Other`.
///
/// A JS string is used verbatim. A JS `Error` object is rendered with its
/// `toString()`. Anything else falls back to the value's `Debug` representation.
fn js_value_to_std_io_error(value: JsValue) -> IoError {
    let message = match value.as_string() {
        Some(text) => text,
        None => match value.dyn_ref::<JsError>() {
            Some(js_error) => String::from(js_error.to_string()),
            None => format!("Error interpreting JsError: {value:?}"),
        },
    };
    IoError::other(message)
}
impl AsyncRead for Socket {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<IoResult<()>> {
fn handle_future(
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
mut fut: JsFuture,
reader: ReadableStreamDefaultReader,
) -> (Reading, Poll<IoResult<()>>) {
match fut.poll_unpin(cx) {
Poll::Pending => (Reading::Pending(fut, reader), Poll::Pending),
Poll::Ready(res) => match res {
Ok(value) => {
reader.release_lock();
let done: JsBoolean = match Reflect::get(&value, &JsValue::from("done")) {
Ok(value) => value.into(),
Err(error) => {
let msg = format!("Unable to interpret field 'done' in ReadableStreamDefaultReader.read(): {error:?}");
return (Reading::None, Poll::Ready(Err(IoError::other(msg))));
}
};
if done.is_truthy() {
(Reading::None, Poll::Ready(Ok(())))
} else {
let arr: Uint8Array = match Reflect::get(
&value,
&JsValue::from("value"),
) {
Ok(value) => value.into(),
Err(error) => {
let msg = format!("Unable to interpret field 'value' in ReadableStreamDefaultReader.read(): {error:?}");
return (Reading::None, Poll::Ready(Err(IoError::other(msg))));
}
};
let data = arr.to_vec();
handle_data(buf, data)
}
}
Err(e) => (Reading::None, Poll::Ready(Err(js_value_to_std_io_error(e)))),
},
}
}
let (new_reading, poll) = match self.read.take().unwrap_or_default() {
Reading::None => {
let reader: ReadableStreamDefaultReader =
match self.readable.get_reader().dyn_into() {
Ok(reader) => reader,
Err(error) => {
let msg = format!(
"Unable to cast JsObject to ReadableStreamDefaultReader: {error:?}"
);
return Poll::Ready(Err(IoError::other(msg)));
}
};
handle_future(cx, buf, JsFuture::from(reader.read()), reader)
}
Reading::Pending(fut, reader) => handle_future(cx, buf, fut, reader),
Reading::Ready(data) => handle_data(buf, data),
};
self.read = Some(new_reading);
poll
}
}
impl AsyncWrite for Socket {
    /// Writes `buf` to the socket's writable stream as one `Uint8Array` chunk.
    ///
    /// A writer is acquired for each write and released once the write settles. While
    /// a write is pending, later calls keep polling it and report its original length.
    /// The `buf` passed to those calls is ignored, so callers are expected to retry
    /// with the same buffer.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<IoResult<usize>> {
        let (new_writing, poll) = match self.write.take().unwrap_or_default() {
            Writing::None => {
                let obj = JsValue::from(Uint8Array::from(buf));
                let writer: WritableStreamDefaultWriter = match self.writable.get_writer() {
                    Ok(writer) => writer,
                    Err(error) => {
                        let msg = format!("Could not retrieve Writer: {error:?}");
                        return Poll::Ready(Err(IoError::other(msg)));
                    }
                };
                Self::handle_write_future(
                    cx,
                    JsFuture::from(writer.write_with_chunk(&obj)),
                    writer,
                    buf.len(),
                )
            }
            Writing::Pending(fut, writer, len) => Self::handle_write_future(cx, fut, writer, len),
        };
        self.write = Some(new_writing);
        poll
    }

    /// Drives any in-flight write to completion. Does nothing if no write is pending.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        let (new_writing, poll) = match self.write.take().unwrap_or_default() {
            Writing::Pending(fut, writer, len) => {
                let (writing, poll) = Self::handle_write_future(cx, fut, writer, len);
                // The byte count is irrelevant to a flush; only success/failure matters.
                (writing, poll.map(|res| res.map(|_| ())))
            }
            writing => (writing, Poll::Ready(Ok(()))),
        };
        self.write = Some(new_writing);
        poll
    }

    /// Flushes any pending write, then closes the socket's writable stream.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        /// Polls the writable stream's `close()` promise, mapping a rejection to an I/O error.
        fn handle_future(cx: &mut Context<'_>, mut fut: JsFuture) -> (Closing, Poll<IoResult<()>>) {
            match fut.poll_unpin(cx) {
                Poll::Pending => (Closing::Pending(fut), Poll::Pending),
                Poll::Ready(res) => match res {
                    Ok(_) => (Closing::None, Poll::Ready(Ok(()))),
                    Err(e) => (Closing::None, Poll::Ready(Err(js_value_to_std_io_error(e)))),
                },
            }
        }

        // Finish any in-flight write first. Its writer still holds the stream's lock,
        // and `WritableStream.close()` rejects when called on a locked stream.
        match self.as_mut().poll_flush(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(error)) => return Poll::Ready(Err(error)),
            Poll::Ready(Ok(())) => {}
        }

        let (new_closing, poll) = match self.close.take().unwrap_or_default() {
            Closing::None => handle_future(cx, JsFuture::from(self.writable.close())),
            Closing::Pending(fut) => handle_future(cx, fut),
        };
        self.close = Some(new_closing);
        poll
    }
}
/// Secure transport options for outbound TCP connections.
///
/// This is a plain, data-free choice, so it is `Copy` and can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SecureTransport {
    /// Do not use TLS.
    Off,
    /// Use TLS.
    On,
    /// Do not use TLS initially, but allow the socket to be upgraded to
    /// use TLS by calling [`Socket.start_tls`](Socket::start_tls).
    StartTls,
}
/// Options applied when opening an outbound TCP connection.
#[derive(Debug, Clone)]
pub struct SocketOptions {
    /// Whether, and how, TLS is used on the TCP socket.
    pub secure_transport: SecureTransport,
    /// Controls what happens to the writable side of the TCP socket on
    /// end-of-file (EOF). With `false`, the writable side is closed
    /// automatically on EOF. With `true`, it remains open on EOF.
    pub allow_half_open: bool,
}

impl Default for SocketOptions {
    /// Plain-text transport ([`SecureTransport::Off`]) with `allow_half_open` disabled.
    fn default() -> Self {
        let secure_transport = SecureTransport::Off;
        let allow_half_open = false;
        Self {
            secure_transport,
            allow_half_open,
        }
    }
}
/// The host and port that you wish to connect to.
///
/// NOTE(review): `ConnectionBuilder::connect` takes the hostname and port as separate
/// arguments rather than this struct; confirm where this type is consumed.
#[derive(Debug, Clone)]
pub struct SocketAddress {
    /// The hostname to connect to. Example: `cloudflare.com`.
    pub hostname: String,
    /// The port number to connect to. Example: `5432`.
    pub port: u16,
}
/// Builder for configuring and opening an outbound TCP [`Socket`].
///
/// Obtain one with [`Socket::builder`] or [`ConnectionBuilder::new`], adjust the
/// options with the chained setters, then call [`connect`](ConnectionBuilder::connect).
#[derive(Default, Debug, Clone)]
pub struct ConnectionBuilder {
    options: SocketOptions,
}

impl ConnectionBuilder {
    /// Create a new `ConnectionBuilder` with default settings.
    #[must_use]
    pub fn new() -> Self {
        ConnectionBuilder {
            options: SocketOptions::default(),
        }
    }

    /// Set whether the writable side of the TCP socket remains open on
    /// end-of-file (EOF).
    ///
    /// With `false` (the default), the writable side is closed automatically on EOF.
    /// With `true`, it remains open. See [`SocketOptions::allow_half_open`].
    #[must_use]
    pub fn allow_half_open(mut self, allow_half_open: bool) -> Self {
        self.options.allow_half_open = allow_half_open;
        self
    }

    /// Specify whether or not to use TLS when creating the TCP socket.
    ///
    /// Use [`SecureTransport::StartTls`] if the socket will later be upgraded with
    /// [`Socket::start_tls`].
    #[must_use]
    pub fn secure_transport(mut self, secure_transport: SecureTransport) -> Self {
        self.options.secure_transport = secure_transport;
        self
    }

    /// Open the connection to `hostname` on port `port`, returning a [`Socket`](Socket).
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying `worker_sys::connect` call fails.
    pub fn connect(self, hostname: impl Into<String>, port: u16) -> Result<Socket> {
        let address: JsValue = js_object!(
            "hostname" => JsObject::from(JsString::from(hostname.into())),
            "port" => JsNumber::from(port)
        )
        .into();
        let options: JsValue = js_object!(
            "allowHalfOpen" => JsBoolean::from(self.options.allow_half_open),
            // String values expected by the JS `secureTransport` option.
            "secureTransport" => JsString::from(match self.options.secure_transport {
                SecureTransport::On => "on",
                SecureTransport::Off => "off",
                SecureTransport::StartTls => "starttls",
            })
        )
        .into();
        let inner = worker_sys::connect(address, options)?;
        Ok(Socket::new(inner))
    }
}
/// Copies as much of `data` as fits into `buf` and keeps the remainder for later.
///
/// Always reports `Poll::Ready(Ok(()))`. The returned state is `Reading::Ready` with
/// the bytes that did not fit, or `Reading::None` if everything was copied.
fn handle_data(buf: &mut ReadBuf<'_>, mut data: Vec<u8>) -> (Reading, Poll<IoResult<()>>) {
    let fits = data.len().min(buf.remaining());
    let leftover = data.split_off(fits);
    buf.put_slice(&data);

    let next_state = if leftover.is_empty() {
        Reading::None
    } else {
        Reading::Ready(leftover)
    };
    (next_state, Poll::Ready(Ok(())))
}
#[cfg(feature = "tokio-postgres")]
/// Implements [`TlsConnect`](tokio_postgres::tls::TlsConnect) for
/// [`Socket`](crate::Socket) to enable `tokio_postgres` connections
/// to databases using TLS.
pub mod postgres_tls {
    use super::Socket;
    use futures_util::future::{ready, Ready};
    use std::error::Error;
    use std::fmt::{self, Display, Formatter};
    use tokio_postgres::tls::{ChannelBinding, TlsConnect, TlsStream};

    /// Supply this to `connect_raw` in place of `NoTls` to specify TLS
    /// when using Workers.
    ///
    /// No TLS is implemented on the Rust side: connecting simply upgrades the
    /// socket with [`Socket::start_tls`]. The socket must therefore be created
    /// with `SecureTransport::StartTls`; otherwise `start_tls` panics.
    ///
    /// ```rust,ignore
    /// let config = tokio_postgres::config::Config::new();
    /// let socket = Socket::builder()
    ///     .secure_transport(SecureTransport::StartTls)
    ///     .connect("database_url", 5432)?;
    /// let _ = config.connect_raw(socket, PassthroughTls).await?;
    /// ```
    #[derive(Debug, Clone, Default)]
    pub struct PassthroughTls;

    #[derive(Debug)]
    /// Error type for PassthroughTls.
    /// Should never be returned.
    pub struct PassthroughTlsError;

    impl Error for PassthroughTlsError {}

    impl Display for PassthroughTlsError {
        fn fmt(&self, fmt: &mut Formatter<'_>) -> fmt::Result {
            fmt.write_str("PassthroughTlsError")
        }
    }

    impl TlsConnect<Socket> for PassthroughTls {
        type Stream = Socket;
        type Error = PassthroughTlsError;
        type Future = Ready<Result<Socket, PassthroughTlsError>>;

        /// Upgrades `s` to TLS via [`Socket::start_tls`]. This always yields `Ok`;
        /// a failed upgrade panics inside `start_tls` instead of returning an error.
        fn connect(self, s: Self::Stream) -> Self::Future {
            let tls = s.start_tls();
            ready(Ok(tls))
        }
    }

    impl TlsStream for Socket {
        /// No channel binding data is exposed by the runtime-managed TLS session.
        fn channel_binding(&self) -> ChannelBinding {
            ChannelBinding::none()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A chunk exactly the size of the buffer is copied in full with nothing left over.
    #[test]
    fn test_handle_data() {
        let mut arr = vec![0u8; 32];
        let mut buf = ReadBuf::new(&mut arr);
        let data = vec![1u8; 32];
        let (reading, poll) = handle_data(&mut buf, data);
        assert!(matches!(reading, Reading::None));
        assert!(matches!(poll, Poll::Ready(Ok(()))));
        assert_eq!(buf.remaining(), 0);
        assert_eq!(buf.filled(), &[1u8; 32][..]);
    }

    /// A chunk larger than the buffer fills it and keeps the surplus.
    #[test]
    fn test_handle_large_data() {
        let mut arr = vec![0u8; 32];
        let mut buf = ReadBuf::new(&mut arr);
        let data = vec![1u8; 64];
        let (reading, _) = handle_data(&mut buf, data);
        assert!(matches!(reading, Reading::Ready(store) if store.len() == 32));
        assert_eq!(buf.remaining(), 0);
        assert_eq!(buf.filled().len(), 32);
    }

    /// The buffered surplus is the tail of the chunk, in order.
    #[test]
    fn test_handle_large_data_keeps_tail_in_order() {
        let mut arr = vec![0u8; 32];
        let mut buf = ReadBuf::new(&mut arr);
        let data: Vec<u8> = (0u8..64).collect();
        let (reading, _) = handle_data(&mut buf, data);
        assert_eq!(buf.filled(), &(0u8..32).collect::<Vec<u8>>()[..]);
        match reading {
            Reading::Ready(store) => assert_eq!(store, (32u8..64).collect::<Vec<u8>>()),
            other => panic!("expected Reading::Ready, got {other:?}"),
        }
    }

    /// A chunk smaller than the buffer is copied in full and leaves spare capacity.
    #[test]
    fn test_handle_small_data() {
        let mut arr = vec![0u8; 32];
        let mut buf = ReadBuf::new(&mut arr);
        let data = vec![1u8; 16];
        let (reading, _) = handle_data(&mut buf, data);
        assert!(matches!(reading, Reading::None));
        assert_eq!(buf.remaining(), 16);
        assert_eq!(buf.filled().len(), 16);
    }

    /// An empty chunk leaves the buffer untouched and buffers nothing.
    #[test]
    fn test_handle_empty_data() {
        let mut arr = vec![0u8; 32];
        let mut buf = ReadBuf::new(&mut arr);
        let (reading, poll) = handle_data(&mut buf, Vec::new());
        assert!(matches!(reading, Reading::None));
        assert!(matches!(poll, Poll::Ready(Ok(()))));
        assert!(buf.filled().is_empty());
        assert_eq!(buf.remaining(), 32);
    }

    /// With a zero-capacity buffer, the whole chunk is kept for later.
    #[test]
    fn test_handle_data_zero_capacity() {
        let mut arr: [u8; 0] = [];
        let mut buf = ReadBuf::new(&mut arr);
        let (reading, _) = handle_data(&mut buf, vec![7u8; 8]);
        assert!(matches!(reading, Reading::Ready(store) if store == vec![7u8; 8]));
        assert!(buf.filled().is_empty());
    }
}