Skip to main content

spikard_http/
testing.rs

1use axum::body::Body;
2use axum::http::Request as AxumRequest;
3use axum_test::{TestResponse as AxumTestResponse, TestServer, TestWebSocket, WsMessage};
4
5pub mod multipart;
6pub use multipart::{MultipartFilePart, build_multipart_body};
7
8pub mod form;
9
10pub mod test_client;
11pub use test_client::TestClient;
12
13use brotli::Decompressor;
14use flate2::read::GzDecoder;
15pub use form::encode_urlencoded_body;
16use http_body_util::BodyExt;
17use serde_json::Value;
18use std::collections::HashMap;
19use std::io::{Cursor, Read};
20
/// Owned, cloneable capture of an HTTP response: status, headers and body.
///
/// Produced by the `snapshot_*` helpers in this module so that language
/// bindings can inspect a response after the Axum value has been consumed.
#[derive(Debug, Clone)]
pub struct ResponseSnapshot {
    /// Numeric HTTP status code (for example `200`).
    pub status: u16,
    /// Header map keyed by lowercase header name, so lookups are predictable.
    pub headers: HashMap<String, String>,
    /// Raw body bytes, already decoded when the encoding is a supported one.
    pub body: Vec<u8>,
}
31
32impl ResponseSnapshot {
33    /// Return response body as UTF-8 string.
34    pub fn text(&self) -> Result<String, std::string::FromUtf8Error> {
35        String::from_utf8(self.body.clone())
36    }
37
38    /// Parse response body as JSON.
39    pub fn json(&self) -> Result<Value, serde_json::Error> {
40        serde_json::from_slice(&self.body)
41    }
42
43    /// Lookup header by case-insensitive name.
44    pub fn header(&self, name: &str) -> Option<&str> {
45        self.headers.get(&name.to_ascii_lowercase()).map(|s| s.as_str())
46    }
47
48    /// Extract GraphQL data from response
49    pub fn graphql_data(&self) -> Result<Value, SnapshotError> {
50        let body: Value = serde_json::from_slice(&self.body)
51            .map_err(|e| SnapshotError::Decompression(format!("Failed to parse JSON: {}", e)))?;
52
53        body.get("data")
54            .cloned()
55            .ok_or_else(|| SnapshotError::Decompression("No 'data' field in GraphQL response".to_string()))
56    }
57
58    /// Extract GraphQL errors from response
59    pub fn graphql_errors(&self) -> Result<Vec<Value>, SnapshotError> {
60        let body: Value = serde_json::from_slice(&self.body)
61            .map_err(|e| SnapshotError::Decompression(format!("Failed to parse JSON: {}", e)))?;
62
63        Ok(body
64            .get("errors")
65            .and_then(|e| e.as_array())
66            .cloned()
67            .unwrap_or_default())
68    }
69}
70
/// Failures that can occur while turning an Axum response into a snapshot.
#[derive(Debug)]
pub enum SnapshotError {
    /// A response header value could not be represented as a string.
    InvalidHeader(String),
    /// The response body could not be decoded.
    Decompression(String),
}

impl std::fmt::Display for SnapshotError {
    /// Render the variant as a short prefix followed by the carried message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidHeader(detail) => write!(f, "Invalid header: {}", detail),
            Self::Decompression(detail) => write!(f, "Failed to decode body: {}", detail),
        }
    }
}

impl std::error::Error for SnapshotError {}
90
91/// Execute an HTTP request against an Axum [`TestServer`] by rehydrating it
92/// into the server's own [`axum_test::TestRequest`] builder.
93pub async fn call_test_server(server: &TestServer, request: AxumRequest<Body>) -> AxumTestResponse {
94    let (parts, body) = request.into_parts();
95
96    let mut path = parts.uri.path().to_string();
97    if let Some(query) = parts.uri.query()
98        && !query.is_empty()
99    {
100        path.push('?');
101        path.push_str(query);
102    }
103
104    let mut test_request = server.method(parts.method.clone(), &path);
105
106    for (name, value) in parts.headers.iter() {
107        test_request = test_request.add_header(name.clone(), value.clone());
108    }
109
110    let collected = body
111        .collect()
112        .await
113        .expect("failed to read request body for test dispatch");
114    let bytes = collected.to_bytes();
115    if !bytes.is_empty() {
116        test_request = test_request.bytes(bytes);
117    }
118
119    test_request.await
120}
121
122/// Convert an `AxumTestResponse` into a reusable [`ResponseSnapshot`].
123pub async fn snapshot_response(response: AxumTestResponse) -> Result<ResponseSnapshot, SnapshotError> {
124    let status = response.status_code().as_u16();
125
126    let mut headers = HashMap::new();
127    for (name, value) in response.headers() {
128        let header_value = value
129            .to_str()
130            .map_err(|e| SnapshotError::InvalidHeader(e.to_string()))?;
131        headers.insert(name.to_string().to_ascii_lowercase(), header_value.to_string());
132    }
133
134    let body = response.into_bytes();
135    let decoded_body = decode_body(&headers, body.to_vec())?;
136
137    Ok(ResponseSnapshot {
138        status,
139        headers,
140        body: decoded_body,
141    })
142}
143
144/// Convert an Axum response into a reusable [`ResponseSnapshot`].
145pub async fn snapshot_http_response(
146    response: axum::response::Response<Body>,
147) -> Result<ResponseSnapshot, SnapshotError> {
148    let (parts, body) = response.into_parts();
149    let status = parts.status.as_u16();
150
151    let mut headers = HashMap::new();
152    for (name, value) in parts.headers.iter() {
153        let header_value = value
154            .to_str()
155            .map_err(|e| SnapshotError::InvalidHeader(e.to_string()))?;
156        headers.insert(name.to_string().to_ascii_lowercase(), header_value.to_string());
157    }
158
159    let collected = body
160        .collect()
161        .await
162        .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
163    let bytes = collected.to_bytes();
164    let decoded_body = decode_body(&headers, bytes.to_vec())?;
165
166    Ok(ResponseSnapshot {
167        status,
168        headers,
169        body: decoded_body,
170    })
171}
172
173/// Convert an Axum response into a reusable [`ResponseSnapshot`], allowing body stream errors.
174///
175/// This is useful for streaming responses where a producer might abort mid-stream (for example,
176/// a JavaScript async generator throwing). In those cases, collecting the whole body can fail
177/// with a "Stream error". This helper returns the bytes read up to the first stream error.
178pub async fn snapshot_http_response_allow_body_errors(
179    response: axum::response::Response<Body>,
180) -> Result<ResponseSnapshot, SnapshotError> {
181    let (parts, mut body) = response.into_parts();
182    let status = parts.status.as_u16();
183
184    let mut headers = HashMap::new();
185    for (name, value) in parts.headers.iter() {
186        let header_value = value
187            .to_str()
188            .map_err(|e| SnapshotError::InvalidHeader(e.to_string()))?;
189        headers.insert(name.to_string().to_ascii_lowercase(), header_value.to_string());
190    }
191
192    let mut bytes = Vec::<u8>::new();
193    while let Some(frame_result) = body.frame().await {
194        match frame_result {
195            Ok(frame) => {
196                if let Ok(data) = frame.into_data() {
197                    bytes.extend_from_slice(&data);
198                }
199            }
200            Err(_) => break,
201        }
202    }
203
204    let decoded_body = decode_body(&headers, bytes)?;
205
206    Ok(ResponseSnapshot {
207        status,
208        headers,
209        body: decoded_body,
210    })
211}
212
213fn decode_body(headers: &HashMap<String, String>, body: Vec<u8>) -> Result<Vec<u8>, SnapshotError> {
214    let encoding = headers
215        .get("content-encoding")
216        .map(|value| value.trim().to_ascii_lowercase());
217
218    match encoding.as_deref() {
219        Some("gzip" | "x-gzip") => decode_gzip(body),
220        Some("br") => decode_brotli(body),
221        _ => Ok(body),
222    }
223}
224
225fn decode_gzip(body: Vec<u8>) -> Result<Vec<u8>, SnapshotError> {
226    let mut decoder = GzDecoder::new(Cursor::new(body));
227    let mut decoded_bytes = Vec::new();
228    decoder
229        .read_to_end(&mut decoded_bytes)
230        .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
231    Ok(decoded_bytes)
232}
233
234fn decode_brotli(body: Vec<u8>) -> Result<Vec<u8>, SnapshotError> {
235    let mut decoder = Decompressor::new(Cursor::new(body), 4096);
236    let mut decoded_bytes = Vec::new();
237    decoder
238        .read_to_end(&mut decoded_bytes)
239        .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
240    Ok(decoded_bytes)
241}
242
243/// WebSocket connection wrapper for testing.
244///
245/// Provides a simple interface for sending and receiving WebSocket messages
246/// during tests without needing a real network connection.
247pub struct WebSocketConnection {
248    inner: TestWebSocket,
249}
250
251impl WebSocketConnection {
252    /// Create a new WebSocket connection from an axum-test TestWebSocket.
253    pub fn new(inner: TestWebSocket) -> Self {
254        Self { inner }
255    }
256
257    /// Send a text message over the WebSocket.
258    pub async fn send_text(&mut self, text: impl std::fmt::Display) {
259        self.inner.send_text(text).await;
260    }
261
262    /// Send a JSON message over the WebSocket.
263    pub async fn send_json<T: serde::Serialize>(&mut self, value: &T) {
264        self.inner.send_json(value).await;
265    }
266
267    /// Send a raw WebSocket message.
268    pub async fn send_message(&mut self, msg: WsMessage) {
269        self.inner.send_message(msg).await;
270    }
271
272    /// Receive the next text message from the WebSocket.
273    pub async fn receive_text(&mut self) -> String {
274        self.inner.receive_text().await
275    }
276
277    /// Receive and parse a JSON message from the WebSocket.
278    pub async fn receive_json<T: serde::de::DeserializeOwned>(&mut self) -> T {
279        self.inner.receive_json().await
280    }
281
282    /// Receive raw bytes from the WebSocket.
283    pub async fn receive_bytes(&mut self) -> bytes::Bytes {
284        self.inner.receive_bytes().await
285    }
286
287    /// Receive the next raw message from the WebSocket.
288    pub async fn receive_message(&mut self) -> WebSocketMessage {
289        let msg = self.inner.receive_message().await;
290        WebSocketMessage::from_ws_message(msg)
291    }
292
293    /// Close the WebSocket connection.
294    pub async fn close(self) {
295        self.inner.close().await;
296    }
297}
298
299/// A WebSocket message that can be text or binary.
300#[derive(Debug, Clone)]
301pub enum WebSocketMessage {
302    /// A text message.
303    Text(String),
304    /// A binary message.
305    Binary(Vec<u8>),
306    /// A close message.
307    Close(Option<String>),
308    /// A ping message.
309    Ping(Vec<u8>),
310    /// A pong message.
311    Pong(Vec<u8>),
312}
313
314impl WebSocketMessage {
315    fn from_ws_message(msg: WsMessage) -> Self {
316        match msg {
317            WsMessage::Text(text) => WebSocketMessage::Text(text.to_string()),
318            WsMessage::Binary(data) => WebSocketMessage::Binary(data.to_vec()),
319            WsMessage::Close(frame) => WebSocketMessage::Close(frame.map(|f| f.reason.to_string())),
320            WsMessage::Ping(data) => WebSocketMessage::Ping(data.to_vec()),
321            WsMessage::Pong(data) => WebSocketMessage::Pong(data.to_vec()),
322            WsMessage::Frame(_) => WebSocketMessage::Close(None),
323        }
324    }
325
326    /// Get the message as text, if it's a text message.
327    pub fn as_text(&self) -> Option<&str> {
328        match self {
329            WebSocketMessage::Text(text) => Some(text),
330            _ => None,
331        }
332    }
333
334    /// Get the message as JSON, if it's a text message containing JSON.
335    pub fn as_json(&self) -> Result<Value, String> {
336        match self {
337            WebSocketMessage::Text(text) => {
338                serde_json::from_str(text).map_err(|e| format!("Failed to parse JSON: {}", e))
339            }
340            _ => Err("Message is not text".to_string()),
341        }
342    }
343
344    /// Get the message as binary, if it's a binary message.
345    pub fn as_binary(&self) -> Option<&[u8]> {
346        match self {
347            WebSocketMessage::Binary(data) => Some(data),
348            _ => None,
349        }
350    }
351
352    /// Check if this is a close message.
353    pub fn is_close(&self) -> bool {
354        matches!(self, WebSocketMessage::Close(_))
355    }
356}
357
358/// Connect to a WebSocket endpoint on the test server.
359pub async fn connect_websocket(server: &TestServer, path: &str) -> WebSocketConnection {
360    let ws = server.get_websocket(path).await.into_websocket().await;
361    WebSocketConnection::new(ws)
362}
363
364/// Server-Sent Events (SSE) stream for testing.
365///
366/// Wraps a response body and provides methods to parse SSE events.
367#[derive(Debug)]
368pub struct SseStream {
369    body: String,
370    events: Vec<SseEvent>,
371}
372
373impl SseStream {
374    /// Create a new SSE stream from a response.
375    pub fn from_response(response: &ResponseSnapshot) -> Result<Self, String> {
376        let body = response
377            .text()
378            .map_err(|e| format!("Failed to read response body: {}", e))?;
379
380        let events = Self::parse_events(&body);
381
382        Ok(Self { body, events })
383    }
384
385    fn parse_events(body: &str) -> Vec<SseEvent> {
386        let mut events = Vec::new();
387        let lines: Vec<&str> = body.lines().collect();
388        let mut i = 0;
389
390        while i < lines.len() {
391            if lines[i].starts_with("data:") {
392                let data = lines[i].trim_start_matches("data:").trim().to_string();
393                events.push(SseEvent { data });
394            } else if lines[i].starts_with("data") {
395                let data = lines[i].trim_start_matches("data").trim().to_string();
396                if !data.is_empty() || lines[i].len() == 4 {
397                    events.push(SseEvent { data });
398                }
399            }
400            i += 1;
401        }
402
403        events
404    }
405
406    /// Get all events from the stream.
407    pub fn events(&self) -> &[SseEvent] {
408        &self.events
409    }
410
411    /// Get the raw body of the SSE response.
412    pub fn body(&self) -> &str {
413        &self.body
414    }
415
416    /// Get events as JSON values.
417    pub fn events_as_json(&self) -> Result<Vec<Value>, String> {
418        self.events
419            .iter()
420            .map(|event| event.as_json())
421            .collect::<Result<Vec<_>, _>>()
422    }
423}
424
425/// A single Server-Sent Event.
426#[derive(Debug, Clone)]
427pub struct SseEvent {
428    /// The data field of the event.
429    pub data: String,
430}
431
432impl SseEvent {
433    /// Parse the event data as JSON.
434    pub fn as_json(&self) -> Result<Value, String> {
435        serde_json::from_str(&self.data).map_err(|e| format!("Failed to parse JSON: {}", e))
436    }
437}
438
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::response::Response;
    use std::io::Write;

    /// Snapshot whose only header is `content-type: application/json` and
    /// whose body is the serialized `payload`.
    fn json_snapshot(status: u16, payload: &Value) -> ResponseSnapshot {
        let mut headers = HashMap::new();
        headers.insert("content-type".to_string(), "application/json".to_string());

        ResponseSnapshot {
            status,
            headers,
            body: serde_json::to_vec(payload).unwrap(),
        }
    }

    /// 200 response carrying `payload` and the given `content-encoding` header.
    fn encoded_response(encoding: &str, payload: Vec<u8>) -> Response<Body> {
        Response::builder()
            .status(200)
            .header("content-encoding", encoding)
            .body(Body::from(payload))
            .unwrap()
    }

    /// Gzip-compress `input` with the default compression level.
    fn gzip_bytes(input: &[u8]) -> Vec<u8> {
        let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
        encoder.write_all(input).expect("gzip write");
        encoder.finish().expect("gzip finish")
    }

    /// Brotli-compress `input`.
    fn brotli_bytes(input: &[u8]) -> Vec<u8> {
        let mut encoder = brotli::CompressorWriter::new(Vec::new(), 4096, 5, 22);
        encoder.write_all(input).expect("brotli write");
        encoder.into_inner()
    }

    #[test]
    fn sse_stream_parses_multiple_events() {
        let mut headers = HashMap::new();
        headers.insert("content-type".to_string(), "text/event-stream".to_string());
        let snapshot = ResponseSnapshot {
            status: 200,
            headers,
            body: b"data: {\"id\": 1}\n\ndata: \"hello\"\n\n".to_vec(),
        };

        let stream = SseStream::from_response(&snapshot).expect("stream");
        let events = stream.events();

        assert_eq!(events.len(), 2);
        assert_eq!(events[0].as_json().unwrap()["id"], serde_json::json!(1));
        assert_eq!(events[1].data, "\"hello\"");
        assert_eq!(stream.events_as_json().unwrap().len(), 2);
    }

    #[test]
    fn sse_event_reports_invalid_json() {
        let event = SseEvent {
            data: String::from("not-json"),
        };
        assert!(event.as_json().is_err());
    }

    #[test]
    fn test_graphql_data_extraction() {
        let payload = serde_json::json!({
            "data": {
                "user": {
                    "id": "1",
                    "name": "Alice"
                }
            }
        });
        let snapshot = json_snapshot(200, &payload);

        let data = snapshot.graphql_data().expect("data extraction");
        assert_eq!(data["user"]["id"], "1");
        assert_eq!(data["user"]["name"], "Alice");
    }

    #[test]
    fn test_graphql_errors_extraction() {
        let payload = serde_json::json!({
            "errors": [
                {
                    "message": "Field not found",
                    "path": ["user", "email"]
                },
                {
                    "message": "Unauthorized",
                    "extensions": { "code": "UNAUTHENTICATED" }
                }
            ]
        });
        let snapshot = json_snapshot(400, &payload);

        let errors = snapshot.graphql_errors().expect("errors extraction");
        assert_eq!(errors.len(), 2);
        assert_eq!(errors[0]["message"], "Field not found");
        assert_eq!(errors[1]["message"], "Unauthorized");
    }

    #[test]
    fn test_graphql_missing_data_field() {
        let payload = serde_json::json!({
            "errors": [{ "message": "Query failed" }]
        });
        let snapshot = json_snapshot(400, &payload);

        let result = snapshot.graphql_data();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("No 'data' field"));
    }

    #[test]
    fn test_graphql_empty_errors() {
        let payload = serde_json::json!({
            "data": { "result": null }
        });
        let snapshot = json_snapshot(200, &payload);

        let errors = snapshot.graphql_errors().expect("errors extraction");
        assert!(errors.is_empty());
    }

    #[tokio::test]
    async fn snapshot_http_response_decodes_gzip_body() {
        let plain = b"hello gzip";
        let response = encoded_response("gzip", gzip_bytes(plain));

        let snapshot = snapshot_http_response(response).await.expect("snapshot");
        assert_eq!(snapshot.body, plain);
    }

    #[tokio::test]
    async fn snapshot_http_response_decodes_brotli_body() {
        let plain = b"hello brotli";
        let response = encoded_response("br", brotli_bytes(plain));

        let snapshot = snapshot_http_response(response).await.expect("snapshot");
        assert_eq!(snapshot.body, plain);
    }

    #[tokio::test]
    async fn snapshot_http_response_leaves_plain_body() {
        let plain = b"plain";
        let response = Response::builder()
            .status(200)
            .body(Body::from(plain.as_slice()))
            .unwrap();

        let snapshot = snapshot_http_response(response).await.expect("snapshot");
        assert_eq!(snapshot.body, plain);
    }
}
617}