spikard_http/
testing.rs

1use axum::body::Body;
2use axum::http::Request as AxumRequest;
3use axum_test::{TestResponse as AxumTestResponse, TestServer, TestWebSocket, WsMessage};
4
5pub mod multipart;
6pub use multipart::{MultipartFilePart, build_multipart_body};
7
8pub mod form;
9
10pub mod test_client;
11pub use test_client::TestClient;
12
13use brotli::Decompressor;
14use flate2::read::GzDecoder;
15pub use form::encode_urlencoded_body;
16use http_body_util::BodyExt;
17use serde_json::Value;
18use std::collections::HashMap;
19use std::io::{Cursor, Read};
20
21/// Snapshot of an Axum response used by higher-level language bindings.
22#[derive(Debug, Clone)]
23pub struct ResponseSnapshot {
24    /// HTTP status code.
25    pub status: u16,
26    /// Response headers (lowercase keys for predictable lookups).
27    pub headers: HashMap<String, String>,
28    /// Response body bytes (decoded for supported encodings).
29    pub body: Vec<u8>,
30}
31
32impl ResponseSnapshot {
33    /// Return response body as UTF-8 string.
34    pub fn text(&self) -> Result<String, std::string::FromUtf8Error> {
35        String::from_utf8(self.body.clone())
36    }
37
38    /// Parse response body as JSON.
39    pub fn json(&self) -> Result<Value, serde_json::Error> {
40        serde_json::from_slice(&self.body)
41    }
42
43    /// Lookup header by case-insensitive name.
44    pub fn header(&self, name: &str) -> Option<&str> {
45        self.headers.get(&name.to_ascii_lowercase()).map(|s| s.as_str())
46    }
47}
48
/// Failure modes when turning a response into a [`ResponseSnapshot`].
#[derive(Debug)]
pub enum SnapshotError {
    /// A response header value could not be read as UTF-8.
    InvalidHeader(String),
    /// The response body could not be read or decoded.
    Decompression(String),
}

impl std::fmt::Display for SnapshotError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidHeader(msg) => write!(f, "Invalid header: {}", msg),
            Self::Decompression(msg) => write!(f, "Failed to decode body: {}", msg),
        }
    }
}

/// Allows `SnapshotError` to be boxed as `dyn std::error::Error` and used with `?`.
impl std::error::Error for SnapshotError {}
68
69/// Execute an HTTP request against an Axum [`TestServer`] by rehydrating it
70/// into the server's own [`axum_test::TestRequest`] builder.
71pub async fn call_test_server(server: &TestServer, request: AxumRequest<Body>) -> AxumTestResponse {
72    let (parts, body) = request.into_parts();
73
74    let mut path = parts.uri.path().to_string();
75    if let Some(query) = parts.uri.query()
76        && !query.is_empty()
77    {
78        path.push('?');
79        path.push_str(query);
80    }
81
82    let mut test_request = server.method(parts.method.clone(), &path);
83
84    for (name, value) in parts.headers.iter() {
85        test_request = test_request.add_header(name.clone(), value.clone());
86    }
87
88    let collected = body
89        .collect()
90        .await
91        .expect("failed to read request body for test dispatch");
92    let bytes = collected.to_bytes();
93    if !bytes.is_empty() {
94        test_request = test_request.bytes(bytes);
95    }
96
97    test_request.await
98}
99
100/// Convert an `AxumTestResponse` into a reusable [`ResponseSnapshot`].
101pub async fn snapshot_response(response: AxumTestResponse) -> Result<ResponseSnapshot, SnapshotError> {
102    let status = response.status_code().as_u16();
103
104    let mut headers = HashMap::new();
105    for (name, value) in response.headers() {
106        let header_value = value
107            .to_str()
108            .map_err(|e| SnapshotError::InvalidHeader(e.to_string()))?;
109        headers.insert(name.to_string().to_ascii_lowercase(), header_value.to_string());
110    }
111
112    let body = response.into_bytes();
113    let decoded_body = decode_body(&headers, body.to_vec())?;
114
115    Ok(ResponseSnapshot {
116        status,
117        headers,
118        body: decoded_body,
119    })
120}
121
122/// Convert an Axum response into a reusable [`ResponseSnapshot`].
123pub async fn snapshot_http_response(
124    response: axum::response::Response<Body>,
125) -> Result<ResponseSnapshot, SnapshotError> {
126    let (parts, body) = response.into_parts();
127    let status = parts.status.as_u16();
128
129    let mut headers = HashMap::new();
130    for (name, value) in parts.headers.iter() {
131        let header_value = value
132            .to_str()
133            .map_err(|e| SnapshotError::InvalidHeader(e.to_string()))?;
134        headers.insert(name.to_string().to_ascii_lowercase(), header_value.to_string());
135    }
136
137    let collected = body
138        .collect()
139        .await
140        .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
141    let bytes = collected.to_bytes();
142    let decoded_body = decode_body(&headers, bytes.to_vec())?;
143
144    Ok(ResponseSnapshot {
145        status,
146        headers,
147        body: decoded_body,
148    })
149}
150
151fn decode_body(headers: &HashMap<String, String>, body: Vec<u8>) -> Result<Vec<u8>, SnapshotError> {
152    let encoding = headers
153        .get("content-encoding")
154        .map(|value| value.trim().to_ascii_lowercase());
155
156    match encoding.as_deref() {
157        Some("gzip" | "x-gzip") => decode_gzip(body),
158        Some("br") => decode_brotli(body),
159        _ => Ok(body),
160    }
161}
162
163fn decode_gzip(body: Vec<u8>) -> Result<Vec<u8>, SnapshotError> {
164    let mut decoder = GzDecoder::new(Cursor::new(body));
165    let mut decoded_bytes = Vec::new();
166    decoder
167        .read_to_end(&mut decoded_bytes)
168        .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
169    Ok(decoded_bytes)
170}
171
172fn decode_brotli(body: Vec<u8>) -> Result<Vec<u8>, SnapshotError> {
173    let mut decoder = Decompressor::new(Cursor::new(body), 4096);
174    let mut decoded_bytes = Vec::new();
175    decoder
176        .read_to_end(&mut decoded_bytes)
177        .map_err(|e| SnapshotError::Decompression(e.to_string()))?;
178    Ok(decoded_bytes)
179}
180
181/// WebSocket connection wrapper for testing.
182///
183/// Provides a simple interface for sending and receiving WebSocket messages
184/// during tests without needing a real network connection.
185pub struct WebSocketConnection {
186    inner: TestWebSocket,
187}
188
189impl WebSocketConnection {
190    /// Create a new WebSocket connection from an axum-test TestWebSocket.
191    pub fn new(inner: TestWebSocket) -> Self {
192        Self { inner }
193    }
194
195    /// Send a text message over the WebSocket.
196    pub async fn send_text(&mut self, text: impl std::fmt::Display) {
197        self.inner.send_text(text).await;
198    }
199
200    /// Send a JSON message over the WebSocket.
201    pub async fn send_json<T: serde::Serialize>(&mut self, value: &T) {
202        self.inner.send_json(value).await;
203    }
204
205    /// Send a raw WebSocket message.
206    pub async fn send_message(&mut self, msg: WsMessage) {
207        self.inner.send_message(msg).await;
208    }
209
210    /// Receive the next text message from the WebSocket.
211    pub async fn receive_text(&mut self) -> String {
212        self.inner.receive_text().await
213    }
214
215    /// Receive and parse a JSON message from the WebSocket.
216    pub async fn receive_json<T: serde::de::DeserializeOwned>(&mut self) -> T {
217        self.inner.receive_json().await
218    }
219
220    /// Receive raw bytes from the WebSocket.
221    pub async fn receive_bytes(&mut self) -> bytes::Bytes {
222        self.inner.receive_bytes().await
223    }
224
225    /// Receive the next raw message from the WebSocket.
226    pub async fn receive_message(&mut self) -> WebSocketMessage {
227        let msg = self.inner.receive_message().await;
228        WebSocketMessage::from_ws_message(msg)
229    }
230
231    /// Close the WebSocket connection.
232    pub async fn close(self) {
233        self.inner.close().await;
234    }
235}
236
237/// A WebSocket message that can be text or binary.
238#[derive(Debug, Clone)]
239pub enum WebSocketMessage {
240    /// A text message.
241    Text(String),
242    /// A binary message.
243    Binary(Vec<u8>),
244    /// A close message.
245    Close(Option<String>),
246    /// A ping message.
247    Ping(Vec<u8>),
248    /// A pong message.
249    Pong(Vec<u8>),
250}
251
252impl WebSocketMessage {
253    fn from_ws_message(msg: WsMessage) -> Self {
254        match msg {
255            WsMessage::Text(text) => WebSocketMessage::Text(text.to_string()),
256            WsMessage::Binary(data) => WebSocketMessage::Binary(data.to_vec()),
257            WsMessage::Close(frame) => WebSocketMessage::Close(frame.map(|f| f.reason.to_string())),
258            WsMessage::Ping(data) => WebSocketMessage::Ping(data.to_vec()),
259            WsMessage::Pong(data) => WebSocketMessage::Pong(data.to_vec()),
260            WsMessage::Frame(_) => WebSocketMessage::Close(None),
261        }
262    }
263
264    /// Get the message as text, if it's a text message.
265    pub fn as_text(&self) -> Option<&str> {
266        match self {
267            WebSocketMessage::Text(text) => Some(text),
268            _ => None,
269        }
270    }
271
272    /// Get the message as JSON, if it's a text message containing JSON.
273    pub fn as_json(&self) -> Result<Value, String> {
274        match self {
275            WebSocketMessage::Text(text) => {
276                serde_json::from_str(text).map_err(|e| format!("Failed to parse JSON: {}", e))
277            }
278            _ => Err("Message is not text".to_string()),
279        }
280    }
281
282    /// Get the message as binary, if it's a binary message.
283    pub fn as_binary(&self) -> Option<&[u8]> {
284        match self {
285            WebSocketMessage::Binary(data) => Some(data),
286            _ => None,
287        }
288    }
289
290    /// Check if this is a close message.
291    pub fn is_close(&self) -> bool {
292        matches!(self, WebSocketMessage::Close(_))
293    }
294}
295
296/// Connect to a WebSocket endpoint on the test server.
297pub async fn connect_websocket(server: &TestServer, path: &str) -> WebSocketConnection {
298    let ws = server.get_websocket(path).await.into_websocket().await;
299    WebSocketConnection::new(ws)
300}
301
302/// Server-Sent Events (SSE) stream for testing.
303///
304/// Wraps a response body and provides methods to parse SSE events.
305#[derive(Debug)]
306pub struct SseStream {
307    body: String,
308    events: Vec<SseEvent>,
309}
310
311impl SseStream {
312    /// Create a new SSE stream from a response.
313    pub fn from_response(response: &ResponseSnapshot) -> Result<Self, String> {
314        let body = response
315            .text()
316            .map_err(|e| format!("Failed to read response body: {}", e))?;
317
318        let events = Self::parse_events(&body);
319
320        Ok(Self { body, events })
321    }
322
323    fn parse_events(body: &str) -> Vec<SseEvent> {
324        let mut events = Vec::new();
325        let lines: Vec<&str> = body.lines().collect();
326        let mut i = 0;
327
328        while i < lines.len() {
329            if lines[i].starts_with("data:") {
330                let data = lines[i].trim_start_matches("data:").trim().to_string();
331                events.push(SseEvent { data });
332            } else if lines[i].starts_with("data") {
333                let data = lines[i].trim_start_matches("data").trim().to_string();
334                if !data.is_empty() || lines[i].len() == 4 {
335                    events.push(SseEvent { data });
336                }
337            }
338            i += 1;
339        }
340
341        events
342    }
343
344    /// Get all events from the stream.
345    pub fn events(&self) -> &[SseEvent] {
346        &self.events
347    }
348
349    /// Get the raw body of the SSE response.
350    pub fn body(&self) -> &str {
351        &self.body
352    }
353
354    /// Get events as JSON values.
355    pub fn events_as_json(&self) -> Result<Vec<Value>, String> {
356        self.events
357            .iter()
358            .map(|event| event.as_json())
359            .collect::<Result<Vec<_>, _>>()
360    }
361}
362
363/// A single Server-Sent Event.
364#[derive(Debug, Clone)]
365pub struct SseEvent {
366    /// The data field of the event.
367    pub data: String,
368}
369
370impl SseEvent {
371    /// Parse the event data as JSON.
372    pub fn as_json(&self) -> Result<Value, String> {
373        serde_json::from_str(&self.data).map_err(|e| format!("Failed to parse JSON: {}", e))
374    }
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a 200 `text/event-stream` snapshot around `body`.
    fn event_stream_snapshot(body: &[u8]) -> ResponseSnapshot {
        let headers = HashMap::from([(
            "content-type".to_string(),
            "text/event-stream".to_string(),
        )]);
        ResponseSnapshot {
            status: 200,
            headers,
            body: body.to_vec(),
        }
    }

    #[test]
    fn sse_stream_parses_multiple_events() {
        let snapshot = event_stream_snapshot(b"data: {\"id\": 1}\n\ndata: \"hello\"\n\n");
        let stream = SseStream::from_response(&snapshot).expect("stream");

        let events = stream.events();
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].as_json().unwrap()["id"], serde_json::json!(1));
        assert_eq!(events[1].data, "\"hello\"");
        assert_eq!(stream.events_as_json().unwrap().len(), 2);
    }

    #[test]
    fn sse_event_reports_invalid_json() {
        let event = SseEvent {
            data: String::from("not-json"),
        };
        assert!(event.as_json().is_err());
    }
}