Skip to main content

spikard_http/
testing.rs

1use axum::body::Body;
2use axum::http::Request as AxumRequest;
3use axum_test::{TestResponse as AxumTestResponse, TestServer, TestWebSocket, WsMessage};
4
5pub mod multipart;
6pub use multipart::{MultipartFilePart, build_multipart_body};
7
8pub mod form;
9
10pub mod test_client;
11pub use test_client::{GraphQLSubscriptionSnapshot, TestClient};
12
13use brotli::Decompressor;
14use flate2::read::GzDecoder;
15pub use form::encode_urlencoded_body;
16use http_body_util::BodyExt;
17use serde_json::Value;
18use std::collections::HashMap;
19use std::io::{Cursor, Read};
20
21/// Snapshot of an Axum response used by higher-level language bindings.
22#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
23pub struct ResponseSnapshot {
24    /// HTTP status code.
25    pub status: u16,
26    /// Response headers (lowercase keys for predictable lookups).
27    pub headers: HashMap<String, String>,
28    /// Response body bytes (decoded for supported encodings).
29    pub body: Vec<u8>,
30}
31
32impl ResponseSnapshot {
33    /// Return response body as UTF-8 string.
34    pub fn text(&self) -> Result<String, std::string::FromUtf8Error> {
35        String::from_utf8(self.body.clone())
36    }
37
38    /// Parse response body as JSON.
39    pub fn json(&self) -> Result<Value, serde_json::Error> {
40        serde_json::from_slice(&self.body)
41    }
42
43    /// Lookup header by case-insensitive name.
44    pub fn header(&self, name: &str) -> Option<&str> {
45        self.headers.get(&name.to_ascii_lowercase()).map(|s| s.as_str())
46    }
47
48    /// Extract GraphQL data from response
49    pub fn graphql_data(&self) -> Result<Value, SnapshotError> {
50        let body: Value = serde_json::from_slice(&self.body)
51            .map_err(|e| SnapshotError::Decompression(format!("Failed to parse JSON: {}", e)))?;
52
53        body.get("data")
54            .cloned()
55            .ok_or_else(|| SnapshotError::Decompression("No 'data' field in GraphQL response".to_string()))
56    }
57
58    /// Extract GraphQL errors from response
59    pub fn graphql_errors(&self) -> Result<Vec<Value>, SnapshotError> {
60        let body: Value = serde_json::from_slice(&self.body)
61            .map_err(|e| SnapshotError::Decompression(format!("Failed to parse JSON: {}", e)))?;
62
63        Ok(body
64            .get("errors")
65            .and_then(|e| e.as_array())
66            .cloned()
67            .unwrap_or_default())
68    }
69}
70
71/// Possible errors while converting an Axum response into a snapshot.
72#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
73pub enum SnapshotError {
74    /// Response header could not be decoded to UTF-8.
75    InvalidHeader(String),
76    /// Body decompression failed.
77    Decompression(String),
78}
79
80impl std::fmt::Display for SnapshotError {
81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        match self {
83            SnapshotError::InvalidHeader(msg) => write!(f, "Invalid header: {}", msg),
84            SnapshotError::Decompression(msg) => write!(f, "Failed to decode body: {}", msg),
85        }
86    }
87}
88
89impl std::error::Error for SnapshotError {}
90
91impl Default for SnapshotError {
92    fn default() -> Self {
93        Self::InvalidHeader(String::new())
94    }
95}
96
97/// Execute an HTTP request against an Axum [`TestServer`] by rehydrating it
98/// into the server's own [`axum_test::TestRequest`] builder.
99pub async fn call_test_server(server: &TestServer, request: AxumRequest<Body>) -> AxumTestResponse {
100    let (parts, body) = request.into_parts();
101
102    let mut path = parts.uri.path().to_string();
103    if let Some(query) = parts.uri.query()
104        && !query.is_empty()
105    {
106        path.push('?');
107        path.push_str(query);
108    }
109
110    let mut test_request = server.method(parts.method.clone(), &path);
111
112    for (name, value) in parts.headers.iter() {
113        test_request = test_request.add_header(name.clone(), value.clone());
114    }
115
116    let collected = body
117        .collect()
118        .await
119        .expect("failed to read request body for test dispatch");
120    let bytes = collected.to_bytes();
121    if !bytes.is_empty() {
122        test_request = test_request.bytes(bytes);
123    }
124
125    test_request.await
126}
127
128/// Convert an `AxumTestResponse` into a reusable [`ResponseSnapshot`].
129pub async fn snapshot_response(response: AxumTestResponse) -> Result<ResponseSnapshot, SnapshotError> {
130    let status = response.status_code().as_u16();
131
132    let mut headers = HashMap::new();
133    for (name, value) in response.headers() {
134        let header_value = value
135            .to_str()
136            .map_err(|e| SnapshotError::InvalidHeader(e.to_string()))?;
137        headers.insert(name.to_string().to_ascii_lowercase(), header_value.to_string());
138    }
139
140    let body = response.into_bytes();
141    let decoded_body = decode_body(&headers, body.to_vec())?;
142
143    Ok(ResponseSnapshot {
144        status,
145        headers,
146        body: decoded_body,
147    })
148}
149
150/// Convert an Axum response into a reusable [`ResponseSnapshot`].
151pub async fn snapshot_http_response(
152    response: axum::response::Response<Body>,
153) -> Result<ResponseSnapshot, SnapshotError> {
154    let (parts, body) = response.into_parts();
155    let status = parts.status.as_u16();
156
157    let mut headers = HashMap::new();
158    for (name, value) in parts.headers.iter() {
159        let header_value = value
160            .to_str()
161            .map_err(|e| SnapshotError::InvalidHeader(e.to_string()))?;
162        headers.insert(name.to_string().to_ascii_lowercase(), header_value.to_string());
163    }
164
165    let collected = body
166        .collect()
167        .await
168        .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
169    let bytes = collected.to_bytes();
170    let decoded_body = decode_body(&headers, bytes.to_vec())?;
171
172    Ok(ResponseSnapshot {
173        status,
174        headers,
175        body: decoded_body,
176    })
177}
178
179/// Convert an Axum response into a reusable [`ResponseSnapshot`], allowing body stream errors.
180///
181/// This is useful for streaming responses where a producer might abort mid-stream (for example,
182/// a JavaScript async generator throwing). In those cases, collecting the whole body can fail
183/// with a "Stream error". This helper returns the bytes read up to the first stream error.
184pub async fn snapshot_http_response_allow_body_errors(
185    response: axum::response::Response<Body>,
186) -> Result<ResponseSnapshot, SnapshotError> {
187    let (parts, mut body) = response.into_parts();
188    let status = parts.status.as_u16();
189
190    let mut headers = HashMap::new();
191    for (name, value) in parts.headers.iter() {
192        let header_value = value
193            .to_str()
194            .map_err(|e| SnapshotError::InvalidHeader(e.to_string()))?;
195        headers.insert(name.to_string().to_ascii_lowercase(), header_value.to_string());
196    }
197
198    let mut bytes = Vec::<u8>::new();
199    while let Some(frame_result) = body.frame().await {
200        match frame_result {
201            Ok(frame) => {
202                if let Ok(data) = frame.into_data() {
203                    bytes.extend_from_slice(&data);
204                }
205            }
206            Err(_) => break,
207        }
208    }
209
210    let decoded_body = decode_body(&headers, bytes)?;
211
212    Ok(ResponseSnapshot {
213        status,
214        headers,
215        body: decoded_body,
216    })
217}
218
219fn decode_body(headers: &HashMap<String, String>, body: Vec<u8>) -> Result<Vec<u8>, SnapshotError> {
220    let encoding = headers
221        .get("content-encoding")
222        .map(|value| value.trim().to_ascii_lowercase());
223
224    match encoding.as_deref() {
225        Some("gzip" | "x-gzip") => decode_gzip(body),
226        Some("br") => decode_brotli(body),
227        _ => Ok(body),
228    }
229}
230
231fn decode_gzip(body: Vec<u8>) -> Result<Vec<u8>, SnapshotError> {
232    let mut decoder = GzDecoder::new(Cursor::new(body));
233    let mut decoded_bytes = Vec::new();
234    decoder
235        .read_to_end(&mut decoded_bytes)
236        .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
237    Ok(decoded_bytes)
238}
239
240fn decode_brotli(body: Vec<u8>) -> Result<Vec<u8>, SnapshotError> {
241    let mut decoder = Decompressor::new(Cursor::new(body), 4096);
242    let mut decoded_bytes = Vec::new();
243    decoder
244        .read_to_end(&mut decoded_bytes)
245        .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
246    Ok(decoded_bytes)
247}
248
249/// WebSocket connection wrapper for testing.
250///
251/// Provides a simple interface for sending and receiving WebSocket messages
252/// during tests without needing a real network connection.
253pub struct WebSocketConnection {
254    inner: TestWebSocket,
255}
256
257impl WebSocketConnection {
258    /// Create a new WebSocket connection from an axum-test TestWebSocket.
259    pub fn new(inner: TestWebSocket) -> Self {
260        Self { inner }
261    }
262
263    /// Send a text message over the WebSocket.
264    pub async fn send_text(&mut self, text: impl std::fmt::Display) {
265        self.inner.send_text(text).await;
266    }
267
268    /// Send a JSON message over the WebSocket.
269    pub async fn send_json<T: serde::Serialize>(&mut self, value: &T) {
270        self.inner.send_json(value).await;
271    }
272
273    /// Send a raw WebSocket message.
274    pub async fn send_message(&mut self, msg: WsMessage) {
275        self.inner.send_message(msg).await;
276    }
277
278    /// Receive the next text message from the WebSocket.
279    pub async fn receive_text(&mut self) -> String {
280        self.inner.receive_text().await
281    }
282
283    /// Receive and parse a JSON message from the WebSocket.
284    pub async fn receive_json<T: serde::de::DeserializeOwned>(&mut self) -> T {
285        self.inner.receive_json().await
286    }
287
288    /// Receive raw bytes from the WebSocket.
289    pub async fn receive_bytes(&mut self) -> bytes::Bytes {
290        self.inner.receive_bytes().await
291    }
292
293    /// Receive the next raw message from the WebSocket.
294    pub async fn receive_message(&mut self) -> WebSocketMessage {
295        let msg = self.inner.receive_message().await;
296        WebSocketMessage::from_ws_message(msg)
297    }
298
299    /// Close the WebSocket connection with code 1000 (Normal Closure) and no reason.
300    pub async fn close(self) -> Result<(), String> {
301        self.close_with(1000, None).await
302    }
303
304    /// Close the WebSocket connection with a specific RFC 6455 close code and optional reason.
305    ///
306    /// Common codes: 1000 Normal Closure, 1001 Going Away, 1002 Protocol Error.
307    pub async fn close_with(mut self, code: u16, reason: Option<String>) -> Result<(), String> {
308        use axum_test::WsMessage;
309        use tungstenite::protocol::frame::CloseFrame;
310        use tungstenite::protocol::frame::coding::CloseCode;
311
312        let frame = CloseFrame {
313            code: CloseCode::from(code),
314            reason: reason.unwrap_or_default().into(),
315        };
316        self.inner.send_message(WsMessage::Close(Some(frame))).await;
317        Ok(())
318    }
319}
320
321/// A WebSocket message that can be text or binary.
322#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
323pub enum WebSocketMessage {
324    /// A text message.
325    Text(String),
326    /// A binary message.
327    Binary(Vec<u8>),
328    /// A close message with a numeric close code (RFC 6455) and optional reason text.
329    ///
330    /// Common codes: 1000 Normal Closure, 1001 Going Away, 1005 No Status Received,
331    /// 1006 Abnormal Closure.
332    Close {
333        /// RFC 6455 close code.
334        code: u16,
335        /// Optional human-readable reason string.
336        reason: Option<String>,
337    },
338    /// A ping message.
339    Ping(Vec<u8>),
340    /// A pong message.
341    Pong(Vec<u8>),
342}
343
344impl Default for WebSocketMessage {
345    fn default() -> Self {
346        Self::Text(String::new())
347    }
348}
349
350impl WebSocketMessage {
351    fn from_ws_message(msg: WsMessage) -> Self {
352        match msg {
353            WsMessage::Text(text) => WebSocketMessage::Text(text.to_string()),
354            WsMessage::Binary(data) => WebSocketMessage::Binary(data.to_vec()),
355            WsMessage::Close(Some(frame)) => {
356                let code: u16 = frame.code.into();
357                let reason_str = frame.reason.to_string();
358                let reason = if reason_str.is_empty() { None } else { Some(reason_str) };
359                WebSocketMessage::Close { code, reason }
360            }
361            // RFC 6455 §7.1.5: no close frame means no status code — use 1005
362            WsMessage::Close(None) | WsMessage::Frame(_) => WebSocketMessage::Close {
363                code: 1005,
364                reason: None,
365            },
366            WsMessage::Ping(data) => WebSocketMessage::Ping(data.to_vec()),
367            WsMessage::Pong(data) => WebSocketMessage::Pong(data.to_vec()),
368        }
369    }
370
371    /// Get the message as text, if it's a text message.
372    pub fn as_text(&self) -> Option<&str> {
373        match self {
374            WebSocketMessage::Text(text) => Some(text),
375            _ => None,
376        }
377    }
378
379    /// Get the message as JSON, if it's a text message containing JSON.
380    pub fn as_json(&self) -> Result<Value, String> {
381        match self {
382            WebSocketMessage::Text(text) => {
383                serde_json::from_str(text).map_err(|e| format!("Failed to parse JSON: {}", e))
384            }
385            _ => Err("Message is not text".to_string()),
386        }
387    }
388
389    /// Get the message as binary, if it's a binary message.
390    pub fn as_binary(&self) -> Option<&[u8]> {
391        match self {
392            WebSocketMessage::Binary(data) => Some(data),
393            _ => None,
394        }
395    }
396
397    /// Check if this is a close message.
398    pub fn is_close(&self) -> bool {
399        matches!(self, WebSocketMessage::Close { .. })
400    }
401
402    /// Return the close code if this is a close message.
403    pub fn close_code(&self) -> Option<u16> {
404        match self {
405            WebSocketMessage::Close { code, .. } => Some(*code),
406            _ => None,
407        }
408    }
409
410    /// Return the close reason if this is a close message with a reason.
411    pub fn close_reason(&self) -> Option<&str> {
412        match self {
413            WebSocketMessage::Close { reason, .. } => reason.as_deref(),
414            _ => None,
415        }
416    }
417}
418
419/// Connect to a WebSocket endpoint on the test server.
420pub async fn connect_websocket(server: &TestServer, path: &str) -> WebSocketConnection {
421    let ws = server.get_websocket(path).await.into_websocket().await;
422    WebSocketConnection::new(ws)
423}
424
425/// Server-Sent Events (SSE) stream for testing.
426///
427/// Wraps a response body and provides methods to parse SSE events.
428#[derive(Debug)]
429pub struct SseStream {
430    body: String,
431    events: Vec<SseEvent>,
432}
433
434impl SseStream {
435    /// Create a new SSE stream from a response.
436    pub fn from_response(response: &ResponseSnapshot) -> Result<Self, String> {
437        let body = response
438            .text()
439            .map_err(|e| format!("Failed to read response body: {}", e))?;
440
441        let events = Self::parse_events(&body);
442
443        Ok(Self { body, events })
444    }
445
446    fn parse_events(body: &str) -> Vec<SseEvent> {
447        let mut events = Vec::new();
448        let lines: Vec<&str> = body.lines().collect();
449        let mut i = 0;
450
451        while i < lines.len() {
452            if lines[i].starts_with("data:") {
453                let data = lines[i].trim_start_matches("data:").trim().to_string();
454                events.push(SseEvent { data });
455            } else if lines[i].starts_with("data") {
456                let data = lines[i].trim_start_matches("data").trim().to_string();
457                if !data.is_empty() || lines[i].len() == 4 {
458                    events.push(SseEvent { data });
459                }
460            }
461            i += 1;
462        }
463
464        events
465    }
466
467    /// Get all events from the stream.
468    pub fn events(&self) -> &[SseEvent] {
469        &self.events
470    }
471
472    /// Get the raw body of the SSE response.
473    pub fn body(&self) -> &str {
474        &self.body
475    }
476
477    /// Get events as JSON values.
478    pub fn events_as_json(&self) -> Result<Vec<Value>, String> {
479        self.events
480            .iter()
481            .map(|event| event.as_json())
482            .collect::<Result<Vec<_>, _>>()
483    }
484}
485
486/// A single Server-Sent Event.
487#[derive(Debug, Clone)]
488pub struct SseEvent {
489    /// The data field of the event.
490    pub data: String,
491}
492
493impl SseEvent {
494    /// Parse the event data as JSON.
495    pub fn as_json(&self) -> Result<Value, String> {
496        serde_json::from_str(&self.data).map_err(|e| format!("Failed to parse JSON: {}", e))
497    }
498}
499
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::response::Response;
    use std::io::Write;

    /// Snapshot with a `content-type: application/json` header whose body is `payload`.
    fn json_snapshot(status: u16, payload: &Value) -> ResponseSnapshot {
        let mut headers = HashMap::new();
        headers.insert("content-type".to_string(), "application/json".to_string());
        ResponseSnapshot {
            status,
            headers,
            body: serde_json::to_vec(payload).expect("serialize JSON payload"),
        }
    }

    fn gzip_bytes(input: &[u8]) -> Vec<u8> {
        let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
        encoder.write_all(input).expect("gzip write");
        encoder.finish().expect("gzip finish")
    }

    fn brotli_bytes(input: &[u8]) -> Vec<u8> {
        let mut encoder = brotli::CompressorWriter::new(Vec::new(), 4096, 5, 22);
        encoder.write_all(input).expect("brotli write");
        encoder.into_inner()
    }

    #[test]
    fn sse_stream_parses_multiple_events() {
        let mut headers = HashMap::new();
        headers.insert("content-type".to_string(), "text/event-stream".to_string());

        let snapshot = ResponseSnapshot {
            status: 200,
            headers,
            body: b"data: {\"id\": 1}\n\ndata: \"hello\"\n\n".to_vec(),
        };

        let stream = SseStream::from_response(&snapshot).expect("stream");
        assert_eq!(stream.events().len(), 2);
        assert_eq!(stream.events()[0].as_json().unwrap()["id"], serde_json::json!(1));
        assert_eq!(stream.events()[1].data, "\"hello\"");
        assert_eq!(stream.events_as_json().unwrap().len(), 2);
    }

    #[test]
    fn sse_event_reports_invalid_json() {
        let event = SseEvent {
            data: "not-json".to_string(),
        };
        assert!(event.as_json().is_err());
    }

    #[test]
    fn header_lookup_is_case_insensitive() {
        let snapshot = json_snapshot(200, &serde_json::json!({}));
        assert_eq!(snapshot.header("Content-Type"), Some("application/json"));
        assert_eq!(snapshot.header("CONTENT-TYPE"), Some("application/json"));
        assert_eq!(snapshot.header("x-missing"), None);
    }

    #[test]
    fn test_graphql_data_extraction() {
        let snapshot = json_snapshot(
            200,
            &serde_json::json!({
                "data": {
                    "user": {
                        "id": "1",
                        "name": "Alice"
                    }
                }
            }),
        );

        let data = snapshot.graphql_data().expect("data extraction");
        assert_eq!(data["user"]["id"], "1");
        assert_eq!(data["user"]["name"], "Alice");
    }

    #[test]
    fn test_graphql_errors_extraction() {
        let snapshot = json_snapshot(
            400,
            &serde_json::json!({
                "errors": [
                    {
                        "message": "Field not found",
                        "path": ["user", "email"]
                    },
                    {
                        "message": "Unauthorized",
                        "extensions": { "code": "UNAUTHENTICATED" }
                    }
                ]
            }),
        );

        let errors = snapshot.graphql_errors().expect("errors extraction");
        assert_eq!(errors.len(), 2);
        assert_eq!(errors[0]["message"], "Field not found");
        assert_eq!(errors[1]["message"], "Unauthorized");
    }

    #[test]
    fn test_graphql_missing_data_field() {
        let snapshot = json_snapshot(400, &serde_json::json!({ "errors": [{ "message": "Query failed" }] }));

        let result = snapshot.graphql_data();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("No 'data' field"));
    }

    #[test]
    fn test_graphql_empty_errors() {
        let snapshot = json_snapshot(200, &serde_json::json!({ "data": { "result": null } }));

        let errors = snapshot.graphql_errors().expect("errors extraction");
        assert!(errors.is_empty());
    }

    #[tokio::test]
    async fn snapshot_http_response_decodes_gzip_body() {
        let body = b"hello gzip";
        let compressed = gzip_bytes(body);
        let response = Response::builder()
            .status(200)
            .header("content-encoding", "gzip")
            .body(Body::from(compressed))
            .unwrap();

        let snapshot = snapshot_http_response(response).await.expect("snapshot");
        assert_eq!(snapshot.body, body);
    }

    #[tokio::test]
    async fn snapshot_http_response_decodes_brotli_body() {
        let body = b"hello brotli";
        let compressed = brotli_bytes(body);
        let response = Response::builder()
            .status(200)
            .header("content-encoding", "br")
            .body(Body::from(compressed))
            .unwrap();

        let snapshot = snapshot_http_response(response).await.expect("snapshot");
        assert_eq!(snapshot.body, body);
    }

    #[tokio::test]
    async fn snapshot_http_response_leaves_plain_body() {
        let body = b"plain";
        let response = Response::builder()
            .status(200)
            .body(Body::from(body.as_slice()))
            .unwrap();

        let snapshot = snapshot_http_response(response).await.expect("snapshot");
        assert_eq!(snapshot.body, body);
    }

    #[tokio::test]
    async fn snapshot_http_response_lowercases_header_names() {
        let response = Response::builder()
            .status(201)
            .header("X-Trace-Id", "abc")
            .body(Body::empty())
            .unwrap();

        let snapshot = snapshot_http_response(response).await.expect("snapshot");
        assert_eq!(snapshot.status, 201);
        assert_eq!(snapshot.headers.get("x-trace-id").map(String::as_str), Some("abc"));
        assert_eq!(snapshot.header("X-TRACE-ID"), Some("abc"));
        assert!(snapshot.body.is_empty());
    }

    #[tokio::test]
    async fn snapshot_http_response_allow_body_errors_decodes_complete_body() {
        let body = b"hello again";
        let response = Response::builder()
            .status(200)
            .header("content-encoding", "gzip")
            .body(Body::from(gzip_bytes(body)))
            .unwrap();

        let snapshot = snapshot_http_response_allow_body_errors(response)
            .await
            .expect("snapshot");
        assert_eq!(snapshot.status, 200);
        assert_eq!(snapshot.body, body);
    }
}