wstd_tungstenite/
lib.rs

1//! Async WebSockets for Wasm / WASI 0.2 applications.
2//!
//! This crate is a thin WebSocket wrapper (powered by [`tungstenite`]) on top of [`wstd`] socket support.
4//!
5//! [`tungstenite`]: https://docs.rs/tungstenite/latest/tungstenite/
6//! [`wstd`]: https://docs.rs/wstd/latest/wstd/
7
8#![deny(
9    missing_docs,
10    unused_must_use,
11    unused_mut,
12    unused_imports,
13    unused_import_braces
14)]
15
16pub use tungstenite;
17
18mod compat;
19#[cfg(feature = "connect")]
20mod connect;
21mod handshake;
22
23use std::io::{Read, Write};
24
25use compat::{AllowStd, ContextWaker, cvt};
26use futures_util::{
27    sink::{Sink, SinkExt},
28    stream::{FusedStream, Stream},
29};
30use std::{
31    pin::Pin,
32    task::{Context, Poll},
33};
34use wstd::io::{AsyncRead, AsyncWrite};
35
36#[cfg(feature = "handshake")]
37use tungstenite::{
38    client::IntoClientRequest,
39    handshake::{
40        HandshakeError,
41        client::{ClientHandshake, Response},
42        server::{Callback, NoCallback},
43    },
44};
45use tungstenite::{
46    error::Error as WsError,
47    protocol::{Message, Role, WebSocket, WebSocketConfig},
48};
49
50#[cfg(feature = "connect")]
51pub use connect::{connect_async, connect_async_with_config};
52
53use tungstenite::protocol::CloseFrame;
54
/// Performs the client side of a WebSocket handshake over an existing stream.
///
/// `request` is anything implementing `IntoClientRequest`; for convenience
/// this may be a url string, a URL, or a `Request`. Passing a `Request`
/// allows the caller to add a WebSocket protocol or other custom headers.
///
/// This is shorthand for `client_async_with_config()` without an explicit
/// configuration (`None`). The returned future resolves to the
/// `WebSocketStream<S>` together with the handshake `Response` on success,
/// or to an `Error` if the handshake fails.
///
/// This is typically used for clients who have already established, for
/// example, a TCP connection to the remote server.
#[cfg(feature = "handshake")]
pub async fn client_async<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    // No explicit configuration: the `_with_config` variant receives `None`.
    let config = None;
    client_async_with_config(request, stream, config).await
}
78
/// The same as `client_async()`, but additionally lets the caller specify a
/// WebSocket configuration. Please refer to `client_async()` for more details.
///
/// # Errors
///
/// A `HandshakeError::Failure` is unwrapped and returned as-is. Any other
/// handshake error is rendered to text and returned as `WsError::Io`.
#[cfg(feature = "handshake")]
pub async fn client_async_with_config<R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let pending = handshake::client_handshake(stream, move |allow_std| {
        // Converting the request and starting the handshake can both fail
        // before any handshake step has run.
        let request = request.into_client_request()?;
        ClientHandshake::start(allow_std, request, config)?.handshake()
    });

    match pending.await {
        Ok(connected) => Ok(connected),
        Err(HandshakeError::Failure(err)) => Err(err),
        Err(other) => Err(WsError::Io(std::io::Error::other(other.to_string()))),
    }
}
101
/// Accepts a new WebSocket connection with the provided stream.
///
/// This delegates to `accept_hdr_async()` with `NoCallback`, i.e. without a
/// header-processing callback. The returned future resolves to either a
/// `WebSocketStream<S>` or an `Error`, depending on whether the handshake
/// succeeds.
///
/// This is typically used after a socket has been accepted from a
/// `TcpListener`. That socket is then passed to this function to perform
/// the server half of accepting a client's WebSocket connection.
#[cfg(feature = "handshake")]
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let callback = NoCallback;
    accept_hdr_async(stream, callback).await
}
120
/// The same as `accept_async()`, but additionally lets the caller specify a
/// WebSocket configuration. Please refer to `accept_async()` for more details.
///
/// Delegates to `accept_hdr_async_with_config()` with `NoCallback`.
#[cfg(feature = "handshake")]
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let callback = NoCallback;
    accept_hdr_async_with_config(stream, callback, config).await
}
133
/// Accepts a new WebSocket connection with the provided stream, using a
/// callback for header processing.
///
/// This does the same as `accept_async()` but takes an extra `callback`.
/// The callback receives the headers of the incoming request and is able to
/// add extra headers to the reply.
///
/// This is shorthand for `accept_hdr_async_with_config()` without an
/// explicit configuration (`None`).
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let config = None;
    accept_hdr_async_with_config(stream, callback, config).await
}
147
/// The same as `accept_hdr_async()`, but additionally lets the caller specify
/// a WebSocket configuration. Please refer to `accept_hdr_async()` for more
/// details.
///
/// # Errors
///
/// A `HandshakeError::Failure` is unwrapped and returned as-is. Any other
/// handshake error is rendered to text and returned as `WsError::Io`.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let pending = handshake::server_handshake(stream, move |allow_std| {
        tungstenite::accept_hdr_with_config(allow_std, callback, config)
    });

    match pending.await {
        Ok(accepted) => Ok(accepted),
        Err(HandshakeError::Failure(err)) => Err(err),
        Err(other) => Err(WsError::Io(std::io::Error::other(other.to_string()))),
    }
}
168
/// A wrapper around an underlying raw stream which implements the WebSocket
/// protocol.
///
/// A `WebSocketStream<S>` represents a handshake that has been completed
/// successfully and both the server and the client are ready for receiving
/// and sending data. Messages from a `WebSocketStream<S>` are accessible
/// through the respective `Stream` and `Sink` implementations. Check more
/// information about them in the `futures-rs` crate documentation or have a
/// look at the examples and unit tests for this crate.
#[derive(Debug)]
pub struct WebSocketStream<S> {
    /// The tungstenite WebSocket, driving the raw stream through the
    /// `AllowStd` adapter.
    inner: WebSocket<AllowStd<S>>,
    /// Set once `poll_close` has hit a `WouldBlock` error; subsequent
    /// `poll_close` calls then flush instead of calling `close` again.
    closing: bool,
    /// Set once `poll_next` has received an error from reading (including a
    /// closed connection); afterwards the stream only yields `None`.
    ended: bool,
    /// Tungstenite is probably ready to receive more data.
    ///
    /// `false` once start_send hits `WouldBlock` errors.
    /// `true` initially and after `flush`ing.
    ready: bool,
}
189
190impl<S> WebSocketStream<S> {
191    /// Convert a raw socket into a WebSocketStream without performing a
192    /// handshake.
193    pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
194    where
195        S: AsyncRead + AsyncWrite + Unpin,
196    {
197        handshake::without_handshake(stream, move |allow_std| {
198            WebSocket::from_raw_socket(allow_std, role, config)
199        })
200        .await
201    }
202
203    /// Convert a raw socket into a WebSocketStream without performing a
204    /// handshake.
205    pub async fn from_partially_read(
206        stream: S,
207        part: Vec<u8>,
208        role: Role,
209        config: Option<WebSocketConfig>,
210    ) -> Self
211    where
212        S: AsyncRead + AsyncWrite + Unpin,
213    {
214        handshake::without_handshake(stream, move |allow_std| {
215            WebSocket::from_partially_read(allow_std, part, role, config)
216        })
217        .await
218    }
219
220    pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
221        Self {
222            inner: ws,
223            closing: false,
224            ended: false,
225            ready: true,
226        }
227    }
228
229    fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
230    where
231        S: Unpin,
232        F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
233        AllowStd<S>: Read + Write,
234    {
235        if let Some((kind, ctx)) = ctx {
236            self.inner.get_mut().set_waker(kind, ctx.waker());
237        }
238        f(&mut self.inner)
239    }
240
241    /// Consumes the `WebSocketStream` and returns the underlying stream.
242    pub fn into_inner(self) -> S {
243        self.inner.into_inner().into_inner()
244    }
245
246    /// Returns a shared reference to the inner stream.
247    pub fn get_ref(&self) -> &S
248    where
249        S: AsyncRead + AsyncWrite + Unpin,
250    {
251        self.inner.get_ref().get_ref()
252    }
253
254    /// Returns a mutable reference to the inner stream.
255    pub fn get_mut(&mut self) -> &mut S
256    where
257        S: AsyncRead + AsyncWrite + Unpin,
258    {
259        self.inner.get_mut().get_mut()
260    }
261
262    /// Returns a reference to the configuration of the tungstenite stream.
263    pub fn get_config(&self) -> &WebSocketConfig {
264        self.inner.get_config()
265    }
266
267    /// Close the underlying web socket
268    pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
269    where
270        S: AsyncRead + AsyncWrite + Unpin,
271    {
272        self.send(Message::Close(msg)).await
273    }
274}
275
276impl<T> Stream for WebSocketStream<T>
277where
278    T: AsyncRead + AsyncWrite + Unpin,
279{
280    type Item = Result<Message, WsError>;
281
282    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
283        // The connection has been closed or a critical error has occurred.
284        // We have already returned the error to the user, the `Stream` is unusable,
285        // so we assume that the stream has been "fused".
286        if self.ended {
287            return Poll::Ready(None);
288        }
289
290        match futures_util::ready!(
291            self.with_context(Some((ContextWaker::Read, cx)), |s| { cvt(s.read()) })
292        ) {
293            Ok(v) => Poll::Ready(Some(Ok(v))),
294            Err(e) => {
295                self.ended = true;
296                if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
297                    Poll::Ready(None)
298                } else {
299                    Poll::Ready(Some(Err(e)))
300                }
301            }
302        }
303    }
304}
305
306impl<T> FusedStream for WebSocketStream<T>
307where
308    T: AsyncRead + AsyncWrite + Unpin,
309{
310    fn is_terminated(&self) -> bool {
311        self.ended
312    }
313}
314
315impl<T> Sink<Message> for WebSocketStream<T>
316where
317    T: AsyncRead + AsyncWrite + Unpin,
318{
319    type Error = WsError;
320
321    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
322        if self.ready {
323            Poll::Ready(Ok(()))
324        } else {
325            // Currently blocked so try to flush the blockage away
326            (*self)
327                .with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
328                .map(|r| {
329                    self.ready = true;
330                    r
331                })
332        }
333    }
334
335    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
336        match (*self).with_context(None, |s| s.write(item)) {
337            Ok(()) => {
338                self.ready = true;
339                Ok(())
340            }
341            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
342                // the message was accepted and queued so not an error
343                // but `poll_ready` will now start trying to flush the block
344                self.ready = false;
345                Ok(())
346            }
347            Err(e) => {
348                self.ready = true;
349                Err(e)
350            }
351        }
352    }
353
354    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
355        (*self)
356            .with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
357            .map(|r| {
358                self.ready = true;
359                match r {
360                    // WebSocket connection has just been closed. Flushing completed, not an error.
361                    Err(WsError::ConnectionClosed) => Ok(()),
362                    other => other,
363                }
364            })
365    }
366
367    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
368        self.ready = true;
369        let res = if self.closing {
370            // After queueing it, we call `flush` to drive the close handshake to completion.
371            (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
372        } else {
373            (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
374        };
375
376        match res {
377            Ok(()) => Poll::Ready(Ok(())),
378            Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
379            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
380                self.closing = true;
381                Poll::Pending
382            }
383            Err(err) => Poll::Ready(Err(err)),
384        }
385    }
386}
387
/// Extracts the host of the request URI as an owned `String`.
///
/// # Errors
///
/// Returns `WsError::Url(UrlError::NoHostName)` when the URI has no host.
#[cfg(feature = "connect")]
#[inline]
fn domain(request: &tungstenite::handshake::client::Request) -> Result<String, WsError> {
    request
        .uri()
        .host()
        .map(|host| host.to_string())
        .ok_or_else(|| WsError::Url(tungstenite::error::UrlError::NoHostName))
}
397
#[cfg(test)]
mod tests {
    use crate::{WebSocketStream, compat::AllowStd};
    use std::io::{Read, Write};

    // Compile-time probes: each one only type-checks when `T` satisfies the
    // named bound, so calling it is the assertion.
    fn assert_read<T: Read>() {}
    fn assert_write<T: Write>() {}
    fn assert_unpin<T: Unpin>() {}

    /// The `AllowStd` adapter over a TCP stream must be `Read + Write`, and
    /// the resulting `WebSocketStream` must be `Unpin`.
    #[test]
    fn web_socket_stream_has_traits() {
        type Tcp = wstd::net::TcpStream;

        assert_read::<AllowStd<Tcp>>();
        assert_write::<AllowStd<Tcp>>();
        assert_unpin::<WebSocketStream<Tcp>>();
    }
}
413}