Skip to main content

tork_core/
ws.rs

1//! WebSocket connections.
2//!
3//! A `#[websocket]` handler receives a [`WebSocket`] handle and calls
4//! [`accept`](WebSocket::accept) to obtain a live [`WebSocketConn`]. Dependencies
5//! and the handshake are resolved before the upgrade, so a failure is rejected
6//! with a normal HTTP response; once accepted, the connection exchanges
7//! [`WsMessage`] values until it closes. The wire protocol is handled by
8//! `tokio-tungstenite`, which users never see directly.
9
10use std::borrow::Cow;
11use std::collections::HashMap;
12use std::net::IpAddr;
13use std::pin::Pin;
14use std::sync::{Arc, Mutex, Weak};
15use std::task::{Context, Poll};
16use std::time::{Duration, Instant};
17
18use bytes::Bytes;
19use futures_util::{SinkExt, StreamExt};
20use garde::Validate;
21use http::header::{
22    CONNECTION, HOST, ORIGIN, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION,
23    UPGRADE,
24};
25use http::Method;
26use http::{HeaderValue, StatusCode};
27use hyper::upgrade::{OnUpgrade, Upgraded};
28use hyper_util::rt::TokioIo;
29use serde::de::DeserializeOwned;
30use serde::Serialize;
31use tokio::io::{AsyncRead, AsyncWrite, DuplexStream, ReadBuf};
32use tokio::sync::watch;
33use tokio_tungstenite::tungstenite::handshake::derive_accept_key;
34use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode as TgCloseCode;
35use tokio_tungstenite::tungstenite::protocol::CloseFrame;
36use tokio_tungstenite::tungstenite::protocol::Role;
37use tokio_tungstenite::tungstenite::protocol::WebSocketConfig as TgWebSocketConfig;
38use tokio_tungstenite::tungstenite::Message;
39use tokio_tungstenite::WebSocketStream;
40
41use crate::body::RespBody;
42use crate::error::{Error, Result};
43use crate::extract::{scheme_from_extensions, RequestContext, RequestScheme};
44use crate::response::Response;
45use crate::router::BoxFuture;
46
/// WebSocket protocol version this server accepts in the handshake.
const WEBSOCKET_VERSION: &str = "13";
/// Error code attached when a request is not a valid WebSocket upgrade.
const NOT_A_WEBSOCKET: &str = "NOT_A_WEBSOCKET";
/// Request header whose value correlates a connection with a request id.
const REQUEST_ID_HEADER: &str = "x-request-id";
/// Upper bound on the upgrade handshake when none is configured. A client that
/// stalls past it has its pending connection abandoned, so it cannot pin a slot.
const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
/// Cap on one incoming message (frames reassembled) when the app/route sets none.
///
/// One mebibyte, far below tungstenite's own 64 MiB default, so a peer cannot
/// force large per-message buffers. Override with
/// [`WebSocketConfig::max_message_size`].
const DEFAULT_WS_MAX_MESSAGE_SIZE: usize = 1 << 20;
/// Cap on one incoming frame when the app/route sets none (one mebibyte).
const DEFAULT_WS_MAX_FRAME_SIZE: usize = 1 << 20;
64
/// Status code carried by a WebSocket close frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WsCloseCode {
    /// `1000`: normal closure.
    NormalClosure,
    /// `1001`: the endpoint is going away.
    GoingAway,
    /// `1002`: protocol error.
    ProtocolError,
    /// `1003`: unsupported data type.
    UnsupportedData,
    /// `1008`: a message violated the endpoint's policy.
    PolicyViolation,
    /// `1009`: a message was too big to process.
    MessageTooBig,
    /// `1011`: the server encountered an internal error.
    InternalError,
    /// Any status code without a dedicated variant, kept verbatim.
    Other(u16),
}

impl WsCloseCode {
    /// The numeric status code sent on the wire.
    pub fn as_u16(self) -> u16 {
        match self {
            Self::NormalClosure => 1000,
            Self::GoingAway => 1001,
            Self::ProtocolError => 1002,
            Self::UnsupportedData => 1003,
            Self::PolicyViolation => 1008,
            Self::MessageTooBig => 1009,
            Self::InternalError => 1011,
            Self::Other(raw) => raw,
        }
    }

    /// Maps a numeric status code to its named variant, or to `Other(code)`
    /// when no named variant uses it.
    pub fn from_u16(code: u16) -> Self {
        // Every named variant; looking them up through `as_u16` keeps the two
        // conversions from drifting apart.
        const NAMED: [WsCloseCode; 7] = [
            WsCloseCode::NormalClosure,
            WsCloseCode::GoingAway,
            WsCloseCode::ProtocolError,
            WsCloseCode::UnsupportedData,
            WsCloseCode::PolicyViolation,
            WsCloseCode::MessageTooBig,
            WsCloseCode::InternalError,
        ];
        NAMED
            .iter()
            .copied()
            .find(|named| named.as_u16() == code)
            .unwrap_or(WsCloseCode::Other(code))
    }
}
115
116/// A close control frame: a status code and a human-readable reason.
117#[derive(Debug, Clone, PartialEq, Eq)]
118pub struct WsClose {
119    /// The close status code.
120    pub code: WsCloseCode,
121    /// The reason for closing.
122    pub reason: String,
123}
124
125/// A WebSocket message.
126#[derive(Debug, Clone, PartialEq, Eq)]
127pub enum WsMessage {
128    /// A UTF-8 text message.
129    Text(String),
130    /// A binary message.
131    Binary(Vec<u8>),
132    /// A ping control frame.
133    Ping(Vec<u8>),
134    /// A pong control frame.
135    Pong(Vec<u8>),
136    /// A close control frame, with an optional reason.
137    Close(Option<WsClose>),
138}
139
140/// An error raised while handling a WebSocket connection.
141///
142/// Before the connection is accepted it converts into an HTTP error (so a guard
143/// can reject the upgrade); after accept, prefer [`WebSocketConn::close`].
144#[derive(Debug, Clone)]
145pub struct WsError {
146    code: WsCloseCode,
147    message: String,
148}
149
150impl WsError {
151    /// Creates an error with an explicit close code.
152    pub fn new(code: WsCloseCode, message: impl Into<String>) -> Self {
153        Self {
154            code,
155            message: message.into(),
156        }
157    }
158
159    /// Creates a `PolicyViolation` (`1008`) error.
160    pub fn policy_violation(message: impl Into<String>) -> Self {
161        Self::new(WsCloseCode::PolicyViolation, message)
162    }
163
164    /// Creates an `InternalError` (`1011`) error.
165    pub fn internal(message: impl Into<String>) -> Self {
166        Self::new(WsCloseCode::InternalError, message)
167    }
168
169    /// Returns the close code this error maps to.
170    pub fn code(&self) -> WsCloseCode {
171        self.code
172    }
173
174    /// Returns the error message.
175    pub fn message(&self) -> &str {
176        &self.message
177    }
178}
179
180impl std::fmt::Display for WsError {
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        f.write_str(&self.message)
183    }
184}
185
186impl std::error::Error for WsError {}
187
188impl From<WsError> for Error {
189    fn from(error: WsError) -> Self {
190        // Used when a guard rejects the upgrade before it is accepted.
191        match error.code {
192            WsCloseCode::PolicyViolation => Error::forbidden(error.message),
193            WsCloseCode::MessageTooBig => Error::payload_too_large(error.message),
194            _ => Error::bad_request(error.message),
195        }
196        .with_code("WS_REJECTED")
197    }
198}
199
/// Limits and timeouts for a WebSocket connection.
///
/// Set defaults for the whole app with
/// [`App::websocket_config`](crate::App::websocket_config), or per route with the
/// `#[websocket(...)]` attributes; a route value overrides the app default.
/// Every field is optional so that unset route values can fall back to the app.
#[derive(Clone, Default)]
pub struct WebSocketConfig {
    /// Cap on one incoming (reassembled) message, in bytes.
    max_message_size: Option<usize>,
    /// Cap on one incoming frame, in bytes.
    max_frame_size: Option<usize>,
    /// Close the connection when no message arrives within this window.
    idle_timeout: Option<Duration>,
    /// Upper bound on completing the upgrade handshake.
    handshake_timeout: Option<Duration>,
    /// Cap on concurrent connections from one client IP.
    max_connections_per_ip: Option<usize>,
    /// Which browser `Origin` values are allowed.
    origin_policy: Option<WsOriginPolicy>,
}

/// The allowed-`Origin` policy of a WebSocket endpoint.
#[derive(Clone)]
enum WsOriginPolicy {
    /// Any origin is allowed.
    Any,
    /// Only the listed origins are allowed.
    Allowlist(Vec<String>),
}
220
221impl WebSocketConfig {
222    /// Creates an empty configuration (all limits unset).
223    pub fn new() -> Self {
224        Self::default()
225    }
226
227    /// Sets the maximum size of an incoming message, in bytes.
228    pub fn max_message_size(mut self, bytes: usize) -> Self {
229        self.max_message_size = Some(bytes);
230        self
231    }
232
233    /// Sets the maximum size of an incoming message, in kibibytes.
234    pub fn max_message_size_kb(self, kb: usize) -> Self {
235        self.max_message_size(kb * 1024)
236    }
237
238    /// Sets the maximum size of a single incoming frame, in bytes.
239    pub fn max_frame_size(mut self, bytes: usize) -> Self {
240        self.max_frame_size = Some(bytes);
241        self
242    }
243
244    /// Sets the maximum size of a single incoming frame, in kibibytes.
245    pub fn max_frame_size_kb(self, kb: usize) -> Self {
246        self.max_frame_size(kb * 1024)
247    }
248
249    /// Closes the connection if no message arrives within `timeout`.
250    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
251        self.idle_timeout = Some(timeout);
252        self
253    }
254
255    /// Closes the connection if no message arrives within `secs` seconds.
256    pub fn idle_timeout_secs(self, secs: u64) -> Self {
257        self.idle_timeout(Duration::from_secs(secs))
258    }
259
260    /// Sets how long the upgrade handshake may take before the pending
261    /// connection is abandoned (default 10 seconds). Guards against a slow client
262    /// that opens the upgrade and then stalls, holding a connection slot.
263    pub fn handshake_timeout(mut self, timeout: Duration) -> Self {
264        self.handshake_timeout = Some(timeout);
265        self
266    }
267
268    /// Limits the number of concurrent WebSocket connections from a single client
269    /// IP; further connections from that IP are rejected with `429`. An
270    /// application-level setting (`App::websocket_config`).
271    pub fn max_connections_per_ip(mut self, max: usize) -> Self {
272        self.max_connections_per_ip = Some(max);
273        self
274    }
275
276    /// Allows a browser `Origin` for this WebSocket endpoint.
277    pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
278        match &mut self.origin_policy {
279            Some(WsOriginPolicy::Allowlist(allowed)) => allowed.push(origin.into()),
280            _ => self.origin_policy = Some(WsOriginPolicy::Allowlist(vec![origin.into()])),
281        }
282        self
283    }
284
285    /// Allows any browser `Origin`.
286    pub fn allow_any_origin(mut self) -> Self {
287        self.origin_policy = Some(WsOriginPolicy::Any);
288        self
289    }
290
291    /// Returns a copy with each unset field taken from `base` (route over app).
292    pub(crate) fn merge(self, base: &WebSocketConfig) -> Self {
293        Self {
294            max_message_size: self.max_message_size.or(base.max_message_size),
295            max_frame_size: self.max_frame_size.or(base.max_frame_size),
296            idle_timeout: self.idle_timeout.or(base.idle_timeout),
297            handshake_timeout: self.handshake_timeout.or(base.handshake_timeout),
298            max_connections_per_ip: self.max_connections_per_ip.or(base.max_connections_per_ip),
299            origin_policy: self.origin_policy.or_else(|| base.origin_policy.clone()),
300        }
301    }
302
303    /// The configured per-IP connection cap, if any.
304    pub(crate) fn ip_connection_limit(&self) -> Option<usize> {
305        self.max_connections_per_ip
306    }
307
308    /// Maps the size limits onto a tungstenite config.
309    ///
310    /// Always returns a config: an unset message/frame size falls back to a secure
311    /// framework default (1 MiB) rather than tungstenite's 64 MiB, so a WebSocket
312    /// route is memory-bounded by default.
313    fn to_tungstenite(&self) -> Option<TgWebSocketConfig> {
314        Some(TgWebSocketConfig {
315            max_message_size: Some(self.max_message_size.unwrap_or(DEFAULT_WS_MAX_MESSAGE_SIZE)),
316            max_frame_size: Some(self.max_frame_size.unwrap_or(DEFAULT_WS_MAX_FRAME_SIZE)),
317            ..TgWebSocketConfig::default()
318        })
319    }
320}
321
322/// The application-wide default WebSocket configuration, stored in the state map.
323#[derive(Clone)]
324pub(crate) struct AppWsConfig(pub(crate) WebSocketConfig);
325
326/// A receiver, shared via the state map, that flips to `true` when the server
327/// begins a graceful shutdown so live WebSocket connections can close cleanly.
328#[derive(Clone)]
329pub(crate) struct WsShutdown(pub(crate) watch::Receiver<bool>);
330
/// Tracks live WebSocket connections per client IP to cap how many a single
/// client may hold open. Shared app-wide via the state map.
#[derive(Clone)]
pub(crate) struct WsIpLimiter {
    /// Live connection count per IP; an IP with no live connections has no entry.
    counts: Arc<Mutex<HashMap<IpAddr, usize>>>,
    /// Maximum concurrent connections allowed per IP.
    max: usize,
}

impl WsIpLimiter {
    /// Creates a limiter allowing at most `max` concurrent connections per IP.
    pub(crate) fn new(max: usize) -> Self {
        Self {
            counts: Arc::new(Mutex::new(HashMap::new())),
            max,
        }
    }

    /// Reserves a connection slot for `ip`, returning a permit that releases it on
    /// drop, or `None` if the client already holds the maximum.
    ///
    /// A rejected attempt leaves the map untouched. This matters for `max == 0`:
    /// inserting before checking would leave a zero-count entry for every IP
    /// that ever tried, and those entries would never be removed.
    fn try_acquire(&self, ip: IpAddr) -> Option<WsIpPermit> {
        // A poisoned lock only means another thread panicked while holding it;
        // the counts are still usable, so recover the guard rather than panic.
        let mut counts = self.counts.lock().unwrap_or_else(|p| p.into_inner());
        let held = counts.get(&ip).copied().unwrap_or(0);
        if held >= self.max {
            return None;
        }
        counts.insert(ip, held + 1);
        Some(WsIpPermit {
            counts: Arc::clone(&self.counts),
            ip,
        })
    }
}

/// Releases an IP's reserved connection slot when the connection ends.
struct WsIpPermit {
    counts: Arc<Mutex<HashMap<IpAddr, usize>>>,
    ip: IpAddr,
}

impl Drop for WsIpPermit {
    /// Decrements the IP's count and removes the entry once it reaches zero,
    /// so the map only holds IPs with live connections.
    fn drop(&mut self) {
        let mut counts = self.counts.lock().unwrap_or_else(|p| p.into_inner());
        if let Some(count) = counts.get_mut(&self.ip) {
            // Saturating: a permit is only issued after an increment, but never
            // risk an underflow panic inside `drop`.
            *count = count.saturating_sub(1);
            if *count == 0 {
                counts.remove(&self.ip);
            }
        }
    }
}
380
381/// Connection metadata shared by the lifecycle events.
382#[derive(Clone)]
383pub(crate) struct WsConnInfo {
384    method: Method,
385    path: String,
386    request_id: Option<String>,
387}
388
389/// Context for [`on_ws_connect`](crate::App::on_ws_connect): a socket opened.
390pub struct WsConnectInfo {
391    info: WsConnInfo,
392}
393
394impl WsConnectInfo {
395    pub(crate) fn new(info: WsConnInfo) -> Self {
396        Self { info }
397    }
398
399    /// The HTTP method of the upgrade request.
400    pub fn method(&self) -> &Method {
401        &self.info.method
402    }
403
404    /// The request path.
405    pub fn path(&self) -> &str {
406        &self.info.path
407    }
408
409    /// The request identifier (the `x-request-id` value), if present.
410    pub fn request_id(&self) -> Option<&str> {
411        self.info.request_id.as_deref()
412    }
413}
414
415/// Context for [`on_ws_disconnect`](crate::App::on_ws_disconnect): a socket closed.
416pub struct WsDisconnectInfo {
417    info: WsConnInfo,
418    duration: Duration,
419    close_code: Option<WsCloseCode>,
420}
421
422impl WsDisconnectInfo {
423    pub(crate) fn new(
424        info: WsConnInfo,
425        duration: Duration,
426        close_code: Option<WsCloseCode>,
427    ) -> Self {
428        Self {
429            info,
430            duration,
431            close_code,
432        }
433    }
434
435    /// The HTTP method of the upgrade request.
436    pub fn method(&self) -> &Method {
437        &self.info.method
438    }
439
440    /// The request path.
441    pub fn path(&self) -> &str {
442        &self.info.path
443    }
444
445    /// The request identifier (the `x-request-id` value), if present.
446    pub fn request_id(&self) -> Option<&str> {
447        self.info.request_id.as_deref()
448    }
449
450    /// How long the connection was open.
451    pub fn duration(&self) -> Duration {
452        self.duration
453    }
454
455    /// The close code, if the connection closed with one.
456    pub fn close_code(&self) -> Option<WsCloseCode> {
457        self.close_code
458    }
459}
460
461/// An observe-only `on_ws_connect` hook.
462pub(crate) type WsConnectHook = Box<dyn Fn(WsConnectInfo) -> BoxFuture<'static, ()> + Send + Sync>;
463/// An observe-only `on_ws_disconnect` hook.
464pub(crate) type WsDisconnectHook =
465    Box<dyn Fn(WsDisconnectInfo) -> BoxFuture<'static, ()> + Send + Sync>;
466
467/// The application's WebSocket lifecycle hooks, stored in the state map.
468#[derive(Default)]
469pub(crate) struct WsHooks {
470    pub(crate) connect: Vec<WsConnectHook>,
471    pub(crate) disconnect: Vec<WsDisconnectHook>,
472}
473
474/// A pending WebSocket upgrade.
475///
476/// Either a real upgrade negotiated by hyper on a live connection, or an
477/// in-memory duplex used by the in-process test client (no network).
478pub(crate) enum Upgrade {
479    /// A real upgrade from hyper.
480    Hyper(OnUpgrade),
481    /// An in-process duplex stream (test client). Constructed by the test client,
482    /// which lands in a later commit of this phase.
483    #[allow(dead_code)]
484    Duplex(DuplexStream),
485}
486
487/// The byte transport beneath a [`WebSocketConn`].
488///
489/// Both variants implement tokio's async IO traits, so the connection type stays
490/// concrete while supporting a real upgraded socket and an in-process duplex.
491enum WsTransport {
492    Upgraded(TokioIo<Upgraded>),
493    Duplex(DuplexStream),
494}
495
496impl AsyncRead for WsTransport {
497    fn poll_read(
498        self: Pin<&mut Self>,
499        cx: &mut Context<'_>,
500        buf: &mut ReadBuf<'_>,
501    ) -> Poll<std::io::Result<()>> {
502        match self.get_mut() {
503            WsTransport::Upgraded(io) => Pin::new(io).poll_read(cx, buf),
504            WsTransport::Duplex(io) => Pin::new(io).poll_read(cx, buf),
505        }
506    }
507}
508
509impl AsyncWrite for WsTransport {
510    fn poll_write(
511        self: Pin<&mut Self>,
512        cx: &mut Context<'_>,
513        buf: &[u8],
514    ) -> Poll<std::io::Result<usize>> {
515        match self.get_mut() {
516            WsTransport::Upgraded(io) => Pin::new(io).poll_write(cx, buf),
517            WsTransport::Duplex(io) => Pin::new(io).poll_write(cx, buf),
518        }
519    }
520
521    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
522        match self.get_mut() {
523            WsTransport::Upgraded(io) => Pin::new(io).poll_flush(cx),
524            WsTransport::Duplex(io) => Pin::new(io).poll_flush(cx),
525        }
526    }
527
528    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
529        match self.get_mut() {
530            WsTransport::Upgraded(io) => Pin::new(io).poll_shutdown(cx),
531            WsTransport::Duplex(io) => Pin::new(io).poll_shutdown(cx),
532        }
533    }
534}
535
536/// A WebSocket upgrade handle: call [`accept`](WebSocket::accept) to open it.
537pub struct WebSocket {
538    upgrade: Upgrade,
539    config: WebSocketConfig,
540    hooks: Arc<WsHooks>,
541    info: WsConnInfo,
542    permit: Option<WsIpPermit>,
543    shutdown: Option<watch::Receiver<bool>>,
544}
545
546impl WebSocket {
547    /// Claims the pending upgrade from the request context, merging the route's
548    /// config over the application default.
549    ///
550    /// This is generated-code support for `#[websocket]`, not part of the
551    /// everyday API. It errors (`NOT_AN_UPGRADE`) if the request is not a
552    /// WebSocket upgrade.
553    #[doc(hidden)]
554    pub fn from_request_context(ctx: &RequestContext, route: WebSocketConfig) -> Result<Self> {
555        let upgrade = ctx.take_upgrade()?;
556        let app_default = ctx
557            .state()
558            .get::<AppWsConfig>()
559            .map(|config| config.0.clone())
560            .unwrap_or_default();
561        let config = route.merge(&app_default);
562
563        // Enforce the per-IP connection cap before the upgrade, so an abusive
564        // client is rejected with a normal HTTP `429` rather than completing a
565        // socket. The permit is held for the connection's lifetime.
566        let permit = match (
567            config.max_connections_per_ip,
568            ctx.state().get::<WsIpLimiter>(),
569            ctx.peer_addr(),
570        ) {
571            (Some(_), Some(limiter), Some(peer)) => {
572                Some(limiter.try_acquire(peer.ip()).ok_or_else(|| {
573                    Error::too_many_requests("too many WebSocket connections from this client")
574                })?)
575            }
576            _ => None,
577        };
578
579        let hooks = ctx
580            .state()
581            .get::<WsHooks>()
582            .unwrap_or_else(|| Arc::new(WsHooks::default()));
583        let request_id = ctx
584            .headers()
585            .get(REQUEST_ID_HEADER)
586            .and_then(|value| value.to_str().ok())
587            .map(str::to_owned);
588        let info = WsConnInfo {
589            method: ctx.method().clone(),
590            path: ctx.uri().path().to_owned(),
591            request_id,
592        };
593        let shutdown = ctx.state().get::<WsShutdown>().map(|s| s.0.clone());
594        Ok(Self {
595            upgrade,
596            config,
597            hooks,
598            info,
599            permit,
600            shutdown,
601        })
602    }
603
604    /// Completes the upgrade and returns the live connection.
605    ///
606    /// Fires the `on_ws_connect` hooks once the socket is open.
607    pub async fn accept(self) -> Result<WebSocketConn> {
608        let idle_timeout = self.config.idle_timeout;
609        let handshake_timeout = self
610            .config
611            .handshake_timeout
612            .unwrap_or(DEFAULT_HANDSHAKE_TIMEOUT);
613        let transport = match self.upgrade {
614            Upgrade::Hyper(on_upgrade) => {
615                // Bound the handshake so a client that stalls after starting the
616                // upgrade cannot hold the pending connection open indefinitely.
617                let upgraded = tokio::time::timeout(handshake_timeout, on_upgrade)
618                    .await
619                    .map_err(|_| Error::internal("websocket upgrade timed out"))?
620                    .map_err(|error| {
621                        Error::internal(format!("websocket upgrade failed: {error}"))
622                    })?;
623                WsTransport::Upgraded(TokioIo::new(upgraded))
624            }
625            Upgrade::Duplex(duplex) => WsTransport::Duplex(duplex),
626        };
627        let stream =
628            WebSocketStream::from_raw_socket(transport, Role::Server, self.config.to_tungstenite())
629                .await;
630
631        for hook in self.hooks.connect.iter() {
632            hook(WsConnectInfo::new(self.info.clone())).await;
633        }
634
635        Ok(WebSocketConn {
636            stream,
637            idle_timeout,
638            hooks: Arc::downgrade(&self.hooks),
639            info: self.info,
640            started: Instant::now(),
641            close_code: None,
642            _permit: self.permit,
643            shutdown: self.shutdown,
644            hooks_fired: false,
645        })
646    }
647}
648
649/// A live WebSocket connection.
650pub struct WebSocketConn {
651    stream: WebSocketStream<WsTransport>,
652    idle_timeout: Option<Duration>,
653    hooks: Weak<WsHooks>,
654    info: WsConnInfo,
655    started: Instant,
656    close_code: Option<WsCloseCode>,
657    /// Held for the connection's lifetime; releases the per-IP slot on drop.
658    _permit: Option<WsIpPermit>,
659    /// Flips to `true` when the server starts shutting down, so [`recv`](WebSocketConn::recv)
660    /// can close the connection cleanly instead of being abruptly dropped.
661    shutdown: Option<watch::Receiver<bool>>,
662    /// Set once the disconnect hooks have run, so they fire exactly once whether
663    /// the connection closes through [`recv`](WebSocketConn::recv) or [`Drop`].
664    hooks_fired: bool,
665}
666
667impl Drop for WebSocketConn {
668    fn drop(&mut self) {
669        let Some(hooks) = self.hooks.upgrade() else {
670            return;
671        };
672        // The common close paths fire the hooks inline (awaited, runtime alive).
673        // Drop is only the fallback when the handler dropped the socket without
674        // closing it, e.g. an early return or a panic mid-stream.
675        if self.hooks_fired || hooks.disconnect.is_empty() {
676            return;
677        }
678        // Fire the disconnect hooks on a detached task (Drop cannot be async).
679        // Skipped when there is no current runtime, so non-server use is safe.
680        if let Ok(handle) = tokio::runtime::Handle::try_current() {
681            let info = self.info.clone();
682            let duration = self.started.elapsed();
683            let close_code = self.close_code;
684            handle.spawn(async move {
685                for hook in hooks.disconnect.iter() {
686                    hook(WsDisconnectInfo::new(info.clone(), duration, close_code)).await;
687                }
688            });
689        }
690    }
691}
692
693/// The result of one `recv` round: a frame, or a shutdown signal.
694enum RecvStep {
695    Shutdown,
696    Frame(FrameStep),
697}
698
699/// The outcome of awaiting the next frame from the socket.
700enum FrameStep {
701    Message(Message),
702    Error(tokio_tungstenite::tungstenite::Error),
703    /// The idle timeout elapsed.
704    Idle,
705    /// The stream ended.
706    Closed,
707}
708
709/// Awaits the next message from `stream`, honoring an optional idle timeout.
710async fn next_frame(
711    stream: &mut WebSocketStream<WsTransport>,
712    idle_timeout: Option<Duration>,
713) -> FrameStep {
714    let next = match idle_timeout {
715        Some(timeout) => match tokio::time::timeout(timeout, stream.next()).await {
716            Ok(item) => item,
717            Err(_elapsed) => return FrameStep::Idle,
718        },
719        None => stream.next().await,
720    };
721    match next {
722        Some(Ok(message)) => FrameStep::Message(message),
723        Some(Err(error)) => FrameStep::Error(error),
724        None => FrameStep::Closed,
725    }
726}
727
728impl WebSocketConn {
729    /// Receives the next message, or `None` once the connection is closed.
730    ///
731    /// Raw frames are not surfaced; ping and pong frames are returned so the
732    /// handler may observe them (the protocol layer answers pings on its own).
733    pub async fn recv(&mut self) -> Result<Option<WsMessage>> {
734        loop {
735            // If the server is already shutting down, close cleanly right away.
736            if self.shutdown.as_ref().is_some_and(|rx| *rx.borrow()) {
737                let _ = self.send_close_going_away().await;
738                self.fire_disconnect_hooks().await;
739                return Ok(None);
740            }
741
742            let step = {
743                let frame = next_frame(&mut self.stream, self.idle_timeout);
744                tokio::pin!(frame);
745                match &mut self.shutdown {
746                    // Race the next frame against the shutdown signal.
747                    Some(rx) => tokio::select! {
748                        biased;
749                        _ = rx.changed() => RecvStep::Shutdown,
750                        outcome = &mut frame => RecvStep::Frame(outcome),
751                    },
752                    None => RecvStep::Frame(frame.await),
753                }
754            };
755
756            match step {
757                RecvStep::Shutdown => {
758                    // Send a Going Away close so the client disconnects cleanly
759                    // rather than seeing the socket dropped mid-shutdown.
760                    let _ = self.send_close_going_away().await;
761                    self.fire_disconnect_hooks().await;
762                    return Ok(None);
763                }
764                RecvStep::Frame(FrameStep::Idle) | RecvStep::Frame(FrameStep::Closed) => {
765                    self.fire_disconnect_hooks().await;
766                    return Ok(None);
767                }
768                RecvStep::Frame(FrameStep::Error(error)) => return Err(connection_error(error)),
769                RecvStep::Frame(FrameStep::Message(message)) => {
770                    if let Some(message) = from_tungstenite(message) {
771                        if let WsMessage::Close(close) = &message {
772                            if let Some(close) = close {
773                                self.close_code = Some(close.code);
774                            }
775                            // The peer initiated the close; fire hooks now while
776                            // the runtime is alive, before the handler drops us.
777                            self.fire_disconnect_hooks().await;
778                        }
779                        return Ok(Some(message));
780                    }
781                    // A control frame the protocol layer handled; keep waiting.
782                }
783            }
784        }
785    }
786
787    /// Sends a `1001 Going Away` close frame (best effort).
788    async fn send_close_going_away(&mut self) -> Result<()> {
789        let close = Message::Close(Some(CloseFrame {
790            code: TgCloseCode::Away,
791            reason: "server shutting down".into(),
792        }));
793        self.stream.send(close).await.map_err(connection_error)
794    }
795
796    /// Runs the `on_ws_disconnect` hooks once, awaited in the connection's own
797    /// task so they cannot be lost to a detached [`Drop`] task during shutdown.
798    async fn fire_disconnect_hooks(&mut self) {
799        let Some(hooks) = self.hooks.upgrade() else {
800            self.hooks_fired = true;
801            return;
802        };
803        if self.hooks_fired || hooks.disconnect.is_empty() {
804            return;
805        }
806        self.hooks_fired = true;
807        let duration = self.started.elapsed();
808        for hook in hooks.disconnect.iter() {
809            hook(WsDisconnectInfo::new(
810                self.info.clone(),
811                duration,
812                self.close_code,
813            ))
814            .await;
815        }
816    }
817
818    /// Sends a message.
819    pub async fn send(&mut self, message: WsMessage) -> Result<()> {
820        self.stream
821            .send(into_tungstenite(message))
822            .await
823            .map_err(connection_error)
824    }
825
826    /// Sends a text message.
827    pub async fn send_text(&mut self, text: impl Into<String>) -> Result<()> {
828        self.send(WsMessage::Text(text.into())).await
829    }
830
831    /// Sends a binary message.
832    pub async fn send_binary(&mut self, bytes: impl Into<Vec<u8>>) -> Result<()> {
833        self.send(WsMessage::Binary(bytes.into())).await
834    }
835
836    /// Receives the next text message, skipping control frames.
837    ///
838    /// Returns `None` if the peer closes the connection.
839    pub async fn receive_text(&mut self) -> Result<Option<String>> {
840        while let Some(message) = self.recv().await? {
841            match message {
842                WsMessage::Text(text) => return Ok(Some(text)),
843                WsMessage::Close(_) => return Ok(None),
844                _ => continue,
845            }
846        }
847        Ok(None)
848    }
849
850    /// Receives the next message and deserializes it from JSON.
851    ///
852    /// Accepts a text or binary payload, skips control frames, and returns `None`
853    /// if the peer closes the connection. A malformed payload is a `400` error.
854    pub async fn receive_json<T: DeserializeOwned>(&mut self) -> Result<Option<T>> {
855        while let Some(message) = self.recv().await? {
856            let value = match message {
857                WsMessage::Text(text) => serde_json::from_str::<T>(&text),
858                WsMessage::Binary(bytes) => serde_json::from_slice::<T>(&bytes),
859                WsMessage::Close(_) => return Ok(None),
860                _ => continue,
861            };
862            return value
863                .map(Some)
864                .map_err(|error| Error::bad_request(format!("invalid JSON message: {error}")));
865        }
866        Ok(None)
867    }
868
869    /// Receives the next message, deserializes it from JSON, and validates it.
870    ///
871    /// Like [`receive_json`](WebSocketConn::receive_json) but also runs the
872    /// type's `garde` validation; an invalid message is a `422` error whose body
873    /// lists the offending fields. Returns `None` if the peer closes.
874    pub async fn receive_valid<T>(&mut self) -> Result<Option<T>>
875    where
876        T: DeserializeOwned + Validate<Context = ()>,
877    {
878        while let Some(message) = self.recv().await? {
879            return match message {
880                WsMessage::Text(text) => deserialize_and_validate::<T>(text.as_bytes()).map(Some),
881                WsMessage::Binary(bytes) => deserialize_and_validate::<T>(&bytes).map(Some),
882                WsMessage::Close(_) => Ok(None),
883                _ => continue,
884            };
885        }
886        Ok(None)
887    }
888
889    /// Serializes `value` to JSON and sends it as a text message.
890    pub async fn send_json<T: Serialize>(&mut self, value: &T) -> Result<()> {
891        let text = serde_json::to_string(value)
892            .map_err(|error| Error::internal(format!("failed to serialize message: {error}")))?;
893        self.send_text(text).await
894    }
895
896    /// Closes the connection with a status code and reason.
897    pub async fn close(&mut self, code: WsCloseCode, reason: impl Into<String>) -> Result<()> {
898        self.close_code = Some(code);
899        self.send(WsMessage::Close(Some(WsClose {
900            code,
901            reason: reason.into(),
902        })))
903        .await?;
904        SinkExt::close(&mut self.stream)
905            .await
906            .map_err(connection_error)
907    }
908}
909
910/// Validates a WebSocket handshake and builds the `101 Switching Protocols`
911/// response.
912///
913/// This is generated-code support for `#[websocket]`, not part of the everyday
914/// API. A request that is not a valid WebSocket upgrade is rejected with a
915/// `400 Bad Request` (code `NOT_A_WEBSOCKET`), before the connection is opened.
916#[doc(hidden)]
917pub fn __ws_handshake(ctx: &RequestContext, route: WebSocketConfig) -> Result<Response> {
918    validate_origin(ctx, &route)?;
919    let headers = ctx.headers();
920
921    let is_websocket = headers
922        .get(UPGRADE)
923        .and_then(|value| value.to_str().ok())
924        .is_some_and(|value| value.eq_ignore_ascii_case("websocket"));
925    if !is_websocket {
926        return Err(Error::bad_request("expected a WebSocket upgrade").with_code(NOT_A_WEBSOCKET));
927    }
928
929    let connection_upgrade = headers
930        .get(CONNECTION)
931        .and_then(|value| value.to_str().ok())
932        .is_some_and(|value| value.to_ascii_lowercase().contains("upgrade"));
933    if !connection_upgrade {
934        return Err(
935            Error::bad_request("WebSocket upgrade requires Connection: upgrade")
936                .with_code(NOT_A_WEBSOCKET),
937        );
938    }
939
940    let version_ok = headers
941        .get(SEC_WEBSOCKET_VERSION)
942        .and_then(|value| value.to_str().ok())
943        .is_some_and(|value| value == WEBSOCKET_VERSION);
944    if !version_ok {
945        return Err(Error::bad_request("unsupported WebSocket version").with_code(NOT_A_WEBSOCKET));
946    }
947
948    let key = headers.get(SEC_WEBSOCKET_KEY).ok_or_else(|| {
949        Error::bad_request("missing Sec-WebSocket-Key").with_code(NOT_A_WEBSOCKET)
950    })?;
951    let accept = derive_accept_key(key.as_bytes());
952    let accept = HeaderValue::from_str(&accept)
953        .map_err(|_| Error::internal("failed to build WebSocket accept header"))?;
954
955    let mut response = http::Response::new(RespBody::new(Bytes::new()));
956    *response.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
957    let headers = response.headers_mut();
958    headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
959    headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
960    headers.insert(SEC_WEBSOCKET_ACCEPT, accept);
961    Ok(response)
962}
963
964fn validate_origin(ctx: &RequestContext, route: &WebSocketConfig) -> Result<()> {
965    let Some(origin) = ctx
966        .headers()
967        .get(ORIGIN)
968        .and_then(|value| value.to_str().ok())
969    else {
970        return Ok(());
971    };
972
973    let policy = effective_config(ctx, route).origin_policy;
974    match policy {
975        Some(WsOriginPolicy::Any) => Ok(()),
976        Some(WsOriginPolicy::Allowlist(allowed)) => {
977            let actual = parse_origin(origin).ok_or_else(|| {
978                Error::forbidden("websocket origin is not allowed").with_code("WS_ORIGIN_FORBIDDEN")
979            })?;
980            let matches = allowed
981                .iter()
982                .filter_map(|origin| parse_origin(origin))
983                .any(|allowed| allowed == actual);
984            if matches {
985                Ok(())
986            } else {
987                Err(Error::forbidden("websocket origin is not allowed")
988                    .with_code("WS_ORIGIN_FORBIDDEN"))
989            }
990        }
991        None => {
992            let actual = parse_origin(origin).ok_or_else(|| {
993                Error::forbidden("websocket origin is not allowed").with_code("WS_ORIGIN_FORBIDDEN")
994            })?;
995            let expected = expected_same_origin(ctx).ok_or_else(|| {
996                Error::forbidden("websocket origin is not allowed").with_code("WS_ORIGIN_FORBIDDEN")
997            })?;
998            if actual == expected {
999                Ok(())
1000            } else {
1001                Err(Error::forbidden("websocket origin is not allowed")
1002                    .with_code("WS_ORIGIN_FORBIDDEN"))
1003            }
1004        }
1005    }
1006}
1007
1008fn effective_config(ctx: &RequestContext, route: &WebSocketConfig) -> WebSocketConfig {
1009    let base = ctx
1010        .state()
1011        .get::<AppWsConfig>()
1012        .map(|config| config.0.clone())
1013        .unwrap_or_default();
1014    route.clone().merge(&base)
1015}
1016
/// An origin reduced to the parts compared for origin equality.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ParsedOrigin {
    /// Scheme as a static string (e.g. `"http"` or `"https"`).
    scheme: &'static str,
    /// Host, stored lowercased by the constructors in this module.
    host: String,
    /// Explicit port, or the scheme's default when none was given.
    port: u16,
}
1023
1024fn parse_origin(origin: &str) -> Option<ParsedOrigin> {
1025    let uri: http::Uri = origin.parse().ok()?;
1026    let scheme = match uri.scheme_str()? {
1027        "http" => "http",
1028        "https" => "https",
1029        _ => return None,
1030    };
1031    let authority = uri.authority()?;
1032    Some(ParsedOrigin {
1033        scheme,
1034        host: authority.host().to_ascii_lowercase(),
1035        port: authority.port_u16().unwrap_or(default_port(scheme)),
1036    })
1037}
1038
1039fn expected_same_origin(ctx: &RequestContext) -> Option<ParsedOrigin> {
1040    let scheme = scheme_from_extensions(&ctx.head().extensions)
1041        .unwrap_or(RequestScheme::Http)
1042        .as_str();
1043    let host = ctx.headers().get(HOST)?.to_str().ok()?;
1044    let authority: http::uri::Authority = host.parse().ok()?;
1045    Some(ParsedOrigin {
1046        scheme,
1047        host: authority.host().to_ascii_lowercase(),
1048        port: authority.port_u16().unwrap_or(default_port(scheme)),
1049    })
1050}
1051
/// Returns the default port for an origin scheme: 443 for exactly `"https"`,
/// 80 for anything else.
fn default_port(scheme: &str) -> u16 {
    match scheme {
        "https" => 443,
        _ => 80,
    }
}
1059
1060/// Deserializes a JSON message and runs its `garde` validation.
1061fn deserialize_and_validate<T>(bytes: &[u8]) -> Result<T>
1062where
1063    T: DeserializeOwned + Validate<Context = ()>,
1064{
1065    let value: T = serde_json::from_slice(bytes)
1066        .map_err(|error| Error::unprocessable(format!("invalid JSON message: {error}")))?;
1067    value.validate().map_err(Error::from_garde_report)?;
1068    Ok(value)
1069}
1070
1071/// Maps a framework message to a tungstenite message.
1072pub(crate) fn into_tungstenite(message: WsMessage) -> Message {
1073    match message {
1074        WsMessage::Text(text) => Message::Text(text),
1075        WsMessage::Binary(bytes) => Message::Binary(bytes),
1076        WsMessage::Ping(bytes) => Message::Ping(bytes),
1077        WsMessage::Pong(bytes) => Message::Pong(bytes),
1078        WsMessage::Close(close) => Message::Close(close.map(|close| CloseFrame {
1079            code: TgCloseCode::from(close.code.as_u16()),
1080            reason: Cow::Owned(close.reason),
1081        })),
1082    }
1083}
1084
1085/// Maps a tungstenite message to a framework message, dropping raw frames.
1086pub(crate) fn from_tungstenite(message: Message) -> Option<WsMessage> {
1087    match message {
1088        Message::Text(text) => Some(WsMessage::Text(text)),
1089        Message::Binary(bytes) => Some(WsMessage::Binary(bytes)),
1090        Message::Ping(bytes) => Some(WsMessage::Ping(bytes)),
1091        Message::Pong(bytes) => Some(WsMessage::Pong(bytes)),
1092        Message::Close(close) => Some(WsMessage::Close(close.map(|close| WsClose {
1093            code: WsCloseCode::from_u16(u16::from(close.code)),
1094            reason: close.reason.into_owned(),
1095        }))),
1096        Message::Frame(_) => None,
1097    }
1098}
1099
1100/// Renders a tungstenite protocol error as a framework error.
1101pub(crate) fn connection_error(error: tokio_tungstenite::tungstenite::Error) -> Error {
1102    Error::internal(format!("websocket connection error: {error}")).with_code("WS_CONNECTION_ERROR")
1103}
1104
#[cfg(test)]
mod tests {
    use super::*;
    use crate::body::box_body;
    use crate::extract::PathParams;
    use crate::state::StateMap;
    use bytes::Bytes;
    use futures_util::{SinkExt, StreamExt};
    use http_body_util::Full;
    use std::sync::Mutex;
    use tokio_tungstenite::tungstenite::protocol::Role;

    /// Builds a `GET /ws` request context carrying `headers`, with an empty
    /// state map and body. It is created via `RequestContext::new`, so unlike
    /// `request_context_with_duplex` no duplex upgrade stream is attached.
    fn request_context(headers: &[(&str, &str)]) -> RequestContext {
        let mut builder = http::Request::builder().method(Method::GET).uri("/ws");
        for (name, value) in headers {
            builder = builder.header(*name, *value);
        }
        let head = builder.body(()).unwrap().into_parts().0;
        RequestContext::new(
            head,
            PathParams::new(),
            Arc::new(StateMap::new()),
            box_body(Full::new(Bytes::new())),
        )
    }

    /// Like `request_context`, but optionally registers an app-wide
    /// `AppWsConfig` and `WsHooks` in the state map, and attaches the server end
    /// of a 64 KiB in-memory duplex pipe as the upgrade stream.
    ///
    /// Returns the context and the client end of the pipe.
    fn request_context_with_duplex(
        headers: &[(&str, &str)],
        config: Option<WebSocketConfig>,
        hooks: Option<WsHooks>,
    ) -> (RequestContext, DuplexStream) {
        let mut builder = http::Request::builder().method(Method::GET).uri("/ws");
        for (name, value) in headers {
            builder = builder.header(*name, *value);
        }
        let head = builder.body(()).unwrap().into_parts().0;
        let mut state = StateMap::new();
        if let Some(config) = config {
            state.insert(AppWsConfig(config));
        }
        if let Some(hooks) = hooks {
            state.insert(hooks);
        }
        let (client, server) = tokio::io::duplex(64 * 1024);
        let ctx = RequestContext::with_duplex_upgrade(
            head,
            PathParams::new(),
            Arc::new(state),
            box_body(Full::new(Bytes::new())),
            server,
        );
        (ctx, client)
    }

    /// Headers of a minimal valid upgrade request. `Connection` lists `Upgrade`
    /// alongside another token, as browsers may send it.
    fn websocket_headers() -> [(&'static str, &'static str); 4] {
        [
            ("upgrade", "websocket"),
            ("connection", "keep-alive, Upgrade"),
            ("sec-websocket-version", "13"),
            ("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ=="),
        ]
    }

    /// A route config built with `WebSocketConfig::new()` and no builder calls.
    fn default_route_config() -> WebSocketConfig {
        WebSocketConfig::new()
    }

    // Every named close code, plus an application code, survives as_u16/from_u16.
    #[test]
    fn close_code_round_trips_through_u16() {
        for code in [
            WsCloseCode::NormalClosure,
            WsCloseCode::GoingAway,
            WsCloseCode::ProtocolError,
            WsCloseCode::UnsupportedData,
            WsCloseCode::PolicyViolation,
            WsCloseCode::MessageTooBig,
            WsCloseCode::InternalError,
            WsCloseCode::Other(4000),
        ] {
            assert_eq!(WsCloseCode::from_u16(code.as_u16()), code);
        }
    }

    // into_tungstenite followed by from_tungstenite is the identity for each variant.
    #[test]
    fn messages_map_to_and_from_tungstenite() {
        let cases = [
            WsMessage::Text("hello".to_owned()),
            WsMessage::Binary(vec![1, 2, 3]),
            WsMessage::Ping(vec![9]),
            WsMessage::Pong(vec![8]),
            WsMessage::Close(Some(WsClose {
                code: WsCloseCode::NormalClosure,
                reason: "bye".to_owned(),
            })),
        ];
        for message in cases {
            let round = from_tungstenite(into_tungstenite(message.clone()));
            assert_eq!(round, Some(message));
        }
    }

    // Route-level settings override app-level ones; unset route fields inherit.
    #[test]
    fn config_merge_prefers_route_over_app() {
        let app = WebSocketConfig::new()
            .max_message_size(1000)
            .idle_timeout_secs(30);
        let route = WebSocketConfig::new().max_message_size(2000);

        let merged = route.merge(&app);
        assert_eq!(merged.max_message_size, Some(2000), "route value wins");
        assert_eq!(merged.max_frame_size, None);
        assert_eq!(
            merged.idle_timeout,
            Some(Duration::from_secs(30)),
            "app default is kept where the route is unset"
        );
    }

    // WsError converts into an HTTP error whose kind depends on the close code.
    #[test]
    fn ws_error_maps_to_an_http_status() {
        let error: Error = WsError::policy_violation("no token").into();
        assert_eq!(error.kind(), crate::ErrorKind::Forbidden);
        assert_eq!(error.code(), "WS_REJECTED");

        let too_large: Error = WsError::new(WsCloseCode::MessageTooBig, "big").into();
        assert_eq!(too_large.kind(), crate::ErrorKind::PayloadTooLarge);

        let internal = WsError::internal("boom");
        assert_eq!(internal.code(), WsCloseCode::InternalError);
        assert_eq!(internal.message(), "boom");
        assert_eq!(internal.to_string(), "boom");
    }

    // WsDisconnectInfo exposes the request metadata plus duration and close code.
    #[test]
    fn disconnect_info_exposes_duration_and_close_code() {
        let info = WsConnInfo {
            method: Method::GET,
            path: "/ws".to_owned(),
            request_id: Some("req-1".to_owned()),
        };
        let event = WsDisconnectInfo::new(
            info,
            Duration::from_secs(3),
            Some(WsCloseCode::NormalClosure),
        );
        assert_eq!(event.path(), "/ws");
        assert_eq!(event.method(), &Method::GET);
        assert_eq!(event.request_id(), Some("req-1"));
        assert_eq!(event.duration(), Duration::from_secs(3));
        assert_eq!(event.close_code(), Some(WsCloseCode::NormalClosure));
    }

    // KiB builders convert to bytes, and unset limits fall back to framework defaults.
    #[test]
    fn websocket_config_builders_and_connect_info_accessors_work() {
        let config = WebSocketConfig::new()
            .max_message_size_kb(2)
            .max_frame_size_kb(1)
            .idle_timeout_secs(3);
        let tungstenite = config.to_tungstenite().expect("limits should be present");
        assert_eq!(tungstenite.max_message_size, Some(2 * 1024));
        assert_eq!(tungstenite.max_frame_size, Some(1024));
        assert_eq!(config.idle_timeout, Some(Duration::from_secs(3)));
        // With nothing set, the secure framework defaults apply (not None / the
        // 64 MiB tungstenite default).
        let defaults = WebSocketConfig::new()
            .to_tungstenite()
            .expect("defaults should be present");
        assert_eq!(defaults.max_message_size, Some(DEFAULT_WS_MAX_MESSAGE_SIZE));
        assert_eq!(defaults.max_frame_size, Some(DEFAULT_WS_MAX_FRAME_SIZE));

        let info = WsConnInfo {
            method: Method::POST,
            path: "/chat".to_owned(),
            request_id: Some("req-9".to_owned()),
        };
        let connect = WsConnectInfo::new(info);
        assert_eq!(connect.method(), &Method::POST);
        assert_eq!(connect.path(), "/chat");
        assert_eq!(connect.request_id(), Some("req-9"));
    }

    // Each missing or invalid required header yields its specific rejection message.
    #[test]
    fn handshake_validates_required_headers() {
        let ctx = request_context(&[]);
        let error = match __ws_handshake(&ctx, default_route_config()) {
            Ok(_) => panic!("expected handshake rejection"),
            Err(error) => error,
        };
        assert_eq!(error.code(), NOT_A_WEBSOCKET);
        assert_eq!(error.message(), "expected a WebSocket upgrade");

        // Missing `Connection` header.
        let ctx = request_context(&[
            ("upgrade", "websocket"),
            ("sec-websocket-version", "13"),
            ("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ=="),
        ]);
        let error = match __ws_handshake(&ctx, default_route_config()) {
            Ok(_) => panic!("expected handshake rejection"),
            Err(error) => error,
        };
        assert_eq!(
            error.message(),
            "WebSocket upgrade requires Connection: upgrade"
        );

        // Unsupported protocol version.
        let ctx = request_context(&[
            ("upgrade", "websocket"),
            ("connection", "upgrade"),
            ("sec-websocket-version", "12"),
            ("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ=="),
        ]);
        let error = match __ws_handshake(&ctx, default_route_config()) {
            Ok(_) => panic!("expected handshake rejection"),
            Err(error) => error,
        };
        assert_eq!(error.message(), "unsupported WebSocket version");

        // Missing key.
        let ctx = request_context(&[
            ("upgrade", "websocket"),
            ("connection", "upgrade"),
            ("sec-websocket-version", "13"),
        ]);
        let error = match __ws_handshake(&ctx, default_route_config()) {
            Ok(_) => panic!("expected handshake rejection"),
            Err(error) => error,
        };
        assert_eq!(error.message(), "missing Sec-WebSocket-Key");
    }

    // A valid request gets a 101 carrying the upgrade headers and an accept key.
    #[test]
    fn handshake_builds_switching_protocols_response() {
        let ctx = request_context(&websocket_headers());
        let response = __ws_handshake(&ctx, default_route_config()).unwrap();
        assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
        assert_eq!(response.headers()[UPGRADE], "websocket");
        assert_eq!(response.headers()[CONNECTION], "upgrade");
        assert!(response.headers().contains_key(SEC_WEBSOCKET_ACCEPT));
    }

    // With no origin policy, only an origin matching scheme + Host is accepted.
    #[test]
    fn handshake_rejects_cross_origin_by_default_and_accepts_same_origin() {
        let ctx = request_context(&[
            ("host", "example.com"),
            ("origin", "https://evil.example.com"),
            ("upgrade", "websocket"),
            ("connection", "upgrade"),
            ("sec-websocket-version", "13"),
            ("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ=="),
        ]);
        let error = match __ws_handshake(&ctx, default_route_config()) {
            Ok(_) => panic!("expected handshake rejection"),
            Err(error) => error,
        };
        assert_eq!(error.kind(), crate::ErrorKind::Forbidden);
        assert_eq!(error.code(), "WS_ORIGIN_FORBIDDEN");

        // The request scheme comes from extensions, so mark this one as HTTPS.
        let mut head = http::Request::builder()
            .method(Method::GET)
            .uri("/ws")
            .header("host", "example.com")
            .header("origin", "https://example.com")
            .header("upgrade", "websocket")
            .header("connection", "upgrade")
            .header("sec-websocket-version", "13")
            .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
            .body(())
            .unwrap()
            .into_parts()
            .0;
        head.extensions.insert(RequestScheme::Https);
        let ctx = RequestContext::new(
            head,
            PathParams::new(),
            Arc::new(StateMap::new()),
            box_body(Full::new(Bytes::new())),
        );
        let response = __ws_handshake(&ctx, default_route_config()).unwrap();
        assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
    }

    // An explicit allowlist entry, or allow-any, admits a cross-origin request.
    #[test]
    fn allowlists_and_allow_any_origin_override_same_origin_policy() {
        let ctx = request_context(&[
            ("host", "example.com"),
            ("origin", "https://evil.example.com"),
            ("upgrade", "websocket"),
            ("connection", "upgrade"),
            ("sec-websocket-version", "13"),
            ("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ=="),
        ]);

        let response = __ws_handshake(
            &ctx,
            WebSocketConfig::new().allow_origin("https://evil.example.com"),
        )
        .unwrap();
        assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);

        let response = __ws_handshake(&ctx, WebSocketConfig::new().allow_any_origin()).unwrap();
        assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
    }

    // from_request_context merges route config over app config and records metadata.
    #[test]
    fn from_request_context_merges_config_and_captures_request_metadata() {
        let hooks = WsHooks::default();
        let (ctx, _client) = request_context_with_duplex(
            &[
                ("x-request-id", "req-2"),
                ("upgrade", "websocket"),
                ("connection", "upgrade"),
                ("sec-websocket-version", "13"),
                ("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ=="),
            ],
            Some(WebSocketConfig::new().max_frame_size(64)),
            Some(hooks),
        );

        let socket = WebSocket::from_request_context(
            &ctx,
            WebSocketConfig::new()
                .max_message_size(128)
                .idle_timeout(Duration::from_secs(2)),
        )
        .unwrap();

        assert_eq!(socket.config.max_message_size, Some(128));
        assert_eq!(socket.config.max_frame_size, Some(64));
        assert_eq!(socket.config.idle_timeout, Some(Duration::from_secs(2)));
        assert_eq!(socket.info.path, "/ws");
        assert_eq!(socket.info.request_id.as_deref(), Some("req-2"));
        assert!(socket.hooks.connect.is_empty());
        assert!(socket.hooks.disconnect.is_empty());
    }

    /// Test payload: a chat message whose text must be non-empty.
    #[derive(Debug, PartialEq, Eq, serde::Deserialize, garde::Validate)]
    struct ChatIn {
        #[garde(length(min = 1))]
        message: String,
    }

    // Validation failures and malformed JSON are both Unprocessable (422).
    #[test]
    fn deserialize_and_validate_accepts_valid_and_rejects_invalid() {
        let ok = deserialize_and_validate::<ChatIn>(br#"{"message":"hi"}"#);
        assert!(ok.is_ok());

        let empty = deserialize_and_validate::<ChatIn>(br#"{"message":""}"#);
        assert_eq!(empty.err().unwrap().kind(), crate::ErrorKind::Unprocessable);

        let malformed = deserialize_and_validate::<ChatIn>(b"not json");
        assert_eq!(
            malformed.err().unwrap().kind(),
            crate::ErrorKind::Unprocessable
        );
    }

    // End-to-end over an in-memory pipe: connect hook, text in, JSON out,
    // close frame, then the disconnect hook with the recorded close code.
    #[tokio::test]
    async fn duplex_accept_runs_hooks_and_exchanges_messages() {
        let connects = Arc::new(Mutex::new(Vec::new()));
        let disconnects = Arc::new(Mutex::new(Vec::new()));
        let hooks = WsHooks {
            connect: vec![Box::new({
                let connects = connects.clone();
                move |info| {
                    let connects = connects.clone();
                    Box::pin(async move {
                        connects.lock().unwrap().push((
                            info.method().clone(),
                            info.path().to_owned(),
                            info.request_id().map(str::to_owned),
                        ));
                    })
                }
            })],
            disconnect: vec![Box::new({
                let disconnects = disconnects.clone();
                move |info| {
                    let disconnects = disconnects.clone();
                    Box::pin(async move {
                        disconnects
                            .lock()
                            .unwrap()
                            .push((info.path().to_owned(), info.close_code()));
                    })
                }
            })],
        };
        let headers = [
            ("x-request-id", "req-hook"),
            ("upgrade", "websocket"),
            ("connection", "upgrade"),
            ("sec-websocket-version", "13"),
            ("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ=="),
        ];
        let (ctx, client_io) = request_context_with_duplex(&headers, None, Some(hooks));
        let socket = WebSocket::from_request_context(&ctx, WebSocketConfig::new()).unwrap();
        let mut conn = socket.accept().await.unwrap();
        let mut client = WebSocketStream::from_raw_socket(client_io, Role::Client, None).await;

        client.send(Message::Text("hello".into())).await.unwrap();
        assert_eq!(conn.receive_text().await.unwrap(), Some("hello".to_owned()));

        conn.send_json(&serde_json::json!({ "ok": true }))
            .await
            .unwrap();
        let message = client.next().await.unwrap().unwrap();
        assert_eq!(message.into_text().unwrap(), r#"{"ok":true}"#);

        conn.close(WsCloseCode::NormalClosure, "bye").await.unwrap();
        match client.next().await.unwrap().unwrap() {
            Message::Close(Some(close)) => {
                assert_eq!(u16::from(close.code), 1000);
                assert_eq!(close.reason, "bye");
            }
            other => panic!("expected close frame, got {other:?}"),
        }
        // Dropping the connection is expected to fire the disconnect hooks;
        // the yield presumably lets a task spawned from Drop run — confirm
        // against WebSocketConn's Drop impl.
        drop(conn);
        tokio::task::yield_now().await;

        assert_eq!(
            connects.lock().unwrap().as_slice(),
            &[(Method::GET, "/ws".to_owned(), Some("req-hook".to_owned()))]
        );
        assert_eq!(
            disconnects.lock().unwrap().as_slice(),
            &[("/ws".to_owned(), Some(WsCloseCode::NormalClosure))]
        );
    }

    // Covers ping skipping, validation and decode errors, peer close, and the
    // idle timeout ending recv() with None.
    #[tokio::test]
    async fn duplex_connection_helpers_cover_close_idle_and_validation_paths() {
        let (ctx, client_io) = request_context_with_duplex(&websocket_headers(), None, None);
        let socket = WebSocket::from_request_context(
            &ctx,
            WebSocketConfig::new().idle_timeout(Duration::from_millis(10)),
        )
        .unwrap();
        let mut conn = socket.accept().await.unwrap();
        let mut client = WebSocketStream::from_raw_socket(client_io, Role::Client, None).await;

        // The ping is skipped; the following text message is validated.
        client.send(Message::Ping(vec![1, 2])).await.unwrap();
        client
            .send(Message::Text("{\"message\":\"ok\"}".into()))
            .await
            .unwrap();
        let validated = conn.receive_valid::<ChatIn>().await.unwrap().unwrap();
        assert_eq!(validated.message, "ok");

        client
            .send(Message::Binary(br#"{"message":""}"#.to_vec()))
            .await
            .unwrap();
        let error = match conn.receive_valid::<ChatIn>().await {
            Ok(_) => panic!("expected validation error"),
            Err(error) => error,
        };
        assert_eq!(error.kind(), crate::ErrorKind::Unprocessable);

        // receive_json reports malformed JSON as BadRequest (not Unprocessable).
        client.send(Message::Text("not-json".into())).await.unwrap();
        let error = match conn.receive_json::<ChatIn>().await {
            Ok(_) => panic!("expected decode error"),
            Err(error) => error,
        };
        assert_eq!(error.kind(), crate::ErrorKind::BadRequest);

        client.close(None).await.unwrap();
        assert_eq!(conn.receive_text().await.unwrap(), None);
        assert_eq!(conn.receive_json::<ChatIn>().await.unwrap(), None);
        assert_eq!(conn.receive_valid::<ChatIn>().await.unwrap(), None);

        // No client traffic: the 5 ms idle timeout is expected to end recv() with None.
        let (ctx, _client_io) = request_context_with_duplex(&websocket_headers(), None, None);
        let socket = WebSocket::from_request_context(
            &ctx,
            WebSocketConfig::new().idle_timeout(Duration::from_millis(5)),
        )
        .unwrap();
        let mut idle_conn = socket.accept().await.unwrap();
        assert_eq!(idle_conn.recv().await.unwrap(), None);
    }

    // Transport errors map to WS_CONNECTION_ERROR with a prefixed message.
    #[test]
    fn frame_and_connection_errors_map_to_expected_results() {
        let error = connection_error(tokio_tungstenite::tungstenite::Error::ConnectionClosed);
        assert_eq!(error.code(), "WS_CONNECTION_ERROR");
        assert!(error.message().contains("websocket connection error:"));
    }

    // Per-IP limiter: caps connections per address, independent budgets per IP,
    // and releases a slot when a permit is dropped.
    #[test]
    fn ws_ip_limiter_caps_per_ip_and_releases_on_drop() {
        use std::net::Ipv4Addr;

        let limiter = WsIpLimiter::new(2);
        let ip = IpAddr::V4(Ipv4Addr::LOCALHOST);

        let first = limiter.try_acquire(ip).expect("first is under the limit");
        let _second = limiter.try_acquire(ip).expect("second reaches the limit");
        assert!(
            limiter.try_acquire(ip).is_none(),
            "a third connection from the same IP is rejected"
        );

        // A different IP has its own budget.
        let other = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
        assert!(limiter.try_acquire(other).is_some());

        // Releasing a permit frees a slot for that IP again.
        drop(first);
        assert!(
            limiter.try_acquire(ip).is_some(),
            "dropping a connection frees a slot"
        );
    }

    // The newer limits (handshake timeout, per-IP cap) follow the same merge rules.
    #[test]
    fn route_config_overrides_app_defaults_for_new_limits() {
        let app = WebSocketConfig::new()
            .handshake_timeout(Duration::from_secs(5))
            .max_connections_per_ip(10);
        let route = WebSocketConfig::new().max_connections_per_ip(3);

        let merged = route.merge(&app);
        assert_eq!(merged.ip_connection_limit(), Some(3), "route wins");
        assert_eq!(
            merged.handshake_timeout,
            Some(Duration::from_secs(5)),
            "unset on the route, taken from the app default"
        );
    }
}