1use std::{
41 borrow::Cow,
42 error::Error,
43 fmt,
44 future::Future,
45 ops::{Deref, DerefMut},
46};
47
48use ahash::AHashSet;
49use http::{
50 header,
51 header::{HeaderMap, HeaderName, HeaderValue},
52 method::Method,
53 request::Parts,
54 status::StatusCode,
55 version::Version,
56};
57use hyper_util::rt::TokioIo;
58use tokio_tungstenite::WebSocketStream;
59pub use tungstenite::Message;
60use tungstenite::{
61 handshake::derive_accept_key,
62 protocol::{self, WebSocketConfig},
63};
64
65use crate::{
66 body::Body,
67 context::ServerContext,
68 response::Response,
69 server::{IntoResponse, extract::FromContext},
70};
71
/// Value of the `Connection` header sent in the `101 Switching Protocols` response
/// (the tests also reuse it to build handshake requests).
const HEADERVALUE_UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
/// Value of the `Upgrade` header sent in the `101 Switching Protocols` response
/// (the tests also reuse it to build handshake requests).
const HEADERVALUE_WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
74
75#[must_use]
94pub struct WebSocketUpgrade<F = DefaultOnFailedUpgrade> {
95 config: WebSocketConfig,
96 protocol: Option<HeaderValue>,
97 sec_websocket_key: HeaderValue,
98 sec_websocket_protocol: Option<HeaderValue>,
99 on_upgrade: hyper::upgrade::OnUpgrade,
100 on_failed_upgrade: F,
101}
102
103impl<F> WebSocketUpgrade<F> {
104 pub fn write_buffer_size(mut self, size: usize) -> Self {
116 self.config.write_buffer_size = size;
117 self
118 }
119
120 pub fn max_write_buffer_size(mut self, max: usize) -> Self {
132 self.config.max_write_buffer_size = max;
133 self
134 }
135
136 pub fn max_message_size(mut self, max: Option<usize>) -> Self {
143 self.config.max_message_size = max;
144 self
145 }
146
147 pub fn max_frame_size(mut self, max: Option<usize>) -> Self {
156 self.config.max_frame_size = max;
157 self
158 }
159
160 pub fn accept_unmasked_frames(mut self, accept: bool) -> Self {
170 self.config.accept_unmasked_frames = accept;
171 self
172 }
173
174 fn get_protocol<I>(&mut self, protocols: I) -> Option<HeaderValue>
175 where
176 I: IntoIterator,
177 I::Item: Into<Cow<'static, str>>,
178 {
179 let req_protocols = self
180 .sec_websocket_protocol
181 .as_ref()?
182 .to_str()
183 .ok()?
184 .split(',')
185 .map(str::trim)
186 .collect::<AHashSet<_>>();
187 for protocol in protocols.into_iter().map(Into::into) {
188 if req_protocols.contains(protocol.as_ref()) {
189 let protocol = match protocol {
190 Cow::Owned(s) => HeaderValue::from_str(&s).ok()?,
191 Cow::Borrowed(s) => HeaderValue::from_static(s),
192 };
193 return Some(protocol);
194 }
195 }
196
197 None
198 }
199
200 pub fn protocols<I>(mut self, protocols: I) -> Self
211 where
212 I: IntoIterator,
213 I::Item: Into<Cow<'static, str>>,
214 {
215 self.protocol = self.get_protocol(protocols);
216 self
217 }
218
219 pub fn on_failed_upgrade<F2>(self, callback: F2) -> WebSocketUpgrade<F2>
245 where
246 F2: OnFailedUpgrade,
247 {
248 WebSocketUpgrade {
249 config: self.config,
250 protocol: self.protocol,
251 sec_websocket_key: self.sec_websocket_key,
252 sec_websocket_protocol: self.sec_websocket_protocol,
253 on_upgrade: self.on_upgrade,
254 on_failed_upgrade: callback,
255 }
256 }
257
258 pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
293 where
294 C: FnOnce(WebSocket) -> Fut + Send + 'static,
295 Fut: Future<Output = ()> + Send,
296 F: OnFailedUpgrade + Send + 'static,
297 {
298 let protocol = self.protocol.clone();
299 let fut = async move {
300 let upgraded = match self.on_upgrade.await {
301 Ok(upgraded) => upgraded,
302 Err(err) => {
303 self.on_failed_upgrade.call(WebSocketError::Upgrade(err));
304 return;
305 }
306 };
307 let upgraded = TokioIo::new(upgraded);
308
309 let socket = WebSocketStream::from_raw_socket(
310 upgraded,
311 protocol::Role::Server,
312 Some(self.config),
313 )
314 .await;
315 let socket = WebSocket {
316 inner: socket,
317 protocol,
318 };
319
320 callback(socket).await;
321 };
322
323 let mut resp = Response::new(Body::empty());
324 *resp.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
325 resp.headers_mut()
326 .insert(header::CONNECTION, HEADERVALUE_UPGRADE);
327 resp.headers_mut()
328 .insert(header::UPGRADE, HEADERVALUE_WEBSOCKET);
329 let Ok(accept_key) =
330 HeaderValue::from_str(&derive_accept_key(self.sec_websocket_key.as_bytes()))
331 else {
332 return StatusCode::BAD_REQUEST.into_response();
333 };
334 resp.headers_mut()
335 .insert(header::SEC_WEBSOCKET_ACCEPT, accept_key);
336 if let Some(protocol) = self.protocol {
337 if let Ok(protocol) = HeaderValue::from_bytes(protocol.as_bytes()) {
338 resp.headers_mut()
339 .insert(header::SEC_WEBSOCKET_PROTOCOL, protocol);
340 }
341 }
342
343 tokio::spawn(fut);
344
345 resp
346 }
347}
348
349fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
350 let Some(header) = headers.get(&key) else {
351 return false;
352 };
353 let Ok(header) = simdutf8::basic::from_utf8(header.as_bytes()) else {
354 return false;
355 };
356 header.to_ascii_lowercase().contains(value)
357}
358
359fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
360 let Some(header) = headers.get(&key) else {
361 return false;
362 };
363 header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
364}
365
366impl FromContext for WebSocketUpgrade<DefaultOnFailedUpgrade> {
367 type Rejection = WebSocketUpgradeRejectionError;
368
369 async fn from_context(
370 _: &mut ServerContext,
371 parts: &mut Parts,
372 ) -> Result<Self, Self::Rejection> {
373 if parts.method != Method::GET {
374 return Err(WebSocketUpgradeRejectionError::MethodNotGet);
375 }
376 if parts.version < Version::HTTP_11 {
377 return Err(WebSocketUpgradeRejectionError::InvalidHttpVersion);
378 }
379
380 if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
385 return Err(WebSocketUpgradeRejectionError::InvalidConnectionHeader);
386 }
387
388 if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
389 return Err(WebSocketUpgradeRejectionError::InvalidUpgradeHeader);
390 }
391
392 if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
393 return Err(WebSocketUpgradeRejectionError::InvalidWebSocketVersionHeader);
394 }
395
396 let sec_websocket_key = parts
397 .headers
398 .get(header::SEC_WEBSOCKET_KEY)
399 .ok_or(WebSocketUpgradeRejectionError::WebSocketKeyHeaderMissing)?
400 .clone();
401
402 let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
403
404 let on_upgrade = parts
405 .extensions
406 .remove::<hyper::upgrade::OnUpgrade>()
407 .expect("`OnUpgrade` is unavailable, maybe something wrong with `hyper`");
408
409 Ok(Self {
410 config: Default::default(),
411 protocol: None,
412 sec_websocket_key,
413 sec_websocket_protocol,
414 on_upgrade,
415 on_failed_upgrade: DefaultOnFailedUpgrade,
416 })
417 }
418}
419
420pub struct WebSocket {
422 inner: WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
423 protocol: Option<HeaderValue>,
424}
425
426impl WebSocket {
427 pub fn protocol(&self) -> Option<&str> {
435 simdutf8::basic::from_utf8(self.protocol.as_ref()?.as_bytes()).ok()
436 }
437}
438
/// Exposes the wrapped `WebSocketStream` by shared reference.
impl Deref for WebSocket {
    type Target = WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

/// Exposes the wrapped `WebSocketStream` mutably.
///
/// This is what lets `&mut self` stream/sink methods such as `next()` and `send(..)` be called
/// directly on a `WebSocket`, as the tests in this module do.
impl DerefMut for WebSocket {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
452
/// Error passed to an [`OnFailedUpgrade`] callback by `WebSocketUpgrade::on_upgrade`.
#[derive(Debug)]
pub enum WebSocketError {
    /// Awaiting hyper's `OnUpgrade` future returned an error, so no WebSocket was established.
    Upgrade(hyper::Error),
}
462
463impl fmt::Display for WebSocketError {
464 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
465 match self {
466 Self::Upgrade(err) => write!(f, "failed to upgrade: {err}"),
467 }
468 }
469}
470
471impl Error for WebSocketError {
472 fn source(&self) -> Option<&(dyn Error + 'static)> {
473 match self {
474 Self::Upgrade(e) => Some(e),
475 }
476 }
477}
478
/// Callback run by `WebSocketUpgrade::on_upgrade` when awaiting the HTTP upgrade fails.
///
/// Implemented for every `FnOnce(WebSocketError)` closure and for [`DefaultOnFailedUpgrade`].
pub trait OnFailedUpgrade {
    /// Consumes the callback and handles `error`.
    fn call(self, error: WebSocketError);
}
486
487impl<F> OnFailedUpgrade for F
488where
489 F: FnOnce(WebSocketError),
490{
491 fn call(self, error: WebSocketError) {
492 self(error)
493 }
494}
495
496#[derive(Debug)]
500pub struct DefaultOnFailedUpgrade;
501
502impl OnFailedUpgrade for DefaultOnFailedUpgrade {
503 fn call(self, _: WebSocketError) {}
504}
505
/// Reasons the `WebSocketUpgrade` extractor refuses a request as a WebSocket handshake.
#[derive(Debug)]
pub enum WebSocketUpgradeRejectionError {
    /// The request method is not `GET`.
    MethodNotGet,
    /// The request's HTTP version is lower than HTTP/1.1.
    InvalidHttpVersion,
    /// The `Connection` header is missing or does not include `upgrade`.
    InvalidConnectionHeader,
    /// The `Upgrade` header is missing or is not `websocket` (ASCII case-insensitive).
    InvalidUpgradeHeader,
    /// The `Sec-WebSocket-Version` header is missing or is not `13`.
    InvalidWebSocketVersionHeader,
    /// The `Sec-WebSocket-Key` header is missing.
    WebSocketKeyHeaderMissing,
}
524
525impl WebSocketUpgradeRejectionError {
526 fn to_status_code(&self) -> StatusCode {
528 match self {
529 Self::MethodNotGet => StatusCode::METHOD_NOT_ALLOWED,
530 Self::InvalidHttpVersion => StatusCode::HTTP_VERSION_NOT_SUPPORTED,
531 Self::InvalidConnectionHeader => StatusCode::UPGRADE_REQUIRED,
532 Self::InvalidUpgradeHeader => StatusCode::BAD_REQUEST,
533 Self::InvalidWebSocketVersionHeader => StatusCode::BAD_REQUEST,
534 Self::WebSocketKeyHeaderMissing => StatusCode::BAD_REQUEST,
535 }
536 }
537}
538
// Keeps the default `source` (none): every rejection variant is a unit variant with no
// underlying cause.
impl Error for WebSocketUpgradeRejectionError {}
540
541impl fmt::Display for WebSocketUpgradeRejectionError {
542 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
543 match self {
544 Self::MethodNotGet => f.write_str("Request method must be `GET`"),
545 Self::InvalidHttpVersion => f.write_str("HTTP version not support"),
546 Self::InvalidConnectionHeader => {
547 f.write_str("Header `Connection` does not include `upgrade`")
548 }
549 Self::InvalidUpgradeHeader => f.write_str("Header `Upgrade` is not `websocket`"),
550 Self::InvalidWebSocketVersionHeader => {
551 f.write_str("Header `Sec-WebSocket-Version` is not `13`")
552 }
553 Self::WebSocketKeyHeaderMissing => f.write_str("Header `Sec-WebSocket-Key` is missing"),
554 }
555 }
556}
557
558impl IntoResponse for WebSocketUpgradeRejectionError {
559 fn into_response(self) -> Response {
560 self.to_status_code().into_response()
561 }
562}
563
#[cfg(test)]
mod websocket_tests {
    use std::{
        convert::Infallible,
        net::{IpAddr, Ipv4Addr, SocketAddr},
        str::FromStr,
        time::Duration,
    };

    use futures_util::{sink::SinkExt, stream::StreamExt};
    use http::uri::Uri;
    use motore::service::Service;
    use tokio::net::TcpStream;
    use tokio_tungstenite::MaybeTlsStream;
    use tungstenite::ClientRequestBuilder;
    use volo::net::Address;

    use super::*;
    use crate::{Server, request::Request, server::test_helpers};

    /// Upper bound on how long `run_ws_handler` waits for the spawned server to accept a
    /// WebSocket handshake.
    const SERVER_START_TIMEOUT: Duration = Duration::from_secs(5);
    /// Pause between connection attempts while the spawned server is starting.
    const CONNECT_RETRY_INTERVAL: Duration = Duration::from_millis(50);

    /// Builds the request head of a minimal, valid WebSocket opening handshake.
    fn simple_parts() -> Parts {
        let req = Request::builder()
            .method(Method::GET)
            .version(Version::HTTP_11)
            .header(header::HOST, "localhost")
            .header(header::CONNECTION, super::HEADERVALUE_UPGRADE)
            .header(header::UPGRADE, super::HEADERVALUE_WEBSOCKET)
            .header(header::SEC_WEBSOCKET_KEY, "6D69KGBOr4Re+Nj6zx9aQA==")
            .header(header::SEC_WEBSOCKET_VERSION, "13")
            .body(())
            .unwrap();
        req.into_parts().0
    }

    /// Serves `service` on `127.0.0.1:port` and performs a client handshake against it,
    /// optionally requesting `sub_protocol`.
    ///
    /// Rather than sleeping a fixed amount of time for the server to come up, the handshake is
    /// retried every `CONNECT_RETRY_INTERVAL` until it succeeds or `SERVER_START_TIMEOUT`
    /// elapses. This keeps the tests fast on idle machines and robust on loaded ones.
    async fn run_ws_handler<S>(
        service: S,
        sub_protocol: Option<&'static str>,
        port: u16,
    ) -> (
        WebSocketStream<MaybeTlsStream<TcpStream>>,
        Response<Option<Vec<u8>>>,
    )
    where
        S: Service<ServerContext, Request, Response = Response, Error = Infallible>
            + Send
            + Sync
            + 'static,
    {
        let addr = Address::Ip(SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
            port,
        ));
        tokio::spawn(Server::new(service).run(addr.clone()));

        let uri = Uri::from_str(&format!("ws://{addr}/")).unwrap();
        let deadline = tokio::time::Instant::now() + SERVER_START_TIMEOUT;
        loop {
            let mut req = ClientRequestBuilder::new(uri.clone());
            if let Some(sub_protocol) = sub_protocol {
                req = req.with_sub_protocol(sub_protocol);
            }
            match tokio_tungstenite::connect_async(req).await {
                Ok(conn) => return conn,
                // The server task may not be listening yet; try again shortly.
                Err(_) if tokio::time::Instant::now() < deadline => {
                    tokio::time::sleep(CONNECT_RETRY_INTERVAL).await;
                }
                Err(err) => panic!("cannot connect to test server at {addr}: {err}"),
            }
        }
    }

    /// Each malformed handshake must be rejected with the matching error variant.
    #[tokio::test]
    async fn rejection() {
        let cases: [(&str, fn(&mut Parts), WebSocketUpgradeRejectionError); 9] = [
            (
                "method is POST",
                |p: &mut Parts| {
                    p.method = Method::POST;
                },
                WebSocketUpgradeRejectionError::MethodNotGet,
            ),
            (
                "version is HTTP/1.0",
                |p: &mut Parts| {
                    p.version = Version::HTTP_10;
                },
                WebSocketUpgradeRejectionError::InvalidHttpVersion,
            ),
            (
                "Connection header missing",
                |p: &mut Parts| {
                    p.headers.remove(header::CONNECTION);
                },
                WebSocketUpgradeRejectionError::InvalidConnectionHeader,
            ),
            (
                "Connection: downgrade",
                |p: &mut Parts| {
                    p.headers
                        .insert(header::CONNECTION, HeaderValue::from_static("downgrade"));
                },
                WebSocketUpgradeRejectionError::InvalidConnectionHeader,
            ),
            (
                "Upgrade header missing",
                |p: &mut Parts| {
                    p.headers.remove(header::UPGRADE);
                },
                WebSocketUpgradeRejectionError::InvalidUpgradeHeader,
            ),
            (
                "Upgrade: supersocket",
                |p: &mut Parts| {
                    p.headers
                        .insert(header::UPGRADE, HeaderValue::from_static("supersocket"));
                },
                WebSocketUpgradeRejectionError::InvalidUpgradeHeader,
            ),
            (
                "Sec-WebSocket-Version header missing",
                |p: &mut Parts| {
                    p.headers.remove(header::SEC_WEBSOCKET_VERSION);
                },
                WebSocketUpgradeRejectionError::InvalidWebSocketVersionHeader,
            ),
            (
                "Sec-WebSocket-Version: 114514",
                |p: &mut Parts| {
                    p.headers.insert(
                        header::SEC_WEBSOCKET_VERSION,
                        HeaderValue::from_static("114514"),
                    );
                },
                WebSocketUpgradeRejectionError::InvalidWebSocketVersionHeader,
            ),
            (
                "Sec-WebSocket-Key header missing",
                |p: &mut Parts| {
                    p.headers.remove(header::SEC_WEBSOCKET_KEY);
                },
                WebSocketUpgradeRejectionError::WebSocketKeyHeaderMissing,
            ),
        ];

        for (case, mutate, expected) in cases {
            let mut parts = simple_parts();
            mutate(&mut parts);
            let res =
                WebSocketUpgrade::from_context(&mut test_helpers::empty_cx(), &mut parts).await;
            match res {
                Ok(_) => panic!("{case}: request was accepted, expected {expected:?}"),
                Err(err) => assert_eq!(
                    std::mem::discriminant(&err),
                    std::mem::discriminant(&expected),
                    "{case}: got {err:?}, expected {expected:?}"
                ),
            }
        }
    }

    /// Header helpers match case-insensitively and handle comma-separated `Connection` lists.
    #[test]
    fn header_matching() {
        let mut headers = HeaderMap::new();
        headers.insert(
            header::CONNECTION,
            HeaderValue::from_static("keep-alive, Upgrade"),
        );
        headers.insert(header::UPGRADE, HeaderValue::from_static("WebSocket"));

        assert!(header_contains(&headers, header::CONNECTION, "upgrade"));
        assert!(header_eq(&headers, header::UPGRADE, "websocket"));
        assert!(!header_contains(&headers, header::UPGRADE, "upgrade"));
        assert!(!header_contains(&HeaderMap::new(), header::CONNECTION, "upgrade"));
        assert!(!header_eq(&headers, header::SEC_WEBSOCKET_VERSION, "13"));
    }

    /// The server answers with the first of its protocols that the client requested.
    #[tokio::test]
    async fn protocol_test() {
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.protocols(["soap", "wmap", "graphql-ws", "chat"])
                .on_upgrade(|_| async {})
        }

        let (_, resp) =
            run_ws_handler(test_helpers::to_service(handler), Some("graphql-ws"), 25230).await;

        assert_eq!(
            resp.headers()
                .get(http::header::SEC_WEBSOCKET_PROTOCOL)
                .unwrap(),
            "graphql-ws"
        );
    }

    /// An echo handler round-trips text messages, and pings are answered with pongs.
    #[tokio::test]
    async fn success_on_upgrade() {
        async fn echo(mut socket: WebSocket) {
            while let Some(Ok(msg)) = socket.next().await {
                if msg.is_ping() || msg.is_pong() {
                    continue;
                }
                if socket.send(msg).await.is_err() {
                    break;
                }
            }
        }

        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.on_upgrade(echo)
        }

        let (mut ws_stream, _) =
            run_ws_handler(test_helpers::to_service(handler), None, 25231).await;

        let input = Message::Text("foobar".into());
        ws_stream.send(input.clone()).await.unwrap();
        let output = ws_stream.next().await.unwrap().unwrap();
        assert_eq!(input, output);

        let input = Message::Ping("foobar".into());
        ws_stream.send(input).await.unwrap();
        let output = ws_stream.next().await.unwrap().unwrap();
        assert_eq!(output, Message::Pong("foobar".into()));
    }
}