spikard_http/testing/
test_client.rs

1//! Core test client for Spikard applications
2//!
3//! This module provides a language-agnostic TestClient that can be wrapped by
4//! language bindings (PyO3, napi-rs, magnus) to provide Pythonic, JavaScripty, and
5//! Ruby-like APIs respectively.
6//!
7//! The core client handles all HTTP method dispatch, query params, header management,
8//! body encoding (JSON, form-data, multipart), and response snapshot capture.
9
10use super::{ResponseSnapshot, SnapshotError, snapshot_response};
11use axum::http::{HeaderName, HeaderValue, Method};
12use axum_test::TestServer;
13use bytes::Bytes;
14use serde_json::Value;
15use std::sync::Arc;
16use urlencoding::encode;
17
18type MultipartPayload = Option<(Vec<(String, String)>, Vec<super::MultipartFilePart>)>;
19
20/// Core test client for making HTTP requests to a Spikard application.
21///
22/// This struct wraps axum-test's TestServer and provides a language-agnostic
23/// interface for making HTTP requests, sending WebSocket connections, and
24/// handling Server-Sent Events. Language bindings wrap this to provide
25/// native API surfaces.
26pub struct TestClient {
27    server: Arc<TestServer>,
28}
29
30impl TestClient {
31    /// Create a new test client from an Axum router
32    pub fn from_router(router: axum::Router) -> Result<Self, String> {
33        let server = if tokio::runtime::Handle::try_current().is_ok() {
34            TestServer::builder()
35                .http_transport()
36                .build(router)
37                .map_err(|e| format!("Failed to create test server: {}", e))?
38        } else {
39            TestServer::new(router).map_err(|e| format!("Failed to create test server: {}", e))?
40        };
41
42        Ok(Self {
43            server: Arc::new(server),
44        })
45    }
46
47    /// Get the underlying test server (for WebSocket and SSE connections)
48    pub fn server(&self) -> &TestServer {
49        &self.server
50    }
51
52    /// Make a GET request
53    pub async fn get(
54        &self,
55        path: &str,
56        query_params: Option<Vec<(String, String)>>,
57        headers: Option<Vec<(String, String)>>,
58    ) -> Result<ResponseSnapshot, SnapshotError> {
59        let full_path = build_full_path(path, query_params.as_deref());
60        let mut request = self.server.get(&full_path);
61
62        if let Some(headers_vec) = headers {
63            request = self.add_headers(request, headers_vec)?;
64        }
65
66        let response = request.await;
67        snapshot_response(response).await
68    }
69
70    /// Make a POST request
71    pub async fn post(
72        &self,
73        path: &str,
74        json: Option<Value>,
75        form_data: Option<Vec<(String, String)>>,
76        multipart: MultipartPayload,
77        query_params: Option<Vec<(String, String)>>,
78        headers: Option<Vec<(String, String)>>,
79    ) -> Result<ResponseSnapshot, SnapshotError> {
80        let full_path = build_full_path(path, query_params.as_deref());
81        let mut request = self.server.post(&full_path);
82
83        if let Some(headers_vec) = headers {
84            request = self.add_headers(request, headers_vec.clone())?;
85        }
86
87        if let Some((form_fields, files)) = multipart {
88            let (body, boundary) = super::build_multipart_body(&form_fields, &files);
89            let content_type = format!("multipart/form-data; boundary={}", boundary);
90            request = request.add_header("content-type", &content_type);
91            request = request.bytes(Bytes::from(body));
92        } else if let Some(form_fields) = form_data {
93            let encoded = super::encode_urlencoded_body(&serde_json::to_value(&form_fields).unwrap_or(Value::Null))
94                .map_err(|e| SnapshotError::Decompression(format!("Form encoding failed: {}", e)))?;
95            request = request.add_header("content-type", "application/x-www-form-urlencoded");
96            request = request.bytes(Bytes::from(encoded));
97        } else if let Some(json_value) = json {
98            request = request.json(&json_value);
99        }
100
101        let response = request.await;
102        snapshot_response(response).await
103    }
104
105    /// Make a request with a raw body payload.
106    pub async fn request_raw(
107        &self,
108        method: Method,
109        path: &str,
110        body: Bytes,
111        query_params: Option<Vec<(String, String)>>,
112        headers: Option<Vec<(String, String)>>,
113    ) -> Result<ResponseSnapshot, SnapshotError> {
114        let full_path = build_full_path(path, query_params.as_deref());
115        let mut request = self.server.method(method, &full_path);
116
117        if let Some(headers_vec) = headers {
118            request = self.add_headers(request, headers_vec)?;
119        }
120
121        request = request.bytes(body);
122        let response = request.await;
123        snapshot_response(response).await
124    }
125
126    /// Make a PUT request
127    pub async fn put(
128        &self,
129        path: &str,
130        json: Option<Value>,
131        query_params: Option<Vec<(String, String)>>,
132        headers: Option<Vec<(String, String)>>,
133    ) -> Result<ResponseSnapshot, SnapshotError> {
134        let full_path = build_full_path(path, query_params.as_deref());
135        let mut request = self.server.put(&full_path);
136
137        if let Some(headers_vec) = headers {
138            request = self.add_headers(request, headers_vec)?;
139        }
140
141        if let Some(json_value) = json {
142            request = request.json(&json_value);
143        }
144
145        let response = request.await;
146        snapshot_response(response).await
147    }
148
149    /// Make a PATCH request
150    pub async fn patch(
151        &self,
152        path: &str,
153        json: Option<Value>,
154        query_params: Option<Vec<(String, String)>>,
155        headers: Option<Vec<(String, String)>>,
156    ) -> Result<ResponseSnapshot, SnapshotError> {
157        let full_path = build_full_path(path, query_params.as_deref());
158        let mut request = self.server.patch(&full_path);
159
160        if let Some(headers_vec) = headers {
161            request = self.add_headers(request, headers_vec)?;
162        }
163
164        if let Some(json_value) = json {
165            request = request.json(&json_value);
166        }
167
168        let response = request.await;
169        snapshot_response(response).await
170    }
171
172    /// Make a DELETE request
173    pub async fn delete(
174        &self,
175        path: &str,
176        query_params: Option<Vec<(String, String)>>,
177        headers: Option<Vec<(String, String)>>,
178    ) -> Result<ResponseSnapshot, SnapshotError> {
179        let full_path = build_full_path(path, query_params.as_deref());
180        let mut request = self.server.delete(&full_path);
181
182        if let Some(headers_vec) = headers {
183            request = self.add_headers(request, headers_vec)?;
184        }
185
186        let response = request.await;
187        snapshot_response(response).await
188    }
189
190    /// Make an OPTIONS request
191    pub async fn options(
192        &self,
193        path: &str,
194        query_params: Option<Vec<(String, String)>>,
195        headers: Option<Vec<(String, String)>>,
196    ) -> Result<ResponseSnapshot, SnapshotError> {
197        let full_path = build_full_path(path, query_params.as_deref());
198        let mut request = self.server.method(Method::OPTIONS, &full_path);
199
200        if let Some(headers_vec) = headers {
201            request = self.add_headers(request, headers_vec)?;
202        }
203
204        let response = request.await;
205        snapshot_response(response).await
206    }
207
208    /// Make a HEAD request
209    pub async fn head(
210        &self,
211        path: &str,
212        query_params: Option<Vec<(String, String)>>,
213        headers: Option<Vec<(String, String)>>,
214    ) -> Result<ResponseSnapshot, SnapshotError> {
215        let full_path = build_full_path(path, query_params.as_deref());
216        let mut request = self.server.method(Method::HEAD, &full_path);
217
218        if let Some(headers_vec) = headers {
219            request = self.add_headers(request, headers_vec)?;
220        }
221
222        let response = request.await;
223        snapshot_response(response).await
224    }
225
226    /// Make a TRACE request
227    pub async fn trace(
228        &self,
229        path: &str,
230        query_params: Option<Vec<(String, String)>>,
231        headers: Option<Vec<(String, String)>>,
232    ) -> Result<ResponseSnapshot, SnapshotError> {
233        let full_path = build_full_path(path, query_params.as_deref());
234        let mut request = self.server.method(Method::TRACE, &full_path);
235
236        if let Some(headers_vec) = headers {
237            request = self.add_headers(request, headers_vec)?;
238        }
239
240        let response = request.await;
241        snapshot_response(response).await
242    }
243
244    /// Add headers to a test request builder
245    fn add_headers(
246        &self,
247        mut request: axum_test::TestRequest,
248        headers: Vec<(String, String)>,
249    ) -> Result<axum_test::TestRequest, SnapshotError> {
250        for (key, value) in headers {
251            let header_name = HeaderName::from_bytes(key.as_bytes())
252                .map_err(|e| SnapshotError::InvalidHeader(format!("Invalid header name: {}", e)))?;
253            let header_value = HeaderValue::from_str(&value)
254                .map_err(|e| SnapshotError::InvalidHeader(format!("Invalid header value: {}", e)))?;
255            request = request.add_header(header_name, header_value);
256        }
257        Ok(request)
258    }
259}
260
261/// Build a full path with query parameters
262fn build_full_path(path: &str, query_params: Option<&[(String, String)]>) -> String {
263    match query_params {
264        None | Some(&[]) => path.to_string(),
265        Some(params) => {
266            let query_string: Vec<String> = params
267                .iter()
268                .map(|(k, v)| format!("{}={}", encode(k), encode(v)))
269                .collect();
270
271            if path.contains('?') {
272                format!("{}&{}", path, query_string.join("&"))
273            } else {
274                format!("{}?{}", path, query_string.join("&"))
275            }
276        }
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Convert borrowed `(key, value)` pairs into the owned form `build_full_path` takes.
    fn pairs(items: &[(&str, &str)]) -> Vec<(String, String)> {
        items.iter().map(|(k, v)| (k.to_string(), v.to_string())).collect()
    }

    #[test]
    fn build_full_path_no_params() {
        let path = "/users";
        assert_eq!(build_full_path(path, None), "/users");
        assert_eq!(build_full_path(path, Some(&[])), "/users");
    }

    #[test]
    fn build_full_path_with_params() {
        let params = pairs(&[("id", "123"), ("name", "test user")]);
        // Pairs keep their input order; spaces are percent-encoded as `%20`.
        assert_eq!(
            build_full_path("/users", Some(&params)),
            "/users?id=123&name=test%20user"
        );
    }

    #[test]
    fn build_full_path_existing_query() {
        let params = pairs(&[("id", "123")]);
        assert_eq!(
            build_full_path("/users?active=true", Some(&params)),
            "/users?active=true&id=123"
        );
    }

    #[test]
    fn build_full_path_encodes_reserved_characters() {
        // `&`, `=` and `/` inside values must not leak into the query structure.
        let params = pairs(&[("q", "a&b=c"), ("tag", "x/y")]);
        assert_eq!(
            build_full_path("/search", Some(&params)),
            "/search?q=a%26b%3Dc&tag=x%2Fy"
        );
    }
}