spikard_http/
testing.rs

1use axum::body::Body;
2use axum::http::Request as AxumRequest;
3use axum_test::{TestResponse as AxumTestResponse, TestServer, TestWebSocket, WsMessage};
4
5pub mod multipart;
6pub use multipart::{MultipartFilePart, build_multipart_body};
7
8pub mod form;
9
10pub mod test_client;
11pub use test_client::TestClient;
12
13use brotli::Decompressor;
14use flate2::read::GzDecoder;
15pub use form::encode_urlencoded_body;
16use http_body_util::BodyExt;
17use serde_json::Value;
18use std::collections::HashMap;
19use std::io::{Cursor, Read};
20
21/// Snapshot of an Axum response used by higher-level language bindings.
22#[derive(Debug, Clone)]
23pub struct ResponseSnapshot {
24    /// HTTP status code.
25    pub status: u16,
26    /// Response headers (lowercase keys for predictable lookups).
27    pub headers: HashMap<String, String>,
28    /// Response body bytes (decoded for supported encodings).
29    pub body: Vec<u8>,
30}
31
32impl ResponseSnapshot {
33    /// Return response body as UTF-8 string.
34    pub fn text(&self) -> Result<String, std::string::FromUtf8Error> {
35        String::from_utf8(self.body.clone())
36    }
37
38    /// Parse response body as JSON.
39    pub fn json(&self) -> Result<Value, serde_json::Error> {
40        serde_json::from_slice(&self.body)
41    }
42
43    /// Lookup header by case-insensitive name.
44    pub fn header(&self, name: &str) -> Option<&str> {
45        self.headers.get(&name.to_ascii_lowercase()).map(|s| s.as_str())
46    }
47
48    /// Extract GraphQL data from response
49    pub fn graphql_data(&self) -> Result<Value, SnapshotError> {
50        let body: Value = serde_json::from_slice(&self.body)
51            .map_err(|e| SnapshotError::Decompression(format!("Failed to parse JSON: {}", e)))?;
52
53        body.get("data")
54            .cloned()
55            .ok_or_else(|| SnapshotError::Decompression("No 'data' field in GraphQL response".to_string()))
56    }
57
58    /// Extract GraphQL errors from response
59    pub fn graphql_errors(&self) -> Result<Vec<Value>, SnapshotError> {
60        let body: Value = serde_json::from_slice(&self.body)
61            .map_err(|e| SnapshotError::Decompression(format!("Failed to parse JSON: {}", e)))?;
62
63        Ok(body
64            .get("errors")
65            .and_then(|e| e.as_array())
66            .cloned()
67            .unwrap_or_default())
68    }
69}
70
/// Failure modes when turning a response into a [`ResponseSnapshot`].
#[derive(Debug)]
pub enum SnapshotError {
    /// A header value could not be converted to a string.
    InvalidHeader(String),
    /// The body could not be read or decoded. The GraphQL helpers also use this
    /// variant for JSON parse failures.
    Decompression(String),
}

impl std::fmt::Display for SnapshotError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Each variant renders as "<fixed prefix>: <detail>".
        let (prefix, detail) = match self {
            Self::InvalidHeader(detail) => ("Invalid header", detail),
            Self::Decompression(detail) => ("Failed to decode body", detail),
        };
        write!(f, "{}: {}", prefix, detail)
    }
}

impl std::error::Error for SnapshotError {}
90
91/// Execute an HTTP request against an Axum [`TestServer`] by rehydrating it
92/// into the server's own [`axum_test::TestRequest`] builder.
93pub async fn call_test_server(server: &TestServer, request: AxumRequest<Body>) -> AxumTestResponse {
94    let (parts, body) = request.into_parts();
95
96    let mut path = parts.uri.path().to_string();
97    if let Some(query) = parts.uri.query()
98        && !query.is_empty()
99    {
100        path.push('?');
101        path.push_str(query);
102    }
103
104    let mut test_request = server.method(parts.method.clone(), &path);
105
106    for (name, value) in parts.headers.iter() {
107        test_request = test_request.add_header(name.clone(), value.clone());
108    }
109
110    let collected = body
111        .collect()
112        .await
113        .expect("failed to read request body for test dispatch");
114    let bytes = collected.to_bytes();
115    if !bytes.is_empty() {
116        test_request = test_request.bytes(bytes);
117    }
118
119    test_request.await
120}
121
122/// Convert an `AxumTestResponse` into a reusable [`ResponseSnapshot`].
123pub async fn snapshot_response(response: AxumTestResponse) -> Result<ResponseSnapshot, SnapshotError> {
124    let status = response.status_code().as_u16();
125
126    let mut headers = HashMap::new();
127    for (name, value) in response.headers() {
128        let header_value = value
129            .to_str()
130            .map_err(|e| SnapshotError::InvalidHeader(e.to_string()))?;
131        headers.insert(name.to_string().to_ascii_lowercase(), header_value.to_string());
132    }
133
134    let body = response.into_bytes();
135    let decoded_body = decode_body(&headers, body.to_vec())?;
136
137    Ok(ResponseSnapshot {
138        status,
139        headers,
140        body: decoded_body,
141    })
142}
143
144/// Convert an Axum response into a reusable [`ResponseSnapshot`].
145pub async fn snapshot_http_response(
146    response: axum::response::Response<Body>,
147) -> Result<ResponseSnapshot, SnapshotError> {
148    let (parts, body) = response.into_parts();
149    let status = parts.status.as_u16();
150
151    let mut headers = HashMap::new();
152    for (name, value) in parts.headers.iter() {
153        let header_value = value
154            .to_str()
155            .map_err(|e| SnapshotError::InvalidHeader(e.to_string()))?;
156        headers.insert(name.to_string().to_ascii_lowercase(), header_value.to_string());
157    }
158
159    let collected = body
160        .collect()
161        .await
162        .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
163    let bytes = collected.to_bytes();
164    let decoded_body = decode_body(&headers, bytes.to_vec())?;
165
166    Ok(ResponseSnapshot {
167        status,
168        headers,
169        body: decoded_body,
170    })
171}
172
173fn decode_body(headers: &HashMap<String, String>, body: Vec<u8>) -> Result<Vec<u8>, SnapshotError> {
174    let encoding = headers
175        .get("content-encoding")
176        .map(|value| value.trim().to_ascii_lowercase());
177
178    match encoding.as_deref() {
179        Some("gzip" | "x-gzip") => decode_gzip(body),
180        Some("br") => decode_brotli(body),
181        _ => Ok(body),
182    }
183}
184
185fn decode_gzip(body: Vec<u8>) -> Result<Vec<u8>, SnapshotError> {
186    let mut decoder = GzDecoder::new(Cursor::new(body));
187    let mut decoded_bytes = Vec::new();
188    decoder
189        .read_to_end(&mut decoded_bytes)
190        .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
191    Ok(decoded_bytes)
192}
193
194fn decode_brotli(body: Vec<u8>) -> Result<Vec<u8>, SnapshotError> {
195    let mut decoder = Decompressor::new(Cursor::new(body), 4096);
196    let mut decoded_bytes = Vec::new();
197    decoder
198        .read_to_end(&mut decoded_bytes)
199        .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
200    Ok(decoded_bytes)
201}
202
203/// WebSocket connection wrapper for testing.
204///
205/// Provides a simple interface for sending and receiving WebSocket messages
206/// during tests without needing a real network connection.
207pub struct WebSocketConnection {
208    inner: TestWebSocket,
209}
210
211impl WebSocketConnection {
212    /// Create a new WebSocket connection from an axum-test TestWebSocket.
213    pub fn new(inner: TestWebSocket) -> Self {
214        Self { inner }
215    }
216
217    /// Send a text message over the WebSocket.
218    pub async fn send_text(&mut self, text: impl std::fmt::Display) {
219        self.inner.send_text(text).await;
220    }
221
222    /// Send a JSON message over the WebSocket.
223    pub async fn send_json<T: serde::Serialize>(&mut self, value: &T) {
224        self.inner.send_json(value).await;
225    }
226
227    /// Send a raw WebSocket message.
228    pub async fn send_message(&mut self, msg: WsMessage) {
229        self.inner.send_message(msg).await;
230    }
231
232    /// Receive the next text message from the WebSocket.
233    pub async fn receive_text(&mut self) -> String {
234        self.inner.receive_text().await
235    }
236
237    /// Receive and parse a JSON message from the WebSocket.
238    pub async fn receive_json<T: serde::de::DeserializeOwned>(&mut self) -> T {
239        self.inner.receive_json().await
240    }
241
242    /// Receive raw bytes from the WebSocket.
243    pub async fn receive_bytes(&mut self) -> bytes::Bytes {
244        self.inner.receive_bytes().await
245    }
246
247    /// Receive the next raw message from the WebSocket.
248    pub async fn receive_message(&mut self) -> WebSocketMessage {
249        let msg = self.inner.receive_message().await;
250        WebSocketMessage::from_ws_message(msg)
251    }
252
253    /// Close the WebSocket connection.
254    pub async fn close(self) {
255        self.inner.close().await;
256    }
257}
258
259/// A WebSocket message that can be text or binary.
260#[derive(Debug, Clone)]
261pub enum WebSocketMessage {
262    /// A text message.
263    Text(String),
264    /// A binary message.
265    Binary(Vec<u8>),
266    /// A close message.
267    Close(Option<String>),
268    /// A ping message.
269    Ping(Vec<u8>),
270    /// A pong message.
271    Pong(Vec<u8>),
272}
273
274impl WebSocketMessage {
275    fn from_ws_message(msg: WsMessage) -> Self {
276        match msg {
277            WsMessage::Text(text) => WebSocketMessage::Text(text.to_string()),
278            WsMessage::Binary(data) => WebSocketMessage::Binary(data.to_vec()),
279            WsMessage::Close(frame) => WebSocketMessage::Close(frame.map(|f| f.reason.to_string())),
280            WsMessage::Ping(data) => WebSocketMessage::Ping(data.to_vec()),
281            WsMessage::Pong(data) => WebSocketMessage::Pong(data.to_vec()),
282            WsMessage::Frame(_) => WebSocketMessage::Close(None),
283        }
284    }
285
286    /// Get the message as text, if it's a text message.
287    pub fn as_text(&self) -> Option<&str> {
288        match self {
289            WebSocketMessage::Text(text) => Some(text),
290            _ => None,
291        }
292    }
293
294    /// Get the message as JSON, if it's a text message containing JSON.
295    pub fn as_json(&self) -> Result<Value, String> {
296        match self {
297            WebSocketMessage::Text(text) => {
298                serde_json::from_str(text).map_err(|e| format!("Failed to parse JSON: {}", e))
299            }
300            _ => Err("Message is not text".to_string()),
301        }
302    }
303
304    /// Get the message as binary, if it's a binary message.
305    pub fn as_binary(&self) -> Option<&[u8]> {
306        match self {
307            WebSocketMessage::Binary(data) => Some(data),
308            _ => None,
309        }
310    }
311
312    /// Check if this is a close message.
313    pub fn is_close(&self) -> bool {
314        matches!(self, WebSocketMessage::Close(_))
315    }
316}
317
318/// Connect to a WebSocket endpoint on the test server.
319pub async fn connect_websocket(server: &TestServer, path: &str) -> WebSocketConnection {
320    let ws = server.get_websocket(path).await.into_websocket().await;
321    WebSocketConnection::new(ws)
322}
323
324/// Server-Sent Events (SSE) stream for testing.
325///
326/// Wraps a response body and provides methods to parse SSE events.
327#[derive(Debug)]
328pub struct SseStream {
329    body: String,
330    events: Vec<SseEvent>,
331}
332
333impl SseStream {
334    /// Create a new SSE stream from a response.
335    pub fn from_response(response: &ResponseSnapshot) -> Result<Self, String> {
336        let body = response
337            .text()
338            .map_err(|e| format!("Failed to read response body: {}", e))?;
339
340        let events = Self::parse_events(&body);
341
342        Ok(Self { body, events })
343    }
344
345    fn parse_events(body: &str) -> Vec<SseEvent> {
346        let mut events = Vec::new();
347        let lines: Vec<&str> = body.lines().collect();
348        let mut i = 0;
349
350        while i < lines.len() {
351            if lines[i].starts_with("data:") {
352                let data = lines[i].trim_start_matches("data:").trim().to_string();
353                events.push(SseEvent { data });
354            } else if lines[i].starts_with("data") {
355                let data = lines[i].trim_start_matches("data").trim().to_string();
356                if !data.is_empty() || lines[i].len() == 4 {
357                    events.push(SseEvent { data });
358                }
359            }
360            i += 1;
361        }
362
363        events
364    }
365
366    /// Get all events from the stream.
367    pub fn events(&self) -> &[SseEvent] {
368        &self.events
369    }
370
371    /// Get the raw body of the SSE response.
372    pub fn body(&self) -> &str {
373        &self.body
374    }
375
376    /// Get events as JSON values.
377    pub fn events_as_json(&self) -> Result<Vec<Value>, String> {
378        self.events
379            .iter()
380            .map(|event| event.as_json())
381            .collect::<Result<Vec<_>, _>>()
382    }
383}
384
385/// A single Server-Sent Event.
386#[derive(Debug, Clone)]
387pub struct SseEvent {
388    /// The data field of the event.
389    pub data: String,
390}
391
392impl SseEvent {
393    /// Parse the event data as JSON.
394    pub fn as_json(&self) -> Result<Value, String> {
395        serde_json::from_str(&self.data).map_err(|e| format!("Failed to parse JSON: {}", e))
396    }
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::response::Response;
    use std::io::Write;

    #[test]
    fn sse_stream_parses_multiple_events() {
        let mut headers = HashMap::new();
        headers.insert("content-type".to_string(), "text/event-stream".to_string());

        let snapshot = ResponseSnapshot {
            status: 200,
            headers,
            body: b"data: {\"id\": 1}\n\ndata: \"hello\"\n\n".to_vec(),
        };

        let stream = SseStream::from_response(&snapshot).expect("stream");
        assert_eq!(stream.events().len(), 2);
        assert_eq!(stream.events()[0].as_json().unwrap()["id"], serde_json::json!(1));
        assert_eq!(stream.events()[1].data, "\"hello\"");
        assert_eq!(stream.events_as_json().unwrap().len(), 2);
    }

    #[test]
    fn sse_stream_ignores_comment_lines() {
        // Lines starting with ':' are SSE comments (e.g. keep-alives) and carry no data.
        let snapshot = ResponseSnapshot {
            status: 200,
            headers: HashMap::new(),
            body: b": keep-alive\n\ndata: 1\n\n".to_vec(),
        };

        let stream = SseStream::from_response(&snapshot).expect("stream");
        assert_eq!(stream.events().len(), 1);
        assert_eq!(stream.events()[0].data, "1");
    }

    #[test]
    fn sse_event_reports_invalid_json() {
        let event = SseEvent {
            data: "not-json".to_string(),
        };
        assert!(event.as_json().is_err());
    }

    #[test]
    fn response_snapshot_header_lookup_is_case_insensitive() {
        let mut headers = HashMap::new();
        headers.insert("content-type".to_string(), "application/json".to_string());

        let snapshot = ResponseSnapshot {
            status: 200,
            headers,
            body: Vec::new(),
        };

        assert_eq!(snapshot.header("Content-Type"), Some("application/json"));
        assert_eq!(snapshot.header("x-missing"), None);
    }

    #[test]
    fn test_graphql_data_extraction() {
        let mut headers = HashMap::new();
        headers.insert("content-type".to_string(), "application/json".to_string());

        let graphql_response = serde_json::json!({
            "data": {
                "user": {
                    "id": "1",
                    "name": "Alice"
                }
            }
        });

        let snapshot = ResponseSnapshot {
            status: 200,
            headers,
            body: serde_json::to_vec(&graphql_response).unwrap(),
        };

        let data = snapshot.graphql_data().expect("data extraction");
        assert_eq!(data["user"]["id"], "1");
        assert_eq!(data["user"]["name"], "Alice");
    }

    #[test]
    fn test_graphql_errors_extraction() {
        let mut headers = HashMap::new();
        headers.insert("content-type".to_string(), "application/json".to_string());

        let graphql_response = serde_json::json!({
            "errors": [
                {
                    "message": "Field not found",
                    "path": ["user", "email"]
                },
                {
                    "message": "Unauthorized",
                    "extensions": { "code": "UNAUTHENTICATED" }
                }
            ]
        });

        let snapshot = ResponseSnapshot {
            status: 400,
            headers,
            body: serde_json::to_vec(&graphql_response).unwrap(),
        };

        let errors = snapshot.graphql_errors().expect("errors extraction");
        assert_eq!(errors.len(), 2);
        assert_eq!(errors[0]["message"], "Field not found");
        assert_eq!(errors[1]["message"], "Unauthorized");
    }

    #[test]
    fn test_graphql_missing_data_field() {
        let mut headers = HashMap::new();
        headers.insert("content-type".to_string(), "application/json".to_string());

        let graphql_response = serde_json::json!({
            "errors": [{ "message": "Query failed" }]
        });

        let snapshot = ResponseSnapshot {
            status: 400,
            headers,
            body: serde_json::to_vec(&graphql_response).unwrap(),
        };

        let result = snapshot.graphql_data();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("No 'data' field"));
    }

    #[test]
    fn test_graphql_empty_errors() {
        let mut headers = HashMap::new();
        headers.insert("content-type".to_string(), "application/json".to_string());

        let graphql_response = serde_json::json!({
            "data": { "result": null }
        });

        let snapshot = ResponseSnapshot {
            status: 200,
            headers,
            body: serde_json::to_vec(&graphql_response).unwrap(),
        };

        let errors = snapshot.graphql_errors().expect("errors extraction");
        assert!(errors.is_empty());
    }

    /// Gzip-compress `input` with default settings.
    fn gzip_bytes(input: &[u8]) -> Vec<u8> {
        let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
        encoder.write_all(input).expect("gzip write");
        encoder.finish().expect("gzip finish")
    }

    /// Brotli-compress `input` (buffer 4096, quality 5, lgwin 22).
    fn brotli_bytes(input: &[u8]) -> Vec<u8> {
        let mut encoder = brotli::CompressorWriter::new(Vec::new(), 4096, 5, 22);
        encoder.write_all(input).expect("brotli write");
        encoder.into_inner()
    }

    #[tokio::test]
    async fn snapshot_http_response_decodes_gzip_body() {
        let body = b"hello gzip";
        let compressed = gzip_bytes(body);
        let response = Response::builder()
            .status(200)
            .header("content-encoding", "gzip")
            .body(Body::from(compressed))
            .unwrap();

        let snapshot = snapshot_http_response(response).await.expect("snapshot");
        assert_eq!(snapshot.body, body);
    }

    #[tokio::test]
    async fn snapshot_http_response_decodes_brotli_body() {
        let body = b"hello brotli";
        let compressed = brotli_bytes(body);
        let response = Response::builder()
            .status(200)
            .header("content-encoding", "br")
            .body(Body::from(compressed))
            .unwrap();

        let snapshot = snapshot_http_response(response).await.expect("snapshot");
        assert_eq!(snapshot.body, body);
    }

    #[tokio::test]
    async fn snapshot_http_response_leaves_plain_body() {
        let body = b"plain";
        let response = Response::builder()
            .status(200)
            .body(Body::from(body.as_slice()))
            .unwrap();

        let snapshot = snapshot_http_response(response).await.expect("snapshot");
        assert_eq!(snapshot.body, body);
    }

    #[tokio::test]
    async fn snapshot_http_response_rejects_corrupt_gzip_body() {
        let response = Response::builder()
            .status(200)
            .header("content-encoding", "gzip")
            .body(Body::from("definitely not gzip"))
            .unwrap();

        let result = snapshot_http_response(response).await;
        assert!(matches!(result, Err(SnapshotError::Decompression(_))));
    }

    #[tokio::test]
    async fn snapshot_http_response_preserves_status_and_lowercases_headers() {
        let response = Response::builder()
            .status(201)
            .header("X-Custom", "Value")
            .body(Body::empty())
            .unwrap();

        let snapshot = snapshot_http_response(response).await.expect("snapshot");
        assert_eq!(snapshot.status, 201);
        assert_eq!(snapshot.headers.get("x-custom").map(String::as_str), Some("Value"));
        assert_eq!(snapshot.header("X-CUSTOM"), Some("Value"));
        assert!(snapshot.body.is_empty());
    }
}