1use std::borrow::Cow;
11use std::collections::HashMap;
12use std::net::IpAddr;
13use std::pin::Pin;
14use std::sync::{Arc, Mutex, Weak};
15use std::task::{Context, Poll};
16use std::time::{Duration, Instant};
17
18use bytes::Bytes;
19use futures_util::{SinkExt, StreamExt};
20use garde::Validate;
21use http::header::{
22 CONNECTION, HOST, ORIGIN, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION,
23 UPGRADE,
24};
25use http::Method;
26use http::{HeaderValue, StatusCode};
27use hyper::upgrade::{OnUpgrade, Upgraded};
28use hyper_util::rt::TokioIo;
29use serde::de::DeserializeOwned;
30use serde::Serialize;
31use tokio::io::{AsyncRead, AsyncWrite, DuplexStream, ReadBuf};
32use tokio::sync::watch;
33use tokio_tungstenite::tungstenite::handshake::derive_accept_key;
34use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode as TgCloseCode;
35use tokio_tungstenite::tungstenite::protocol::CloseFrame;
36use tokio_tungstenite::tungstenite::protocol::Role;
37use tokio_tungstenite::tungstenite::protocol::WebSocketConfig as TgWebSocketConfig;
38use tokio_tungstenite::tungstenite::Message;
39use tokio_tungstenite::WebSocketStream;
40
41use crate::body::RespBody;
42use crate::error::{Error, Result};
43use crate::extract::{scheme_from_extensions, RequestContext, RequestScheme};
44use crate::response::Response;
45use crate::router::BoxFuture;
46
/// The only `Sec-WebSocket-Version` value the handshake accepts.
const WEBSOCKET_VERSION: &str = "13";
/// Error code attached to every malformed-handshake rejection.
const NOT_A_WEBSOCKET: &str = "NOT_A_WEBSOCKET";
/// Request header copied into the connection info handed to hooks.
const REQUEST_ID_HEADER: &str = "x-request-id";
/// How long `accept` waits for the HTTP upgrade when no timeout is configured.
const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
/// Message-size limit used when neither the route nor the app sets one (1 MiB).
const DEFAULT_WS_MAX_MESSAGE_SIZE: usize = 1 << 20;
/// Frame-size limit used when neither the route nor the app sets one (1 MiB).
const DEFAULT_WS_MAX_FRAME_SIZE: usize = 1 << 20;
64
/// WebSocket close status code.
///
/// Named variants cover the codes this module maps explicitly; every other
/// numeric code is carried through unchanged in [`WsCloseCode::Other`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WsCloseCode {
    /// Wire value 1000.
    NormalClosure,
    /// Wire value 1001.
    GoingAway,
    /// Wire value 1002.
    ProtocolError,
    /// Wire value 1003.
    UnsupportedData,
    /// Wire value 1008.
    PolicyViolation,
    /// Wire value 1009.
    MessageTooBig,
    /// Wire value 1011.
    InternalError,
    /// Any code without a named variant, kept verbatim.
    Other(u16),
}

impl WsCloseCode {
    /// Returns the numeric code as it appears on the wire.
    pub fn as_u16(self) -> u16 {
        match self {
            Self::NormalClosure => 1000,
            Self::GoingAway => 1001,
            Self::ProtocolError => 1002,
            Self::UnsupportedData => 1003,
            Self::PolicyViolation => 1008,
            Self::MessageTooBig => 1009,
            Self::InternalError => 1011,
            Self::Other(raw) => raw,
        }
    }

    /// Maps a numeric code to its named variant, or `Other` when unnamed.
    pub fn from_u16(code: u16) -> Self {
        match code {
            1000 => Self::NormalClosure,
            1001 => Self::GoingAway,
            1002 => Self::ProtocolError,
            1003 => Self::UnsupportedData,
            1008 => Self::PolicyViolation,
            1009 => Self::MessageTooBig,
            1011 => Self::InternalError,
            unnamed => Self::Other(unnamed),
        }
    }
}
115
116#[derive(Debug, Clone, PartialEq, Eq)]
118pub struct WsClose {
119 pub code: WsCloseCode,
121 pub reason: String,
123}
124
125#[derive(Debug, Clone, PartialEq, Eq)]
127pub enum WsMessage {
128 Text(String),
130 Binary(Vec<u8>),
132 Ping(Vec<u8>),
134 Pong(Vec<u8>),
136 Close(Option<WsClose>),
138}
139
140#[derive(Debug, Clone)]
145pub struct WsError {
146 code: WsCloseCode,
147 message: String,
148}
149
150impl WsError {
151 pub fn new(code: WsCloseCode, message: impl Into<String>) -> Self {
153 Self {
154 code,
155 message: message.into(),
156 }
157 }
158
159 pub fn policy_violation(message: impl Into<String>) -> Self {
161 Self::new(WsCloseCode::PolicyViolation, message)
162 }
163
164 pub fn internal(message: impl Into<String>) -> Self {
166 Self::new(WsCloseCode::InternalError, message)
167 }
168
169 pub fn code(&self) -> WsCloseCode {
171 self.code
172 }
173
174 pub fn message(&self) -> &str {
176 &self.message
177 }
178}
179
180impl std::fmt::Display for WsError {
181 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182 f.write_str(&self.message)
183 }
184}
185
186impl std::error::Error for WsError {}
187
188impl From<WsError> for Error {
189 fn from(error: WsError) -> Self {
190 match error.code {
192 WsCloseCode::PolicyViolation => Error::forbidden(error.message),
193 WsCloseCode::MessageTooBig => Error::payload_too_large(error.message),
194 _ => Error::bad_request(error.message),
195 }
196 .with_code("WS_REJECTED")
197 }
198}
199
/// WebSocket limits and policies, settable per route and app-wide.
///
/// Every setting is optional. Values left unset on a route are taken from the
/// app-level config (see `merge`). Size limits and the handshake timeout that
/// are still unset then fall back to module defaults when a connection is
/// accepted.
#[derive(Debug, Clone, Default)]
pub struct WebSocketConfig {
    /// Largest complete message accepted, in bytes.
    max_message_size: Option<usize>,
    /// Largest single frame accepted, in bytes.
    max_frame_size: Option<usize>,
    /// How long `recv` waits for a frame before reporting the connection idle.
    idle_timeout: Option<Duration>,
    /// How long `accept` waits for the HTTP upgrade to complete.
    handshake_timeout: Option<Duration>,
    /// Turns on per-client-IP connection limiting when set.
    max_connections_per_ip: Option<usize>,
    /// Accepted `Origin` values; `None` means only the request's own origin.
    origin_policy: Option<WsOriginPolicy>,
}

/// Origin acceptance policy for the WebSocket handshake.
#[derive(Debug, Clone)]
enum WsOriginPolicy {
    /// Any `Origin` is accepted.
    Any,
    /// Only origins matching one of these entries (by scheme, host and port).
    Allowlist(Vec<String>),
}
220
221impl WebSocketConfig {
222 pub fn new() -> Self {
224 Self::default()
225 }
226
227 pub fn max_message_size(mut self, bytes: usize) -> Self {
229 self.max_message_size = Some(bytes);
230 self
231 }
232
233 pub fn max_message_size_kb(self, kb: usize) -> Self {
235 self.max_message_size(kb * 1024)
236 }
237
238 pub fn max_frame_size(mut self, bytes: usize) -> Self {
240 self.max_frame_size = Some(bytes);
241 self
242 }
243
244 pub fn max_frame_size_kb(self, kb: usize) -> Self {
246 self.max_frame_size(kb * 1024)
247 }
248
249 pub fn idle_timeout(mut self, timeout: Duration) -> Self {
251 self.idle_timeout = Some(timeout);
252 self
253 }
254
255 pub fn idle_timeout_secs(self, secs: u64) -> Self {
257 self.idle_timeout(Duration::from_secs(secs))
258 }
259
260 pub fn handshake_timeout(mut self, timeout: Duration) -> Self {
264 self.handshake_timeout = Some(timeout);
265 self
266 }
267
268 pub fn max_connections_per_ip(mut self, max: usize) -> Self {
272 self.max_connections_per_ip = Some(max);
273 self
274 }
275
276 pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
278 match &mut self.origin_policy {
279 Some(WsOriginPolicy::Allowlist(allowed)) => allowed.push(origin.into()),
280 _ => self.origin_policy = Some(WsOriginPolicy::Allowlist(vec![origin.into()])),
281 }
282 self
283 }
284
285 pub fn allow_any_origin(mut self) -> Self {
287 self.origin_policy = Some(WsOriginPolicy::Any);
288 self
289 }
290
291 pub(crate) fn merge(self, base: &WebSocketConfig) -> Self {
293 Self {
294 max_message_size: self.max_message_size.or(base.max_message_size),
295 max_frame_size: self.max_frame_size.or(base.max_frame_size),
296 idle_timeout: self.idle_timeout.or(base.idle_timeout),
297 handshake_timeout: self.handshake_timeout.or(base.handshake_timeout),
298 max_connections_per_ip: self.max_connections_per_ip.or(base.max_connections_per_ip),
299 origin_policy: self.origin_policy.or_else(|| base.origin_policy.clone()),
300 }
301 }
302
303 pub(crate) fn ip_connection_limit(&self) -> Option<usize> {
305 self.max_connections_per_ip
306 }
307
308 fn to_tungstenite(&self) -> Option<TgWebSocketConfig> {
314 Some(TgWebSocketConfig {
315 max_message_size: Some(self.max_message_size.unwrap_or(DEFAULT_WS_MAX_MESSAGE_SIZE)),
316 max_frame_size: Some(self.max_frame_size.unwrap_or(DEFAULT_WS_MAX_FRAME_SIZE)),
317 ..TgWebSocketConfig::default()
318 })
319 }
320}
321
322#[derive(Clone)]
324pub(crate) struct AppWsConfig(pub(crate) WebSocketConfig);
325
326#[derive(Clone)]
329pub(crate) struct WsShutdown(pub(crate) watch::Receiver<bool>);
330
/// Caps concurrent WebSocket connections per client IP.
///
/// Clones share one counter table.
#[derive(Clone)]
pub(crate) struct WsIpLimiter {
    /// Live connection count per IP. IPs with no live connection are absent.
    counts: Arc<Mutex<HashMap<IpAddr, usize>>>,
    /// Maximum concurrent connections per IP.
    max: usize,
}

impl WsIpLimiter {
    /// Creates a limiter allowing `max` concurrent connections per IP.
    pub(crate) fn new(max: usize) -> Self {
        Self {
            counts: Arc::new(Mutex::new(HashMap::new())),
            max,
        }
    }

    /// Reserves a slot for `ip`. Returns `None` when `ip` already holds `max`
    /// connections.
    ///
    /// The slot is released when the returned permit is dropped.
    fn try_acquire(&self, ip: IpAddr) -> Option<WsIpPermit> {
        // Recover the guard from a poisoned lock instead of panicking.
        let mut counts = self.counts.lock().unwrap_or_else(|p| p.into_inner());
        // Check before inserting so a rejected IP never leaves a zero-count
        // entry behind (with `max == 0`, every attempt would otherwise add one).
        if counts.get(&ip).copied().unwrap_or(0) >= self.max {
            return None;
        }
        *counts.entry(ip).or_insert(0) += 1;
        Some(WsIpPermit {
            counts: Arc::clone(&self.counts),
            ip,
        })
    }
}

/// A reserved connection slot. Dropping it decrements the IP's count and
/// removes the entry once the count reaches zero.
struct WsIpPermit {
    counts: Arc<Mutex<HashMap<IpAddr, usize>>>,
    ip: IpAddr,
}

impl Drop for WsIpPermit {
    fn drop(&mut self) {
        let mut counts = self.counts.lock().unwrap_or_else(|p| p.into_inner());
        if let Some(count) = counts.get_mut(&self.ip) {
            // Saturating as a safety net; permits only exist after an increment.
            *count = count.saturating_sub(1);
            if *count == 0 {
                counts.remove(&self.ip);
            }
        }
    }
}
380
381#[derive(Clone)]
383pub(crate) struct WsConnInfo {
384 method: Method,
385 path: String,
386 request_id: Option<String>,
387}
388
389pub struct WsConnectInfo {
391 info: WsConnInfo,
392}
393
394impl WsConnectInfo {
395 pub(crate) fn new(info: WsConnInfo) -> Self {
396 Self { info }
397 }
398
399 pub fn method(&self) -> &Method {
401 &self.info.method
402 }
403
404 pub fn path(&self) -> &str {
406 &self.info.path
407 }
408
409 pub fn request_id(&self) -> Option<&str> {
411 self.info.request_id.as_deref()
412 }
413}
414
415pub struct WsDisconnectInfo {
417 info: WsConnInfo,
418 duration: Duration,
419 close_code: Option<WsCloseCode>,
420}
421
422impl WsDisconnectInfo {
423 pub(crate) fn new(
424 info: WsConnInfo,
425 duration: Duration,
426 close_code: Option<WsCloseCode>,
427 ) -> Self {
428 Self {
429 info,
430 duration,
431 close_code,
432 }
433 }
434
435 pub fn method(&self) -> &Method {
437 &self.info.method
438 }
439
440 pub fn path(&self) -> &str {
442 &self.info.path
443 }
444
445 pub fn request_id(&self) -> Option<&str> {
447 self.info.request_id.as_deref()
448 }
449
450 pub fn duration(&self) -> Duration {
452 self.duration
453 }
454
455 pub fn close_code(&self) -> Option<WsCloseCode> {
457 self.close_code
458 }
459}
460
461pub(crate) type WsConnectHook = Box<dyn Fn(WsConnectInfo) -> BoxFuture<'static, ()> + Send + Sync>;
463pub(crate) type WsDisconnectHook =
465 Box<dyn Fn(WsDisconnectInfo) -> BoxFuture<'static, ()> + Send + Sync>;
466
467#[derive(Default)]
469pub(crate) struct WsHooks {
470 pub(crate) connect: Vec<WsConnectHook>,
471 pub(crate) disconnect: Vec<WsDisconnectHook>,
472}
473
474pub(crate) enum Upgrade {
479 Hyper(OnUpgrade),
481 #[allow(dead_code)]
484 Duplex(DuplexStream),
485}
486
487enum WsTransport {
492 Upgraded(TokioIo<Upgraded>),
493 Duplex(DuplexStream),
494}
495
496impl AsyncRead for WsTransport {
497 fn poll_read(
498 self: Pin<&mut Self>,
499 cx: &mut Context<'_>,
500 buf: &mut ReadBuf<'_>,
501 ) -> Poll<std::io::Result<()>> {
502 match self.get_mut() {
503 WsTransport::Upgraded(io) => Pin::new(io).poll_read(cx, buf),
504 WsTransport::Duplex(io) => Pin::new(io).poll_read(cx, buf),
505 }
506 }
507}
508
509impl AsyncWrite for WsTransport {
510 fn poll_write(
511 self: Pin<&mut Self>,
512 cx: &mut Context<'_>,
513 buf: &[u8],
514 ) -> Poll<std::io::Result<usize>> {
515 match self.get_mut() {
516 WsTransport::Upgraded(io) => Pin::new(io).poll_write(cx, buf),
517 WsTransport::Duplex(io) => Pin::new(io).poll_write(cx, buf),
518 }
519 }
520
521 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
522 match self.get_mut() {
523 WsTransport::Upgraded(io) => Pin::new(io).poll_flush(cx),
524 WsTransport::Duplex(io) => Pin::new(io).poll_flush(cx),
525 }
526 }
527
528 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
529 match self.get_mut() {
530 WsTransport::Upgraded(io) => Pin::new(io).poll_shutdown(cx),
531 WsTransport::Duplex(io) => Pin::new(io).poll_shutdown(cx),
532 }
533 }
534}
535
536pub struct WebSocket {
538 upgrade: Upgrade,
539 config: WebSocketConfig,
540 hooks: Arc<WsHooks>,
541 info: WsConnInfo,
542 permit: Option<WsIpPermit>,
543 shutdown: Option<watch::Receiver<bool>>,
544}
545
546impl WebSocket {
547 #[doc(hidden)]
554 pub fn from_request_context(ctx: &RequestContext, route: WebSocketConfig) -> Result<Self> {
555 let upgrade = ctx.take_upgrade()?;
556 let app_default = ctx
557 .state()
558 .get::<AppWsConfig>()
559 .map(|config| config.0.clone())
560 .unwrap_or_default();
561 let config = route.merge(&app_default);
562
563 let permit = match (
567 config.max_connections_per_ip,
568 ctx.state().get::<WsIpLimiter>(),
569 ctx.peer_addr(),
570 ) {
571 (Some(_), Some(limiter), Some(peer)) => {
572 Some(limiter.try_acquire(peer.ip()).ok_or_else(|| {
573 Error::too_many_requests("too many WebSocket connections from this client")
574 })?)
575 }
576 _ => None,
577 };
578
579 let hooks = ctx
580 .state()
581 .get::<WsHooks>()
582 .unwrap_or_else(|| Arc::new(WsHooks::default()));
583 let request_id = ctx
584 .headers()
585 .get(REQUEST_ID_HEADER)
586 .and_then(|value| value.to_str().ok())
587 .map(str::to_owned);
588 let info = WsConnInfo {
589 method: ctx.method().clone(),
590 path: ctx.uri().path().to_owned(),
591 request_id,
592 };
593 let shutdown = ctx.state().get::<WsShutdown>().map(|s| s.0.clone());
594 Ok(Self {
595 upgrade,
596 config,
597 hooks,
598 info,
599 permit,
600 shutdown,
601 })
602 }
603
604 pub async fn accept(self) -> Result<WebSocketConn> {
608 let idle_timeout = self.config.idle_timeout;
609 let handshake_timeout = self
610 .config
611 .handshake_timeout
612 .unwrap_or(DEFAULT_HANDSHAKE_TIMEOUT);
613 let transport = match self.upgrade {
614 Upgrade::Hyper(on_upgrade) => {
615 let upgraded = tokio::time::timeout(handshake_timeout, on_upgrade)
618 .await
619 .map_err(|_| Error::internal("websocket upgrade timed out"))?
620 .map_err(|error| {
621 Error::internal(format!("websocket upgrade failed: {error}"))
622 })?;
623 WsTransport::Upgraded(TokioIo::new(upgraded))
624 }
625 Upgrade::Duplex(duplex) => WsTransport::Duplex(duplex),
626 };
627 let stream =
628 WebSocketStream::from_raw_socket(transport, Role::Server, self.config.to_tungstenite())
629 .await;
630
631 for hook in self.hooks.connect.iter() {
632 hook(WsConnectInfo::new(self.info.clone())).await;
633 }
634
635 Ok(WebSocketConn {
636 stream,
637 idle_timeout,
638 hooks: Arc::downgrade(&self.hooks),
639 info: self.info,
640 started: Instant::now(),
641 close_code: None,
642 _permit: self.permit,
643 shutdown: self.shutdown,
644 hooks_fired: false,
645 })
646 }
647}
648
649pub struct WebSocketConn {
651 stream: WebSocketStream<WsTransport>,
652 idle_timeout: Option<Duration>,
653 hooks: Weak<WsHooks>,
654 info: WsConnInfo,
655 started: Instant,
656 close_code: Option<WsCloseCode>,
657 _permit: Option<WsIpPermit>,
659 shutdown: Option<watch::Receiver<bool>>,
662 hooks_fired: bool,
665}
666
667impl Drop for WebSocketConn {
668 fn drop(&mut self) {
669 let Some(hooks) = self.hooks.upgrade() else {
670 return;
671 };
672 if self.hooks_fired || hooks.disconnect.is_empty() {
676 return;
677 }
678 if let Ok(handle) = tokio::runtime::Handle::try_current() {
681 let info = self.info.clone();
682 let duration = self.started.elapsed();
683 let close_code = self.close_code;
684 handle.spawn(async move {
685 for hook in hooks.disconnect.iter() {
686 hook(WsDisconnectInfo::new(info.clone(), duration, close_code)).await;
687 }
688 });
689 }
690 }
691}
692
693enum RecvStep {
695 Shutdown,
696 Frame(FrameStep),
697}
698
699enum FrameStep {
701 Message(Message),
702 Error(tokio_tungstenite::tungstenite::Error),
703 Idle,
705 Closed,
707}
708
709async fn next_frame(
711 stream: &mut WebSocketStream<WsTransport>,
712 idle_timeout: Option<Duration>,
713) -> FrameStep {
714 let next = match idle_timeout {
715 Some(timeout) => match tokio::time::timeout(timeout, stream.next()).await {
716 Ok(item) => item,
717 Err(_elapsed) => return FrameStep::Idle,
718 },
719 None => stream.next().await,
720 };
721 match next {
722 Some(Ok(message)) => FrameStep::Message(message),
723 Some(Err(error)) => FrameStep::Error(error),
724 None => FrameStep::Closed,
725 }
726}
727
728impl WebSocketConn {
729 pub async fn recv(&mut self) -> Result<Option<WsMessage>> {
734 loop {
735 if self.shutdown.as_ref().is_some_and(|rx| *rx.borrow()) {
737 let _ = self.send_close_going_away().await;
738 self.fire_disconnect_hooks().await;
739 return Ok(None);
740 }
741
742 let step = {
743 let frame = next_frame(&mut self.stream, self.idle_timeout);
744 tokio::pin!(frame);
745 match &mut self.shutdown {
746 Some(rx) => tokio::select! {
748 biased;
749 _ = rx.changed() => RecvStep::Shutdown,
750 outcome = &mut frame => RecvStep::Frame(outcome),
751 },
752 None => RecvStep::Frame(frame.await),
753 }
754 };
755
756 match step {
757 RecvStep::Shutdown => {
758 let _ = self.send_close_going_away().await;
761 self.fire_disconnect_hooks().await;
762 return Ok(None);
763 }
764 RecvStep::Frame(FrameStep::Idle) | RecvStep::Frame(FrameStep::Closed) => {
765 self.fire_disconnect_hooks().await;
766 return Ok(None);
767 }
768 RecvStep::Frame(FrameStep::Error(error)) => return Err(connection_error(error)),
769 RecvStep::Frame(FrameStep::Message(message)) => {
770 if let Some(message) = from_tungstenite(message) {
771 if let WsMessage::Close(close) = &message {
772 if let Some(close) = close {
773 self.close_code = Some(close.code);
774 }
775 self.fire_disconnect_hooks().await;
778 }
779 return Ok(Some(message));
780 }
781 }
783 }
784 }
785 }
786
787 async fn send_close_going_away(&mut self) -> Result<()> {
789 let close = Message::Close(Some(CloseFrame {
790 code: TgCloseCode::Away,
791 reason: "server shutting down".into(),
792 }));
793 self.stream.send(close).await.map_err(connection_error)
794 }
795
796 async fn fire_disconnect_hooks(&mut self) {
799 let Some(hooks) = self.hooks.upgrade() else {
800 self.hooks_fired = true;
801 return;
802 };
803 if self.hooks_fired || hooks.disconnect.is_empty() {
804 return;
805 }
806 self.hooks_fired = true;
807 let duration = self.started.elapsed();
808 for hook in hooks.disconnect.iter() {
809 hook(WsDisconnectInfo::new(
810 self.info.clone(),
811 duration,
812 self.close_code,
813 ))
814 .await;
815 }
816 }
817
818 pub async fn send(&mut self, message: WsMessage) -> Result<()> {
820 self.stream
821 .send(into_tungstenite(message))
822 .await
823 .map_err(connection_error)
824 }
825
826 pub async fn send_text(&mut self, text: impl Into<String>) -> Result<()> {
828 self.send(WsMessage::Text(text.into())).await
829 }
830
831 pub async fn send_binary(&mut self, bytes: impl Into<Vec<u8>>) -> Result<()> {
833 self.send(WsMessage::Binary(bytes.into())).await
834 }
835
836 pub async fn receive_text(&mut self) -> Result<Option<String>> {
840 while let Some(message) = self.recv().await? {
841 match message {
842 WsMessage::Text(text) => return Ok(Some(text)),
843 WsMessage::Close(_) => return Ok(None),
844 _ => continue,
845 }
846 }
847 Ok(None)
848 }
849
850 pub async fn receive_json<T: DeserializeOwned>(&mut self) -> Result<Option<T>> {
855 while let Some(message) = self.recv().await? {
856 let value = match message {
857 WsMessage::Text(text) => serde_json::from_str::<T>(&text),
858 WsMessage::Binary(bytes) => serde_json::from_slice::<T>(&bytes),
859 WsMessage::Close(_) => return Ok(None),
860 _ => continue,
861 };
862 return value
863 .map(Some)
864 .map_err(|error| Error::bad_request(format!("invalid JSON message: {error}")));
865 }
866 Ok(None)
867 }
868
869 pub async fn receive_valid<T>(&mut self) -> Result<Option<T>>
875 where
876 T: DeserializeOwned + Validate<Context = ()>,
877 {
878 while let Some(message) = self.recv().await? {
879 return match message {
880 WsMessage::Text(text) => deserialize_and_validate::<T>(text.as_bytes()).map(Some),
881 WsMessage::Binary(bytes) => deserialize_and_validate::<T>(&bytes).map(Some),
882 WsMessage::Close(_) => Ok(None),
883 _ => continue,
884 };
885 }
886 Ok(None)
887 }
888
889 pub async fn send_json<T: Serialize>(&mut self, value: &T) -> Result<()> {
891 let text = serde_json::to_string(value)
892 .map_err(|error| Error::internal(format!("failed to serialize message: {error}")))?;
893 self.send_text(text).await
894 }
895
896 pub async fn close(&mut self, code: WsCloseCode, reason: impl Into<String>) -> Result<()> {
898 self.close_code = Some(code);
899 self.send(WsMessage::Close(Some(WsClose {
900 code,
901 reason: reason.into(),
902 })))
903 .await?;
904 SinkExt::close(&mut self.stream)
905 .await
906 .map_err(connection_error)
907 }
908}
909
910#[doc(hidden)]
917pub fn __ws_handshake(ctx: &RequestContext, route: WebSocketConfig) -> Result<Response> {
918 validate_origin(ctx, &route)?;
919 let headers = ctx.headers();
920
921 let is_websocket = headers
922 .get(UPGRADE)
923 .and_then(|value| value.to_str().ok())
924 .is_some_and(|value| value.eq_ignore_ascii_case("websocket"));
925 if !is_websocket {
926 return Err(Error::bad_request("expected a WebSocket upgrade").with_code(NOT_A_WEBSOCKET));
927 }
928
929 let connection_upgrade = headers
930 .get(CONNECTION)
931 .and_then(|value| value.to_str().ok())
932 .is_some_and(|value| value.to_ascii_lowercase().contains("upgrade"));
933 if !connection_upgrade {
934 return Err(
935 Error::bad_request("WebSocket upgrade requires Connection: upgrade")
936 .with_code(NOT_A_WEBSOCKET),
937 );
938 }
939
940 let version_ok = headers
941 .get(SEC_WEBSOCKET_VERSION)
942 .and_then(|value| value.to_str().ok())
943 .is_some_and(|value| value == WEBSOCKET_VERSION);
944 if !version_ok {
945 return Err(Error::bad_request("unsupported WebSocket version").with_code(NOT_A_WEBSOCKET));
946 }
947
948 let key = headers.get(SEC_WEBSOCKET_KEY).ok_or_else(|| {
949 Error::bad_request("missing Sec-WebSocket-Key").with_code(NOT_A_WEBSOCKET)
950 })?;
951 let accept = derive_accept_key(key.as_bytes());
952 let accept = HeaderValue::from_str(&accept)
953 .map_err(|_| Error::internal("failed to build WebSocket accept header"))?;
954
955 let mut response = http::Response::new(RespBody::new(Bytes::new()));
956 *response.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
957 let headers = response.headers_mut();
958 headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
959 headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
960 headers.insert(SEC_WEBSOCKET_ACCEPT, accept);
961 Ok(response)
962}
963
964fn validate_origin(ctx: &RequestContext, route: &WebSocketConfig) -> Result<()> {
965 let Some(origin) = ctx
966 .headers()
967 .get(ORIGIN)
968 .and_then(|value| value.to_str().ok())
969 else {
970 return Ok(());
971 };
972
973 let policy = effective_config(ctx, route).origin_policy;
974 match policy {
975 Some(WsOriginPolicy::Any) => Ok(()),
976 Some(WsOriginPolicy::Allowlist(allowed)) => {
977 let actual = parse_origin(origin).ok_or_else(|| {
978 Error::forbidden("websocket origin is not allowed").with_code("WS_ORIGIN_FORBIDDEN")
979 })?;
980 let matches = allowed
981 .iter()
982 .filter_map(|origin| parse_origin(origin))
983 .any(|allowed| allowed == actual);
984 if matches {
985 Ok(())
986 } else {
987 Err(Error::forbidden("websocket origin is not allowed")
988 .with_code("WS_ORIGIN_FORBIDDEN"))
989 }
990 }
991 None => {
992 let actual = parse_origin(origin).ok_or_else(|| {
993 Error::forbidden("websocket origin is not allowed").with_code("WS_ORIGIN_FORBIDDEN")
994 })?;
995 let expected = expected_same_origin(ctx).ok_or_else(|| {
996 Error::forbidden("websocket origin is not allowed").with_code("WS_ORIGIN_FORBIDDEN")
997 })?;
998 if actual == expected {
999 Ok(())
1000 } else {
1001 Err(Error::forbidden("websocket origin is not allowed")
1002 .with_code("WS_ORIGIN_FORBIDDEN"))
1003 }
1004 }
1005 }
1006}
1007
1008fn effective_config(ctx: &RequestContext, route: &WebSocketConfig) -> WebSocketConfig {
1009 let base = ctx
1010 .state()
1011 .get::<AppWsConfig>()
1012 .map(|config| config.0.clone())
1013 .unwrap_or_default();
1014 route.clone().merge(&base)
1015}
1016
1017#[derive(Clone, PartialEq, Eq)]
1018struct ParsedOrigin {
1019 scheme: &'static str,
1020 host: String,
1021 port: u16,
1022}
1023
1024fn parse_origin(origin: &str) -> Option<ParsedOrigin> {
1025 let uri: http::Uri = origin.parse().ok()?;
1026 let scheme = match uri.scheme_str()? {
1027 "http" => "http",
1028 "https" => "https",
1029 _ => return None,
1030 };
1031 let authority = uri.authority()?;
1032 Some(ParsedOrigin {
1033 scheme,
1034 host: authority.host().to_ascii_lowercase(),
1035 port: authority.port_u16().unwrap_or(default_port(scheme)),
1036 })
1037}
1038
1039fn expected_same_origin(ctx: &RequestContext) -> Option<ParsedOrigin> {
1040 let scheme = scheme_from_extensions(&ctx.head().extensions)
1041 .unwrap_or(RequestScheme::Http)
1042 .as_str();
1043 let host = ctx.headers().get(HOST)?.to_str().ok()?;
1044 let authority: http::uri::Authority = host.parse().ok()?;
1045 Some(ParsedOrigin {
1046 scheme,
1047 host: authority.host().to_ascii_lowercase(),
1048 port: authority.port_u16().unwrap_or(default_port(scheme)),
1049 })
1050}
1051
/// Default port for an origin scheme: 443 for exactly `"https"`, 80 for any
/// other string.
fn default_port(scheme: &str) -> u16 {
    match scheme {
        "https" => 443,
        _ => 80,
    }
}
1059
1060fn deserialize_and_validate<T>(bytes: &[u8]) -> Result<T>
1062where
1063 T: DeserializeOwned + Validate<Context = ()>,
1064{
1065 let value: T = serde_json::from_slice(bytes)
1066 .map_err(|error| Error::unprocessable(format!("invalid JSON message: {error}")))?;
1067 value.validate().map_err(Error::from_garde_report)?;
1068 Ok(value)
1069}
1070
1071pub(crate) fn into_tungstenite(message: WsMessage) -> Message {
1073 match message {
1074 WsMessage::Text(text) => Message::Text(text),
1075 WsMessage::Binary(bytes) => Message::Binary(bytes),
1076 WsMessage::Ping(bytes) => Message::Ping(bytes),
1077 WsMessage::Pong(bytes) => Message::Pong(bytes),
1078 WsMessage::Close(close) => Message::Close(close.map(|close| CloseFrame {
1079 code: TgCloseCode::from(close.code.as_u16()),
1080 reason: Cow::Owned(close.reason),
1081 })),
1082 }
1083}
1084
1085pub(crate) fn from_tungstenite(message: Message) -> Option<WsMessage> {
1087 match message {
1088 Message::Text(text) => Some(WsMessage::Text(text)),
1089 Message::Binary(bytes) => Some(WsMessage::Binary(bytes)),
1090 Message::Ping(bytes) => Some(WsMessage::Ping(bytes)),
1091 Message::Pong(bytes) => Some(WsMessage::Pong(bytes)),
1092 Message::Close(close) => Some(WsMessage::Close(close.map(|close| WsClose {
1093 code: WsCloseCode::from_u16(u16::from(close.code)),
1094 reason: close.reason.into_owned(),
1095 }))),
1096 Message::Frame(_) => None,
1097 }
1098}
1099
1100pub(crate) fn connection_error(error: tokio_tungstenite::tungstenite::Error) -> Error {
1102 Error::internal(format!("websocket connection error: {error}")).with_code("WS_CONNECTION_ERROR")
1103}
1104
1105#[cfg(test)]
1106mod tests {
1107 use super::*;
1108 use crate::body::box_body;
1109 use crate::extract::PathParams;
1110 use crate::state::StateMap;
1111 use bytes::Bytes;
1112 use futures_util::{SinkExt, StreamExt};
1113 use http_body_util::Full;
1114 use std::sync::Mutex;
1115 use tokio_tungstenite::tungstenite::protocol::Role;
1116
1117 fn request_context(headers: &[(&str, &str)]) -> RequestContext {
1118 let mut builder = http::Request::builder().method(Method::GET).uri("/ws");
1119 for (name, value) in headers {
1120 builder = builder.header(*name, *value);
1121 }
1122 let head = builder.body(()).unwrap().into_parts().0;
1123 RequestContext::new(
1124 head,
1125 PathParams::new(),
1126 Arc::new(StateMap::new()),
1127 box_body(Full::new(Bytes::new())),
1128 )
1129 }
1130
1131 fn request_context_with_duplex(
1132 headers: &[(&str, &str)],
1133 config: Option<WebSocketConfig>,
1134 hooks: Option<WsHooks>,
1135 ) -> (RequestContext, DuplexStream) {
1136 let mut builder = http::Request::builder().method(Method::GET).uri("/ws");
1137 for (name, value) in headers {
1138 builder = builder.header(*name, *value);
1139 }
1140 let head = builder.body(()).unwrap().into_parts().0;
1141 let mut state = StateMap::new();
1142 if let Some(config) = config {
1143 state.insert(AppWsConfig(config));
1144 }
1145 if let Some(hooks) = hooks {
1146 state.insert(hooks);
1147 }
1148 let (client, server) = tokio::io::duplex(64 * 1024);
1149 let ctx = RequestContext::with_duplex_upgrade(
1150 head,
1151 PathParams::new(),
1152 Arc::new(state),
1153 box_body(Full::new(Bytes::new())),
1154 server,
1155 );
1156 (ctx, client)
1157 }
1158
1159 fn websocket_headers() -> [(&'static str, &'static str); 4] {
1160 [
1161 ("upgrade", "websocket"),
1162 ("connection", "keep-alive, Upgrade"),
1163 ("sec-websocket-version", "13"),
1164 ("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ=="),
1165 ]
1166 }
1167
1168 fn default_route_config() -> WebSocketConfig {
1169 WebSocketConfig::new()
1170 }
1171
1172 #[test]
1173 fn close_code_round_trips_through_u16() {
1174 for code in [
1175 WsCloseCode::NormalClosure,
1176 WsCloseCode::GoingAway,
1177 WsCloseCode::ProtocolError,
1178 WsCloseCode::UnsupportedData,
1179 WsCloseCode::PolicyViolation,
1180 WsCloseCode::MessageTooBig,
1181 WsCloseCode::InternalError,
1182 WsCloseCode::Other(4000),
1183 ] {
1184 assert_eq!(WsCloseCode::from_u16(code.as_u16()), code);
1185 }
1186 }
1187
1188 #[test]
1189 fn messages_map_to_and_from_tungstenite() {
1190 let cases = [
1191 WsMessage::Text("hello".to_owned()),
1192 WsMessage::Binary(vec![1, 2, 3]),
1193 WsMessage::Ping(vec![9]),
1194 WsMessage::Pong(vec![8]),
1195 WsMessage::Close(Some(WsClose {
1196 code: WsCloseCode::NormalClosure,
1197 reason: "bye".to_owned(),
1198 })),
1199 ];
1200 for message in cases {
1201 let round = from_tungstenite(into_tungstenite(message.clone()));
1202 assert_eq!(round, Some(message));
1203 }
1204 }
1205
1206 #[test]
1207 fn config_merge_prefers_route_over_app() {
1208 let app = WebSocketConfig::new()
1209 .max_message_size(1000)
1210 .idle_timeout_secs(30);
1211 let route = WebSocketConfig::new().max_message_size(2000);
1212
1213 let merged = route.merge(&app);
1214 assert_eq!(merged.max_message_size, Some(2000), "route value wins");
1215 assert_eq!(merged.max_frame_size, None);
1216 assert_eq!(
1217 merged.idle_timeout,
1218 Some(Duration::from_secs(30)),
1219 "app default is kept where the route is unset"
1220 );
1221 }
1222
1223 #[test]
1224 fn ws_error_maps_to_an_http_status() {
1225 let error: Error = WsError::policy_violation("no token").into();
1226 assert_eq!(error.kind(), crate::ErrorKind::Forbidden);
1227 assert_eq!(error.code(), "WS_REJECTED");
1228
1229 let too_large: Error = WsError::new(WsCloseCode::MessageTooBig, "big").into();
1230 assert_eq!(too_large.kind(), crate::ErrorKind::PayloadTooLarge);
1231
1232 let internal = WsError::internal("boom");
1233 assert_eq!(internal.code(), WsCloseCode::InternalError);
1234 assert_eq!(internal.message(), "boom");
1235 assert_eq!(internal.to_string(), "boom");
1236 }
1237
1238 #[test]
1239 fn disconnect_info_exposes_duration_and_close_code() {
1240 let info = WsConnInfo {
1241 method: Method::GET,
1242 path: "/ws".to_owned(),
1243 request_id: Some("req-1".to_owned()),
1244 };
1245 let event = WsDisconnectInfo::new(
1246 info,
1247 Duration::from_secs(3),
1248 Some(WsCloseCode::NormalClosure),
1249 );
1250 assert_eq!(event.path(), "/ws");
1251 assert_eq!(event.method(), &Method::GET);
1252 assert_eq!(event.request_id(), Some("req-1"));
1253 assert_eq!(event.duration(), Duration::from_secs(3));
1254 assert_eq!(event.close_code(), Some(WsCloseCode::NormalClosure));
1255 }
1256
1257 #[test]
1258 fn websocket_config_builders_and_connect_info_accessors_work() {
1259 let config = WebSocketConfig::new()
1260 .max_message_size_kb(2)
1261 .max_frame_size_kb(1)
1262 .idle_timeout_secs(3);
1263 let tungstenite = config.to_tungstenite().expect("limits should be present");
1264 assert_eq!(tungstenite.max_message_size, Some(2 * 1024));
1265 assert_eq!(tungstenite.max_frame_size, Some(1024));
1266 assert_eq!(config.idle_timeout, Some(Duration::from_secs(3)));
1267 let defaults = WebSocketConfig::new()
1270 .to_tungstenite()
1271 .expect("defaults should be present");
1272 assert_eq!(defaults.max_message_size, Some(DEFAULT_WS_MAX_MESSAGE_SIZE));
1273 assert_eq!(defaults.max_frame_size, Some(DEFAULT_WS_MAX_FRAME_SIZE));
1274
1275 let info = WsConnInfo {
1276 method: Method::POST,
1277 path: "/chat".to_owned(),
1278 request_id: Some("req-9".to_owned()),
1279 };
1280 let connect = WsConnectInfo::new(info);
1281 assert_eq!(connect.method(), &Method::POST);
1282 assert_eq!(connect.path(), "/chat");
1283 assert_eq!(connect.request_id(), Some("req-9"));
1284 }
1285
#[test]
fn handshake_validates_required_headers() {
    // Runs the handshake against `ctx` and returns the rejection, panicking if
    // the handshake unexpectedly succeeds.
    let reject = |ctx: &RequestContext| {
        __ws_handshake(ctx, default_route_config())
            .err()
            .expect("expected handshake rejection")
    };

    // No upgrade headers at all: not a WebSocket request.
    let bare_ctx = request_context(&[]);
    let not_ws = reject(&bare_ctx);
    assert_eq!(not_ws.code(), NOT_A_WEBSOCKET);
    assert_eq!(not_ws.message(), "expected a WebSocket upgrade");

    // `Connection: upgrade` is missing.
    let no_connection_ctx = request_context(&[
        ("upgrade", "websocket"),
        ("sec-websocket-version", "13"),
        ("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ=="),
    ]);
    assert_eq!(
        reject(&no_connection_ctx).message(),
        "WebSocket upgrade requires Connection: upgrade"
    );

    // Only protocol version 13 is supported.
    let old_version_ctx = request_context(&[
        ("upgrade", "websocket"),
        ("connection", "upgrade"),
        ("sec-websocket-version", "12"),
        ("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ=="),
    ]);
    assert_eq!(
        reject(&old_version_ctx).message(),
        "unsupported WebSocket version"
    );

    // The client key is required to derive the accept header.
    let keyless_ctx = request_context(&[
        ("upgrade", "websocket"),
        ("connection", "upgrade"),
        ("sec-websocket-version", "13"),
    ]);
    assert_eq!(reject(&keyless_ctx).message(), "missing Sec-WebSocket-Key");
}
1333
#[test]
fn handshake_builds_switching_protocols_response() {
    // A complete, valid upgrade request is answered with 101 plus the
    // upgrade/connection/accept headers.
    let upgrade_ctx = request_context(&websocket_headers());
    let response = __ws_handshake(&upgrade_ctx, default_route_config()).unwrap();
    let headers = response.headers();

    assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
    assert_eq!(headers[UPGRADE], "websocket");
    assert_eq!(headers[CONNECTION], "upgrade");
    assert!(headers.contains_key(SEC_WEBSOCKET_ACCEPT));
}
1343
#[test]
fn handshake_rejects_cross_origin_by_default_and_accepts_same_origin() {
    // Builds an HTTPS upgrade request for host `example.com` carrying `origin`.
    // Both cases below go through this builder, so they share the same scheme
    // and Host header and the Origin host is the only thing that differs. A
    // rejection therefore cannot be explained by a scheme mismatch alone.
    // NOTE(review): the shared `request_context` helper does not appear to
    // insert a `RequestScheme` extension (which is why it is not used here) —
    // confirm.
    let https_upgrade = |origin: &str| {
        let mut head = http::Request::builder()
            .method(Method::GET)
            .uri("/ws")
            .header("host", "example.com")
            .header("origin", origin)
            .header("upgrade", "websocket")
            .header("connection", "upgrade")
            .header("sec-websocket-version", "13")
            .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
            .body(())
            .unwrap()
            .into_parts()
            .0;
        head.extensions.insert(RequestScheme::Https);
        RequestContext::new(
            head,
            PathParams::new(),
            Arc::new(StateMap::new()),
            box_body(Full::new(Bytes::new())),
        )
    };

    // Same scheme, different host: refused by the default origin policy.
    let cross_origin = https_upgrade("https://evil.example.com");
    let error = __ws_handshake(&cross_origin, default_route_config())
        .err()
        .expect("expected handshake rejection");
    assert_eq!(error.kind(), crate::ErrorKind::Forbidden);
    assert_eq!(error.code(), "WS_ORIGIN_FORBIDDEN");

    // Origin matching both scheme and Host: accepted.
    let same_origin = https_upgrade("https://example.com");
    let response = __ws_handshake(&same_origin, default_route_config()).unwrap();
    assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
}
1384
#[test]
fn allowlists_and_allow_any_origin_override_same_origin_policy() {
    // Cross-origin upgrade request: the Origin host differs from the Host header.
    let ctx = request_context(&[
        ("host", "example.com"),
        ("origin", "https://evil.example.com"),
        ("upgrade", "websocket"),
        ("connection", "upgrade"),
        ("sec-websocket-version", "13"),
        ("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ=="),
    ]);

    // Listing this exact Origin admits it despite the host mismatch.
    let response = __ws_handshake(
        &ctx,
        WebSocketConfig::new().allow_origin("https://evil.example.com"),
    )
    .unwrap();
    assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);

    // The allowlist must be selective: allowing some *other* origin must not
    // admit this request. Without this check, an implementation that accepts
    // everything once any allowlist is configured would still pass.
    let unlisted = __ws_handshake(
        &ctx,
        WebSocketConfig::new().allow_origin("https://other.example.com"),
    );
    assert!(
        unlisted.is_err(),
        "an Origin missing from the allowlist must still be rejected"
    );

    // `allow_any_origin` admits the cross-origin request as well.
    let response = __ws_handshake(&ctx, WebSocketConfig::new().allow_any_origin()).unwrap();
    assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
}
1406
#[test]
fn from_request_context_merges_config_and_captures_request_metadata() {
    // Config stored on the test request context (presumably playing the role of
    // the app-wide default — see `request_context_with_duplex`) only sets the
    // frame limit. The config passed explicitly to `from_request_context` sets
    // the rest. The socket should end up with all three values.
    let stored_config = WebSocketConfig::new().max_frame_size(64);
    let explicit_config = WebSocketConfig::new()
        .max_message_size(128)
        .idle_timeout(Duration::from_secs(2));

    let (ctx, _client) = request_context_with_duplex(
        &[
            ("x-request-id", "req-2"),
            ("upgrade", "websocket"),
            ("connection", "upgrade"),
            ("sec-websocket-version", "13"),
            ("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ=="),
        ],
        Some(stored_config),
        Some(WsHooks::default()),
    );

    let socket = WebSocket::from_request_context(&ctx, explicit_config).unwrap();

    let merged = &socket.config;
    assert_eq!(merged.max_message_size, Some(128));
    assert_eq!(merged.max_frame_size, Some(64));
    assert_eq!(merged.idle_timeout, Some(Duration::from_secs(2)));

    // Request metadata is captured from the head; default hooks are empty.
    assert_eq!(socket.info.path, "/ws");
    assert_eq!(socket.info.request_id.as_deref(), Some("req-2"));
    assert!(socket.hooks.connect.is_empty());
    assert!(socket.hooks.disconnect.is_empty());
}
1438
1439 #[derive(Debug, PartialEq, Eq, serde::Deserialize, garde::Validate)]
1440 struct ChatIn {
1441 #[garde(length(min = 1))]
1442 message: String,
1443 }
1444
#[test]
fn deserialize_and_validate_accepts_valid_and_rejects_invalid() {
    // A well-formed, valid payload decodes successfully.
    assert!(deserialize_and_validate::<ChatIn>(br#"{"message":"hi"}"#).is_ok());

    // Both a validation failure (empty message) and undecodable input are
    // reported as Unprocessable by this helper.
    let rejected: [&[u8]; 2] = [br#"{"message":""}"#, b"not json"];
    for payload in rejected {
        let kind = deserialize_and_validate::<ChatIn>(payload)
            .err()
            .unwrap()
            .kind();
        assert_eq!(kind, crate::ErrorKind::Unprocessable);
    }
}
1459
1460 #[tokio::test]
1461 async fn duplex_accept_runs_hooks_and_exchanges_messages() {
1462 let connects = Arc::new(Mutex::new(Vec::new()));
1463 let disconnects = Arc::new(Mutex::new(Vec::new()));
1464 let hooks = WsHooks {
1465 connect: vec![Box::new({
1466 let connects = connects.clone();
1467 move |info| {
1468 let connects = connects.clone();
1469 Box::pin(async move {
1470 connects.lock().unwrap().push((
1471 info.method().clone(),
1472 info.path().to_owned(),
1473 info.request_id().map(str::to_owned),
1474 ));
1475 })
1476 }
1477 })],
1478 disconnect: vec![Box::new({
1479 let disconnects = disconnects.clone();
1480 move |info| {
1481 let disconnects = disconnects.clone();
1482 Box::pin(async move {
1483 disconnects
1484 .lock()
1485 .unwrap()
1486 .push((info.path().to_owned(), info.close_code()));
1487 })
1488 }
1489 })],
1490 };
1491 let headers = [
1492 ("x-request-id", "req-hook"),
1493 ("upgrade", "websocket"),
1494 ("connection", "upgrade"),
1495 ("sec-websocket-version", "13"),
1496 ("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ=="),
1497 ];
1498 let (ctx, client_io) = request_context_with_duplex(&headers, None, Some(hooks));
1499 let socket = WebSocket::from_request_context(&ctx, WebSocketConfig::new()).unwrap();
1500 let mut conn = socket.accept().await.unwrap();
1501 let mut client = WebSocketStream::from_raw_socket(client_io, Role::Client, None).await;
1502
1503 client.send(Message::Text("hello".into())).await.unwrap();
1504 assert_eq!(conn.receive_text().await.unwrap(), Some("hello".to_owned()));
1505
1506 conn.send_json(&serde_json::json!({ "ok": true }))
1507 .await
1508 .unwrap();
1509 let message = client.next().await.unwrap().unwrap();
1510 assert_eq!(message.into_text().unwrap(), r#"{"ok":true}"#);
1511
1512 conn.close(WsCloseCode::NormalClosure, "bye").await.unwrap();
1513 match client.next().await.unwrap().unwrap() {
1514 Message::Close(Some(close)) => {
1515 assert_eq!(u16::from(close.code), 1000);
1516 assert_eq!(close.reason, "bye");
1517 }
1518 other => panic!("expected close frame, got {other:?}"),
1519 }
1520 drop(conn);
1521 tokio::task::yield_now().await;
1522
1523 assert_eq!(
1524 connects.lock().unwrap().as_slice(),
1525 &[(Method::GET, "/ws".to_owned(), Some("req-hook".to_owned()))]
1526 );
1527 assert_eq!(
1528 disconnects.lock().unwrap().as_slice(),
1529 &[("/ws".to_owned(), Some(WsCloseCode::NormalClosure))]
1530 );
1531 }
1532
1533 #[tokio::test]
1534 async fn duplex_connection_helpers_cover_close_idle_and_validation_paths() {
1535 let (ctx, client_io) = request_context_with_duplex(&websocket_headers(), None, None);
1536 let socket = WebSocket::from_request_context(
1537 &ctx,
1538 WebSocketConfig::new().idle_timeout(Duration::from_millis(10)),
1539 )
1540 .unwrap();
1541 let mut conn = socket.accept().await.unwrap();
1542 let mut client = WebSocketStream::from_raw_socket(client_io, Role::Client, None).await;
1543
1544 client.send(Message::Ping(vec![1, 2])).await.unwrap();
1545 client
1546 .send(Message::Text("{\"message\":\"ok\"}".into()))
1547 .await
1548 .unwrap();
1549 let validated = conn.receive_valid::<ChatIn>().await.unwrap().unwrap();
1550 assert_eq!(validated.message, "ok");
1551
1552 client
1553 .send(Message::Binary(br#"{"message":""}"#.to_vec()))
1554 .await
1555 .unwrap();
1556 let error = match conn.receive_valid::<ChatIn>().await {
1557 Ok(_) => panic!("expected validation error"),
1558 Err(error) => error,
1559 };
1560 assert_eq!(error.kind(), crate::ErrorKind::Unprocessable);
1561
1562 client.send(Message::Text("not-json".into())).await.unwrap();
1563 let error = match conn.receive_json::<ChatIn>().await {
1564 Ok(_) => panic!("expected decode error"),
1565 Err(error) => error,
1566 };
1567 assert_eq!(error.kind(), crate::ErrorKind::BadRequest);
1568
1569 client.close(None).await.unwrap();
1570 assert_eq!(conn.receive_text().await.unwrap(), None);
1571 assert_eq!(conn.receive_json::<ChatIn>().await.unwrap(), None);
1572 assert_eq!(conn.receive_valid::<ChatIn>().await.unwrap(), None);
1573
1574 let (ctx, _client_io) = request_context_with_duplex(&websocket_headers(), None, None);
1575 let socket = WebSocket::from_request_context(
1576 &ctx,
1577 WebSocketConfig::new().idle_timeout(Duration::from_millis(5)),
1578 )
1579 .unwrap();
1580 let mut idle_conn = socket.accept().await.unwrap();
1581 assert_eq!(idle_conn.recv().await.unwrap(), None);
1582 }
1583
#[test]
fn frame_and_connection_errors_map_to_expected_results() {
    // `connection_error` wraps a tungstenite error in a crate error with a
    // fixed code and a recognizable message prefix.
    let mapped = connection_error(tokio_tungstenite::tungstenite::Error::ConnectionClosed);
    let (code, message) = (mapped.code(), mapped.message());
    assert_eq!(code, "WS_CONNECTION_ERROR");
    assert!(message.contains("websocket connection error:"));
}
1590
#[test]
fn ws_ip_limiter_caps_per_ip_and_releases_on_drop() {
    use std::net::Ipv4Addr;

    let limiter = WsIpLimiter::new(2);
    let localhost = IpAddr::V4(Ipv4Addr::LOCALHOST);

    // Both slots for one address are taken; the guards are kept alive so the
    // slots stay occupied.
    let held_first = limiter
        .try_acquire(localhost)
        .expect("first is under the limit");
    let _held_second = limiter
        .try_acquire(localhost)
        .expect("second reaches the limit");
    assert!(
        limiter.try_acquire(localhost).is_none(),
        "a third connection from the same IP is rejected"
    );

    // The cap is tracked per address, so another IP still gets a slot.
    let elsewhere = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
    assert!(limiter.try_acquire(elsewhere).is_some());

    // Releasing a guard makes its slot available again.
    drop(held_first);
    assert!(
        limiter.try_acquire(localhost).is_some(),
        "dropping a connection frees a slot"
    );
}
1616
#[test]
fn route_config_overrides_app_defaults_for_new_limits() {
    let app_defaults = WebSocketConfig::new()
        .handshake_timeout(Duration::from_secs(5))
        .max_connections_per_ip(10);
    let per_route = WebSocketConfig::new().max_connections_per_ip(3);

    // Values set on the route take precedence; unset ones come from the app.
    let effective = per_route.merge(&app_defaults);
    assert_eq!(effective.ip_connection_limit(), Some(3), "route wins");
    assert_eq!(
        effective.handshake_timeout,
        Some(Duration::from_secs(5)),
        "unset on the route, taken from the app default"
    );
}
1632}