// web_server_abstraction/adapters/axum.rs
use std::{collections::HashMap, mem, net::SocketAddr, sync::Arc};

use axum::{
    http::{self, HeaderMap, Method, StatusCode as AxumStatusCode},
    response::IntoResponse,
    routing::{delete, get, head, options, patch, post, put, MethodRouter},
    Router,
};
use tokio::net::TcpListener;
use tower::ServiceBuilder;
use tower_http::trace::TraceLayer;

use crate::core::{HandlerFn, Middleware};
use crate::error::{Result, WebServerError};
use crate::types::{Body, Headers, HttpMethod, Request, Response, StatusCode};
16
17pub struct AxumAdapter {
19 router: Router,
20 middleware: Vec<Box<dyn Middleware>>,
21 addr: Option<SocketAddr>,
22}
23
24impl Default for AxumAdapter {
25 fn default() -> Self {
26 Self::new()
27 }
28}
29
30impl AxumAdapter {
31 pub fn new() -> Self {
32 Self {
33 router: Router::new(),
34 middleware: Vec::new(),
35 addr: None,
36 }
37 }
38
39 pub async fn bind(&mut self, addr: &str) -> Result<()> {
41 let socket_addr: SocketAddr = addr
42 .parse()
43 .map_err(|e| WebServerError::bind_error(format!("Invalid address {}: {}", addr, e)))?;
44 self.addr = Some(socket_addr);
45 Ok(())
46 }
47
48 pub async fn run(self) -> Result<()> {
50 let addr = self
51 .addr
52 .ok_or_else(|| WebServerError::bind_error("No address bound"))?;
53
54 let app = self
56 .router
57 .layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()));
58
59 println!("Axum server starting on {}", addr);
60
61 let listener = TcpListener::bind(addr).await.map_err(|e| {
62 WebServerError::adapter_error(format!("Failed to bind listener: {}", e))
63 })?;
64
65 axum::serve(listener, app)
66 .await
67 .map_err(|e| WebServerError::adapter_error(format!("Axum server error: {}", e)))?;
68
69 Ok(())
70 }
71
72 pub fn route(&mut self, path: &str, method: HttpMethod, handler: HandlerFn) -> &mut Self {
74 let handler = Arc::new(handler);
76 let adapter_handler = {
77 let handler_clone = Arc::clone(&handler);
79 move |req: axum::extract::Request| {
80 let handler = Arc::clone(&handler_clone);
81 async move {
82 let our_request = match convert_axum_request_to_ours(req).await {
84 Ok(req) => req,
85 Err(e) => {
86 return (
87 AxumStatusCode::BAD_REQUEST,
88 format!("Request conversion error: {}", e),
89 )
90 .into_response();
91 }
92 };
93
94 let result = handler(our_request).await;
96
97 match result {
99 Ok(response) => {
100 convert_our_response_to_axum(response).await.into_response()
101 }
102 Err(e) => (
103 AxumStatusCode::INTERNAL_SERVER_ERROR,
104 format!("Handler error: {}", e),
105 )
106 .into_response(),
107 }
108 }
109 }
110 };
111
112 let method_router: MethodRouter = match method {
113 HttpMethod::GET => get(adapter_handler),
114 HttpMethod::POST => post(adapter_handler),
115 HttpMethod::PUT => put(adapter_handler),
116 HttpMethod::DELETE => delete(adapter_handler),
117 HttpMethod::PATCH => patch(adapter_handler),
118 HttpMethod::HEAD => head(adapter_handler),
119 HttpMethod::OPTIONS => options(adapter_handler),
120 _ => get(adapter_handler),
121 };
122
123 self.router = self.router.clone().route(path, method_router);
124 println!("Added Axum route: {:?} {}", method, path);
125 self
126 }
127
128 pub fn middleware(&mut self, middleware: Box<dyn Middleware>) -> &mut Self {
130 self.middleware.push(middleware);
131 println!("Added middleware to Axum adapter");
132 self
133 }
134}
135
136async fn convert_axum_request_to_ours(req: axum::extract::Request) -> Result<Request> {
138 let (parts, body) = req.into_parts();
139
140 let method = match parts.method {
142 Method::GET => HttpMethod::GET,
143 Method::POST => HttpMethod::POST,
144 Method::PUT => HttpMethod::PUT,
145 Method::DELETE => HttpMethod::DELETE,
146 Method::PATCH => HttpMethod::PATCH,
147 Method::HEAD => HttpMethod::HEAD,
148 Method::OPTIONS => HttpMethod::OPTIONS,
149 _ => HttpMethod::GET, };
151
152 let uri = parts.uri;
154
155 let mut headers = Headers::new();
157 for (name, value) in parts.headers.iter() {
158 if let Ok(value_str) = value.to_str() {
159 headers.insert(name.to_string(), value_str.to_string());
160 }
161 }
162
163 let body_bytes = axum::body::to_bytes(body, usize::MAX)
165 .await
166 .map_err(|e| WebServerError::adapter_error(format!("Failed to read body: {}", e)))?;
167 let body = Body::from(body_bytes.to_vec());
168
169 let extensions = HashMap::new();
171 let path_params = HashMap::new();
172 let cookies = HashMap::new();
173 let form_data = None;
174 let multipart = None;
175
176 Ok(Request {
177 method,
178 uri,
179 version: parts.version, headers,
181 body,
182 extensions,
183 path_params,
184 cookies,
185 form_data,
186 multipart,
187 })
188}
189
190async fn convert_our_response_to_axum(response: Response) -> impl IntoResponse {
192 let mut header_map = HeaderMap::new();
193
194 for (name, value) in response.headers.iter() {
196 if let (Ok(header_name), Ok(header_value)) = (
197 name.parse::<http::HeaderName>(),
198 value.parse::<http::HeaderValue>(),
199 ) {
200 header_map.insert(header_name, header_value);
201 }
202 }
203
204 let axum_status = match response.status {
206 StatusCode::OK => AxumStatusCode::OK,
207 StatusCode::CREATED => AxumStatusCode::CREATED,
208 StatusCode::NOT_FOUND => AxumStatusCode::NOT_FOUND,
209 StatusCode::INTERNAL_SERVER_ERROR => AxumStatusCode::INTERNAL_SERVER_ERROR,
210 StatusCode::BAD_REQUEST => AxumStatusCode::BAD_REQUEST,
211 StatusCode::UNAUTHORIZED => AxumStatusCode::UNAUTHORIZED,
212 StatusCode::FORBIDDEN => AxumStatusCode::FORBIDDEN,
213 StatusCode::NO_CONTENT => AxumStatusCode::NO_CONTENT,
214 _ => AxumStatusCode::OK, };
216
217 let body_bytes = response.body.bytes().await.unwrap_or_default();
219
220 (axum_status, header_map, body_bytes)
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Body, Headers, StatusCode};

    // Construction needs no async runtime, so a plain test is enough.
    #[test]
    fn test_axum_adapter_creation() {
        let adapter = AxumAdapter::new();
        assert!(adapter.middleware.is_empty());
        assert!(adapter.addr.is_none());
    }

    #[tokio::test]
    async fn test_axum_adapter_bind() {
        let mut adapter = AxumAdapter::new();
        let result = adapter.bind("127.0.0.1:0").await;
        assert!(result.is_ok());
        assert_eq!(adapter.addr, Some("127.0.0.1:0".parse().unwrap()));
    }

    #[tokio::test]
    async fn test_axum_adapter_bind_rejects_invalid_address() {
        let mut adapter = AxumAdapter::new();
        assert!(adapter.bind("not-an-address").await.is_err());
        assert!(adapter.addr.is_none());
    }

    #[tokio::test]
    async fn test_request_conversion() {
        let req = axum::http::Request::builder()
            .method(Method::POST)
            .uri("/items?id=7")
            .header("x-test", "value")
            .body(axum::body::Body::from("payload"))
            .unwrap();

        let Ok(ours) = convert_axum_request_to_ours(req).await else {
            panic!("request conversion failed");
        };
        assert!(matches!(ours.method, HttpMethod::POST));
        assert_eq!(ours.uri.path(), "/items");
        assert_eq!(ours.uri.query(), Some("id=7"));
    }

    #[tokio::test]
    async fn test_response_conversion() {
        let response = Response {
            status: StatusCode::OK,
            headers: {
                let mut h = Headers::new();
                h.insert("content-type".to_string(), "application/json".to_string());
                h
            },
            body: Body::from("test response"),
        };

        let axum_response = convert_our_response_to_axum(response).await.into_response();
        assert_eq!(axum_response.status(), AxumStatusCode::OK);
        assert_eq!(axum_response.headers()["content-type"], "application/json");
    }

    #[tokio::test]
    async fn test_response_conversion_maps_not_found() {
        let response = Response {
            status: StatusCode::NOT_FOUND,
            headers: Headers::new(),
            body: Body::from("missing"),
        };

        let axum_response = convert_our_response_to_axum(response).await.into_response();
        assert_eq!(axum_response.status(), AxumStatusCode::NOT_FOUND);
    }
}