Skip to main content

tork_core/testing/
websocket.rs

//! The WebSocket test client, over either an in-process duplex stream or a real TCP port.
2
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6
7use futures_util::{SinkExt, StreamExt};
8use http::header::{
9    CONNECTION, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE,
10};
11use http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode};
12use serde::de::DeserializeOwned;
13use serde::Serialize;
14use tokio::io::{AsyncRead, AsyncWrite, DuplexStream, ReadBuf};
15use tokio::net::TcpStream;
16use tokio_tungstenite::tungstenite::client::IntoClientRequest;
17use tokio_tungstenite::tungstenite::protocol::Role;
18use tokio_tungstenite::WebSocketStream;
19
20use super::client::{Shared, TestHeader, Transport};
21use crate::body::box_body;
22use crate::error::{Error, Result};
23use crate::ws::{
24    connection_error, from_tungstenite, into_tungstenite, WsClose, WsCloseCode, WsMessage,
25};
26
/// Capacity, in bytes, of the in-memory duplex pipe that joins the test
/// client to the server when the app runs in process (64 KiB).
const WS_DUPLEX_BUFFER: usize = 1 << 16;
/// `Sec-WebSocket-Key` sent on in-process upgrades: the example key from
/// RFC 6455 §1.3 (base64 of "the sample nonce"). The handshake never leaves
/// the process, so the value only has to be present and well formed.
const WS_TEST_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ==";
32
33/// The client side of a test WebSocket transport.
34///
35/// In-process tests use a duplex stream; a real-port variant is added later.
36pub(crate) enum ClientIo {
37    Duplex(DuplexStream),
38    Tcp(TcpStream),
39}
40
41impl AsyncRead for ClientIo {
42    fn poll_read(
43        self: Pin<&mut Self>,
44        cx: &mut Context<'_>,
45        buf: &mut ReadBuf<'_>,
46    ) -> Poll<std::io::Result<()>> {
47        match self.get_mut() {
48            ClientIo::Duplex(io) => Pin::new(io).poll_read(cx, buf),
49            ClientIo::Tcp(io) => Pin::new(io).poll_read(cx, buf),
50        }
51    }
52}
53
54impl AsyncWrite for ClientIo {
55    fn poll_write(
56        self: Pin<&mut Self>,
57        cx: &mut Context<'_>,
58        buf: &[u8],
59    ) -> Poll<std::io::Result<usize>> {
60        match self.get_mut() {
61            ClientIo::Duplex(io) => Pin::new(io).poll_write(cx, buf),
62            ClientIo::Tcp(io) => Pin::new(io).poll_write(cx, buf),
63        }
64    }
65
66    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
67        match self.get_mut() {
68            ClientIo::Duplex(io) => Pin::new(io).poll_flush(cx),
69            ClientIo::Tcp(io) => Pin::new(io).poll_flush(cx),
70        }
71    }
72
73    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
74        match self.get_mut() {
75            ClientIo::Duplex(io) => Pin::new(io).poll_shutdown(cx),
76            ClientIo::Tcp(io) => Pin::new(io).poll_shutdown(cx),
77        }
78    }
79}
80
81/// Builds a WebSocket connection: set headers, query parameters, and
82/// subprotocols, then call [`connect`](TestWebSocketBuilder::connect).
83pub struct TestWebSocketBuilder {
84    shared: Arc<Shared>,
85    path: String,
86    query: Vec<(String, String)>,
87    headers: Vec<TestHeader>,
88    subprotocols: Vec<String>,
89}
90
91impl TestWebSocketBuilder {
92    pub(crate) fn new(shared: Arc<Shared>, path: impl Into<String>) -> Self {
93        Self {
94            shared,
95            path: path.into(),
96            query: Vec::new(),
97            headers: Vec::new(),
98            subprotocols: Vec::new(),
99        }
100    }
101
102    /// Adds a header to the upgrade request.
103    pub fn header(mut self, name: &str, value: &str) -> Self {
104        if let (Ok(name), Ok(value)) = (
105            HeaderName::from_bytes(name.as_bytes()),
106            HeaderValue::from_str(value),
107        ) {
108            self.headers.push(TestHeader::safe(name, value));
109        }
110        self
111    }
112
113    /// Adds a security-sensitive header to the upgrade request, bypassing the
114    /// in-process guard.
115    pub fn unsafe_header(mut self, name: &str, value: &str) -> Self {
116        if let (Ok(name), Ok(value)) = (
117            HeaderName::from_bytes(name.as_bytes()),
118            HeaderValue::from_str(value),
119        ) {
120            self.headers.push(TestHeader::unsafe_allowed(name, value));
121        }
122        self
123    }
124
125    /// Adds a query parameter to the upgrade request.
126    pub fn query(mut self, name: &str, value: &str) -> Self {
127        self.query.push((name.to_owned(), value.to_owned()));
128        self
129    }
130
131    /// Requests a subprotocol (sent in `Sec-WebSocket-Protocol`).
132    pub fn subprotocol(mut self, protocol: &str) -> Self {
133        self.subprotocols.push(protocol.to_owned());
134        self
135    }
136
137    /// Performs the upgrade and returns the open connection.
138    ///
139    /// Returns an error if the handshake or a dependency is rejected before the
140    /// upgrade (the response status is not `101`).
141    pub async fn connect(self) -> Result<TestWebSocket> {
142        let path = if self.query.is_empty() {
143            self.path.clone()
144        } else {
145            let encoded = serde_urlencoded::to_string(&self.query)
146                .map_err(|_| Error::internal("failed to encode query parameters"))?;
147            format!("{}?{}", self.path, encoded)
148        };
149
150        // The headers common to both transports: defaults, per-request, and cookies.
151        let mut base_headers = HeaderMap::new();
152        for (name, value) in self.shared.default_headers.iter() {
153            base_headers.insert(name, value.clone());
154        }
155        for (name, value) in self.shared.unsafe_default_headers.iter() {
156            base_headers.insert(name, value.clone());
157        }
158        self.shared
159            .reject_in_process_sensitive_headers(&self.headers)?;
160        for header in &self.headers {
161            base_headers.insert(header.name.clone(), header.value.clone());
162        }
163        self.shared
164            .cookies
165            .lock()
166            .expect("cookie jar mutex poisoned")
167            .apply(&mut base_headers);
168        let subprotocol = if self.subprotocols.is_empty() {
169            None
170        } else {
171            HeaderValue::from_str(&self.subprotocols.join(", ")).ok()
172        };
173
174        match &self.shared.transport {
175            Transport::InProcess(app) => {
176                let mut request =
177                    http::Request::new(box_body(http_body_util::Full::new(bytes::Bytes::new())));
178                *request.method_mut() = Method::GET;
179                *request.uri_mut() = path
180                    .parse()
181                    .map_err(|_| Error::bad_request(format!("invalid request URI: {path}")))?;
182                let map = request.headers_mut();
183                *map = base_headers;
184                map.insert(UPGRADE, HeaderValue::from_static("websocket"));
185                map.insert(CONNECTION, HeaderValue::from_static("upgrade"));
186                map.insert(SEC_WEBSOCKET_VERSION, HeaderValue::from_static("13"));
187                map.insert(SEC_WEBSOCKET_KEY, HeaderValue::from_static(WS_TEST_KEY));
188                if let Some(value) = subprotocol {
189                    map.insert(SEC_WEBSOCKET_PROTOCOL, value);
190                }
191
192                let (client_io, server_io) = tokio::io::duplex(WS_DUPLEX_BUFFER);
193                let response = app.dispatch_upgrade(request, server_io).await;
194                if response.status() != StatusCode::SWITCHING_PROTOCOLS {
195                    return Err(rejected(response.status()));
196                }
197                let stream = WebSocketStream::from_raw_socket(
198                    ClientIo::Duplex(client_io),
199                    Role::Client,
200                    None,
201                )
202                .await;
203                Ok(TestWebSocket { stream })
204            }
205            Transport::RealPort(addr) => {
206                // Build the handshake request from the URL so tungstenite generates
207                // the mandatory headers (key, version, upgrade), then add ours.
208                let url = format!("ws://{addr}{path}");
209                let mut request = url
210                    .as_str()
211                    .into_client_request()
212                    .map_err(connection_error)?;
213                for (name, value) in base_headers.iter() {
214                    request.headers_mut().insert(name, value.clone());
215                }
216                if let Some(value) = subprotocol {
217                    request.headers_mut().insert(SEC_WEBSOCKET_PROTOCOL, value);
218                }
219
220                let stream = TcpStream::connect(addr).await.map_err(|error| {
221                    Error::internal(format!("failed to connect to {addr}: {error}"))
222                })?;
223                let (stream, _response) =
224                    tokio_tungstenite::client_async(request, ClientIo::Tcp(stream))
225                        .await
226                        .map_err(connection_error)?;
227                Ok(TestWebSocket { stream })
228            }
229        }
230    }
231}
232
233/// The error returned when a WebSocket upgrade is rejected before acceptance.
234fn rejected(status: StatusCode) -> Error {
235    Error::bad_request(format!("websocket upgrade rejected with status {status}"))
236        .with_code("WS_UPGRADE_REJECTED")
237}
238
239/// An open WebSocket connection in a test.
240pub struct TestWebSocket {
241    stream: WebSocketStream<ClientIo>,
242}
243
244impl TestWebSocket {
245    /// Sends a text message.
246    pub async fn send_text(&mut self, text: impl Into<String>) -> Result<()> {
247        self.send(WsMessage::Text(text.into())).await
248    }
249
250    /// Serializes `value` as JSON and sends it as a text message.
251    pub async fn send_json<T: Serialize>(&mut self, value: &T) -> Result<()> {
252        let text = serde_json::to_string(value)
253            .map_err(|error| Error::internal(format!("failed to encode message: {error}")))?;
254        self.send_text(text).await
255    }
256
257    /// Sends a binary message.
258    pub async fn send_binary(&mut self, bytes: impl Into<Vec<u8>>) -> Result<()> {
259        self.send(WsMessage::Binary(bytes.into())).await
260    }
261
262    async fn send(&mut self, message: WsMessage) -> Result<()> {
263        self.stream
264            .send(into_tungstenite(message))
265            .await
266            .map_err(connection_error)
267    }
268
269    /// Receives the next message, or `None` once the connection closes.
270    pub async fn receive(&mut self) -> Result<Option<WsMessage>> {
271        loop {
272            match self.stream.next().await {
273                Some(Ok(message)) => {
274                    if let Some(message) = from_tungstenite(message) {
275                        return Ok(Some(message));
276                    }
277                }
278                Some(Err(error)) => return Err(connection_error(error)),
279                None => return Ok(None),
280            }
281        }
282    }
283
284    /// Receives the next text message, skipping control frames.
285    pub async fn receive_text(&mut self) -> Result<String> {
286        loop {
287            match self.receive().await? {
288                Some(WsMessage::Text(text)) => return Ok(text),
289                Some(WsMessage::Close(_)) | None => {
290                    return Err(closed_error());
291                }
292                Some(_) => continue,
293            }
294        }
295    }
296
297    /// Receives the next message and deserializes it from JSON.
298    pub async fn receive_json<T: DeserializeOwned>(&mut self) -> Result<T> {
299        loop {
300            match self.receive().await? {
301                Some(WsMessage::Text(text)) => {
302                    return serde_json::from_str(&text).map_err(decode_error);
303                }
304                Some(WsMessage::Binary(bytes)) => {
305                    return serde_json::from_slice(&bytes).map_err(decode_error);
306                }
307                Some(WsMessage::Close(_)) | None => return Err(closed_error()),
308                Some(_) => continue,
309            }
310        }
311    }
312
313    /// Waits for the close frame and returns it.
314    pub async fn receive_close(&mut self) -> Result<WsClose> {
315        loop {
316            match self.receive().await? {
317                Some(WsMessage::Close(Some(close))) => return Ok(close),
318                Some(WsMessage::Close(None)) => {
319                    return Ok(WsClose {
320                        code: WsCloseCode::NormalClosure,
321                        reason: String::new(),
322                    });
323                }
324                None => return Err(closed_error()),
325                Some(_) => continue,
326            }
327        }
328    }
329
330    /// Closes the connection.
331    pub async fn close(&mut self) -> Result<()> {
332        SinkExt::close(&mut self.stream)
333            .await
334            .map_err(connection_error)
335    }
336}
337
338/// The error returned when the connection closed before a message was received.
339fn closed_error() -> Error {
340    Error::internal("websocket connection closed").with_code("WS_CLOSED")
341}
342
343/// Maps a JSON decode failure to an error.
344fn decode_error(error: serde_json::Error) -> Error {
345    Error::internal(format!("message is not valid JSON: {error}"))
346}
347
#[cfg(test)]
mod tests {
    use super::super::client::{Shared, Transport};
    use super::super::cookie::CookieJar;
    use super::*;
    use crate::app::App;
    use http::HeaderMap;
    use std::sync::Arc;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    /// Builds in-process client state around an empty app, with no default
    /// headers and an empty cookie jar.
    fn in_process_shared() -> Arc<Shared> {
        Arc::new(Shared {
            transport: Transport::InProcess(Arc::new(App::new().build().unwrap())),
            default_headers: HeaderMap::new(),
            unsafe_default_headers: HeaderMap::new(),
            cookies: std::sync::Mutex::new(CookieJar::default()),
        })
    }

    #[test]
    fn builder_ignores_invalid_headers_and_keeps_query_and_subprotocols() {
        let builder = TestWebSocketBuilder::new(in_process_shared(), "/ws")
            .header("x-good", "ok")
            .header("bad name", "ignored")
            .header("x-bad-value", "line\nbreak")
            .query("room", "main hall")
            .subprotocol("json")
            .subprotocol("binary");

        assert_eq!(builder.headers.len(), 1);
        assert_eq!(builder.headers[0].name, "x-good");
        assert_eq!(
            builder.query,
            vec![("room".to_owned(), "main hall".to_owned())]
        );
        assert_eq!(
            builder.subprotocols,
            vec!["json".to_owned(), "binary".to_owned()]
        );
    }

    #[test]
    fn unsafe_header_marks_the_entry() {
        let builder =
            TestWebSocketBuilder::new(in_process_shared(), "/ws").unsafe_header("host", "example.com");
        assert_eq!(builder.headers.len(), 1);
        assert!(builder.headers[0].unsafe_allowed);
    }

    #[test]
    fn rejected_error_uses_stable_code() {
        let error = rejected(StatusCode::FORBIDDEN);

        assert_eq!(error.code(), "WS_UPGRADE_REJECTED");
        assert_eq!(
            error.message(),
            "websocket upgrade rejected with status 403 Forbidden"
        );
    }

    #[test]
    fn closed_error_uses_stable_code() {
        let error = closed_error();

        assert_eq!(error.code(), "WS_CLOSED");
        assert_eq!(error.message(), "websocket connection closed");
    }

    #[test]
    fn decode_error_reports_json_failure() {
        let source = serde_json::from_str::<serde_json::Value>("{").unwrap_err();
        let error = decode_error(source);

        assert!(error.message().starts_with("message is not valid JSON:"));
    }

    #[tokio::test]
    async fn client_io_duplex_supports_async_read_and_write() {
        let (left, mut right) = tokio::io::duplex(16);
        let mut io = ClientIo::Duplex(left);

        io.write_all(b"ping").await.unwrap();
        io.flush().await.unwrap();

        let mut buf = [0u8; 4];
        right.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"ping");

        right.write_all(b"pong").await.unwrap();
        right.flush().await.unwrap();

        let mut back = [0u8; 4];
        io.read_exact(&mut back).await.unwrap();
        assert_eq!(&back, b"pong");
    }

    #[tokio::test]
    async fn client_io_tcp_supports_async_read_and_write() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 4];
            socket.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, b"ping");
            socket.write_all(b"pong").await.unwrap();
            socket.flush().await.unwrap();
        });

        let stream = TcpStream::connect(addr).await.unwrap();
        let mut io = ClientIo::Tcp(stream);

        io.write_all(b"ping").await.unwrap();
        io.flush().await.unwrap();

        let mut back = [0u8; 4];
        io.read_exact(&mut back).await.unwrap();
        assert_eq!(&back, b"pong");

        // Propagate a panic (failed assertion) from the server task.
        server.await.unwrap();
    }

    #[tokio::test]
    async fn client_io_duplex_poll_shutdown_completes() {
        let (left, _right) = tokio::io::duplex(16);
        let mut io = ClientIo::Duplex(left);
        // Shutdown should complete without error.
        io.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn client_io_tcp_poll_shutdown_completes() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            let _ = listener.accept().await;
        });
        let stream = TcpStream::connect(addr).await.unwrap();
        let mut io = ClientIo::Tcp(stream);
        io.shutdown().await.unwrap();
        server.await.unwrap();
    }

    #[test]
    fn builder_keeps_query_parameters() {
        let builder = TestWebSocketBuilder::new(in_process_shared(), "/ws")
            .query("a", "1")
            .query("b", "two");

        assert_eq!(
            builder.query,
            vec![
                ("a".to_owned(), "1".to_owned()),
                ("b".to_owned(), "two".to_owned()),
            ]
        );
    }
}