1use axum::body::Body;
2use axum::http::Request as AxumRequest;
3use axum_test::{TestResponse as AxumTestResponse, TestServer, TestWebSocket, WsMessage};
4
5pub mod multipart;
6pub use multipart::{MultipartFilePart, build_multipart_body};
7
8pub mod form;
9
10pub mod test_client;
11pub use test_client::TestClient;
12
13use brotli::Decompressor;
14use flate2::read::GzDecoder;
15pub use form::encode_urlencoded_body;
16use http_body_util::BodyExt;
17use serde_json::Value;
18use std::collections::HashMap;
19use std::io::{Cursor, Read};
20
21#[derive(Debug, Clone)]
23pub struct ResponseSnapshot {
24 pub status: u16,
26 pub headers: HashMap<String, String>,
28 pub body: Vec<u8>,
30}
31
32impl ResponseSnapshot {
33 pub fn text(&self) -> Result<String, std::string::FromUtf8Error> {
35 String::from_utf8(self.body.clone())
36 }
37
38 pub fn json(&self) -> Result<Value, serde_json::Error> {
40 serde_json::from_slice(&self.body)
41 }
42
43 pub fn header(&self, name: &str) -> Option<&str> {
45 self.headers.get(&name.to_ascii_lowercase()).map(|s| s.as_str())
46 }
47
48 pub fn graphql_data(&self) -> Result<Value, SnapshotError> {
50 let body: Value = serde_json::from_slice(&self.body)
51 .map_err(|e| SnapshotError::Decompression(format!("Failed to parse JSON: {}", e)))?;
52
53 body.get("data")
54 .cloned()
55 .ok_or_else(|| SnapshotError::Decompression("No 'data' field in GraphQL response".to_string()))
56 }
57
58 pub fn graphql_errors(&self) -> Result<Vec<Value>, SnapshotError> {
60 let body: Value = serde_json::from_slice(&self.body)
61 .map_err(|e| SnapshotError::Decompression(format!("Failed to parse JSON: {}", e)))?;
62
63 Ok(body
64 .get("errors")
65 .and_then(|e| e.as_array())
66 .cloned()
67 .unwrap_or_default())
68 }
69}
70
/// Reasons a response could not be turned into a [`ResponseSnapshot`].
#[derive(Debug)]
pub enum SnapshotError {
    /// A header value could not be converted to a string; carries the conversion error text.
    InvalidHeader(String),
    /// The body could not be read or decoded; carries detail text. The GraphQL helpers on
    /// `ResponseSnapshot` also report JSON/shape problems through this variant.
    Decompression(String),
}

impl std::fmt::Display for SnapshotError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick a fixed prefix per variant, then render "<prefix>: <detail>".
        let (prefix, detail) = match self {
            Self::InvalidHeader(detail) => ("Invalid header", detail),
            Self::Decompression(detail) => ("Failed to decode body", detail),
        };
        write!(f, "{prefix}: {detail}")
    }
}

impl std::error::Error for SnapshotError {}
90
91pub async fn call_test_server(server: &TestServer, request: AxumRequest<Body>) -> AxumTestResponse {
94 let (parts, body) = request.into_parts();
95
96 let mut path = parts.uri.path().to_string();
97 if let Some(query) = parts.uri.query()
98 && !query.is_empty()
99 {
100 path.push('?');
101 path.push_str(query);
102 }
103
104 let mut test_request = server.method(parts.method.clone(), &path);
105
106 for (name, value) in parts.headers.iter() {
107 test_request = test_request.add_header(name.clone(), value.clone());
108 }
109
110 let collected = body
111 .collect()
112 .await
113 .expect("failed to read request body for test dispatch");
114 let bytes = collected.to_bytes();
115 if !bytes.is_empty() {
116 test_request = test_request.bytes(bytes);
117 }
118
119 test_request.await
120}
121
122pub async fn snapshot_response(response: AxumTestResponse) -> Result<ResponseSnapshot, SnapshotError> {
124 let status = response.status_code().as_u16();
125
126 let mut headers = HashMap::new();
127 for (name, value) in response.headers() {
128 let header_value = value
129 .to_str()
130 .map_err(|e| SnapshotError::InvalidHeader(e.to_string()))?;
131 headers.insert(name.to_string().to_ascii_lowercase(), header_value.to_string());
132 }
133
134 let body = response.into_bytes();
135 let decoded_body = decode_body(&headers, body.to_vec())?;
136
137 Ok(ResponseSnapshot {
138 status,
139 headers,
140 body: decoded_body,
141 })
142}
143
144pub async fn snapshot_http_response(
146 response: axum::response::Response<Body>,
147) -> Result<ResponseSnapshot, SnapshotError> {
148 let (parts, body) = response.into_parts();
149 let status = parts.status.as_u16();
150
151 let mut headers = HashMap::new();
152 for (name, value) in parts.headers.iter() {
153 let header_value = value
154 .to_str()
155 .map_err(|e| SnapshotError::InvalidHeader(e.to_string()))?;
156 headers.insert(name.to_string().to_ascii_lowercase(), header_value.to_string());
157 }
158
159 let collected = body
160 .collect()
161 .await
162 .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
163 let bytes = collected.to_bytes();
164 let decoded_body = decode_body(&headers, bytes.to_vec())?;
165
166 Ok(ResponseSnapshot {
167 status,
168 headers,
169 body: decoded_body,
170 })
171}
172
173pub async fn snapshot_http_response_allow_body_errors(
179 response: axum::response::Response<Body>,
180) -> Result<ResponseSnapshot, SnapshotError> {
181 let (parts, mut body) = response.into_parts();
182 let status = parts.status.as_u16();
183
184 let mut headers = HashMap::new();
185 for (name, value) in parts.headers.iter() {
186 let header_value = value
187 .to_str()
188 .map_err(|e| SnapshotError::InvalidHeader(e.to_string()))?;
189 headers.insert(name.to_string().to_ascii_lowercase(), header_value.to_string());
190 }
191
192 let mut bytes = Vec::<u8>::new();
193 while let Some(frame_result) = body.frame().await {
194 match frame_result {
195 Ok(frame) => {
196 if let Ok(data) = frame.into_data() {
197 bytes.extend_from_slice(&data);
198 }
199 }
200 Err(_) => break,
201 }
202 }
203
204 let decoded_body = decode_body(&headers, bytes)?;
205
206 Ok(ResponseSnapshot {
207 status,
208 headers,
209 body: decoded_body,
210 })
211}
212
213fn decode_body(headers: &HashMap<String, String>, body: Vec<u8>) -> Result<Vec<u8>, SnapshotError> {
214 let encoding = headers
215 .get("content-encoding")
216 .map(|value| value.trim().to_ascii_lowercase());
217
218 match encoding.as_deref() {
219 Some("gzip" | "x-gzip") => decode_gzip(body),
220 Some("br") => decode_brotli(body),
221 _ => Ok(body),
222 }
223}
224
225fn decode_gzip(body: Vec<u8>) -> Result<Vec<u8>, SnapshotError> {
226 let mut decoder = GzDecoder::new(Cursor::new(body));
227 let mut decoded_bytes = Vec::new();
228 decoder
229 .read_to_end(&mut decoded_bytes)
230 .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
231 Ok(decoded_bytes)
232}
233
234fn decode_brotli(body: Vec<u8>) -> Result<Vec<u8>, SnapshotError> {
235 let mut decoder = Decompressor::new(Cursor::new(body), 4096);
236 let mut decoded_bytes = Vec::new();
237 decoder
238 .read_to_end(&mut decoded_bytes)
239 .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
240 Ok(decoded_bytes)
241}
242
243pub struct WebSocketConnection {
248 inner: TestWebSocket,
249}
250
251impl WebSocketConnection {
252 pub fn new(inner: TestWebSocket) -> Self {
254 Self { inner }
255 }
256
257 pub async fn send_text(&mut self, text: impl std::fmt::Display) {
259 self.inner.send_text(text).await;
260 }
261
262 pub async fn send_json<T: serde::Serialize>(&mut self, value: &T) {
264 self.inner.send_json(value).await;
265 }
266
267 pub async fn send_message(&mut self, msg: WsMessage) {
269 self.inner.send_message(msg).await;
270 }
271
272 pub async fn receive_text(&mut self) -> String {
274 self.inner.receive_text().await
275 }
276
277 pub async fn receive_json<T: serde::de::DeserializeOwned>(&mut self) -> T {
279 self.inner.receive_json().await
280 }
281
282 pub async fn receive_bytes(&mut self) -> bytes::Bytes {
284 self.inner.receive_bytes().await
285 }
286
287 pub async fn receive_message(&mut self) -> WebSocketMessage {
289 let msg = self.inner.receive_message().await;
290 WebSocketMessage::from_ws_message(msg)
291 }
292
293 pub async fn close(self) {
295 self.inner.close().await;
296 }
297}
298
299#[derive(Debug, Clone)]
301pub enum WebSocketMessage {
302 Text(String),
304 Binary(Vec<u8>),
306 Close(Option<String>),
308 Ping(Vec<u8>),
310 Pong(Vec<u8>),
312}
313
314impl WebSocketMessage {
315 fn from_ws_message(msg: WsMessage) -> Self {
316 match msg {
317 WsMessage::Text(text) => WebSocketMessage::Text(text.to_string()),
318 WsMessage::Binary(data) => WebSocketMessage::Binary(data.to_vec()),
319 WsMessage::Close(frame) => WebSocketMessage::Close(frame.map(|f| f.reason.to_string())),
320 WsMessage::Ping(data) => WebSocketMessage::Ping(data.to_vec()),
321 WsMessage::Pong(data) => WebSocketMessage::Pong(data.to_vec()),
322 WsMessage::Frame(_) => WebSocketMessage::Close(None),
323 }
324 }
325
326 pub fn as_text(&self) -> Option<&str> {
328 match self {
329 WebSocketMessage::Text(text) => Some(text),
330 _ => None,
331 }
332 }
333
334 pub fn as_json(&self) -> Result<Value, String> {
336 match self {
337 WebSocketMessage::Text(text) => {
338 serde_json::from_str(text).map_err(|e| format!("Failed to parse JSON: {}", e))
339 }
340 _ => Err("Message is not text".to_string()),
341 }
342 }
343
344 pub fn as_binary(&self) -> Option<&[u8]> {
346 match self {
347 WebSocketMessage::Binary(data) => Some(data),
348 _ => None,
349 }
350 }
351
352 pub fn is_close(&self) -> bool {
354 matches!(self, WebSocketMessage::Close(_))
355 }
356}
357
358pub async fn connect_websocket(server: &TestServer, path: &str) -> WebSocketConnection {
360 let ws = server.get_websocket(path).await.into_websocket().await;
361 WebSocketConnection::new(ws)
362}
363
364#[derive(Debug)]
368pub struct SseStream {
369 body: String,
370 events: Vec<SseEvent>,
371}
372
373impl SseStream {
374 pub fn from_response(response: &ResponseSnapshot) -> Result<Self, String> {
376 let body = response
377 .text()
378 .map_err(|e| format!("Failed to read response body: {}", e))?;
379
380 let events = Self::parse_events(&body);
381
382 Ok(Self { body, events })
383 }
384
385 fn parse_events(body: &str) -> Vec<SseEvent> {
386 let mut events = Vec::new();
387 let lines: Vec<&str> = body.lines().collect();
388 let mut i = 0;
389
390 while i < lines.len() {
391 if lines[i].starts_with("data:") {
392 let data = lines[i].trim_start_matches("data:").trim().to_string();
393 events.push(SseEvent { data });
394 } else if lines[i].starts_with("data") {
395 let data = lines[i].trim_start_matches("data").trim().to_string();
396 if !data.is_empty() || lines[i].len() == 4 {
397 events.push(SseEvent { data });
398 }
399 }
400 i += 1;
401 }
402
403 events
404 }
405
406 pub fn events(&self) -> &[SseEvent] {
408 &self.events
409 }
410
411 pub fn body(&self) -> &str {
413 &self.body
414 }
415
416 pub fn events_as_json(&self) -> Result<Vec<Value>, String> {
418 self.events
419 .iter()
420 .map(|event| event.as_json())
421 .collect::<Result<Vec<_>, _>>()
422 }
423}
424
425#[derive(Debug, Clone)]
427pub struct SseEvent {
428 pub data: String,
430}
431
432impl SseEvent {
433 pub fn as_json(&self) -> Result<Value, String> {
435 serde_json::from_str(&self.data).map_err(|e| format!("Failed to parse JSON: {}", e))
436 }
437}
438
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::response::Response;
    use std::io::Write;

    /// Builds a snapshot that carries only a `content-type` header.
    fn snapshot_with(status: u16, content_type: &str, body: Vec<u8>) -> ResponseSnapshot {
        let headers = HashMap::from([("content-type".to_string(), content_type.to_string())]);
        ResponseSnapshot {
            status,
            headers,
            body,
        }
    }

    /// Builds an `application/json` snapshot whose body is `payload` serialized.
    fn json_snapshot(status: u16, payload: &Value) -> ResponseSnapshot {
        let body = serde_json::to_vec(payload).unwrap();
        snapshot_with(status, "application/json", body)
    }

    #[test]
    fn sse_stream_parses_multiple_events() {
        let snapshot = snapshot_with(
            200,
            "text/event-stream",
            b"data: {\"id\": 1}\n\ndata: \"hello\"\n\n".to_vec(),
        );

        let stream = SseStream::from_response(&snapshot).expect("stream");
        let events = stream.events();
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].as_json().unwrap()["id"], serde_json::json!(1));
        assert_eq!(events[1].data, "\"hello\"");
        assert_eq!(stream.events_as_json().unwrap().len(), 2);
    }

    #[test]
    fn sse_event_reports_invalid_json() {
        let event = SseEvent {
            data: String::from("not-json"),
        };
        assert!(event.as_json().is_err());
    }

    #[test]
    fn test_graphql_data_extraction() {
        let snapshot = json_snapshot(
            200,
            &serde_json::json!({
                "data": { "user": { "id": "1", "name": "Alice" } }
            }),
        );

        let data = snapshot.graphql_data().expect("data extraction");
        assert_eq!(data["user"]["id"], "1");
        assert_eq!(data["user"]["name"], "Alice");
    }

    #[test]
    fn test_graphql_errors_extraction() {
        let snapshot = json_snapshot(
            400,
            &serde_json::json!({
                "errors": [
                    { "message": "Field not found", "path": ["user", "email"] },
                    { "message": "Unauthorized", "extensions": { "code": "UNAUTHENTICATED" } }
                ]
            }),
        );

        let errors = snapshot.graphql_errors().expect("errors extraction");
        assert_eq!(errors.len(), 2);
        assert_eq!(errors[0]["message"], "Field not found");
        assert_eq!(errors[1]["message"], "Unauthorized");
    }

    #[test]
    fn test_graphql_missing_data_field() {
        let snapshot = json_snapshot(
            400,
            &serde_json::json!({ "errors": [{ "message": "Query failed" }] }),
        );

        let err = snapshot
            .graphql_data()
            .expect_err("a response without `data` must be rejected");
        assert!(err.to_string().contains("No 'data' field"));
    }

    #[test]
    fn test_graphql_empty_errors() {
        let snapshot = json_snapshot(200, &serde_json::json!({ "data": { "result": null } }));

        let errors = snapshot.graphql_errors().expect("errors extraction");
        assert!(errors.is_empty());
    }

    /// Gzip-compresses `input` at the default level.
    fn gzip_bytes(input: &[u8]) -> Vec<u8> {
        let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
        encoder.write_all(input).expect("gzip write");
        encoder.finish().expect("gzip finish")
    }

    /// Brotli-compresses `input` (buffer 4096, quality 5, lgwin 22).
    fn brotli_bytes(input: &[u8]) -> Vec<u8> {
        let mut encoder = brotli::CompressorWriter::new(Vec::new(), 4096, 5, 22);
        encoder.write_all(input).expect("brotli write");
        encoder.into_inner()
    }

    /// Sends `payload` through `snapshot_http_response` as a 200 response, optionally
    /// tagged with a `content-encoding`, and returns the snapshot body.
    async fn decoded_body(encoding: Option<&str>, payload: Vec<u8>) -> Vec<u8> {
        let mut builder = Response::builder().status(200);
        if let Some(encoding) = encoding {
            builder = builder.header("content-encoding", encoding);
        }
        let response = builder.body(Body::from(payload)).unwrap();
        snapshot_http_response(response).await.expect("snapshot").body
    }

    #[tokio::test]
    async fn snapshot_http_response_decodes_gzip_body() {
        let body = b"hello gzip";
        assert_eq!(decoded_body(Some("gzip"), gzip_bytes(body)).await, body);
    }

    #[tokio::test]
    async fn snapshot_http_response_decodes_brotli_body() {
        let body = b"hello brotli";
        assert_eq!(decoded_body(Some("br"), brotli_bytes(body)).await, body);
    }

    #[tokio::test]
    async fn snapshot_http_response_leaves_plain_body() {
        let body = b"plain";
        assert_eq!(decoded_body(None, body.to_vec()).await, body);
    }
}