1use super::{ResponseSnapshot, SnapshotError, snapshot_response};
11use axum::http::{HeaderName, HeaderValue, Method};
12use axum_test::TestServer;
13use bytes::Bytes;
14use serde_json::Value;
15use std::sync::Arc;
16use urlencoding::encode;
17
/// Optional multipart body for [`TestClient::post`]: the plain form fields as
/// `(name, value)` pairs, followed by the file parts to upload.
type MultipartPayload = Option<(Vec<(String, String)>, Vec<super::MultipartFilePart>)>;
19
/// Wrapper around an in-process `axum_test::TestServer`.
///
/// Exposes one async helper per HTTP method (plus GraphQL conveniences) and
/// captures every response as a [`ResponseSnapshot`].
pub struct TestClient {
    /// The server all requests are sent to.
    // NOTE(review): held behind `Arc`, presumably so it can be shared cheaply — confirm with callers.
    server: Arc<TestServer>,
}
29
30impl TestClient {
31 pub fn from_router(router: axum::Router) -> Result<Self, String> {
33 let server = if tokio::runtime::Handle::try_current().is_ok() {
34 TestServer::builder()
35 .http_transport()
36 .build(router)
37 .map_err(|e| format!("Failed to create test server: {}", e))?
38 } else {
39 TestServer::new(router).map_err(|e| format!("Failed to create test server: {}", e))?
40 };
41
42 Ok(Self {
43 server: Arc::new(server),
44 })
45 }
46
47 pub fn server(&self) -> &TestServer {
49 &self.server
50 }
51
52 pub async fn get(
54 &self,
55 path: &str,
56 query_params: Option<Vec<(String, String)>>,
57 headers: Option<Vec<(String, String)>>,
58 ) -> Result<ResponseSnapshot, SnapshotError> {
59 let full_path = build_full_path(path, query_params.as_deref());
60 let mut request = self.server.get(&full_path);
61
62 if let Some(headers_vec) = headers {
63 request = self.add_headers(request, headers_vec)?;
64 }
65
66 let response = request.await;
67 snapshot_response(response).await
68 }
69
70 pub async fn post(
72 &self,
73 path: &str,
74 json: Option<Value>,
75 form_data: Option<Vec<(String, String)>>,
76 multipart: MultipartPayload,
77 query_params: Option<Vec<(String, String)>>,
78 headers: Option<Vec<(String, String)>>,
79 ) -> Result<ResponseSnapshot, SnapshotError> {
80 let full_path = build_full_path(path, query_params.as_deref());
81 let mut request = self.server.post(&full_path);
82
83 if let Some(headers_vec) = headers {
84 request = self.add_headers(request, headers_vec)?;
85 }
86
87 if let Some((form_fields, files)) = multipart {
88 let (body, boundary) = super::build_multipart_body(&form_fields, &files);
89 let content_type = format!("multipart/form-data; boundary={}", boundary);
90 request = request.add_header("content-type", &content_type);
91 request = request.bytes(Bytes::from(body));
92 } else if let Some(form_fields) = form_data {
93 let fields_value = serde_json::to_value(&form_fields)
94 .map_err(|e| SnapshotError::Decompression(format!("Failed to serialize form fields: {}", e)))?;
95 let encoded = super::encode_urlencoded_body(&fields_value)
96 .map_err(|e| SnapshotError::Decompression(format!("Form encoding failed: {}", e)))?;
97 request = request.add_header("content-type", "application/x-www-form-urlencoded");
98 request = request.bytes(Bytes::from(encoded));
99 } else if let Some(json_value) = json {
100 request = request.json(&json_value);
101 }
102
103 let response = request.await;
104 snapshot_response(response).await
105 }
106
107 pub async fn request_raw(
109 &self,
110 method: Method,
111 path: &str,
112 body: Bytes,
113 query_params: Option<Vec<(String, String)>>,
114 headers: Option<Vec<(String, String)>>,
115 ) -> Result<ResponseSnapshot, SnapshotError> {
116 let full_path = build_full_path(path, query_params.as_deref());
117 let mut request = self.server.method(method, &full_path);
118
119 if let Some(headers_vec) = headers {
120 request = self.add_headers(request, headers_vec)?;
121 }
122
123 request = request.bytes(body);
124 let response = request.await;
125 snapshot_response(response).await
126 }
127
128 pub async fn put(
130 &self,
131 path: &str,
132 json: Option<Value>,
133 query_params: Option<Vec<(String, String)>>,
134 headers: Option<Vec<(String, String)>>,
135 ) -> Result<ResponseSnapshot, SnapshotError> {
136 let full_path = build_full_path(path, query_params.as_deref());
137 let mut request = self.server.put(&full_path);
138
139 if let Some(headers_vec) = headers {
140 request = self.add_headers(request, headers_vec)?;
141 }
142
143 if let Some(json_value) = json {
144 request = request.json(&json_value);
145 }
146
147 let response = request.await;
148 snapshot_response(response).await
149 }
150
151 pub async fn patch(
153 &self,
154 path: &str,
155 json: Option<Value>,
156 query_params: Option<Vec<(String, String)>>,
157 headers: Option<Vec<(String, String)>>,
158 ) -> Result<ResponseSnapshot, SnapshotError> {
159 let full_path = build_full_path(path, query_params.as_deref());
160 let mut request = self.server.patch(&full_path);
161
162 if let Some(headers_vec) = headers {
163 request = self.add_headers(request, headers_vec)?;
164 }
165
166 if let Some(json_value) = json {
167 request = request.json(&json_value);
168 }
169
170 let response = request.await;
171 snapshot_response(response).await
172 }
173
174 pub async fn delete(
176 &self,
177 path: &str,
178 query_params: Option<Vec<(String, String)>>,
179 headers: Option<Vec<(String, String)>>,
180 ) -> Result<ResponseSnapshot, SnapshotError> {
181 let full_path = build_full_path(path, query_params.as_deref());
182 let mut request = self.server.delete(&full_path);
183
184 if let Some(headers_vec) = headers {
185 request = self.add_headers(request, headers_vec)?;
186 }
187
188 let response = request.await;
189 snapshot_response(response).await
190 }
191
192 pub async fn options(
194 &self,
195 path: &str,
196 query_params: Option<Vec<(String, String)>>,
197 headers: Option<Vec<(String, String)>>,
198 ) -> Result<ResponseSnapshot, SnapshotError> {
199 let full_path = build_full_path(path, query_params.as_deref());
200 let mut request = self.server.method(Method::OPTIONS, &full_path);
201
202 if let Some(headers_vec) = headers {
203 request = self.add_headers(request, headers_vec)?;
204 }
205
206 let response = request.await;
207 snapshot_response(response).await
208 }
209
210 pub async fn head(
212 &self,
213 path: &str,
214 query_params: Option<Vec<(String, String)>>,
215 headers: Option<Vec<(String, String)>>,
216 ) -> Result<ResponseSnapshot, SnapshotError> {
217 let full_path = build_full_path(path, query_params.as_deref());
218 let mut request = self.server.method(Method::HEAD, &full_path);
219
220 if let Some(headers_vec) = headers {
221 request = self.add_headers(request, headers_vec)?;
222 }
223
224 let response = request.await;
225 snapshot_response(response).await
226 }
227
228 pub async fn trace(
230 &self,
231 path: &str,
232 query_params: Option<Vec<(String, String)>>,
233 headers: Option<Vec<(String, String)>>,
234 ) -> Result<ResponseSnapshot, SnapshotError> {
235 let full_path = build_full_path(path, query_params.as_deref());
236 let mut request = self.server.method(Method::TRACE, &full_path);
237
238 if let Some(headers_vec) = headers {
239 request = self.add_headers(request, headers_vec)?;
240 }
241
242 let response = request.await;
243 snapshot_response(response).await
244 }
245
246 pub async fn graphql_at(
248 &self,
249 endpoint: &str,
250 query: &str,
251 variables: Option<Value>,
252 operation_name: Option<&str>,
253 ) -> Result<ResponseSnapshot, SnapshotError> {
254 let body = build_graphql_body(query, variables, operation_name);
255 self.post(endpoint, Some(body), None, None, None, None).await
256 }
257
258 pub async fn graphql(
260 &self,
261 query: &str,
262 variables: Option<Value>,
263 operation_name: Option<&str>,
264 ) -> Result<ResponseSnapshot, SnapshotError> {
265 self.graphql_at("/graphql", query, variables, operation_name).await
266 }
267
268 pub async fn graphql_with_status(
284 &self,
285 query: &str,
286 variables: Option<Value>,
287 operation_name: Option<&str>,
288 ) -> Result<(u16, ResponseSnapshot), SnapshotError> {
289 let snapshot = self.graphql(query, variables, operation_name).await?;
290 let status = snapshot.status;
291 Ok((status, snapshot))
292 }
293
294 pub async fn graphql_subscription(
296 &self,
297 _query: &str,
298 _variables: Option<Value>,
299 _operation_name: Option<&str>,
300 ) -> Result<(), SnapshotError> {
301 Err(SnapshotError::Decompression(
303 "GraphQL subscriptions not yet implemented".to_string(),
304 ))
305 }
306
307 fn add_headers(
309 &self,
310 mut request: axum_test::TestRequest,
311 headers: Vec<(String, String)>,
312 ) -> Result<axum_test::TestRequest, SnapshotError> {
313 for (key, value) in headers {
314 let header_name = HeaderName::from_bytes(key.as_bytes())
315 .map_err(|e| SnapshotError::InvalidHeader(format!("Invalid header name: {}", e)))?;
316 let header_value = HeaderValue::from_str(&value)
317 .map_err(|e| SnapshotError::InvalidHeader(format!("Invalid header value: {}", e)))?;
318 request = request.add_header(header_name, header_value);
319 }
320 Ok(request)
321 }
322}
323
324pub fn build_graphql_body(query: &str, variables: Option<Value>, operation_name: Option<&str>) -> Value {
326 let mut body = serde_json::json!({ "query": query });
327 if let Some(vars) = variables {
328 body["variables"] = vars;
329 }
330 if let Some(op_name) = operation_name {
331 body["operationName"] = Value::String(op_name.to_string());
332 }
333 body
334}
335
336fn build_full_path(path: &str, query_params: Option<&[(String, String)]>) -> String {
338 match query_params {
339 None | Some(&[]) => path.to_string(),
340 Some(params) => {
341 let query_string: Vec<String> = params
342 .iter()
343 .map(|(k, v)| format!("{}={}", encode(k), encode(v)))
344 .collect();
345
346 if path.contains('?') {
347 format!("{}&{}", path, query_string.join("&"))
348 } else {
349 format!("{}?{}", path, query_string.join("&"))
350 }
351 }
352 }
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn build_full_path_no_params() {
        let path = "/users";
        assert_eq!(build_full_path(path, None), "/users");
        assert_eq!(build_full_path(path, Some(&[])), "/users");
    }

    #[test]
    fn build_full_path_with_params() {
        let path = "/users";
        let params = vec![
            ("id".to_string(), "123".to_string()),
            ("name".to_string(), "test user".to_string()),
        ];
        let result = build_full_path(path, Some(&params));
        assert!(result.starts_with("/users?"));
        assert!(result.contains("id=123"));
        assert!(result.contains("name=test%20user"));
    }

    #[test]
    fn build_full_path_existing_query() {
        let path = "/users?active=true";
        let params = vec![("id".to_string(), "123".to_string())];
        let result = build_full_path(path, Some(&params));
        assert_eq!(result, "/users?active=true&id=123");
    }

    #[test]
    fn test_graphql_query_builder() {
        let query = "{ users { id name } }";
        let variables = Some(serde_json::json!({ "limit": 10 }));
        let op_name = Some("GetUsers");

        // Exercise the real builder rather than re-implementing its logic here.
        let body = build_graphql_body(query, variables, op_name);

        assert_eq!(body["query"], query);
        assert_eq!(body["variables"]["limit"], 10);
        assert_eq!(body["operationName"], "GetUsers");
    }

    #[test]
    fn test_graphql_with_status_method() {
        // NOTE(review): this only checks a hand-written JSON literal; it does not
        // call `TestClient::graphql_with_status`, which needs a running server.
        let query = "query { hello }";
        let body = serde_json::json!({
            "query": query,
            "variables": null,
            "operationName": null
        });

        let expected_fields = vec!["query", "variables", "operationName"];
        for field in expected_fields {
            assert!(body.get(field).is_some(), "Missing field: {}", field);
        }
    }

    #[test]
    fn test_build_graphql_body_basic() {
        let query = "{ users { id name } }";
        let body = build_graphql_body(query, None, None);

        assert_eq!(body["query"], query);
        assert!(body.get("variables").is_none() || body["variables"].is_null());
        assert!(body.get("operationName").is_none() || body["operationName"].is_null());
    }

    #[test]
    fn test_build_graphql_body_with_variables() {
        let query = "query GetUser($id: ID!) { user(id: $id) { name } }";
        let variables = Some(serde_json::json!({ "id": "123" }));
        let body = build_graphql_body(query, variables, None);

        assert_eq!(body["query"], query);
        assert_eq!(body["variables"]["id"], "123");
    }

    #[test]
    fn test_build_graphql_body_with_operation_name() {
        let query = "query GetUsers { users { id } }";
        let op_name = Some("GetUsers");
        let body = build_graphql_body(query, None, op_name);

        assert_eq!(body["query"], query);
        assert_eq!(body["operationName"], "GetUsers");
    }

    #[test]
    fn test_build_graphql_body_all_fields() {
        let query = "mutation CreateUser($name: String!) { createUser(name: $name) { id } }";
        let variables = Some(serde_json::json!({ "name": "Alice" }));
        let op_name = Some("CreateUser");
        let body = build_graphql_body(query, variables, op_name);

        assert_eq!(body["query"], query);
        assert_eq!(body["variables"]["name"], "Alice");
        assert_eq!(body["operationName"], "CreateUser");
    }
}