wstd_tungstenite/
lib.rs

1//! Async WebSockets for Wasm / WASI 0.2 applications.
2//!
//! This crate is a thin WebSocket wrapper (powered by [`tungstenite`]) on top of the socket support provided by [`wstd`].
4//!
5//! [`tungstenite`]: https://docs.rs/tungstenite/latest/tungstenite/
6//! [`wstd`]: https://docs.rs/wstd/latest/wstd/
7
8#![deny(
9    missing_docs,
10    unused_must_use,
11    unused_mut,
12    unused_imports,
13    unused_import_braces
14)]
15
16pub use tungstenite;
17
18mod compat;
19mod connect;
20mod handshake;
21
22use std::io::{Read, Write};
23
24use compat::{AllowStd, ContextWaker, cvt};
25use futures_util::{
26    sink::{Sink, SinkExt},
27    stream::{FusedStream, Stream},
28};
29use std::{
30    pin::Pin,
31    task::{Context, Poll},
32};
33use wstd::io::{AsyncRead, AsyncWrite};
34
35#[cfg(feature = "handshake")]
36use tungstenite::{
37    client::IntoClientRequest,
38    handshake::{
39        HandshakeError,
40        client::{ClientHandshake, Response},
41        server::{Callback, NoCallback},
42    },
43};
44use tungstenite::{
45    error::Error as WsError,
46    protocol::{Message, Role, WebSocket, WebSocketConfig},
47};
48
49#[cfg(feature = "connect")]
50pub use connect::{connect_async, connect_async_with_config};
51
52use tungstenite::protocol::CloseFrame;
53
/// Performs the client side of a WebSocket handshake over an already
/// established `stream`, using the default configuration.
///
/// `request` may be anything implementing [`IntoClientRequest`]: a URL string,
/// a URL, or a full `Request`. Passing a `Request` lets the caller add a
/// WebSocket sub-protocol or other custom headers.
///
/// The returned future resolves to the ready-to-use [`WebSocketStream`]
/// together with the server's handshake `Response`, or to a [`WsError`] if the
/// handshake does not succeed (see [`client_async_with_config`] for how errors
/// are mapped).
///
/// Typical use: a client that has already opened, for example, a TCP
/// connection to the remote server.
#[cfg(feature = "handshake")]
pub async fn client_async<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    // No explicit configuration: let tungstenite pick its defaults.
    let default_config: Option<WebSocketConfig> = None;
    client_async_with_config(request, stream, default_config).await
}
77
/// Like [`client_async`], but with an explicit WebSocket configuration.
///
/// `config` is handed straight to tungstenite's `ClientHandshake::start`
/// (`None` selects tungstenite's defaults). See [`client_async`] for the
/// accepted `request` forms.
///
/// # Errors
///
/// A `HandshakeError::Failure` is unwrapped and returned as its inner
/// [`WsError`]. Any other handshake outcome is converted into `WsError::Io`
/// whose message is that outcome's `Display` text.
#[cfg(feature = "handshake")]
pub async fn client_async_with_config<R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let outcome = handshake::client_handshake(stream, move |allow_std| {
        // `?` lifts both the request-building error and the start error
        // into the handshake error type.
        let client_request = request.into_client_request()?;
        ClientHandshake::start(allow_std, client_request, config)?.handshake()
    })
    .await;

    match outcome {
        Ok(established) => Ok(established),
        Err(HandshakeError::Failure(cause)) => Err(cause),
        Err(other) => Err(WsError::Io(std::io::Error::other(other.to_string()))),
    }
}
100
/// Performs the server side of a WebSocket handshake on an accepted `stream`,
/// using the default configuration and no header callback.
///
/// The handshake is driven by tungstenite's server handshake. The returned
/// future resolves to a [`WebSocketStream`] on success, or to a [`WsError`]
/// otherwise.
///
/// Typical use: right after a socket has been accepted from a `TcpListener`,
/// pass it here to complete the server half of the client's WebSocket upgrade.
#[cfg(feature = "handshake")]
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // Same as `accept_hdr_async(stream, NoCallback)`, with the default
    // (`None`) configuration spelled out directly.
    accept_hdr_async_with_config(stream, NoCallback, None).await
}
119
/// Like [`accept_async`], but with an explicit WebSocket configuration.
///
/// No header callback is installed. See [`accept_async`] for details.
#[cfg(feature = "handshake")]
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let no_header_hook = NoCallback;
    accept_hdr_async_with_config(stream, no_header_hook, config).await
}
132
/// Accepts a new WebSocket connection on `stream`, invoking `callback` during
/// the handshake.
///
/// This behaves like [`accept_async`]. In addition, the callback receives the
/// headers of the incoming request and may add extra headers to the reply
/// (see tungstenite's [`Callback`] trait). The default configuration is used.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let default_config: Option<WebSocketConfig> = None;
    accept_hdr_async_with_config(stream, callback, default_config).await
}
146
/// Like [`accept_hdr_async`], but with an explicit WebSocket configuration.
///
/// `callback` and `config` are forwarded to
/// `tungstenite::accept_hdr_with_config`.
///
/// # Errors
///
/// A `HandshakeError::Failure` is unwrapped and returned as its inner
/// [`WsError`]. Any other handshake outcome becomes `WsError::Io` carrying
/// that outcome's `Display` text.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let outcome = handshake::server_handshake(stream, move |allow_std| {
        tungstenite::accept_hdr_with_config(allow_std, callback, config)
    })
    .await;

    outcome.map_err(|handshake_err| match handshake_err {
        HandshakeError::Failure(cause) => cause,
        other => WsError::Io(std::io::Error::other(other.to_string())),
    })
}
167
168/// A wrapper around an underlying raw stream which implements the WebSocket
169/// protocol.
170///
171/// A `WebSocketStream<S>` represents a handshake that has been completed
172/// successfully and both the server and the client are ready for receiving
173/// and sending data. Message from a `WebSocketStream<S>` are accessible
174/// through the respective `Stream` and `Sink`. Check more information about
175/// them in `futures-rs` crate documentation or have a look on the examples
176/// and unit tests for this crate.
177#[derive(Debug)]
178pub struct WebSocketStream<S> {
179    inner: WebSocket<AllowStd<S>>,
180    closing: bool,
181    ended: bool,
182    /// Tungstenite is probably ready to receive more data.
183    ///
184    /// `false` once start_send hits `WouldBlock` errors.
185    /// `true` initially and after `flush`ing.
186    ready: bool,
187}
188
189impl<S> WebSocketStream<S> {
190    /// Convert a raw socket into a WebSocketStream without performing a
191    /// handshake.
192    pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
193    where
194        S: AsyncRead + AsyncWrite + Unpin,
195    {
196        handshake::without_handshake(stream, move |allow_std| {
197            WebSocket::from_raw_socket(allow_std, role, config)
198        })
199        .await
200    }
201
202    /// Convert a raw socket into a WebSocketStream without performing a
203    /// handshake.
204    pub async fn from_partially_read(
205        stream: S,
206        part: Vec<u8>,
207        role: Role,
208        config: Option<WebSocketConfig>,
209    ) -> Self
210    where
211        S: AsyncRead + AsyncWrite + Unpin,
212    {
213        handshake::without_handshake(stream, move |allow_std| {
214            WebSocket::from_partially_read(allow_std, part, role, config)
215        })
216        .await
217    }
218
219    pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
220        Self {
221            inner: ws,
222            closing: false,
223            ended: false,
224            ready: true,
225        }
226    }
227
228    fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
229    where
230        S: Unpin,
231        F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
232        AllowStd<S>: Read + Write,
233    {
234        if let Some((kind, ctx)) = ctx {
235            self.inner.get_mut().set_waker(kind, ctx.waker());
236        }
237        f(&mut self.inner)
238    }
239
240    /// Consumes the `WebSocketStream` and returns the underlying stream.
241    pub fn into_inner(self) -> S {
242        self.inner.into_inner().into_inner()
243    }
244
245    /// Returns a shared reference to the inner stream.
246    pub fn get_ref(&self) -> &S
247    where
248        S: AsyncRead + AsyncWrite + Unpin,
249    {
250        self.inner.get_ref().get_ref()
251    }
252
253    /// Returns a mutable reference to the inner stream.
254    pub fn get_mut(&mut self) -> &mut S
255    where
256        S: AsyncRead + AsyncWrite + Unpin,
257    {
258        self.inner.get_mut().get_mut()
259    }
260
261    /// Returns a reference to the configuration of the tungstenite stream.
262    pub fn get_config(&self) -> &WebSocketConfig {
263        self.inner.get_config()
264    }
265
266    /// Close the underlying web socket
267    pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
268    where
269        S: AsyncRead + AsyncWrite + Unpin,
270    {
271        self.send(Message::Close(msg)).await
272    }
273}
274
275impl<T> Stream for WebSocketStream<T>
276where
277    T: AsyncRead + AsyncWrite + Unpin,
278{
279    type Item = Result<Message, WsError>;
280
281    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
282        // The connection has been closed or a critical error has occurred.
283        // We have already returned the error to the user, the `Stream` is unusable,
284        // so we assume that the stream has been "fused".
285        if self.ended {
286            return Poll::Ready(None);
287        }
288
289        match futures_util::ready!(
290            self.with_context(Some((ContextWaker::Read, cx)), |s| { cvt(s.read()) })
291        ) {
292            Ok(v) => Poll::Ready(Some(Ok(v))),
293            Err(e) => {
294                self.ended = true;
295                if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
296                    Poll::Ready(None)
297                } else {
298                    Poll::Ready(Some(Err(e)))
299                }
300            }
301        }
302    }
303}
304
305impl<T> FusedStream for WebSocketStream<T>
306where
307    T: AsyncRead + AsyncWrite + Unpin,
308{
309    fn is_terminated(&self) -> bool {
310        self.ended
311    }
312}
313
314impl<T> Sink<Message> for WebSocketStream<T>
315where
316    T: AsyncRead + AsyncWrite + Unpin,
317{
318    type Error = WsError;
319
320    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
321        if self.ready {
322            Poll::Ready(Ok(()))
323        } else {
324            // Currently blocked so try to flush the blockage away
325            (*self)
326                .with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
327                .map(|r| {
328                    self.ready = true;
329                    r
330                })
331        }
332    }
333
334    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
335        match (*self).with_context(None, |s| s.write(item)) {
336            Ok(()) => {
337                self.ready = true;
338                Ok(())
339            }
340            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
341                // the message was accepted and queued so not an error
342                // but `poll_ready` will now start trying to flush the block
343                self.ready = false;
344                Ok(())
345            }
346            Err(e) => {
347                self.ready = true;
348                Err(e)
349            }
350        }
351    }
352
353    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
354        (*self)
355            .with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
356            .map(|r| {
357                self.ready = true;
358                match r {
359                    // WebSocket connection has just been closed. Flushing completed, not an error.
360                    Err(WsError::ConnectionClosed) => Ok(()),
361                    other => other,
362                }
363            })
364    }
365
366    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
367        self.ready = true;
368        let res = if self.closing {
369            // After queueing it, we call `flush` to drive the close handshake to completion.
370            (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
371        } else {
372            (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
373        };
374
375        match res {
376            Ok(()) => Poll::Ready(Ok(())),
377            Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
378            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
379                self.closing = true;
380                Poll::Pending
381            }
382            Err(err) => Poll::Ready(Err(err)),
383        }
384    }
385}
386
/// Extracts the host of `request`'s URI as an owned `String`.
///
/// # Errors
///
/// Returns `WsError::Url(UrlError::NoHostName)` when the URI has no host.
///
/// NOTE(review): for IPv6 literals, `Uri::host()` keeps the surrounding
/// brackets (e.g. `[::1]`). Confirm that the caller in `connect` expects
/// that form.
#[cfg(feature = "connect")]
#[inline]
fn domain(request: &tungstenite::handshake::client::Request) -> Result<String, WsError> {
    request
        .uri()
        .host()
        .map(str::to_string)
        .ok_or_else(|| WsError::Url(tungstenite::error::UrlError::NoHostName))
}
396
#[cfg(test)]
mod tests {
    use crate::{WebSocketStream, compat::AllowStd};
    use std::io::{Read, Write};

    // Compile-time trait probes: each call only type-checks when `T`
    // satisfies the bound, so the test passes simply by compiling.
    fn assert_read_write<T: Read + Write>() {}
    fn assert_unpin<T: Unpin>() {}

    #[test]
    fn web_socket_stream_has_traits() {
        type Tcp = wstd::net::TcpStream;
        assert_read_write::<AllowStd<Tcp>>();
        assert_unpin::<WebSocketStream<Tcp>>();
    }
}
412}