1use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6
7use futures_util::{SinkExt, StreamExt};
8use http::header::{
9 CONNECTION, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE,
10};
11use http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode};
12use serde::de::DeserializeOwned;
13use serde::Serialize;
14use tokio::io::{AsyncRead, AsyncWrite, DuplexStream, ReadBuf};
15use tokio::net::TcpStream;
16use tokio_tungstenite::tungstenite::client::IntoClientRequest;
17use tokio_tungstenite::tungstenite::protocol::Role;
18use tokio_tungstenite::WebSocketStream;
19
20use super::client::{Shared, TestHeader, Transport};
21use crate::body::box_body;
22use crate::error::{Error, Result};
23use crate::ws::{
24 connection_error, from_tungstenite, into_tungstenite, WsClose, WsCloseCode, WsMessage,
25};
26
/// Buffer size in bytes (64 KiB) handed to `tokio::io::duplex` when an
/// in-process WebSocket connection is opened.
const WS_DUPLEX_BUFFER: usize = 1 << 16;

/// Fixed `Sec-WebSocket-Key` value sent with in-process upgrade requests.
/// It is the base64 encoding of the text "the sample nonce".
const WS_TEST_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ==";
32
33pub(crate) enum ClientIo {
37 Duplex(DuplexStream),
38 Tcp(TcpStream),
39}
40
41impl AsyncRead for ClientIo {
42 fn poll_read(
43 self: Pin<&mut Self>,
44 cx: &mut Context<'_>,
45 buf: &mut ReadBuf<'_>,
46 ) -> Poll<std::io::Result<()>> {
47 match self.get_mut() {
48 ClientIo::Duplex(io) => Pin::new(io).poll_read(cx, buf),
49 ClientIo::Tcp(io) => Pin::new(io).poll_read(cx, buf),
50 }
51 }
52}
53
54impl AsyncWrite for ClientIo {
55 fn poll_write(
56 self: Pin<&mut Self>,
57 cx: &mut Context<'_>,
58 buf: &[u8],
59 ) -> Poll<std::io::Result<usize>> {
60 match self.get_mut() {
61 ClientIo::Duplex(io) => Pin::new(io).poll_write(cx, buf),
62 ClientIo::Tcp(io) => Pin::new(io).poll_write(cx, buf),
63 }
64 }
65
66 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
67 match self.get_mut() {
68 ClientIo::Duplex(io) => Pin::new(io).poll_flush(cx),
69 ClientIo::Tcp(io) => Pin::new(io).poll_flush(cx),
70 }
71 }
72
73 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
74 match self.get_mut() {
75 ClientIo::Duplex(io) => Pin::new(io).poll_shutdown(cx),
76 ClientIo::Tcp(io) => Pin::new(io).poll_shutdown(cx),
77 }
78 }
79}
80
81pub struct TestWebSocketBuilder {
84 shared: Arc<Shared>,
85 path: String,
86 query: Vec<(String, String)>,
87 headers: Vec<TestHeader>,
88 subprotocols: Vec<String>,
89}
90
91impl TestWebSocketBuilder {
92 pub(crate) fn new(shared: Arc<Shared>, path: impl Into<String>) -> Self {
93 Self {
94 shared,
95 path: path.into(),
96 query: Vec::new(),
97 headers: Vec::new(),
98 subprotocols: Vec::new(),
99 }
100 }
101
102 pub fn header(mut self, name: &str, value: &str) -> Self {
104 if let (Ok(name), Ok(value)) = (
105 HeaderName::from_bytes(name.as_bytes()),
106 HeaderValue::from_str(value),
107 ) {
108 self.headers.push(TestHeader::safe(name, value));
109 }
110 self
111 }
112
113 pub fn unsafe_header(mut self, name: &str, value: &str) -> Self {
116 if let (Ok(name), Ok(value)) = (
117 HeaderName::from_bytes(name.as_bytes()),
118 HeaderValue::from_str(value),
119 ) {
120 self.headers.push(TestHeader::unsafe_allowed(name, value));
121 }
122 self
123 }
124
125 pub fn query(mut self, name: &str, value: &str) -> Self {
127 self.query.push((name.to_owned(), value.to_owned()));
128 self
129 }
130
131 pub fn subprotocol(mut self, protocol: &str) -> Self {
133 self.subprotocols.push(protocol.to_owned());
134 self
135 }
136
137 pub async fn connect(self) -> Result<TestWebSocket> {
142 let path = if self.query.is_empty() {
143 self.path.clone()
144 } else {
145 let encoded = serde_urlencoded::to_string(&self.query)
146 .map_err(|_| Error::internal("failed to encode query parameters"))?;
147 format!("{}?{}", self.path, encoded)
148 };
149
150 let mut base_headers = HeaderMap::new();
152 for (name, value) in self.shared.default_headers.iter() {
153 base_headers.insert(name, value.clone());
154 }
155 for (name, value) in self.shared.unsafe_default_headers.iter() {
156 base_headers.insert(name, value.clone());
157 }
158 self.shared
159 .reject_in_process_sensitive_headers(&self.headers)?;
160 for header in &self.headers {
161 base_headers.insert(header.name.clone(), header.value.clone());
162 }
163 self.shared
164 .cookies
165 .lock()
166 .expect("cookie jar mutex poisoned")
167 .apply(&mut base_headers);
168 let subprotocol = if self.subprotocols.is_empty() {
169 None
170 } else {
171 HeaderValue::from_str(&self.subprotocols.join(", ")).ok()
172 };
173
174 match &self.shared.transport {
175 Transport::InProcess(app) => {
176 let mut request =
177 http::Request::new(box_body(http_body_util::Full::new(bytes::Bytes::new())));
178 *request.method_mut() = Method::GET;
179 *request.uri_mut() = path
180 .parse()
181 .map_err(|_| Error::bad_request(format!("invalid request URI: {path}")))?;
182 let map = request.headers_mut();
183 *map = base_headers;
184 map.insert(UPGRADE, HeaderValue::from_static("websocket"));
185 map.insert(CONNECTION, HeaderValue::from_static("upgrade"));
186 map.insert(SEC_WEBSOCKET_VERSION, HeaderValue::from_static("13"));
187 map.insert(SEC_WEBSOCKET_KEY, HeaderValue::from_static(WS_TEST_KEY));
188 if let Some(value) = subprotocol {
189 map.insert(SEC_WEBSOCKET_PROTOCOL, value);
190 }
191
192 let (client_io, server_io) = tokio::io::duplex(WS_DUPLEX_BUFFER);
193 let response = app.dispatch_upgrade(request, server_io).await;
194 if response.status() != StatusCode::SWITCHING_PROTOCOLS {
195 return Err(rejected(response.status()));
196 }
197 let stream = WebSocketStream::from_raw_socket(
198 ClientIo::Duplex(client_io),
199 Role::Client,
200 None,
201 )
202 .await;
203 Ok(TestWebSocket { stream })
204 }
205 Transport::RealPort(addr) => {
206 let url = format!("ws://{addr}{path}");
209 let mut request = url
210 .as_str()
211 .into_client_request()
212 .map_err(connection_error)?;
213 for (name, value) in base_headers.iter() {
214 request.headers_mut().insert(name, value.clone());
215 }
216 if let Some(value) = subprotocol {
217 request.headers_mut().insert(SEC_WEBSOCKET_PROTOCOL, value);
218 }
219
220 let stream = TcpStream::connect(addr).await.map_err(|error| {
221 Error::internal(format!("failed to connect to {addr}: {error}"))
222 })?;
223 let (stream, _response) =
224 tokio_tungstenite::client_async(request, ClientIo::Tcp(stream))
225 .await
226 .map_err(connection_error)?;
227 Ok(TestWebSocket { stream })
228 }
229 }
230 }
231}
232
233fn rejected(status: StatusCode) -> Error {
235 Error::bad_request(format!("websocket upgrade rejected with status {status}"))
236 .with_code("WS_UPGRADE_REJECTED")
237}
238
239pub struct TestWebSocket {
241 stream: WebSocketStream<ClientIo>,
242}
243
244impl TestWebSocket {
245 pub async fn send_text(&mut self, text: impl Into<String>) -> Result<()> {
247 self.send(WsMessage::Text(text.into())).await
248 }
249
250 pub async fn send_json<T: Serialize>(&mut self, value: &T) -> Result<()> {
252 let text = serde_json::to_string(value)
253 .map_err(|error| Error::internal(format!("failed to encode message: {error}")))?;
254 self.send_text(text).await
255 }
256
257 pub async fn send_binary(&mut self, bytes: impl Into<Vec<u8>>) -> Result<()> {
259 self.send(WsMessage::Binary(bytes.into())).await
260 }
261
262 async fn send(&mut self, message: WsMessage) -> Result<()> {
263 self.stream
264 .send(into_tungstenite(message))
265 .await
266 .map_err(connection_error)
267 }
268
269 pub async fn receive(&mut self) -> Result<Option<WsMessage>> {
271 loop {
272 match self.stream.next().await {
273 Some(Ok(message)) => {
274 if let Some(message) = from_tungstenite(message) {
275 return Ok(Some(message));
276 }
277 }
278 Some(Err(error)) => return Err(connection_error(error)),
279 None => return Ok(None),
280 }
281 }
282 }
283
284 pub async fn receive_text(&mut self) -> Result<String> {
286 loop {
287 match self.receive().await? {
288 Some(WsMessage::Text(text)) => return Ok(text),
289 Some(WsMessage::Close(_)) | None => {
290 return Err(closed_error());
291 }
292 Some(_) => continue,
293 }
294 }
295 }
296
297 pub async fn receive_json<T: DeserializeOwned>(&mut self) -> Result<T> {
299 loop {
300 match self.receive().await? {
301 Some(WsMessage::Text(text)) => {
302 return serde_json::from_str(&text).map_err(decode_error);
303 }
304 Some(WsMessage::Binary(bytes)) => {
305 return serde_json::from_slice(&bytes).map_err(decode_error);
306 }
307 Some(WsMessage::Close(_)) | None => return Err(closed_error()),
308 Some(_) => continue,
309 }
310 }
311 }
312
313 pub async fn receive_close(&mut self) -> Result<WsClose> {
315 loop {
316 match self.receive().await? {
317 Some(WsMessage::Close(Some(close))) => return Ok(close),
318 Some(WsMessage::Close(None)) => {
319 return Ok(WsClose {
320 code: WsCloseCode::NormalClosure,
321 reason: String::new(),
322 });
323 }
324 None => return Err(closed_error()),
325 Some(_) => continue,
326 }
327 }
328 }
329
330 pub async fn close(&mut self) -> Result<()> {
332 SinkExt::close(&mut self.stream)
333 .await
334 .map_err(connection_error)
335 }
336}
337
338fn closed_error() -> Error {
340 Error::internal("websocket connection closed").with_code("WS_CLOSED")
341}
342
343fn decode_error(error: serde_json::Error) -> Error {
345 Error::internal(format!("message is not valid JSON: {error}"))
346}
347
#[cfg(test)]
mod tests {
    use super::super::client::{Shared, Transport};
    use super::super::cookie::CookieJar;
    use super::*;
    use crate::app::App;
    use http::HeaderMap;
    use std::sync::Arc;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    // Shared client state backed by an empty in-process app, with no default
    // headers and an empty cookie jar.
    fn in_process_shared() -> Arc<Shared> {
        let app = App::new().build().unwrap();
        Arc::new(Shared {
            transport: Transport::InProcess(Arc::new(app)),
            default_headers: HeaderMap::new(),
            unsafe_default_headers: HeaderMap::new(),
            cookies: std::sync::Mutex::new(CookieJar::default()),
        })
    }

    #[test]
    fn builder_ignores_invalid_headers_and_keeps_query_and_subprotocols() {
        let builder = TestWebSocketBuilder::new(in_process_shared(), "/ws")
            .header("x-good", "ok")
            .header("bad name", "ignored")
            .header("x-bad-value", "line\nbreak")
            .query("room", "main hall")
            .subprotocol("json")
            .subprotocol("binary");

        // Only the well-formed header survives.
        assert_eq!(builder.headers.len(), 1);

        let expected_query = vec![("room".to_owned(), "main hall".to_owned())];
        assert_eq!(builder.query, expected_query);

        let expected_protocols = vec!["json".to_owned(), "binary".to_owned()];
        assert_eq!(builder.subprotocols, expected_protocols);
    }

    #[test]
    fn unsafe_header_marks_the_entry() {
        let builder = TestWebSocketBuilder::new(in_process_shared(), "/ws")
            .unsafe_header("host", "example.com");

        assert_eq!(builder.headers.len(), 1);
        assert!(builder.headers[0].unsafe_allowed);
    }

    #[test]
    fn rejected_error_uses_stable_code() {
        let error = rejected(StatusCode::FORBIDDEN);

        assert_eq!(error.code(), "WS_UPGRADE_REJECTED");
        assert_eq!(
            error.message(),
            "websocket upgrade rejected with status 403 Forbidden"
        );
    }

    #[test]
    fn closed_error_uses_stable_code() {
        let error = closed_error();

        assert_eq!(error.code(), "WS_CLOSED");
        assert_eq!(error.message(), "websocket connection closed");
    }

    #[test]
    fn decode_error_reports_json_failure() {
        let parse_failure = serde_json::from_str::<serde_json::Value>("{").unwrap_err();
        let error = decode_error(parse_failure);

        assert!(error.message().starts_with("message is not valid JSON:"));
    }

    #[tokio::test]
    async fn client_io_duplex_supports_async_read_and_write() {
        let (near, mut far) = tokio::io::duplex(16);
        let mut io = ClientIo::Duplex(near);

        // Client to peer.
        io.write_all(b"ping").await.unwrap();
        io.flush().await.unwrap();
        let mut received = [0u8; 4];
        far.read_exact(&mut received).await.unwrap();
        assert_eq!(&received, b"ping");

        // Peer to client.
        far.write_all(b"pong").await.unwrap();
        far.flush().await.unwrap();
        let mut reply = [0u8; 4];
        io.read_exact(&mut reply).await.unwrap();
        assert_eq!(&reply, b"pong");
    }

    #[tokio::test]
    async fn client_io_tcp_supports_async_read_and_write() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // Peer that expects "ping" and answers "pong".
        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut received = [0u8; 4];
            socket.read_exact(&mut received).await.unwrap();
            assert_eq!(&received, b"ping");
            socket.write_all(b"pong").await.unwrap();
            socket.flush().await.unwrap();
        });

        let mut io = ClientIo::Tcp(TcpStream::connect(addr).await.unwrap());

        io.write_all(b"ping").await.unwrap();
        io.flush().await.unwrap();

        let mut reply = [0u8; 4];
        io.read_exact(&mut reply).await.unwrap();
        assert_eq!(&reply, b"pong");

        let _ = server.await;
    }

    #[tokio::test]
    async fn client_io_duplex_poll_shutdown_completes() {
        let (near, _far) = tokio::io::duplex(16);
        let mut io = ClientIo::Duplex(near);

        io.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn client_io_tcp_poll_shutdown_completes() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            let _ = listener.accept().await;
        });

        let mut io = ClientIo::Tcp(TcpStream::connect(addr).await.unwrap());
        io.shutdown().await.unwrap();

        let _ = server.await;
    }

    #[test]
    fn builder_keeps_query_parameters() {
        let builder = TestWebSocketBuilder::new(in_process_shared(), "/ws")
            .query("a", "1")
            .query("b", "two");

        let expected = vec![
            ("a".to_owned(), "1".to_owned()),
            ("b".to_owned(), "two".to_owned()),
        ];
        assert_eq!(builder.query, expected);
    }
}