volo_http/server/utils/
ws.rs

1//! WebSocket implementation for server.
2//!
3//! This module provides utilities for setting up and handling WebSocket connections, including
4//! configuring WebSocket options, setting protocols and upgrading connections.
5//!
6//! # Example
7//!
8//! ```
9//! use std::convert::Infallible;
10//!
11//! use futures_util::{sink::SinkExt, stream::StreamExt};
12//! use volo_http::{
13//!     response::Response,
14//!     server::{
15//!         route::{Router, get},
16//!         utils::ws::{Message, WebSocket, WebSocketUpgrade},
17//!     },
18//! };
19//!
20//! async fn handle_socket(mut socket: WebSocket) {
21//!     while let Some(Ok(msg)) = socket.next().await {
22//!         match msg {
23//!             Message::Text(_) => {
24//!                 socket.send(msg).await.unwrap();
25//!             }
26//!             _ => {}
27//!         }
28//!     }
29//! }
30//!
31//! async fn ws_handler(ws: WebSocketUpgrade) -> Response {
32//!     ws.on_upgrade(handle_socket)
33//! }
34//!
35//! let app: Router = Router::new().route("/ws", get(ws_handler));
36//! ```
37//!
38//! See [`WebSocketUpgrade`] and [`WebSocket`] for more details.
39
40use std::{
41    borrow::Cow,
42    error::Error,
43    fmt,
44    future::Future,
45    ops::{Deref, DerefMut},
46};
47
48use ahash::AHashSet;
49use http::{
50    header,
51    header::{HeaderMap, HeaderName, HeaderValue},
52    method::Method,
53    request::Parts,
54    status::StatusCode,
55    version::Version,
56};
57use hyper_util::rt::TokioIo;
58use tokio_tungstenite::WebSocketStream;
59pub use tungstenite::Message;
60use tungstenite::{
61    handshake::derive_accept_key,
62    protocol::{self, WebSocketConfig},
63};
64
65use crate::{
66    body::Body,
67    context::ServerContext,
68    response::Response,
69    server::{IntoResponse, extract::FromContext},
70};
71
72const HEADERVALUE_UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
73const HEADERVALUE_WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
74
75/// Handle request for establishing WebSocket connection.
76///
77/// [`WebSocketUpgrade`] can be passed as an argument to a handler, which will be called if the
78/// http connection making the request can be upgraded to a websocket connection.
79///
80/// [`WebSocketUpgrade`] must be used with [`WebSocketUpgrade::on_upgrade`] and a websocket
81/// handler, [`WebSocketUpgrade::on_upgrade`] will return a [`Response`] for the client and
82/// the connection will then be upgraded later.
83///
84/// # Example
85///
86/// ```
87/// use volo_http::{response::Response, server::utils::ws::WebSocketUpgrade};
88///
89/// fn ws_handler(ws: WebSocketUpgrade) -> Response {
90///     ws.on_upgrade(|socket| async { todo!() })
91/// }
92/// ```
93#[must_use]
94pub struct WebSocketUpgrade<F = DefaultOnFailedUpgrade> {
95    config: WebSocketConfig,
96    protocol: Option<HeaderValue>,
97    sec_websocket_key: HeaderValue,
98    sec_websocket_protocol: Option<HeaderValue>,
99    on_upgrade: hyper::upgrade::OnUpgrade,
100    on_failed_upgrade: F,
101}
102
103impl<F> WebSocketUpgrade<F> {
104    /// The target minimum size of the write buffer to reach before writing the data to the
105    /// underlying stream.
106    ///
107    /// The default value is 128 KiB.
108    ///
109    /// If set to `0` each message will be eagerly written to the underlying stream. It is often
110    /// more optimal to allow them to buffer a little, hence the default value.
111    ///
112    /// Note: [`flush`] will always fully write the buffer regardless.
113    ///
114    /// [`flush`]: futures_util::sink::SinkExt::flush
115    pub fn write_buffer_size(mut self, size: usize) -> Self {
116        self.config.write_buffer_size = size;
117        self
118    }
119
120    /// The max size of the write buffer in bytes. Setting this can provide backpressure
121    /// in the case the write buffer is filling up due to write errors.
122    ///
123    /// The default value is unlimited.
124    ///
125    /// Note: The write buffer only builds up past [`write_buffer_size`](Self::write_buffer_size)
126    /// when writes to the underlying stream are failing. So the **write buffer can not
127    /// fill up if you are not observing write errors even if not flushing**.
128    ///
129    /// Note: Should always be at least [`write_buffer_size + 1 message`](Self::write_buffer_size)
130    /// and probably a little more depending on error handling strategy.
131    pub fn max_write_buffer_size(mut self, max: usize) -> Self {
132        self.config.max_write_buffer_size = max;
133        self
134    }
135
136    /// The maximum size of an incoming message.
137    ///
138    /// `None` means no size limit.
139    ///
140    /// The default value is 64 MiB, which should be reasonably big for all normal use-cases but
141    /// small enough to prevent memory eating by a malicious user.
142    pub fn max_message_size(mut self, max: Option<usize>) -> Self {
143        self.config.max_message_size = max;
144        self
145    }
146
147    /// The maximum size of a single incoming message frame.
148    ///
149    /// `None` means no size limit.
150    ///
151    /// The limit is for frame payload NOT including the frame header.
152    ///
153    /// The default value is 16 MiB, which should be reasonably big for all normal use-cases but
154    /// small enough to prevent memory eating by a malicious user.
155    pub fn max_frame_size(mut self, max: Option<usize>) -> Self {
156        self.config.max_frame_size = max;
157        self
158    }
159
160    /// If server to accept unmasked frames.
161    ///
162    /// When set to `true`, the server will accept and handle unmasked frames from the client.
163    ///
164    /// According to the RFC 6455, the server must close the connection to the client in such
165    /// cases, however it seems like there are some popular libraries that are sending unmasked
166    /// frames, ignoring the RFC.
167    ///
168    /// By default this option is set to `false`, i.e. according to RFC 6455.
169    pub fn accept_unmasked_frames(mut self, accept: bool) -> Self {
170        self.config.accept_unmasked_frames = accept;
171        self
172    }
173
174    fn get_protocol<I>(&mut self, protocols: I) -> Option<HeaderValue>
175    where
176        I: IntoIterator,
177        I::Item: Into<Cow<'static, str>>,
178    {
179        let req_protocols = self
180            .sec_websocket_protocol
181            .as_ref()?
182            .to_str()
183            .ok()?
184            .split(',')
185            .map(str::trim)
186            .collect::<AHashSet<_>>();
187        for protocol in protocols.into_iter().map(Into::into) {
188            if req_protocols.contains(protocol.as_ref()) {
189                let protocol = match protocol {
190                    Cow::Owned(s) => HeaderValue::from_str(&s).ok()?,
191                    Cow::Borrowed(s) => HeaderValue::from_static(s),
192                };
193                return Some(protocol);
194            }
195        }
196
197        None
198    }
199
200    /// Set available protocols for [`Sec-WebSocket-Protocol`][mdn].
201    ///
202    /// If the protocol in [`Sec-WebSocket-Protocol`][mdn] matches any protocol, the upgrade
203    /// response will insert [`Sec-WebSocket-Protocol`][mdn] and [`WebSocket`] will contain the
204    /// protocol name.
205    ///
206    /// Note that if the client offers multiple protocols that the server supports, the server will
207    /// pick the first one in the list.
208    ///
209    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-WebSocket-Protocol
210    pub fn protocols<I>(mut self, protocols: I) -> Self
211    where
212        I: IntoIterator,
213        I::Item: Into<Cow<'static, str>>,
214    {
215        self.protocol = self.get_protocol(protocols);
216        self
217    }
218
219    /// Provide a callback to call if upgrading the connection fails.
220    ///
221    /// The connection upgrade is performed in a background task. If that fails this callback will
222    /// be called.
223    ///
224    /// By default, any errors will be silently ignored.
225    ///
226    /// # Example
227    ///
228    /// ```
229    /// use volo_http::{
230    ///     response::Response,
231    ///     server::{
232    ///         route::{Router, get},
233    ///         utils::ws::{WebSocket, WebSocketUpgrade},
234    ///     },
235    /// };
236    ///
237    /// async fn ws_handler(ws: WebSocketUpgrade) -> Response {
238    ///     ws.on_failed_upgrade(|err| eprintln!("Failed to upgrade connection, err: {err}"))
239    ///         .on_upgrade(|socket| async { todo!() })
240    /// }
241    ///
242    /// let router: Router = Router::new().route("/ws", get(ws_handler));
243    /// ```
244    pub fn on_failed_upgrade<F2>(self, callback: F2) -> WebSocketUpgrade<F2>
245    where
246        F2: OnFailedUpgrade,
247    {
248        WebSocketUpgrade {
249            config: self.config,
250            protocol: self.protocol,
251            sec_websocket_key: self.sec_websocket_key,
252            sec_websocket_protocol: self.sec_websocket_protocol,
253            on_upgrade: self.on_upgrade,
254            on_failed_upgrade: callback,
255        }
256    }
257
258    /// Finalize upgrading the connection and call the provided callback
259    ///
260    /// If request protocol is matched, it will use `callback` to handle the connection stream
261    /// data.
262    ///
263    /// The callback function should be an async function with [`WebSocket`] as parameter.
264    ///
265    /// # Example
266    ///
267    /// ```
268    /// use futures_util::{sink::SinkExt, stream::StreamExt};
269    /// use volo_http::{
270    ///     response::Response,
271    ///     server::{
272    ///         route::{Router, get},
273    ///         utils::ws::{WebSocket, WebSocketUpgrade},
274    ///     },
275    /// };
276    ///
277    /// async fn ws_handler(ws: WebSocketUpgrade) -> Response {
278    ///     ws.on_upgrade(|mut socket| async move {
279    ///         while let Some(Ok(msg)) = socket.next().await {
280    ///             if msg.is_ping() || msg.is_pong() {
281    ///                 continue;
282    ///             }
283    ///             if socket.send(msg).await.is_err() {
284    ///                 break;
285    ///             }
286    ///         }
287    ///     })
288    /// }
289    ///
290    /// let router: Router = Router::new().route("/ws", get(ws_handler));
291    /// ```
292    pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
293    where
294        C: FnOnce(WebSocket) -> Fut + Send + 'static,
295        Fut: Future<Output = ()> + Send,
296        F: OnFailedUpgrade + Send + 'static,
297    {
298        let protocol = self.protocol.clone();
299        let fut = async move {
300            let upgraded = match self.on_upgrade.await {
301                Ok(upgraded) => upgraded,
302                Err(err) => {
303                    self.on_failed_upgrade.call(WebSocketError::Upgrade(err));
304                    return;
305                }
306            };
307            let upgraded = TokioIo::new(upgraded);
308
309            let socket = WebSocketStream::from_raw_socket(
310                upgraded,
311                protocol::Role::Server,
312                Some(self.config),
313            )
314            .await;
315            let socket = WebSocket {
316                inner: socket,
317                protocol,
318            };
319
320            callback(socket).await;
321        };
322
323        let mut resp = Response::new(Body::empty());
324        *resp.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
325        resp.headers_mut()
326            .insert(header::CONNECTION, HEADERVALUE_UPGRADE);
327        resp.headers_mut()
328            .insert(header::UPGRADE, HEADERVALUE_WEBSOCKET);
329        let Ok(accept_key) =
330            HeaderValue::from_str(&derive_accept_key(self.sec_websocket_key.as_bytes()))
331        else {
332            return StatusCode::BAD_REQUEST.into_response();
333        };
334        resp.headers_mut()
335            .insert(header::SEC_WEBSOCKET_ACCEPT, accept_key);
336        if let Some(protocol) = self.protocol {
337            if let Ok(protocol) = HeaderValue::from_bytes(protocol.as_bytes()) {
338                resp.headers_mut()
339                    .insert(header::SEC_WEBSOCKET_PROTOCOL, protocol);
340            }
341        }
342
343        tokio::spawn(fut);
344
345        resp
346    }
347}
348
349fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
350    let Some(header) = headers.get(&key) else {
351        return false;
352    };
353    let Ok(header) = simdutf8::basic::from_utf8(header.as_bytes()) else {
354        return false;
355    };
356    header.to_ascii_lowercase().contains(value)
357}
358
359fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
360    let Some(header) = headers.get(&key) else {
361        return false;
362    };
363    header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
364}
365
366impl FromContext for WebSocketUpgrade<DefaultOnFailedUpgrade> {
367    type Rejection = WebSocketUpgradeRejectionError;
368
369    async fn from_context(
370        _: &mut ServerContext,
371        parts: &mut Parts,
372    ) -> Result<Self, Self::Rejection> {
373        if parts.method != Method::GET {
374            return Err(WebSocketUpgradeRejectionError::MethodNotGet);
375        }
376        if parts.version < Version::HTTP_11 {
377            return Err(WebSocketUpgradeRejectionError::InvalidHttpVersion);
378        }
379
380        // The `Connection` may be multiple values separated by comma, so we should use
381        // `header_contains` rather than `header_eq` here.
382        //
383        // ref: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Connection
384        if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
385            return Err(WebSocketUpgradeRejectionError::InvalidConnectionHeader);
386        }
387
388        if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
389            return Err(WebSocketUpgradeRejectionError::InvalidUpgradeHeader);
390        }
391
392        if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
393            return Err(WebSocketUpgradeRejectionError::InvalidWebSocketVersionHeader);
394        }
395
396        let sec_websocket_key = parts
397            .headers
398            .get(header::SEC_WEBSOCKET_KEY)
399            .ok_or(WebSocketUpgradeRejectionError::WebSocketKeyHeaderMissing)?
400            .clone();
401
402        let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
403
404        let on_upgrade = parts
405            .extensions
406            .remove::<hyper::upgrade::OnUpgrade>()
407            .expect("`OnUpgrade` is unavailable, maybe something wrong with `hyper`");
408
409        Ok(Self {
410            config: Default::default(),
411            protocol: None,
412            sec_websocket_key,
413            sec_websocket_protocol,
414            on_upgrade,
415            on_failed_upgrade: DefaultOnFailedUpgrade,
416        })
417    }
418}
419
420/// WebSocketStream used In handler Request
421pub struct WebSocket {
422    inner: WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
423    protocol: Option<HeaderValue>,
424}
425
426impl WebSocket {
427    /// Get protocol of current websocket.
428    ///
429    /// The value of protocol is from [`Sec-WebSocket-Protocol`][mdn] and
430    /// [`WebSocketUpgrade::protocols`] will pick one if there is any protocol that the server
431    /// gived.
432    ///
433    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-WebSocket-Protocol
434    pub fn protocol(&self) -> Option<&str> {
435        simdutf8::basic::from_utf8(self.protocol.as_ref()?.as_bytes()).ok()
436    }
437}
438
439impl Deref for WebSocket {
440    type Target = WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>;
441
442    fn deref(&self) -> &Self::Target {
443        &self.inner
444    }
445}
446
447impl DerefMut for WebSocket {
448    fn deref_mut(&mut self) -> &mut Self::Target {
449        &mut self.inner
450    }
451}
452
453/// Error type when using [`WebSocket`].
454#[derive(Debug)]
455pub enum WebSocketError {
456    /// Error from [`hyper`] when calling [`OnUpgrade.await`][OnUpgrade] for upgrade a HTTP
457    /// connection to a WebSocket connection.
458    ///
459    /// [OnUpgrade]: hyper::upgrade::OnUpgrade
460    Upgrade(hyper::Error),
461}
462
463impl fmt::Display for WebSocketError {
464    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
465        match self {
466            Self::Upgrade(err) => write!(f, "failed to upgrade: {err}"),
467        }
468    }
469}
470
471impl Error for WebSocketError {
472    fn source(&self) -> Option<&(dyn Error + 'static)> {
473        match self {
474            Self::Upgrade(e) => Some(e),
475        }
476    }
477}
478
479/// What to do when a connection upgrade fails.
480///
481/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details.
482pub trait OnFailedUpgrade {
483    /// Called when a connection upgrade fails.
484    fn call(self, error: WebSocketError);
485}
486
487impl<F> OnFailedUpgrade for F
488where
489    F: FnOnce(WebSocketError),
490{
491    fn call(self, error: WebSocketError) {
492        self(error)
493    }
494}
495
496/// The default `OnFailedUpgrade` used by `WebSocketUpgrade`.
497///
498/// It simply ignores the error.
499#[derive(Debug)]
500pub struct DefaultOnFailedUpgrade;
501
502impl OnFailedUpgrade for DefaultOnFailedUpgrade {
503    fn call(self, _: WebSocketError) {}
504}
505
/// Reasons why a request could not be extracted as a [`WebSocketUpgrade`].
///
/// [`WebSocketUpgrade`]: crate::server::utils::ws::WebSocketUpgrade
#[derive(Debug)]
pub enum WebSocketUpgradeRejectionError {
    /// The request method is not `GET`.
    MethodNotGet,
    /// The request uses an HTTP version older than HTTP/1.1.
    InvalidHttpVersion,
    /// The `Connection` header is missing or does not include `upgrade`.
    InvalidConnectionHeader,
    /// The `Upgrade` header is missing or is not `websocket`.
    InvalidUpgradeHeader,
    /// The `Sec-WebSocket-Version` header is missing or is not `13`.
    InvalidWebSocketVersionHeader,
    /// The `Sec-WebSocket-Key` header is missing.
    WebSocketKeyHeaderMissing,
}
524
525impl WebSocketUpgradeRejectionError {
526    /// Convert the [`WebSocketUpgradeRejectionError`] to the corresponding [`StatusCode`]
527    fn to_status_code(&self) -> StatusCode {
528        match self {
529            Self::MethodNotGet => StatusCode::METHOD_NOT_ALLOWED,
530            Self::InvalidHttpVersion => StatusCode::HTTP_VERSION_NOT_SUPPORTED,
531            Self::InvalidConnectionHeader => StatusCode::UPGRADE_REQUIRED,
532            Self::InvalidUpgradeHeader => StatusCode::BAD_REQUEST,
533            Self::InvalidWebSocketVersionHeader => StatusCode::BAD_REQUEST,
534            Self::WebSocketKeyHeaderMissing => StatusCode::BAD_REQUEST,
535        }
536    }
537}
538
539impl Error for WebSocketUpgradeRejectionError {}
540
541impl fmt::Display for WebSocketUpgradeRejectionError {
542    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
543        match self {
544            Self::MethodNotGet => f.write_str("Request method must be `GET`"),
545            Self::InvalidHttpVersion => f.write_str("HTTP version not support"),
546            Self::InvalidConnectionHeader => {
547                f.write_str("Header `Connection` does not include `upgrade`")
548            }
549            Self::InvalidUpgradeHeader => f.write_str("Header `Upgrade` is not `websocket`"),
550            Self::InvalidWebSocketVersionHeader => {
551                f.write_str("Header `Sec-WebSocket-Version` is not `13`")
552            }
553            Self::WebSocketKeyHeaderMissing => f.write_str("Header `Sec-WebSocket-Key` is missing"),
554        }
555    }
556}
557
558impl IntoResponse for WebSocketUpgradeRejectionError {
559    fn into_response(self) -> Response {
560        self.to_status_code().into_response()
561    }
562}
563
#[cfg(test)]
mod websocket_tests {
    use std::{
        convert::Infallible,
        net::{IpAddr, Ipv4Addr, SocketAddr},
        str::FromStr,
    };

    use futures_util::{sink::SinkExt, stream::StreamExt};
    use http::uri::Uri;
    use motore::service::Service;
    use tokio::net::TcpStream;
    use tokio_tungstenite::MaybeTlsStream;
    use tungstenite::ClientRequestBuilder;
    use volo::net::Address;

    use super::*;
    use crate::{Server, request::Request, server::test_helpers};

    /// Builds the request head of a well-formed WebSocket opening handshake.
    fn simple_parts() -> Parts {
        let (parts, ()) = Request::builder()
            .method(Method::GET)
            .version(Version::HTTP_11)
            .header(header::HOST, "localhost")
            .header(header::CONNECTION, super::HEADERVALUE_UPGRADE)
            .header(header::UPGRADE, super::HEADERVALUE_WEBSOCKET)
            .header(header::SEC_WEBSOCKET_KEY, "6D69KGBOr4Re+Nj6zx9aQA==")
            .header(header::SEC_WEBSOCKET_VERSION, "13")
            .body(())
            .unwrap()
            .into_parts();
        parts
    }

    /// Serves `service` on `127.0.0.1:port` and connects a WebSocket client to it, optionally
    /// requesting `sub_protocol`. Returns the client stream and the handshake response.
    async fn run_ws_handler<S>(
        service: S,
        sub_protocol: Option<&'static str>,
        port: u16,
    ) -> (
        WebSocketStream<MaybeTlsStream<TcpStream>>,
        Response<Option<Vec<u8>>>,
    )
    where
        S: Service<ServerContext, Request, Response = Response, Error = Infallible>
            + Send
            + Sync
            + 'static,
    {
        let addr = Address::Ip(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port));
        tokio::spawn(Server::new(service).run(addr.clone()));

        // Give the server a moment to start listening before connecting.
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;

        let uri = Uri::from_str(&format!("ws://{addr}/")).unwrap();
        let req = match sub_protocol {
            Some(sub_protocol) => ClientRequestBuilder::new(uri).with_sub_protocol(sub_protocol),
            None => ClientRequestBuilder::new(uri),
        };
        tokio_tungstenite::connect_async(req).await.unwrap()
    }

    /// Runs the extractor on a handshake altered by `mutate` and returns the rejection,
    /// failing the test if extraction unexpectedly succeeds.
    async fn reject_with<M>(mutate: M) -> WebSocketUpgradeRejectionError
    where
        M: FnOnce(&mut Parts),
    {
        let mut parts = simple_parts();
        mutate(&mut parts);
        let mut cx = test_helpers::empty_cx();
        match WebSocketUpgrade::from_context(&mut cx, &mut parts).await {
            Ok(_) => panic!("WebSocketUpgrade extraction should have been rejected"),
            Err(rejection) => rejection,
        }
    }

    #[tokio::test]
    async fn rejection() {
        assert!(matches!(
            reject_with(|parts| {
                parts.method = Method::POST;
            })
            .await,
            WebSocketUpgradeRejectionError::MethodNotGet
        ));
        assert!(matches!(
            reject_with(|parts| {
                parts.version = Version::HTTP_10;
            })
            .await,
            WebSocketUpgradeRejectionError::InvalidHttpVersion
        ));
        assert!(matches!(
            reject_with(|parts| {
                parts.headers.remove(header::CONNECTION);
            })
            .await,
            WebSocketUpgradeRejectionError::InvalidConnectionHeader
        ));
        assert!(matches!(
            reject_with(|parts| {
                parts
                    .headers
                    .insert(header::CONNECTION, HeaderValue::from_static("downgrade"));
            })
            .await,
            WebSocketUpgradeRejectionError::InvalidConnectionHeader
        ));
        assert!(matches!(
            reject_with(|parts| {
                parts.headers.remove(header::UPGRADE);
            })
            .await,
            WebSocketUpgradeRejectionError::InvalidUpgradeHeader
        ));
        assert!(matches!(
            reject_with(|parts| {
                parts
                    .headers
                    .insert(header::UPGRADE, HeaderValue::from_static("supersocket"));
            })
            .await,
            WebSocketUpgradeRejectionError::InvalidUpgradeHeader
        ));
        assert!(matches!(
            reject_with(|parts| {
                parts.headers.remove(header::SEC_WEBSOCKET_VERSION);
            })
            .await,
            WebSocketUpgradeRejectionError::InvalidWebSocketVersionHeader
        ));
        assert!(matches!(
            reject_with(|parts| {
                parts.headers.insert(
                    header::SEC_WEBSOCKET_VERSION,
                    HeaderValue::from_static("114514"),
                );
            })
            .await,
            WebSocketUpgradeRejectionError::InvalidWebSocketVersionHeader
        ));
        assert!(matches!(
            reject_with(|parts| {
                parts.headers.remove(header::SEC_WEBSOCKET_KEY);
            })
            .await,
            WebSocketUpgradeRejectionError::WebSocketKeyHeaderMissing
        ));
    }

    #[tokio::test]
    async fn protocol_test() {
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.protocols(["soap", "wmap", "graphql-ws", "chat"])
                .on_upgrade(|_| async {})
        }

        let service = test_helpers::to_service(handler);
        let (_, resp) = run_ws_handler(service, Some("graphql-ws"), 25230).await;

        let negotiated = resp
            .headers()
            .get(http::header::SEC_WEBSOCKET_PROTOCOL)
            .unwrap();
        assert_eq!(negotiated, "graphql-ws");
    }

    #[tokio::test]
    async fn success_on_upgrade() {
        /// Echoes every data message back, skipping ping/pong frames.
        async fn echo(mut socket: WebSocket) {
            while let Some(Ok(msg)) = socket.next().await {
                if msg.is_ping() || msg.is_pong() {
                    continue;
                }
                if socket.send(msg).await.is_err() {
                    break;
                }
            }
        }

        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.on_upgrade(echo)
        }

        let service = test_helpers::to_service(handler);
        let (mut client, _) = run_ws_handler(service, None, 25231).await;

        let text = Message::Text("foobar".into());
        client.send(text.clone()).await.unwrap();
        assert_eq!(client.next().await.unwrap().unwrap(), text);

        client.send(Message::Ping("foobar".into())).await.unwrap();
        assert_eq!(
            client.next().await.unwrap().unwrap(),
            Message::Pong("foobar".into())
        );
    }
}