1use axum::http::{self};
60use axum::response::Response;
61use axum::Router;
62use axum::{body::Body, http::Request};
63use serde::{Deserialize, Serialize};
64use std::collections::HashMap;
65use std::fmt::Display;
66use thiserror::Error;
67use tower_service::Service;
68
69#[derive(Error, Debug)]
70pub enum Error {
71 #[error("Could not parse method from LocalRequest")]
72 RequestMethodParseError(String),
73
74 #[error("Could not parse body from LocalRequest")]
75 RequestBodyParseError(#[from] http::Error),
76}
77
78#[derive(Serialize, Deserialize, Clone, Debug)]
80pub struct LocalRequest {
81 pub uri: String,
82 pub method: String,
83 pub body: Option<String>,
84 pub headers: HashMap<String, String>,
85}
86
87impl LocalRequest {
88 pub async fn send_to_router(self, router: &mut Router) -> LocalResponse {
89 match self.to_axum_request() {
90 Ok(request) => match router.call(request).await {
91 Ok(response) => LocalResponse::from_response(response).await,
92 Err(error) => LocalResponse::internal_server_error(error),
93 },
94 Err(error) => LocalResponse::internal_server_error(error),
95 }
96 }
97
98 fn to_axum_request(&self) -> Result<http::Request<Body>, Error> {
99 let uri = self.uri.to_string();
100 let mut request_builder = match self.method.to_uppercase().as_str() {
101 "GET" => Ok(Request::get(uri)),
102 "POST" => Ok(Request::post(uri)),
103 "PUT" => Ok(Request::put(uri)),
104 "DELETE" => Ok(Request::delete(uri)),
105 "PATCH" => Ok(Request::patch(uri)),
106 _ => Err(Error::RequestMethodParseError(self.method.to_string())),
107 }?;
108
109 for (key, value) in self.headers.iter() {
110 request_builder = request_builder.header(key, value);
111 }
112
113 let request = match &self.body {
114 None => request_builder.body(Body::empty()),
115 Some(body) => request_builder.body(body.to_string().into()),
116 }?;
117
118 Ok(request)
119 }
120}
121
122#[derive(Serialize, Deserialize, Debug, Clone)]
124pub struct LocalResponse {
125 pub status_code: u16,
126 pub body: Vec<u8>,
127 pub headers: HashMap<String, String>,
128}
129
130impl LocalResponse {
131 pub fn internal_server_error(error: impl Display) -> Self {
132 let error_message = format!("An error occured: {}", error);
133 LocalResponse {
134 status_code: 500,
135 body: error_message.into(),
136 headers: Default::default(),
137 }
138 }
139}
140
141impl LocalResponse {
142 pub async fn from_response(response: Response) -> Self {
143 let code = response.status();
144 let response_headers = response.headers().clone();
145 let bytes_result = axum::body::to_bytes(response.into_body(), usize::MAX).await;
146
147 let mut headers: HashMap<String, String> = HashMap::new();
148 for (key, value) in response_headers.iter() {
149 headers.insert(key.to_string(), value.to_str().unwrap().to_string());
150 }
151
152 match bytes_result {
153 Ok(data) => LocalResponse {
154 status_code: code.as_u16(),
155 body: data.to_vec(),
156 headers,
157 },
158 Err(_) => LocalResponse {
159 status_code: code.as_u16(),
160 body: Vec::new(),
161 headers: headers.clone(),
162 },
163 }
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        routing::{get, post},
        Json,
    };
    use serde_json::json;

    /// Router exposing the fixture routes used by the request tests.
    fn create_test_router() -> Router {
        Router::new()
            .route("/test", get(|| async { "Hello, World!" }))
            .route("/echo", post(|body: String| async move { body }))
            .route("/json", get(|| async { Json(json!({"status": "ok"})) }))
    }

    /// Builds a header-less `LocalRequest` from string slices.
    fn local_request(method: &str, uri: &str, body: Option<&str>) -> LocalRequest {
        LocalRequest {
            uri: uri.to_owned(),
            method: method.to_owned(),
            body: body.map(str::to_owned),
            headers: HashMap::new(),
        }
    }

    /// Decodes a response body as UTF-8, panicking if it is not valid text.
    fn body_text(response: LocalResponse) -> String {
        String::from_utf8(response.body).unwrap()
    }

    mod local_request_tests {
        use super::*;

        #[tokio::test]
        async fn test_basic_get_request() {
            let mut router = create_test_router();

            let response = local_request("GET", "/test", None)
                .send_to_router(&mut router)
                .await;

            assert_eq!(response.status_code, 200);
            assert_eq!(body_text(response), "Hello, World!");
        }

        #[tokio::test]
        async fn test_post_request_with_body() {
            let mut router = create_test_router();
            let payload = "Test Body";

            let response = local_request("POST", "/echo", Some(payload))
                .send_to_router(&mut router)
                .await;

            assert_eq!(response.status_code, 200);
            assert_eq!(body_text(response), payload);
        }

        #[tokio::test]
        async fn test_invalid_method() {
            let mut router = create_test_router();

            let response = local_request("INVALID", "/test", None)
                .send_to_router(&mut router)
                .await;

            assert_eq!(response.status_code, 500);
            assert!(body_text(response).contains("Could not parse method"));
        }

        #[tokio::test]
        async fn test_request_with_headers() {
            // The handler echoes the value of `X-Test-Header` (or "" if absent).
            let mut router = Router::new().route(
                "/headers",
                get(|req: Request<Body>| async move {
                    req.headers()
                        .get("X-Test-Header")
                        .and_then(|v| v.to_str().ok())
                        .unwrap_or("")
                        .to_string()
                }),
            );

            let request = LocalRequest {
                headers: HashMap::from([(
                    "X-Test-Header".to_string(),
                    "test-value".to_string(),
                )]),
                ..local_request("GET", "/headers", None)
            };

            let response = request.send_to_router(&mut router).await;

            assert_eq!(response.status_code, 200);
            assert_eq!(body_text(response), "test-value");
        }
    }

    mod local_response_tests {
        use super::*;
        use http::response::Builder;

        #[tokio::test]
        async fn test_response_creation_with_body() {
            let raw = Builder::new()
                .status(200)
                .body(Body::from("test body"))
                .unwrap();

            let converted = LocalResponse::from_response(raw).await;

            assert_eq!(converted.status_code, 200);
            assert_eq!(body_text(converted), "test body");
        }

        #[tokio::test]
        async fn test_response_with_headers() {
            let raw = Builder::new()
                .status(200)
                .header("X-Test", "test-value")
                .body(Body::empty())
                .unwrap();

            let converted = LocalResponse::from_response(raw).await;

            assert_eq!(converted.status_code, 200);
            // Header names come back lower-cased.
            assert_eq!(converted.headers.get("x-test").unwrap(), "test-value");
        }

        #[tokio::test]
        async fn test_internal_server_error() {
            let reason = "Test error";

            let response = LocalResponse::internal_server_error(reason);

            assert_eq!(response.status_code, 500);
            assert!(response.headers.is_empty());
            assert!(body_text(response).contains(reason));
        }
    }

    mod method_tests {
        use super::*;

        #[tokio::test]
        async fn test_all_valid_methods() {
            for method in ["GET", "POST", "PUT", "DELETE", "PATCH"] {
                let request = local_request(method, "/test", None);
                assert!(request.to_axum_request().is_ok());
            }
        }

        #[tokio::test]
        async fn test_method_case_insensitivity() {
            let request = local_request("get", "/test", None);
            assert!(request.to_axum_request().is_ok());
        }
    }
}