Skip to main content

worker/
socket.rs

1use std::{
2    convert::TryFrom,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use crate::Result;
8use crate::{r2::js_object, Error};
9use futures_util::FutureExt;
10use js_sys::{
11    Boolean as JsBoolean, Error as JsError, JsString, Number as JsNumber, Object as JsObject,
12    Reflect, Uint8Array,
13};
14use std::convert::TryInto;
15use std::io::Error as IoError;
16use std::io::Result as IoResult;
17use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
18use wasm_bindgen::{JsCast, JsValue};
19use wasm_bindgen_futures::JsFuture;
20use web_sys::{
21    ReadableStream, ReadableStreamDefaultReader, WritableStream, WritableStreamDefaultWriter,
22};
23
/// Addresses the runtime reports for an opened socket.
///
/// Either address may be absent if the runtime did not supply it as a string.
#[derive(Debug)]
pub struct SocketInfo {
    /// Address of the remote peer, if reported.
    pub remote_address: Option<String>,
    /// Local address of the socket, if reported.
    pub local_address: Option<String>,
}
29
30impl TryFrom<JsValue> for SocketInfo {
31    type Error = Error;
32    fn try_from(value: JsValue) -> Result<Self> {
33        let remote_address_value =
34            js_sys::Reflect::get(&value, &JsValue::from_str("remoteAddress"))?;
35        let local_address_value = js_sys::Reflect::get(&value, &JsValue::from_str("localAddress"))?;
36        Ok(Self {
37            remote_address: remote_address_value.as_string(),
38            local_address: local_address_value.as_string(),
39        })
40    }
41}
42
43#[derive(Debug, Default)]
44enum Reading {
45    #[default]
46    None,
47    Pending(JsFuture, ReadableStreamDefaultReader),
48    Ready(Vec<u8>),
49}
50
51#[derive(Debug, Default)]
52enum Writing {
53    Pending(JsFuture, WritableStreamDefaultWriter, usize),
54    #[default]
55    None,
56}
57
58#[derive(Debug, Default)]
59enum Closing {
60    Pending(JsFuture),
61    #[default]
62    None,
63}
64
65/// Represents an outbound TCP connection from your Worker.
66#[derive(Debug)]
67pub struct Socket {
68    inner: worker_sys::Socket,
69    writable: WritableStream,
70    readable: ReadableStream,
71    write: Option<Writing>,
72    read: Option<Reading>,
73    close: Option<Closing>,
74}
75
76// This can only be done because workers are single threaded.
77unsafe impl Send for Socket {}
78unsafe impl Sync for Socket {}
79
80impl Socket {
81    fn new(inner: worker_sys::Socket) -> Self {
82        let writable = inner.writable().unwrap();
83        let readable = inner.readable().unwrap();
84        Socket {
85            inner,
86            writable,
87            readable,
88            read: None,
89            write: None,
90            close: None,
91        }
92    }
93
94    pub(crate) fn from_inner(inner: worker_sys::Socket) -> Self {
95        Self::new(inner)
96    }
97
98    /// Closes the TCP socket. Both the readable and writable streams are forcibly closed.
99    pub async fn close(&mut self) -> Result<()> {
100        JsFuture::from(self.inner.close()?).await?;
101        Ok(())
102    }
103
104    /// This Future is resolved when the socket is closed
105    /// and is rejected if the socket encounters an error.
106    pub async fn closed(&self) -> Result<()> {
107        JsFuture::from(self.inner.closed()?).await?;
108        Ok(())
109    }
110
111    pub async fn opened(&self) -> Result<SocketInfo> {
112        let value = JsFuture::from(self.inner.opened()?).await?;
113        value.try_into()
114    }
115
116    /// Upgrades an insecure socket to a secure one that uses TLS,
117    /// returning a new Socket. Note that in order to call this method,
118    /// you must set [`secure_transport`](SocketOptions::secure_transport)
119    /// to [`StartTls`](SecureTransport::StartTls) when initially
120    /// calling [`connect`](connect) to create the socket.
121    pub fn start_tls(self) -> Socket {
122        let inner = self.inner.start_tls().unwrap();
123        Socket::new(inner)
124    }
125
126    pub fn builder() -> ConnectionBuilder {
127        ConnectionBuilder::default()
128    }
129
130    fn handle_write_future(
131        cx: &mut Context<'_>,
132        mut fut: JsFuture,
133        writer: WritableStreamDefaultWriter,
134        len: usize,
135    ) -> (Writing, Poll<IoResult<usize>>) {
136        match fut.poll_unpin(cx) {
137            Poll::Pending => (Writing::Pending(fut, writer, len), Poll::Pending),
138            Poll::Ready(res) => {
139                writer.release_lock();
140                match res {
141                    Ok(_) => (Writing::None, Poll::Ready(Ok(len))),
142                    Err(e) => (Writing::None, Poll::Ready(Err(js_value_to_std_io_error(e)))),
143                }
144            }
145        }
146    }
147}
148
149fn js_value_to_std_io_error(value: JsValue) -> IoError {
150    let s = if value.is_string() {
151        value.as_string().unwrap()
152    } else if let Some(value) = value.dyn_ref::<JsError>() {
153        value.to_string().into()
154    } else {
155        format!("Error interpreting JsError: {value:?}")
156    };
157    IoError::other(s)
158}
159impl AsyncRead for Socket {
160    fn poll_read(
161        mut self: Pin<&mut Self>,
162        cx: &mut Context<'_>,
163        buf: &mut ReadBuf<'_>,
164    ) -> Poll<IoResult<()>> {
165        fn handle_future(
166            cx: &mut Context<'_>,
167            buf: &mut ReadBuf<'_>,
168            mut fut: JsFuture,
169            reader: ReadableStreamDefaultReader,
170        ) -> (Reading, Poll<IoResult<()>>) {
171            match fut.poll_unpin(cx) {
172                Poll::Pending => (Reading::Pending(fut, reader), Poll::Pending),
173                Poll::Ready(res) => match res {
174                    Ok(value) => {
175                        reader.release_lock();
176                        let done: JsBoolean = match Reflect::get(&value, &JsValue::from("done")) {
177                            Ok(value) => value.into(),
178                            Err(error) => {
179                                let msg = format!("Unable to interpret field 'done' in ReadableStreamDefaultReader.read(): {error:?}");
180                                return (Reading::None, Poll::Ready(Err(IoError::other(msg))));
181                            }
182                        };
183                        if done.is_truthy() {
184                            (Reading::None, Poll::Ready(Ok(())))
185                        } else {
186                            let arr: Uint8Array = match Reflect::get(
187                                &value,
188                                &JsValue::from("value"),
189                            ) {
190                                Ok(value) => value.into(),
191                                Err(error) => {
192                                    let msg = format!("Unable to interpret field 'value' in ReadableStreamDefaultReader.read(): {error:?}");
193                                    return (Reading::None, Poll::Ready(Err(IoError::other(msg))));
194                                }
195                            };
196                            let data = arr.to_vec();
197                            handle_data(buf, data)
198                        }
199                    }
200                    Err(e) => (Reading::None, Poll::Ready(Err(js_value_to_std_io_error(e)))),
201                },
202            }
203        }
204
205        let (new_reading, poll) = match self.read.take().unwrap_or_default() {
206            Reading::None => {
207                let reader: ReadableStreamDefaultReader =
208                    match self.readable.get_reader().dyn_into() {
209                        Ok(reader) => reader,
210                        Err(error) => {
211                            let msg = format!(
212                                "Unable to cast JsObject to ReadableStreamDefaultReader: {error:?}"
213                            );
214                            return Poll::Ready(Err(IoError::other(msg)));
215                        }
216                    };
217
218                handle_future(cx, buf, JsFuture::from(reader.read()), reader)
219            }
220            Reading::Pending(fut, reader) => handle_future(cx, buf, fut, reader),
221            Reading::Ready(data) => handle_data(buf, data),
222        };
223        self.read = Some(new_reading);
224        poll
225    }
226}
227
228impl AsyncWrite for Socket {
229    fn poll_write(
230        mut self: Pin<&mut Self>,
231        cx: &mut Context<'_>,
232        buf: &[u8],
233    ) -> Poll<IoResult<usize>> {
234        let (new_writing, poll) = match self.write.take().unwrap_or_default() {
235            Writing::None => {
236                let obj = JsValue::from(Uint8Array::from(buf));
237                let writer: WritableStreamDefaultWriter = match self.writable.get_writer() {
238                    Ok(writer) => writer,
239                    Err(error) => {
240                        let msg = format!("Could not retrieve Writer: {error:?}");
241                        return Poll::Ready(Err(IoError::other(msg)));
242                    }
243                };
244                Self::handle_write_future(
245                    cx,
246                    JsFuture::from(writer.write_with_chunk(&obj)),
247                    writer,
248                    buf.len(),
249                )
250            }
251            Writing::Pending(fut, writer, len) => Self::handle_write_future(cx, fut, writer, len),
252        };
253        self.write = Some(new_writing);
254        poll
255    }
256
257    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
258        // Poll existing write future if it exists.
259        let (new_writing, poll) = match self.write.take().unwrap_or_default() {
260            Writing::Pending(fut, writer, len) => {
261                let (writing, poll) = Self::handle_write_future(cx, fut, writer, len);
262                // Map poll output to ()
263                (writing, poll.map(|res| res.map(|_| ())))
264            }
265            writing => (writing, Poll::Ready(Ok(()))),
266        };
267        self.write = Some(new_writing);
268        poll
269    }
270
271    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
272        fn handle_future(cx: &mut Context<'_>, mut fut: JsFuture) -> (Closing, Poll<IoResult<()>>) {
273            match fut.poll_unpin(cx) {
274                Poll::Pending => (Closing::Pending(fut), Poll::Pending),
275                Poll::Ready(res) => match res {
276                    Ok(_) => (Closing::None, Poll::Ready(Ok(()))),
277                    Err(e) => (Closing::None, Poll::Ready(Err(js_value_to_std_io_error(e)))),
278                },
279            }
280        }
281        let (new_closing, poll) = match self.close.take().unwrap_or_default() {
282            Closing::None => handle_future(cx, JsFuture::from(self.writable.close())),
283            Closing::Pending(fut) => handle_future(cx, fut),
284        };
285        self.close = Some(new_closing);
286        poll
287    }
288}
289
/// TLS mode for an outbound TCP connection.
#[derive(Clone, Debug)]
pub enum SecureTransport {
    /// Plaintext only; TLS is never used.
    Off,
    /// TLS from the start of the connection.
    On,
    /// Starts in plaintext, with the option to upgrade to TLS later via
    /// [`Socket.start_tls`](Socket::start_tls).
    StartTls,
}
301
302/// Used to configure outbound TCP connections.
303#[derive(Debug, Clone)]
304pub struct SocketOptions {
305    /// Specifies whether or not to use TLS when creating the TCP socket.
306    pub secure_transport: SecureTransport,
307    /// Defines whether the writable side of the TCP socket will automatically
308    /// close on end-of-file (EOF). When set to false, the writable side of the
309    /// TCP socket will automatically close on EOF. When set to true, the
310    /// writable side of the TCP socket will remain open on EOF.
311    pub allow_half_open: bool,
312}
313
314impl Default for SocketOptions {
315    fn default() -> Self {
316        SocketOptions {
317            secure_transport: SecureTransport::Off,
318            allow_half_open: false,
319        }
320    }
321}
322
/// A hostname/port pair identifying the TCP endpoint to connect to.
#[derive(Clone, Debug)]
pub struct SocketAddress {
    /// Host to connect to, for example `cloudflare.com`.
    pub hostname: String,
    /// TCP port to connect to, for example `5432`.
    pub port: u16,
}
331
332#[derive(Default, Debug, Clone)]
333pub struct ConnectionBuilder {
334    options: SocketOptions,
335}
336
337impl ConnectionBuilder {
338    /// Create a new `ConnectionBuilder` with default settings.
339    pub fn new() -> Self {
340        ConnectionBuilder {
341            options: SocketOptions::default(),
342        }
343    }
344
345    /// Set whether the writable side of the TCP socket will automatically
346    /// close on end-of-file (EOF).
347    pub fn allow_half_open(mut self, allow_half_open: bool) -> Self {
348        self.options.allow_half_open = allow_half_open;
349        self
350    }
351
352    // Specify whether or not to use TLS when creating the TCP socket.
353    pub fn secure_transport(mut self, secure_transport: SecureTransport) -> Self {
354        self.options.secure_transport = secure_transport;
355        self
356    }
357
358    /// Open the connection to `hostname` on port `port`, returning a [`Socket`](Socket).
359    pub fn connect(self, hostname: impl Into<String>, port: u16) -> Result<Socket> {
360        let address: JsValue = js_object!(
361            "hostname" => JsObject::from(JsString::from(hostname.into())),
362            "port" => JsNumber::from(port)
363        )
364        .into();
365
366        let options = socket_options_to_js_value(&self.options);
367
368        let inner = worker_sys::connect(address, options)?;
369        Ok(Socket::new(inner))
370    }
371}
372
373pub(crate) fn secure_transport_label(secure_transport: &SecureTransport) -> &'static str {
374    match secure_transport {
375        SecureTransport::On => "on",
376        SecureTransport::Off => "off",
377        SecureTransport::StartTls => "starttls",
378    }
379}
380
381pub(crate) fn socket_options_to_js_value(options: &SocketOptions) -> JsValue {
382    js_object!(
383        "allowHalfOpen" => JsBoolean::from(options.allow_half_open),
384        "secureTransport" => JsString::from(secure_transport_label(&options.secure_transport))
385    )
386    .into()
387}
388
389// Writes as much as possible to buf, and stores the rest in internal buffer
390fn handle_data(buf: &mut ReadBuf<'_>, mut data: Vec<u8>) -> (Reading, Poll<IoResult<()>>) {
391    let idx = buf.remaining().min(data.len());
392    let store = data.split_off(idx);
393    buf.put_slice(&data);
394    if store.is_empty() {
395        (Reading::None, Poll::Ready(Ok(())))
396    } else {
397        (Reading::Ready(store), Poll::Ready(Ok(())))
398    }
399}
400
#[cfg(feature = "tokio-postgres")]
/// Implements [`TlsConnect`](tokio_postgres::TlsConnect) for
/// [`Socket`](crate::Socket) to enable `tokio_postgres` connections
/// to databases using TLS.
pub mod postgres_tls {
    use super::Socket;
    use futures_util::future::{ready, Ready};
    use std::error::Error;
    use std::fmt::{self, Display, Formatter};
    use tokio_postgres::tls::{ChannelBinding, TlsConnect, TlsStream};

    /// Supply this to `connect_raw` in place of `NoTls` to specify TLS
    /// when using Workers.
    ///
    /// The socket must be created with
    /// [`SecureTransport::StartTls`](super::SecureTransport::StartTls),
    /// because the upgrade is performed by [`Socket::start_tls`].
    ///
    /// ```rust,ignore
    /// let config = tokio_postgres::config::Config::new();
    /// let socket = Socket::builder()
    ///     .secure_transport(SecureTransport::StartTls)
    ///     .connect("database_url", 5432)?;
    /// let _ = config.connect_raw(socket, PassthroughTls).await?;
    /// ```
    #[derive(Debug, Clone, Default)]
    pub struct PassthroughTls;

    /// Error type for [`PassthroughTls`].
    ///
    /// Never actually returned: `PassthroughTls::connect` always yields `Ok`
    /// (a failed upgrade panics inside [`Socket::start_tls`] instead).
    #[derive(Debug)]
    pub struct PassthroughTlsError;

    impl Error for PassthroughTlsError {}

    impl Display for PassthroughTlsError {
        fn fmt(&self, fmt: &mut Formatter<'_>) -> fmt::Result {
            fmt.write_str("PassthroughTlsError")
        }
    }

    impl TlsConnect<Socket> for PassthroughTls {
        type Stream = Socket;
        type Error = PassthroughTlsError;
        type Future = Ready<Result<Socket, PassthroughTlsError>>;

        /// Upgrades `s` to TLS via [`Socket::start_tls`].
        ///
        /// # Panics
        ///
        /// Panics if [`Socket::start_tls`] panics, i.e. if the upgrade fails.
        fn connect(self, s: Self::Stream) -> Self::Future {
            let tls = s.start_tls();
            ready(Ok(tls))
        }
    }

    impl TlsStream for Socket {
        /// Channel binding is not supported by the Workers socket API, so
        /// none is offered.
        fn channel_binding(&self) -> ChannelBinding {
            ChannelBinding::none()
        }
    }
}
455
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn secure_transport_labels_match_runtime_strings() {
        assert_eq!(secure_transport_label(&SecureTransport::On), "on");
        assert_eq!(secure_transport_label(&SecureTransport::Off), "off");
        assert_eq!(secure_transport_label(&SecureTransport::StartTls), "starttls");
    }

    #[test]
    fn test_handle_data() {
        let mut arr = vec![0u8; 32];
        let mut buf = ReadBuf::new(&mut arr);
        let data = vec![1u8; 32];
        let (reading, _) = handle_data(&mut buf, data);

        assert!(matches!(reading, Reading::None));
        assert_eq!(buf.remaining(), 0);
        assert_eq!(buf.filled().len(), 32);
    }

    #[test]
    fn test_handle_large_data() {
        let mut arr = vec![0u8; 32];
        let mut buf = ReadBuf::new(&mut arr);
        let data = vec![1u8; 64];
        let (reading, _) = handle_data(&mut buf, data);

        assert!(matches!(reading, Reading::Ready(store) if store.len() == 32));
        assert_eq!(buf.remaining(), 0);
        assert_eq!(buf.filled().len(), 32);
    }

    #[test]
    fn test_handle_small_data() {
        let mut arr = vec![0u8; 32];
        let mut buf = ReadBuf::new(&mut arr);
        let data = vec![1u8; 16];
        let (reading, _) = handle_data(&mut buf, data);

        assert!(matches!(reading, Reading::None));
        assert_eq!(buf.remaining(), 16);
        assert_eq!(buf.filled().len(), 16);
    }

    // An empty chunk must not leave a bogus `Ready(vec![])` state behind.
    #[test]
    fn test_handle_empty_data() {
        let mut arr = vec![0u8; 32];
        let mut buf = ReadBuf::new(&mut arr);
        let (reading, poll) = handle_data(&mut buf, Vec::new());

        assert!(matches!(reading, Reading::None));
        assert!(matches!(poll, Poll::Ready(Ok(()))));
        assert_eq!(buf.remaining(), 32);
        assert_eq!(buf.filled().len(), 0);
    }

    // With no room in the caller's buffer, the whole chunk is kept for later.
    #[test]
    fn test_handle_data_with_full_buffer() {
        let mut arr: Vec<u8> = Vec::new();
        let mut buf = ReadBuf::new(&mut arr);
        let data = vec![7u8; 8];
        let (reading, poll) = handle_data(&mut buf, data);

        assert!(matches!(reading, Reading::Ready(store) if store == vec![7u8; 8]));
        assert!(matches!(poll, Poll::Ready(Ok(()))));
        assert_eq!(buf.filled().len(), 0);
    }
}