1use super::{ResponseSnapshot, SnapshotError, snapshot_response};
11use axum::http::{HeaderName, HeaderValue, Method};
12use axum_test::TestServer;
13use bytes::Bytes;
14use serde_json::Value;
15use std::sync::Arc;
16use urlencoding::encode;
17
/// Optional multipart body for [`TestClient::post`]: `(form_fields, files)`, where
/// `form_fields` are plain name/value pairs and `files` are the file parts to include.
/// `None` means no multipart body is sent.
type MultipartPayload = Option<(Vec<(String, String)>, Vec<super::MultipartFilePart>)>;
19
/// HTTP test client that drives an axum router through an `axum_test::TestServer`
/// and converts every response into a [`ResponseSnapshot`].
pub struct TestClient {
    /// Reference-counted handle to the underlying test server.
    server: Arc<TestServer>,
}
29
30impl TestClient {
31 pub fn from_router(router: axum::Router) -> Result<Self, String> {
33 let server = if tokio::runtime::Handle::try_current().is_ok() {
34 TestServer::builder()
35 .http_transport()
36 .build(router)
37 .map_err(|e| format!("Failed to create test server: {}", e))?
38 } else {
39 TestServer::new(router).map_err(|e| format!("Failed to create test server: {}", e))?
40 };
41
42 Ok(Self {
43 server: Arc::new(server),
44 })
45 }
46
47 pub fn server(&self) -> &TestServer {
49 &self.server
50 }
51
52 pub async fn get(
54 &self,
55 path: &str,
56 query_params: Option<Vec<(String, String)>>,
57 headers: Option<Vec<(String, String)>>,
58 ) -> Result<ResponseSnapshot, SnapshotError> {
59 let full_path = build_full_path(path, query_params.as_deref());
60 let mut request = self.server.get(&full_path);
61
62 if let Some(headers_vec) = headers {
63 request = self.add_headers(request, headers_vec)?;
64 }
65
66 let response = request.await;
67 snapshot_response(response).await
68 }
69
70 pub async fn post(
72 &self,
73 path: &str,
74 json: Option<Value>,
75 form_data: Option<Vec<(String, String)>>,
76 multipart: MultipartPayload,
77 query_params: Option<Vec<(String, String)>>,
78 headers: Option<Vec<(String, String)>>,
79 ) -> Result<ResponseSnapshot, SnapshotError> {
80 let full_path = build_full_path(path, query_params.as_deref());
81 let mut request = self.server.post(&full_path);
82
83 if let Some(headers_vec) = headers {
84 request = self.add_headers(request, headers_vec.clone())?;
85 }
86
87 if let Some((form_fields, files)) = multipart {
88 let (body, boundary) = super::build_multipart_body(&form_fields, &files);
89 let content_type = format!("multipart/form-data; boundary={}", boundary);
90 request = request.add_header("content-type", &content_type);
91 request = request.bytes(Bytes::from(body));
92 } else if let Some(form_fields) = form_data {
93 let encoded = super::encode_urlencoded_body(&serde_json::to_value(&form_fields).unwrap_or(Value::Null))
94 .map_err(|e| SnapshotError::Decompression(format!("Form encoding failed: {}", e)))?;
95 request = request.add_header("content-type", "application/x-www-form-urlencoded");
96 request = request.bytes(Bytes::from(encoded));
97 } else if let Some(json_value) = json {
98 request = request.json(&json_value);
99 }
100
101 let response = request.await;
102 snapshot_response(response).await
103 }
104
105 pub async fn request_raw(
107 &self,
108 method: Method,
109 path: &str,
110 body: Bytes,
111 query_params: Option<Vec<(String, String)>>,
112 headers: Option<Vec<(String, String)>>,
113 ) -> Result<ResponseSnapshot, SnapshotError> {
114 let full_path = build_full_path(path, query_params.as_deref());
115 let mut request = self.server.method(method, &full_path);
116
117 if let Some(headers_vec) = headers {
118 request = self.add_headers(request, headers_vec)?;
119 }
120
121 request = request.bytes(body);
122 let response = request.await;
123 snapshot_response(response).await
124 }
125
126 pub async fn put(
128 &self,
129 path: &str,
130 json: Option<Value>,
131 query_params: Option<Vec<(String, String)>>,
132 headers: Option<Vec<(String, String)>>,
133 ) -> Result<ResponseSnapshot, SnapshotError> {
134 let full_path = build_full_path(path, query_params.as_deref());
135 let mut request = self.server.put(&full_path);
136
137 if let Some(headers_vec) = headers {
138 request = self.add_headers(request, headers_vec)?;
139 }
140
141 if let Some(json_value) = json {
142 request = request.json(&json_value);
143 }
144
145 let response = request.await;
146 snapshot_response(response).await
147 }
148
149 pub async fn patch(
151 &self,
152 path: &str,
153 json: Option<Value>,
154 query_params: Option<Vec<(String, String)>>,
155 headers: Option<Vec<(String, String)>>,
156 ) -> Result<ResponseSnapshot, SnapshotError> {
157 let full_path = build_full_path(path, query_params.as_deref());
158 let mut request = self.server.patch(&full_path);
159
160 if let Some(headers_vec) = headers {
161 request = self.add_headers(request, headers_vec)?;
162 }
163
164 if let Some(json_value) = json {
165 request = request.json(&json_value);
166 }
167
168 let response = request.await;
169 snapshot_response(response).await
170 }
171
172 pub async fn delete(
174 &self,
175 path: &str,
176 query_params: Option<Vec<(String, String)>>,
177 headers: Option<Vec<(String, String)>>,
178 ) -> Result<ResponseSnapshot, SnapshotError> {
179 let full_path = build_full_path(path, query_params.as_deref());
180 let mut request = self.server.delete(&full_path);
181
182 if let Some(headers_vec) = headers {
183 request = self.add_headers(request, headers_vec)?;
184 }
185
186 let response = request.await;
187 snapshot_response(response).await
188 }
189
190 pub async fn options(
192 &self,
193 path: &str,
194 query_params: Option<Vec<(String, String)>>,
195 headers: Option<Vec<(String, String)>>,
196 ) -> Result<ResponseSnapshot, SnapshotError> {
197 let full_path = build_full_path(path, query_params.as_deref());
198 let mut request = self.server.method(Method::OPTIONS, &full_path);
199
200 if let Some(headers_vec) = headers {
201 request = self.add_headers(request, headers_vec)?;
202 }
203
204 let response = request.await;
205 snapshot_response(response).await
206 }
207
208 pub async fn head(
210 &self,
211 path: &str,
212 query_params: Option<Vec<(String, String)>>,
213 headers: Option<Vec<(String, String)>>,
214 ) -> Result<ResponseSnapshot, SnapshotError> {
215 let full_path = build_full_path(path, query_params.as_deref());
216 let mut request = self.server.method(Method::HEAD, &full_path);
217
218 if let Some(headers_vec) = headers {
219 request = self.add_headers(request, headers_vec)?;
220 }
221
222 let response = request.await;
223 snapshot_response(response).await
224 }
225
226 pub async fn trace(
228 &self,
229 path: &str,
230 query_params: Option<Vec<(String, String)>>,
231 headers: Option<Vec<(String, String)>>,
232 ) -> Result<ResponseSnapshot, SnapshotError> {
233 let full_path = build_full_path(path, query_params.as_deref());
234 let mut request = self.server.method(Method::TRACE, &full_path);
235
236 if let Some(headers_vec) = headers {
237 request = self.add_headers(request, headers_vec)?;
238 }
239
240 let response = request.await;
241 snapshot_response(response).await
242 }
243
244 fn add_headers(
246 &self,
247 mut request: axum_test::TestRequest,
248 headers: Vec<(String, String)>,
249 ) -> Result<axum_test::TestRequest, SnapshotError> {
250 for (key, value) in headers {
251 let header_name = HeaderName::from_bytes(key.as_bytes())
252 .map_err(|e| SnapshotError::InvalidHeader(format!("Invalid header name: {}", e)))?;
253 let header_value = HeaderValue::from_str(&value)
254 .map_err(|e| SnapshotError::InvalidHeader(format!("Invalid header value: {}", e)))?;
255 request = request.add_header(header_name, header_value);
256 }
257 Ok(request)
258 }
259}
260
261fn build_full_path(path: &str, query_params: Option<&[(String, String)]>) -> String {
263 match query_params {
264 None | Some(&[]) => path.to_string(),
265 Some(params) => {
266 let query_string: Vec<String> = params
267 .iter()
268 .map(|(k, v)| format!("{}={}", encode(k), encode(v)))
269 .collect();
270
271 if path.contains('?') {
272 format!("{}&{}", path, query_string.join("&"))
273 } else {
274 format!("{}?{}", path, query_string.join("&"))
275 }
276 }
277 }
278}
279
// Unit tests for `build_full_path`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn build_full_path_no_params() {
        let path = "/users";
        assert_eq!(build_full_path(path, None), "/users");
        assert_eq!(build_full_path(path, Some(&[])), "/users");
    }

    #[test]
    fn build_full_path_with_params() {
        let path = "/users";
        let params = vec![
            ("id".to_string(), "123".to_string()),
            ("name".to_string(), "test user".to_string()),
        ];
        // Pairs keep their input order; keys and values are percent-encoded.
        assert_eq!(
            build_full_path(path, Some(&params)),
            "/users?id=123&name=test%20user"
        );
    }

    #[test]
    fn build_full_path_existing_query() {
        let path = "/users?active=true";
        let params = vec![("id".to_string(), "123".to_string())];
        assert_eq!(
            build_full_path(path, Some(&params)),
            "/users?active=true&id=123"
        );
    }
}