1use axum::body::Body;
2use axum::http::Request as AxumRequest;
3use axum_test::{TestResponse as AxumTestResponse, TestServer, TestWebSocket, WsMessage};
4
5pub mod multipart;
6pub use multipart::{MultipartFilePart, build_multipart_body};
7
8pub mod form;
9
10pub mod test_client;
11pub use test_client::{GraphQLSubscriptionSnapshot, TestClient};
12
13use brotli::Decompressor;
14use flate2::read::GzDecoder;
15pub use form::encode_urlencoded_body;
16use http_body_util::BodyExt;
17use serde_json::Value;
18use std::collections::HashMap;
19use std::io::{Cursor, Read};
20
/// Fully buffered view of an HTTP response: status, headers and (decoded) body.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ResponseSnapshot {
    /// Numeric HTTP status code.
    pub status: u16,
    /// Response headers. The snapshot functions in this module store names lowercased and keep
    /// only the last value when a header name repeats; `header()` relies on lowercase keys.
    pub headers: HashMap<String, String>,
    /// Response body; the snapshot functions in this module undo any supported
    /// `content-encoding` before storing it.
    pub body: Vec<u8>,
}
31
32impl ResponseSnapshot {
33 pub fn text(&self) -> Result<String, std::string::FromUtf8Error> {
35 String::from_utf8(self.body.clone())
36 }
37
38 pub fn json(&self) -> Result<Value, serde_json::Error> {
40 serde_json::from_slice(&self.body)
41 }
42
43 pub fn header(&self, name: &str) -> Option<&str> {
45 self.headers.get(&name.to_ascii_lowercase()).map(|s| s.as_str())
46 }
47
48 pub fn graphql_data(&self) -> Result<Value, SnapshotError> {
50 let body: Value = serde_json::from_slice(&self.body)
51 .map_err(|e| SnapshotError::Decompression(format!("Failed to parse JSON: {}", e)))?;
52
53 body.get("data")
54 .cloned()
55 .ok_or_else(|| SnapshotError::Decompression("No 'data' field in GraphQL response".to_string()))
56 }
57
58 pub fn graphql_errors(&self) -> Result<Vec<Value>, SnapshotError> {
60 let body: Value = serde_json::from_slice(&self.body)
61 .map_err(|e| SnapshotError::Decompression(format!("Failed to parse JSON: {}", e)))?;
62
63 Ok(body
64 .get("errors")
65 .and_then(|e| e.as_array())
66 .cloned()
67 .unwrap_or_default())
68 }
69}
70
/// Errors raised while capturing or inspecting a [`ResponseSnapshot`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum SnapshotError {
    /// A response header value could not be converted to a string
    /// (`HeaderValue::to_str` failed); carries that error's message.
    InvalidHeader(String),
    /// The body could not be collected or decoded; carries a description.
    ///
    /// NOTE(review): `ResponseSnapshot::graphql_data`/`graphql_errors` also use this variant for
    /// JSON parse failures and a missing `data` field, so it is broader than its name suggests.
    Decompression(String),
}
79
80impl std::fmt::Display for SnapshotError {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 match self {
83 SnapshotError::InvalidHeader(msg) => write!(f, "Invalid header: {}", msg),
84 SnapshotError::Decompression(msg) => write!(f, "Failed to decode body: {}", msg),
85 }
86 }
87}
88
/// Marker impl: `SnapshotError` has no underlying `source()`; its text comes from `Display`.
impl std::error::Error for SnapshotError {}
90
91impl Default for SnapshotError {
92 fn default() -> Self {
93 Self::InvalidHeader(String::new())
94 }
95}
96
97pub async fn call_test_server(server: &TestServer, request: AxumRequest<Body>) -> AxumTestResponse {
100 let (parts, body) = request.into_parts();
101
102 let mut path = parts.uri.path().to_string();
103 if let Some(query) = parts.uri.query()
104 && !query.is_empty()
105 {
106 path.push('?');
107 path.push_str(query);
108 }
109
110 let mut test_request = server.method(parts.method.clone(), &path);
111
112 for (name, value) in parts.headers.iter() {
113 test_request = test_request.add_header(name.clone(), value.clone());
114 }
115
116 let collected = body
117 .collect()
118 .await
119 .expect("failed to read request body for test dispatch");
120 let bytes = collected.to_bytes();
121 if !bytes.is_empty() {
122 test_request = test_request.bytes(bytes);
123 }
124
125 test_request.await
126}
127
128pub async fn snapshot_response(response: AxumTestResponse) -> Result<ResponseSnapshot, SnapshotError> {
130 let status = response.status_code().as_u16();
131
132 let mut headers = HashMap::new();
133 for (name, value) in response.headers() {
134 let header_value = value
135 .to_str()
136 .map_err(|e| SnapshotError::InvalidHeader(e.to_string()))?;
137 headers.insert(name.to_string().to_ascii_lowercase(), header_value.to_string());
138 }
139
140 let body = response.into_bytes();
141 let decoded_body = decode_body(&headers, body.to_vec())?;
142
143 Ok(ResponseSnapshot {
144 status,
145 headers,
146 body: decoded_body,
147 })
148}
149
150pub async fn snapshot_http_response(
152 response: axum::response::Response<Body>,
153) -> Result<ResponseSnapshot, SnapshotError> {
154 let (parts, body) = response.into_parts();
155 let status = parts.status.as_u16();
156
157 let mut headers = HashMap::new();
158 for (name, value) in parts.headers.iter() {
159 let header_value = value
160 .to_str()
161 .map_err(|e| SnapshotError::InvalidHeader(e.to_string()))?;
162 headers.insert(name.to_string().to_ascii_lowercase(), header_value.to_string());
163 }
164
165 let collected = body
166 .collect()
167 .await
168 .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
169 let bytes = collected.to_bytes();
170 let decoded_body = decode_body(&headers, bytes.to_vec())?;
171
172 Ok(ResponseSnapshot {
173 status,
174 headers,
175 body: decoded_body,
176 })
177}
178
179pub async fn snapshot_http_response_allow_body_errors(
185 response: axum::response::Response<Body>,
186) -> Result<ResponseSnapshot, SnapshotError> {
187 let (parts, mut body) = response.into_parts();
188 let status = parts.status.as_u16();
189
190 let mut headers = HashMap::new();
191 for (name, value) in parts.headers.iter() {
192 let header_value = value
193 .to_str()
194 .map_err(|e| SnapshotError::InvalidHeader(e.to_string()))?;
195 headers.insert(name.to_string().to_ascii_lowercase(), header_value.to_string());
196 }
197
198 let mut bytes = Vec::<u8>::new();
199 while let Some(frame_result) = body.frame().await {
200 match frame_result {
201 Ok(frame) => {
202 if let Ok(data) = frame.into_data() {
203 bytes.extend_from_slice(&data);
204 }
205 }
206 Err(_) => break,
207 }
208 }
209
210 let decoded_body = decode_body(&headers, bytes)?;
211
212 Ok(ResponseSnapshot {
213 status,
214 headers,
215 body: decoded_body,
216 })
217}
218
219fn decode_body(headers: &HashMap<String, String>, body: Vec<u8>) -> Result<Vec<u8>, SnapshotError> {
220 let encoding = headers
221 .get("content-encoding")
222 .map(|value| value.trim().to_ascii_lowercase());
223
224 match encoding.as_deref() {
225 Some("gzip" | "x-gzip") => decode_gzip(body),
226 Some("br") => decode_brotli(body),
227 _ => Ok(body),
228 }
229}
230
231fn decode_gzip(body: Vec<u8>) -> Result<Vec<u8>, SnapshotError> {
232 let mut decoder = GzDecoder::new(Cursor::new(body));
233 let mut decoded_bytes = Vec::new();
234 decoder
235 .read_to_end(&mut decoded_bytes)
236 .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
237 Ok(decoded_bytes)
238}
239
240fn decode_brotli(body: Vec<u8>) -> Result<Vec<u8>, SnapshotError> {
241 let mut decoder = Decompressor::new(Cursor::new(body), 4096);
242 let mut decoded_bytes = Vec::new();
243 decoder
244 .read_to_end(&mut decoded_bytes)
245 .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
246 Ok(decoded_bytes)
247}
248
/// Wrapper around an `axum_test` [`TestWebSocket`] exposing send/receive/close helpers.
pub struct WebSocketConnection {
    /// The underlying upgraded test socket all calls delegate to.
    inner: TestWebSocket,
}
256
257impl WebSocketConnection {
258 pub fn new(inner: TestWebSocket) -> Self {
260 Self { inner }
261 }
262
263 pub async fn send_text(&mut self, text: impl std::fmt::Display) {
265 self.inner.send_text(text).await;
266 }
267
268 pub async fn send_json<T: serde::Serialize>(&mut self, value: &T) {
270 self.inner.send_json(value).await;
271 }
272
273 pub async fn send_message(&mut self, msg: WsMessage) {
275 self.inner.send_message(msg).await;
276 }
277
278 pub async fn receive_text(&mut self) -> String {
280 self.inner.receive_text().await
281 }
282
283 pub async fn receive_json<T: serde::de::DeserializeOwned>(&mut self) -> T {
285 self.inner.receive_json().await
286 }
287
288 pub async fn receive_bytes(&mut self) -> bytes::Bytes {
290 self.inner.receive_bytes().await
291 }
292
293 pub async fn receive_message(&mut self) -> WebSocketMessage {
295 let msg = self.inner.receive_message().await;
296 WebSocketMessage::from_ws_message(msg)
297 }
298
299 pub async fn close(self) -> Result<(), String> {
301 self.close_with(1000, None).await
302 }
303
304 pub async fn close_with(mut self, code: u16, reason: Option<String>) -> Result<(), String> {
308 use axum_test::WsMessage;
309 use tungstenite::protocol::frame::CloseFrame;
310 use tungstenite::protocol::frame::coding::CloseCode;
311
312 let frame = CloseFrame {
313 code: CloseCode::from(code),
314 reason: reason.unwrap_or_default().into(),
315 };
316 self.inner.send_message(WsMessage::Close(Some(frame))).await;
317 Ok(())
318 }
319}
320
/// Serializable representation of a WebSocket message received in a test.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum WebSocketMessage {
    /// Text message payload.
    Text(String),
    /// Binary message payload.
    Binary(Vec<u8>),
    /// Close message.
    Close {
        /// Close status code; `from_ws_message` reports `1005` when no close frame was present.
        code: u16,
        /// Close reason; `None` when the reason was empty or no close frame was present.
        reason: Option<String>,
    },
    /// Ping payload.
    Ping(Vec<u8>),
    /// Pong payload.
    Pong(Vec<u8>),
}
343
344impl Default for WebSocketMessage {
345 fn default() -> Self {
346 Self::Text(String::new())
347 }
348}
349
350impl WebSocketMessage {
351 fn from_ws_message(msg: WsMessage) -> Self {
352 match msg {
353 WsMessage::Text(text) => WebSocketMessage::Text(text.to_string()),
354 WsMessage::Binary(data) => WebSocketMessage::Binary(data.to_vec()),
355 WsMessage::Close(Some(frame)) => {
356 let code: u16 = frame.code.into();
357 let reason_str = frame.reason.to_string();
358 let reason = if reason_str.is_empty() { None } else { Some(reason_str) };
359 WebSocketMessage::Close { code, reason }
360 }
361 WsMessage::Close(None) | WsMessage::Frame(_) => WebSocketMessage::Close {
363 code: 1005,
364 reason: None,
365 },
366 WsMessage::Ping(data) => WebSocketMessage::Ping(data.to_vec()),
367 WsMessage::Pong(data) => WebSocketMessage::Pong(data.to_vec()),
368 }
369 }
370
371 pub fn as_text(&self) -> Option<&str> {
373 match self {
374 WebSocketMessage::Text(text) => Some(text),
375 _ => None,
376 }
377 }
378
379 pub fn as_json(&self) -> Result<Value, String> {
381 match self {
382 WebSocketMessage::Text(text) => {
383 serde_json::from_str(text).map_err(|e| format!("Failed to parse JSON: {}", e))
384 }
385 _ => Err("Message is not text".to_string()),
386 }
387 }
388
389 pub fn as_binary(&self) -> Option<&[u8]> {
391 match self {
392 WebSocketMessage::Binary(data) => Some(data),
393 _ => None,
394 }
395 }
396
397 pub fn is_close(&self) -> bool {
399 matches!(self, WebSocketMessage::Close { .. })
400 }
401
402 pub fn close_code(&self) -> Option<u16> {
404 match self {
405 WebSocketMessage::Close { code, .. } => Some(*code),
406 _ => None,
407 }
408 }
409
410 pub fn close_reason(&self) -> Option<&str> {
412 match self {
413 WebSocketMessage::Close { reason, .. } => reason.as_deref(),
414 _ => None,
415 }
416 }
417}
418
419pub async fn connect_websocket(server: &TestServer, path: &str) -> WebSocketConnection {
421 let ws = server.get_websocket(path).await.into_websocket().await;
422 WebSocketConnection::new(ws)
423}
424
/// Parsed view of a server-sent-events (`text/event-stream`) response body.
#[derive(Debug)]
pub struct SseStream {
    /// The complete response body as UTF-8 text.
    body: String,
    /// Event `data` payloads parsed from `body`, in stream order.
    events: Vec<SseEvent>,
}
433
434impl SseStream {
435 pub fn from_response(response: &ResponseSnapshot) -> Result<Self, String> {
437 let body = response
438 .text()
439 .map_err(|e| format!("Failed to read response body: {}", e))?;
440
441 let events = Self::parse_events(&body);
442
443 Ok(Self { body, events })
444 }
445
446 fn parse_events(body: &str) -> Vec<SseEvent> {
447 let mut events = Vec::new();
448 let lines: Vec<&str> = body.lines().collect();
449 let mut i = 0;
450
451 while i < lines.len() {
452 if lines[i].starts_with("data:") {
453 let data = lines[i].trim_start_matches("data:").trim().to_string();
454 events.push(SseEvent { data });
455 } else if lines[i].starts_with("data") {
456 let data = lines[i].trim_start_matches("data").trim().to_string();
457 if !data.is_empty() || lines[i].len() == 4 {
458 events.push(SseEvent { data });
459 }
460 }
461 i += 1;
462 }
463
464 events
465 }
466
467 pub fn events(&self) -> &[SseEvent] {
469 &self.events
470 }
471
472 pub fn body(&self) -> &str {
474 &self.body
475 }
476
477 pub fn events_as_json(&self) -> Result<Vec<Value>, String> {
479 self.events
480 .iter()
481 .map(|event| event.as_json())
482 .collect::<Result<Vec<_>, _>>()
483 }
484}
485
/// A single server-sent event; only its `data` payload is kept.
#[derive(Debug, Clone)]
pub struct SseEvent {
    /// The event's `data` value as extracted by `SseStream`'s parser.
    pub data: String,
}
492
493impl SseEvent {
494 pub fn as_json(&self) -> Result<Value, String> {
496 serde_json::from_str(&self.data).map_err(|e| format!("Failed to parse JSON: {}", e))
497 }
498}
499
// Unit tests for SSE parsing, the GraphQL helpers on `ResponseSnapshot`, and body decoding.
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::response::Response;
    use std::io::Write;

    // Two `data:` events separated by blank lines yield two events with JSON-parsable data.
    #[test]
    fn sse_stream_parses_multiple_events() {
        let mut headers = HashMap::new();
        headers.insert("content-type".to_string(), "text/event-stream".to_string());

        let snapshot = ResponseSnapshot {
            status: 200,
            headers,
            body: b"data: {\"id\": 1}\n\ndata: \"hello\"\n\n".to_vec(),
        };

        let stream = SseStream::from_response(&snapshot).expect("stream");
        assert_eq!(stream.events().len(), 2);
        assert_eq!(stream.events()[0].as_json().unwrap()["id"], serde_json::json!(1));
        assert_eq!(stream.events()[1].data, "\"hello\"");
        assert_eq!(stream.events_as_json().unwrap().len(), 2);
    }

    // Non-JSON event data surfaces as an error rather than a panic.
    #[test]
    fn sse_event_reports_invalid_json() {
        let event = SseEvent {
            data: "not-json".to_string(),
        };
        assert!(event.as_json().is_err());
    }

    // `graphql_data` returns the contents of the top-level `data` member.
    #[test]
    fn test_graphql_data_extraction() {
        let mut headers = HashMap::new();
        headers.insert("content-type".to_string(), "application/json".to_string());

        let graphql_response = serde_json::json!({
            "data": {
                "user": {
                    "id": "1",
                    "name": "Alice"
                }
            }
        });

        let snapshot = ResponseSnapshot {
            status: 200,
            headers,
            body: serde_json::to_vec(&graphql_response).unwrap(),
        };

        let data = snapshot.graphql_data().expect("data extraction");
        assert_eq!(data["user"]["id"], "1");
        assert_eq!(data["user"]["name"], "Alice");
    }

    // `graphql_errors` returns every entry of the `errors` array in order.
    #[test]
    fn test_graphql_errors_extraction() {
        let mut headers = HashMap::new();
        headers.insert("content-type".to_string(), "application/json".to_string());

        let graphql_response = serde_json::json!({
            "errors": [
                {
                    "message": "Field not found",
                    "path": ["user", "email"]
                },
                {
                    "message": "Unauthorized",
                    "extensions": { "code": "UNAUTHENTICATED" }
                }
            ]
        });

        let snapshot = ResponseSnapshot {
            status: 400,
            headers,
            body: serde_json::to_vec(&graphql_response).unwrap(),
        };

        let errors = snapshot.graphql_errors().expect("errors extraction");
        assert_eq!(errors.len(), 2);
        assert_eq!(errors[0]["message"], "Field not found");
        assert_eq!(errors[1]["message"], "Unauthorized");
    }

    // A response without `data` makes `graphql_data` fail with a descriptive message.
    #[test]
    fn test_graphql_missing_data_field() {
        let mut headers = HashMap::new();
        headers.insert("content-type".to_string(), "application/json".to_string());

        let graphql_response = serde_json::json!({
            "errors": [{ "message": "Query failed" }]
        });

        let snapshot = ResponseSnapshot {
            status: 400,
            headers,
            body: serde_json::to_vec(&graphql_response).unwrap(),
        };

        let result = snapshot.graphql_data();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("No 'data' field"));
    }

    // A response without `errors` yields an empty error list, not a failure.
    #[test]
    fn test_graphql_empty_errors() {
        let mut headers = HashMap::new();
        headers.insert("content-type".to_string(), "application/json".to_string());

        let graphql_response = serde_json::json!({
            "data": { "result": null }
        });

        let snapshot = ResponseSnapshot {
            status: 200,
            headers,
            body: serde_json::to_vec(&graphql_response).unwrap(),
        };

        let errors = snapshot.graphql_errors().expect("errors extraction");
        assert!(errors.is_empty());
    }

    // Test helper: gzip-compress `input` with flate2's default level.
    fn gzip_bytes(input: &[u8]) -> Vec<u8> {
        let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
        encoder.write_all(input).expect("gzip write");
        encoder.finish().expect("gzip finish")
    }

    // Test helper: brotli-compress `input` (buffer 4096, quality 5, lgwin 22).
    fn brotli_bytes(input: &[u8]) -> Vec<u8> {
        let mut encoder = brotli::CompressorWriter::new(Vec::new(), 4096, 5, 22);
        encoder.write_all(input).expect("brotli write");
        encoder.into_inner()
    }

    // `content-encoding: gzip` bodies are transparently decompressed.
    #[tokio::test]
    async fn snapshot_http_response_decodes_gzip_body() {
        let body = b"hello gzip";
        let compressed = gzip_bytes(body);
        let response = Response::builder()
            .status(200)
            .header("content-encoding", "gzip")
            .body(Body::from(compressed))
            .unwrap();

        let snapshot = snapshot_http_response(response).await.expect("snapshot");
        assert_eq!(snapshot.body, body);
    }

    // `content-encoding: br` bodies are transparently decompressed.
    #[tokio::test]
    async fn snapshot_http_response_decodes_brotli_body() {
        let body = b"hello brotli";
        let compressed = brotli_bytes(body);
        let response = Response::builder()
            .status(200)
            .header("content-encoding", "br")
            .body(Body::from(compressed))
            .unwrap();

        let snapshot = snapshot_http_response(response).await.expect("snapshot");
        assert_eq!(snapshot.body, body);
    }

    // Without `content-encoding` the body is stored unchanged.
    #[tokio::test]
    async fn snapshot_http_response_leaves_plain_body() {
        let body = b"plain";
        let response = Response::builder()
            .status(200)
            .body(Body::from(body.as_slice()))
            .unwrap();

        let snapshot = snapshot_http_response(response).await.expect("snapshot");
        assert_eq!(snapshot.body, body);
    }
}