1use super::{ResponseSnapshot, SnapshotError, snapshot_response};
11use axum::http::{HeaderName, HeaderValue, Method};
12use axum_test::TestServer;
13use bytes::Bytes;
14use serde_json::Value;
15use std::sync::Arc;
16use urlencoding::encode;
17
18type MultipartPayload = Option<(Vec<(String, String)>, Vec<super::MultipartFilePart>)>;
19
20pub struct TestClient {
27 server: Arc<TestServer>,
28}
29
30impl TestClient {
31 pub fn from_router(router: axum::Router) -> Result<Self, String> {
33 let server = TestServer::new(router).map_err(|e| format!("Failed to create test server: {}", e))?;
34
35 Ok(Self {
36 server: Arc::new(server),
37 })
38 }
39
40 pub fn server(&self) -> &TestServer {
42 &self.server
43 }
44
45 pub async fn get(
47 &self,
48 path: &str,
49 query_params: Option<Vec<(String, String)>>,
50 headers: Option<Vec<(String, String)>>,
51 ) -> Result<ResponseSnapshot, SnapshotError> {
52 let full_path = build_full_path(path, query_params.as_deref());
53 let mut request = self.server.get(&full_path);
54
55 if let Some(headers_vec) = headers {
56 request = self.add_headers(request, headers_vec)?;
57 }
58
59 let response = request.await;
60 snapshot_response(response).await
61 }
62
63 pub async fn post(
65 &self,
66 path: &str,
67 json: Option<Value>,
68 form_data: Option<Vec<(String, String)>>,
69 multipart: MultipartPayload,
70 query_params: Option<Vec<(String, String)>>,
71 headers: Option<Vec<(String, String)>>,
72 ) -> Result<ResponseSnapshot, SnapshotError> {
73 let full_path = build_full_path(path, query_params.as_deref());
74 let mut request = self.server.post(&full_path);
75
76 if let Some(headers_vec) = headers {
77 request = self.add_headers(request, headers_vec.clone())?;
78 }
79
80 if let Some((form_fields, files)) = multipart {
81 let (body, boundary) = super::build_multipart_body(&form_fields, &files);
82 let content_type = format!("multipart/form-data; boundary={}", boundary);
83 request = request.add_header("content-type", &content_type);
84 request = request.bytes(Bytes::from(body));
85 } else if let Some(form_fields) = form_data {
86 let encoded = super::encode_urlencoded_body(&serde_json::to_value(&form_fields).unwrap_or(Value::Null))
87 .map_err(|e| SnapshotError::Decompression(format!("Form encoding failed: {}", e)))?;
88 request = request.add_header("content-type", "application/x-www-form-urlencoded");
89 request = request.bytes(Bytes::from(encoded));
90 } else if let Some(json_value) = json {
91 request = request.json(&json_value);
92 }
93
94 let response = request.await;
95 snapshot_response(response).await
96 }
97
98 pub async fn put(
100 &self,
101 path: &str,
102 json: Option<Value>,
103 query_params: Option<Vec<(String, String)>>,
104 headers: Option<Vec<(String, String)>>,
105 ) -> Result<ResponseSnapshot, SnapshotError> {
106 let full_path = build_full_path(path, query_params.as_deref());
107 let mut request = self.server.put(&full_path);
108
109 if let Some(headers_vec) = headers {
110 request = self.add_headers(request, headers_vec)?;
111 }
112
113 if let Some(json_value) = json {
114 request = request.json(&json_value);
115 }
116
117 let response = request.await;
118 snapshot_response(response).await
119 }
120
121 pub async fn patch(
123 &self,
124 path: &str,
125 json: Option<Value>,
126 query_params: Option<Vec<(String, String)>>,
127 headers: Option<Vec<(String, String)>>,
128 ) -> Result<ResponseSnapshot, SnapshotError> {
129 let full_path = build_full_path(path, query_params.as_deref());
130 let mut request = self.server.patch(&full_path);
131
132 if let Some(headers_vec) = headers {
133 request = self.add_headers(request, headers_vec)?;
134 }
135
136 if let Some(json_value) = json {
137 request = request.json(&json_value);
138 }
139
140 let response = request.await;
141 snapshot_response(response).await
142 }
143
144 pub async fn delete(
146 &self,
147 path: &str,
148 query_params: Option<Vec<(String, String)>>,
149 headers: Option<Vec<(String, String)>>,
150 ) -> Result<ResponseSnapshot, SnapshotError> {
151 let full_path = build_full_path(path, query_params.as_deref());
152 let mut request = self.server.delete(&full_path);
153
154 if let Some(headers_vec) = headers {
155 request = self.add_headers(request, headers_vec)?;
156 }
157
158 let response = request.await;
159 snapshot_response(response).await
160 }
161
162 pub async fn options(
164 &self,
165 path: &str,
166 query_params: Option<Vec<(String, String)>>,
167 headers: Option<Vec<(String, String)>>,
168 ) -> Result<ResponseSnapshot, SnapshotError> {
169 let full_path = build_full_path(path, query_params.as_deref());
170 let mut request = self.server.method(Method::OPTIONS, &full_path);
171
172 if let Some(headers_vec) = headers {
173 request = self.add_headers(request, headers_vec)?;
174 }
175
176 let response = request.await;
177 snapshot_response(response).await
178 }
179
180 pub async fn head(
182 &self,
183 path: &str,
184 query_params: Option<Vec<(String, String)>>,
185 headers: Option<Vec<(String, String)>>,
186 ) -> Result<ResponseSnapshot, SnapshotError> {
187 let full_path = build_full_path(path, query_params.as_deref());
188 let mut request = self.server.method(Method::HEAD, &full_path);
189
190 if let Some(headers_vec) = headers {
191 request = self.add_headers(request, headers_vec)?;
192 }
193
194 let response = request.await;
195 snapshot_response(response).await
196 }
197
198 pub async fn trace(
200 &self,
201 path: &str,
202 query_params: Option<Vec<(String, String)>>,
203 headers: Option<Vec<(String, String)>>,
204 ) -> Result<ResponseSnapshot, SnapshotError> {
205 let full_path = build_full_path(path, query_params.as_deref());
206 let mut request = self.server.method(Method::TRACE, &full_path);
207
208 if let Some(headers_vec) = headers {
209 request = self.add_headers(request, headers_vec)?;
210 }
211
212 let response = request.await;
213 snapshot_response(response).await
214 }
215
216 fn add_headers(
218 &self,
219 mut request: axum_test::TestRequest,
220 headers: Vec<(String, String)>,
221 ) -> Result<axum_test::TestRequest, SnapshotError> {
222 for (key, value) in headers {
223 let header_name = HeaderName::from_bytes(key.as_bytes())
224 .map_err(|e| SnapshotError::InvalidHeader(format!("Invalid header name: {}", e)))?;
225 let header_value = HeaderValue::from_str(&value)
226 .map_err(|e| SnapshotError::InvalidHeader(format!("Invalid header value: {}", e)))?;
227 request = request.add_header(header_name, header_value);
228 }
229 Ok(request)
230 }
231}
232
233fn build_full_path(path: &str, query_params: Option<&[(String, String)]>) -> String {
235 match query_params {
236 None | Some(&[]) => path.to_string(),
237 Some(params) => {
238 let query_string: Vec<String> = params
239 .iter()
240 .map(|(k, v)| format!("{}={}", encode(k), encode(v)))
241 .collect();
242
243 if path.contains('?') {
244 format!("{}&{}", path, query_string.join("&"))
245 } else {
246 format!("{}?{}", path, query_string.join("&"))
247 }
248 }
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn build_full_path_no_params() {
        let path = "/users";
        assert_eq!(build_full_path(path, None), "/users");
        assert_eq!(build_full_path(path, Some(&[])), "/users");
    }

    #[test]
    fn build_full_path_with_params() {
        let path = "/users";
        let params = vec![
            ("id".to_string(), "123".to_string()),
            ("name".to_string(), "test user".to_string()),
        ];
        // Parameter order is preserved and spaces are percent-encoded.
        let result = build_full_path(path, Some(&params));
        assert_eq!(result, "/users?id=123&name=test%20user");
    }

    #[test]
    fn build_full_path_existing_query() {
        let path = "/users?active=true";
        let params = vec![("id".to_string(), "123".to_string())];
        let result = build_full_path(path, Some(&params));
        assert_eq!(result, "/users?active=true&id=123");
    }
}