veilid_async_tungstenite/
lib.rs

1//! Async WebSockets.
2//!
3//! This crate is based on [tungstenite](https://crates.io/crates/tungstenite)
4//! Rust WebSocket library and provides async bindings and wrappers for it, so you
//! can use it with non-blocking/asynchronous `TcpStream`s and couple it
//! together with other crates from the async stack. In addition, optional
//! integration with various other crates can be enabled via feature flags:
8//!
9//!  * `async-tls`: Enables the `async_tls` module, which provides integration
10//!    with the [async-tls](https://crates.io/crates/async-tls) TLS stack and can
11//!    be used independent of any async runtime.
12//!  * `async-std-runtime`: Enables the `async_std` module, which provides
13//!    integration with the [async-std](https://async.rs) runtime.
14//!  * `async-native-tls`: Enables the additional functions in the `async_std`
15//!    module to implement TLS via
16//!    [async-native-tls](https://crates.io/crates/async-native-tls).
17//!  * `tokio-runtime`: Enables the `tokio` module, which provides integration
18//!    with the [tokio](https://tokio.rs) runtime.
19//!  * `tokio-native-tls`: Enables the additional functions in the `tokio` module to
20//!    implement TLS via [tokio-native-tls](https://crates.io/crates/tokio-native-tls).
21//!  * `tokio-rustls-native-certs`: Enables the additional functions in the `tokio`
22//!    module to implement TLS via [tokio-rustls](https://crates.io/crates/tokio-rustls)
23//!    and uses native system certificates found with
24//!    [rustls-native-certs](https://github.com/rustls/rustls-native-certs).
25//!  * `tokio-rustls-webpki-roots`: Enables the additional functions in the `tokio`
26//!    module to implement TLS via [tokio-rustls](https://crates.io/crates/tokio-rustls)
27//!    and uses the certificates [webpki-roots](https://github.com/rustls/webpki-roots)
28//!    provides.
29//!  * `tokio-openssl`: Enables the additional functions in the `tokio` module to
30//!    implement TLS via [tokio-openssl](https://crates.io/crates/tokio-openssl).
31//!  * `gio-runtime`: Enables the `gio` module, which provides integration with
32//!    the [gio](https://www.gtk-rs.org) runtime.
33//!
34//! Each WebSocket stream implements the required `Stream` and `Sink` traits,
35//! making the socket a stream of WebSocket messages coming in and going out.
36
37#![deny(
38    missing_docs,
39    unused_must_use,
40    unused_mut,
41    unused_imports,
42    unused_import_braces
43)]
44
45pub use tungstenite;
46
47mod compat;
48mod handshake;
49
50#[cfg(any(
51    feature = "async-tls",
52    feature = "async-native-tls",
53    feature = "tokio-native-tls",
54    feature = "tokio-rustls-manual-roots",
55    feature = "tokio-rustls-native-certs",
56    feature = "tokio-rustls-webpki-roots",
57    feature = "tokio-openssl",
58))]
59pub mod stream;
60
61use std::io::{Read, Write};
62
63use compat::{cvt, AllowStd, ContextWaker};
64use futures_io::{AsyncRead, AsyncWrite};
65use futures_util::{
66    sink::{Sink, SinkExt},
67    stream::{FusedStream, Stream},
68};
69use log::*;
70use std::pin::Pin;
71use std::task::{Context, Poll};
72
73#[cfg(feature = "handshake")]
74use tungstenite::{
75    client::IntoClientRequest,
76    handshake::{
77        client::{ClientHandshake, Response},
78        server::{Callback, NoCallback},
79        HandshakeError,
80    },
81};
82use tungstenite::{
83    error::Error as WsError,
84    protocol::{Message, Role, WebSocket, WebSocketConfig},
85};
86
87#[cfg(feature = "async-std-runtime")]
88pub mod async_std;
89#[cfg(feature = "async-tls")]
90pub mod async_tls;
91#[cfg(feature = "gio-runtime")]
92pub mod gio;
93#[cfg(feature = "tokio-runtime")]
94pub mod tokio;
95
96use tungstenite::protocol::CloseFrame;
97
/// Creates a WebSocket handshake from a request and a stream.
/// For convenience, the user may call this with a url string, a URL,
/// or a `Request`. Calling with `Request` allows the user to add
/// a WebSocket protocol or other custom headers.
///
/// Internally, this creates a handshake representation and returns
/// a future representing the resolution of the WebSocket handshake. The
/// returned future will resolve to either `WebSocketStream<S>` (together with
/// the server's handshake `Response`) or an error, depending on whether the
/// handshake is successful.
///
/// This is typically used for clients who have already established, for
/// example, a TCP connection to the remote server.
///
/// This is equivalent to calling `client_async_with_config(request, stream, None)`,
/// i.e. no explicit `WebSocketConfig` is supplied.
#[cfg(feature = "handshake")]
pub async fn client_async<'a, R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    client_async_with_config(request, stream, None).await
}
121
/// Like `client_async()`, but lets the caller supply a websocket configuration.
///
/// `config` is passed to tungstenite's `ClientHandshake::start`; see
/// `client_async()` for the remaining details. A handshake `Failure` is
/// returned unchanged, while any other handshake error (e.g. an interrupted
/// handshake) is turned into a `WsError::Io` of kind `Other` carrying that
/// error's message.
#[cfg(feature = "handshake")]
pub async fn client_async_with_config<'a, R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let handshake = handshake::client_handshake(stream, move |allow_std| {
        let req = request.into_client_request()?;
        ClientHandshake::start(allow_std, req, config)?.handshake()
    });

    match handshake.await {
        Ok(connected) => Ok(connected),
        Err(HandshakeError::Failure(failure)) => Err(failure),
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
147
/// Accepts a new WebSocket connection with the provided stream.
///
/// This performs the server side of the WebSocket handshake with no header
/// callback (`NoCallback`) and no explicit websocket configuration; it is
/// equivalent to `accept_hdr_async(stream, NoCallback)`. The returned future
/// resolves to either `WebSocketStream<S>` or an error, depending on whether
/// the handshake succeeds.
///
/// This is typically used after a socket has been accepted from a
/// `TcpListener`. That socket is then passed to this function to perform
/// the server half of accepting a client's websocket connection.
#[cfg(feature = "handshake")]
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    accept_hdr_async(stream, NoCallback).await
}
166
/// Like `accept_async()`, but lets the caller supply a websocket configuration.
///
/// No header callback is installed (`NoCallback`); see `accept_async()` for
/// the remaining details.
#[cfg(feature = "handshake")]
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let accepting = accept_hdr_async_with_config(stream, NoCallback, config);
    accepting.await
}
179
/// Accepts a new WebSocket connection with the provided stream, letting
/// `callback` take part in the handshake.
///
/// Behaves like `accept_async()`, except that the callback receives the
/// headers of the incoming request and is able to add extra headers to the
/// reply. No explicit websocket configuration (`None`) is supplied.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let accepting = accept_hdr_async_with_config(stream, callback, None);
    accepting.await
}
193
/// Like `accept_hdr_async()`, but lets the caller supply a websocket configuration.
///
/// The handshake is delegated to `tungstenite::accept_hdr_with_config`. A
/// handshake `Failure` is returned unchanged, while any other handshake error
/// (e.g. an interrupted handshake) is turned into a `WsError::Io` of kind
/// `Other` carrying that error's message.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let handshake = handshake::server_handshake(stream, move |allow_std| {
        tungstenite::accept_hdr_with_config(allow_std, callback, config)
    });

    match handshake.await {
        Ok(ws_stream) => Ok(ws_stream),
        Err(HandshakeError::Failure(failure)) => Err(failure),
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
217
218/// A wrapper around an underlying raw stream which implements the WebSocket
219/// protocol.
220///
221/// A `WebSocketStream<S>` represents a handshake that has been completed
222/// successfully and both the server and the client are ready for receiving
223/// and sending data. Message from a `WebSocketStream<S>` are accessible
224/// through the respective `Stream` and `Sink`. Check more information about
225/// them in `futures-rs` crate documentation or have a look on the examples
226/// and unit tests for this crate.
227#[derive(Debug)]
228pub struct WebSocketStream<S> {
229    inner: WebSocket<AllowStd<S>>,
230    closing: bool,
231    ended: bool,
232    /// Tungstenite is probably ready to receive more data.
233    ///
234    /// `false` once start_send hits `WouldBlock` errors.
235    /// `true` initially and after `flush`ing.
236    ready: bool,
237}
238
239impl<S> WebSocketStream<S> {
240    /// Convert a raw socket into a WebSocketStream without performing a
241    /// handshake.
242    pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
243    where
244        S: AsyncRead + AsyncWrite + Unpin,
245    {
246        handshake::without_handshake(stream, move |allow_std| {
247            WebSocket::from_raw_socket(allow_std, role, config)
248        })
249        .await
250    }
251
252    /// Convert a raw socket into a WebSocketStream without performing a
253    /// handshake.
254    pub async fn from_partially_read(
255        stream: S,
256        part: Vec<u8>,
257        role: Role,
258        config: Option<WebSocketConfig>,
259    ) -> Self
260    where
261        S: AsyncRead + AsyncWrite + Unpin,
262    {
263        handshake::without_handshake(stream, move |allow_std| {
264            WebSocket::from_partially_read(allow_std, part, role, config)
265        })
266        .await
267    }
268
269    pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
270        Self {
271            inner: ws,
272            closing: false,
273            ended: false,
274            ready: true,
275        }
276    }
277
278    fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
279    where
280        S: Unpin,
281        F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
282        AllowStd<S>: Read + Write,
283    {
284        #[cfg(feature = "verbose-logging")]
285        trace!("{}:{} WebSocketStream.with_context", file!(), line!());
286        if let Some((kind, ctx)) = ctx {
287            self.inner.get_mut().set_waker(kind, ctx.waker());
288        }
289        f(&mut self.inner)
290    }
291
292    /// Returns a shared reference to the inner stream.
293    pub fn get_ref(&self) -> &S
294    where
295        S: AsyncRead + AsyncWrite + Unpin,
296    {
297        self.inner.get_ref().get_ref()
298    }
299
300    /// Returns a mutable reference to the inner stream.
301    pub fn get_mut(&mut self) -> &mut S
302    where
303        S: AsyncRead + AsyncWrite + Unpin,
304    {
305        self.inner.get_mut().get_mut()
306    }
307
308    /// Returns a reference to the configuration of the tungstenite stream.
309    pub fn get_config(&self) -> &WebSocketConfig {
310        self.inner.get_config()
311    }
312
313    /// Close the underlying web socket
314    pub async fn close(&mut self, msg: Option<CloseFrame<'_>>) -> Result<(), WsError>
315    where
316        S: AsyncRead + AsyncWrite + Unpin,
317    {
318        let msg = msg.map(|msg| msg.into_owned());
319        self.send(Message::Close(msg)).await
320    }
321}
322
323impl<T> Stream for WebSocketStream<T>
324where
325    T: AsyncRead + AsyncWrite + Unpin,
326{
327    type Item = Result<Message, WsError>;
328
329    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
330        #[cfg(feature = "verbose-logging")]
331        trace!("{}:{} Stream.poll_next", file!(), line!());
332
333        // The connection has been closed or a critical error has occurred.
334        // We have already returned the error to the user, the `Stream` is unusable,
335        // so we assume that the stream has been "fused".
336        if self.ended {
337            return Poll::Ready(None);
338        }
339
340        match futures_util::ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
341            #[cfg(feature = "verbose-logging")]
342            trace!(
343                "{}:{} Stream.with_context poll_next -> read()",
344                file!(),
345                line!()
346            );
347            cvt(s.read())
348        })) {
349            Ok(v) => Poll::Ready(Some(Ok(v))),
350            Err(e) => {
351                self.ended = true;
352                if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
353                    Poll::Ready(None)
354                } else {
355                    Poll::Ready(Some(Err(e)))
356                }
357            }
358        }
359    }
360}
361
362impl<T> FusedStream for WebSocketStream<T>
363where
364    T: AsyncRead + AsyncWrite + Unpin,
365{
366    fn is_terminated(&self) -> bool {
367        self.ended
368    }
369}
370
371impl<T> Sink<Message> for WebSocketStream<T>
372where
373    T: AsyncRead + AsyncWrite + Unpin,
374{
375    type Error = WsError;
376
377    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
378        if self.ready {
379            Poll::Ready(Ok(()))
380        } else {
381            // Currently blocked so try to flush the blockage away
382            (*self)
383                .with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
384                .map(|r| {
385                    self.ready = true;
386                    r
387                })
388        }
389    }
390
391    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
392        match (*self).with_context(None, |s| s.write(item)) {
393            Ok(()) => {
394                self.ready = true;
395                Ok(())
396            }
397            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
398                // the message was accepted and queued so not an error
399                // but `poll_ready` will now start trying to flush the block
400                self.ready = false;
401                Ok(())
402            }
403            Err(e) => {
404                self.ready = true;
405                debug!("websocket start_send error: {}", e);
406                Err(e)
407            }
408        }
409    }
410
411    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
412        (*self)
413            .with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
414            .map(|r| {
415                self.ready = true;
416                match r {
417                    // WebSocket connection has just been closed. Flushing completed, not an error.
418                    Err(WsError::ConnectionClosed) => Ok(()),
419                    other => other,
420                }
421            })
422    }
423
424    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
425        self.ready = true;
426        let res = if self.closing {
427            // After queueing it, we call `flush` to drive the close handshake to completion.
428            (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
429        } else {
430            (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
431        };
432
433        match res {
434            Ok(()) => Poll::Ready(Ok(())),
435            Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
436            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
437                trace!("WouldBlock");
438                self.closing = true;
439                Poll::Pending
440            }
441            Err(err) => {
442                debug!("websocket close error: {}", err);
443                Poll::Ready(Err(err))
444            }
445        }
446    }
447}
448
#[cfg(any(
    feature = "async-tls",
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
/// Get the domain (host) from a request URL.
///
/// Returns the host component of the request URI as an owned `String`. If the
/// host is a bracketed IPv6 literal such as `[::1]`, the brackets are removed,
/// because they are URI syntax and not part of the address.
///
/// # Errors
///
/// Returns `tungstenite::Error::Url(UrlError::NoHostName)` when the URI has no host.
#[inline]
pub(crate) fn domain(
    request: &tungstenite::handshake::client::Request,
) -> Result<String, tungstenite::Error> {
    let host = request.uri().host().ok_or(tungstenite::Error::Url(
        tungstenite::error::UrlError::NoHostName,
    ))?;

    // Strip the IPv6 brackets with `strip_prefix`/`strip_suffix` instead of
    // slicing by index: a host that is not fully bracketed is returned as-is
    // rather than risking an out-of-range slice panic (e.g. on a lone "[").
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    Ok(host.to_owned())
}
481
#[cfg(any(
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
/// Get the port from a request URL.
///
/// An explicit port in the URI wins. Otherwise the scheme decides: `wss`
/// maps to 443 and `ws` to 80. Any other (or missing) scheme is reported as
/// `UrlError::UnsupportedUrlScheme`.
#[inline]
pub(crate) fn port(
    request: &tungstenite::handshake::client::Request,
) -> Result<u16, tungstenite::Error> {
    let uri = request.uri();
    if let Some(explicit) = uri.port_u16() {
        return Ok(explicit);
    }
    match uri.scheme_str() {
        Some("wss") => Ok(443),
        Some("ws") => Ok(80),
        _ => Err(tungstenite::Error::Url(
            tungstenite::error::UrlError::UnsupportedUrlScheme,
        )),
    }
}
504
#[cfg(test)]
mod tests {
    #[cfg(any(
        feature = "async-tls",
        feature = "async-std-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn domain_strips_ipv6_brackets() {
        use tungstenite::client::IntoClientRequest;

        let request = "ws://[::1]:80".into_client_request().unwrap();
        assert_eq!(crate::domain(&request).unwrap(), "::1");
    }

    // Plain host names must pass through `domain` untouched.
    #[cfg(any(
        feature = "async-tls",
        feature = "async-std-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn domain_keeps_plain_host_names() {
        use tungstenite::client::IntoClientRequest;

        let request = "wss://example.com/chat".into_client_request().unwrap();
        assert_eq!(crate::domain(&request).unwrap(), "example.com");
    }

    // `port` falls back to the scheme default only when no explicit port is given.
    #[cfg(any(
        feature = "async-std-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn port_uses_explicit_port_or_scheme_default() {
        use tungstenite::client::IntoClientRequest;

        let ws = "ws://example.com".into_client_request().unwrap();
        assert_eq!(crate::port(&ws).unwrap(), 80);

        let wss = "wss://example.com".into_client_request().unwrap();
        assert_eq!(crate::port(&wss).unwrap(), 443);

        let explicit = "ws://example.com:9001".into_client_request().unwrap();
        assert_eq!(crate::port(&explicit).unwrap(), 9001);
    }

    #[cfg(feature = "handshake")]
    #[test]
    fn requests_cannot_contain_invalid_uris() {
        use tungstenite::client::IntoClientRequest;

        assert!("ws://[".into_client_request().is_err());
        assert!("ws://[blabla/bla".into_client_request().is_err());
        assert!("ws://[::1/bla".into_client_request().is_err());
    }
}