web_server_abstraction/adapters/
axum.rs

1//! Axum framework adapter.
2
3use crate::core::{HandlerFn, Middleware};
4use crate::error::{Result, WebServerError};
5use crate::types::{Body, Headers, HttpMethod, Request, Response, StatusCode};
6use axum::{
7    http::{self, HeaderMap, Method, StatusCode as AxumStatusCode},
8    response::IntoResponse,
9    routing::{delete, get, head, options, patch, post, put, MethodRouter},
10    Router,
11};
12use std::{collections::HashMap, net::SocketAddr, sync::Arc};
13use tokio::net::TcpListener;
14use tower::ServiceBuilder;
15use tower_http::trace::TraceLayer;
16
17/// Axum framework adapter
18pub struct AxumAdapter {
19    router: Router,
20    middleware: Vec<Box<dyn Middleware>>,
21    addr: Option<SocketAddr>,
22}
23
24impl Default for AxumAdapter {
25    fn default() -> Self {
26        Self::new()
27    }
28}
29
30impl AxumAdapter {
31    pub fn new() -> Self {
32        Self {
33            router: Router::new(),
34            middleware: Vec::new(),
35            addr: None,
36        }
37    }
38
39    /// Bind the server to an address
40    pub async fn bind(&mut self, addr: &str) -> Result<()> {
41        let socket_addr: SocketAddr = addr
42            .parse()
43            .map_err(|e| WebServerError::bind_error(format!("Invalid address {}: {}", addr, e)))?;
44        self.addr = Some(socket_addr);
45        Ok(())
46    }
47
48    /// Run the server
49    pub async fn run(self) -> Result<()> {
50        let addr = self
51            .addr
52            .ok_or_else(|| WebServerError::bind_error("No address bound"))?;
53
54        // Apply middleware and tracing to the router
55        let app = self
56            .router
57            .layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()));
58
59        println!("Axum server starting on {}", addr);
60
61        let listener = TcpListener::bind(addr).await.map_err(|e| {
62            WebServerError::adapter_error(format!("Failed to bind listener: {}", e))
63        })?;
64
65        axum::serve(listener, app)
66            .await
67            .map_err(|e| WebServerError::adapter_error(format!("Axum server error: {}", e)))?;
68
69        Ok(())
70    }
71
72    /// Add a route to the server
73    pub fn route(&mut self, path: &str, method: HttpMethod, handler: HandlerFn) -> &mut Self {
74        // Create an adapter that converts between Axum and our types
75        let handler = Arc::new(handler);
76        let adapter_handler = {
77            // Clone the Arc for the closure
78            let handler_clone = Arc::clone(&handler);
79            move |req: axum::extract::Request| {
80                let handler = Arc::clone(&handler_clone);
81                async move {
82                    // Convert Axum request to our Request type
83                    let our_request = match convert_axum_request_to_ours(req).await {
84                        Ok(req) => req,
85                        Err(e) => {
86                            return (
87                                AxumStatusCode::BAD_REQUEST,
88                                format!("Request conversion error: {}", e),
89                            )
90                                .into_response();
91                        }
92                    };
93
94                    // Call handler directly (middleware would be applied at server level)
95                    let result = handler(our_request).await;
96
97                    // Convert our Response to Axum response
98                    match result {
99                        Ok(response) => {
100                            convert_our_response_to_axum(response).await.into_response()
101                        }
102                        Err(e) => (
103                            AxumStatusCode::INTERNAL_SERVER_ERROR,
104                            format!("Handler error: {}", e),
105                        )
106                            .into_response(),
107                    }
108                }
109            }
110        };
111
112        let method_router: MethodRouter = match method {
113            HttpMethod::GET => get(adapter_handler),
114            HttpMethod::POST => post(adapter_handler),
115            HttpMethod::PUT => put(adapter_handler),
116            HttpMethod::DELETE => delete(adapter_handler),
117            HttpMethod::PATCH => patch(adapter_handler),
118            HttpMethod::HEAD => head(adapter_handler),
119            HttpMethod::OPTIONS => options(adapter_handler),
120            _ => get(adapter_handler),
121        };
122
123        self.router = self.router.clone().route(path, method_router);
124        println!("Added Axum route: {:?} {}", method, path);
125        self
126    }
127
128    /// Add middleware to the server
129    pub fn middleware(&mut self, middleware: Box<dyn Middleware>) -> &mut Self {
130        self.middleware.push(middleware);
131        println!("Added middleware to Axum adapter");
132        self
133    }
134}
135
136/// Convert Axum request to our Request type
137async fn convert_axum_request_to_ours(req: axum::extract::Request) -> Result<Request> {
138    let (parts, body) = req.into_parts();
139
140    // Convert method
141    let method = match parts.method {
142        Method::GET => HttpMethod::GET,
143        Method::POST => HttpMethod::POST,
144        Method::PUT => HttpMethod::PUT,
145        Method::DELETE => HttpMethod::DELETE,
146        Method::PATCH => HttpMethod::PATCH,
147        Method::HEAD => HttpMethod::HEAD,
148        Method::OPTIONS => HttpMethod::OPTIONS,
149        _ => HttpMethod::GET, // Default fallback
150    };
151
152    // Convert URI directly from the Axum request
153    let uri = parts.uri;
154
155    // Convert headers
156    let mut headers = Headers::new();
157    for (name, value) in parts.headers.iter() {
158        if let Ok(value_str) = value.to_str() {
159            headers.insert(name.to_string(), value_str.to_string());
160        }
161    }
162
163    // Convert body
164    let body_bytes = axum::body::to_bytes(body, usize::MAX)
165        .await
166        .map_err(|e| WebServerError::adapter_error(format!("Failed to read body: {}", e)))?;
167    let body = Body::from(body_bytes.to_vec());
168
169    // Initialize extensions and other fields
170    let extensions = HashMap::new();
171    let path_params = HashMap::new();
172    let cookies = HashMap::new();
173    let form_data = None;
174    let multipart = None;
175
176    Ok(Request {
177        method,
178        uri,
179        version: parts.version, // Use the actual version from the request
180        headers,
181        body,
182        extensions,
183        path_params,
184        cookies,
185        form_data,
186        multipart,
187    })
188}
189
190/// Convert our Response to Axum response format
191async fn convert_our_response_to_axum(response: Response) -> impl IntoResponse {
192    let mut header_map = HeaderMap::new();
193
194    // Convert headers
195    for (name, value) in response.headers.iter() {
196        if let (Ok(header_name), Ok(header_value)) = (
197            name.parse::<http::HeaderName>(),
198            value.parse::<http::HeaderValue>(),
199        ) {
200            header_map.insert(header_name, header_value);
201        }
202    }
203
204    // Convert status code
205    let axum_status = match response.status {
206        StatusCode::OK => AxumStatusCode::OK,
207        StatusCode::CREATED => AxumStatusCode::CREATED,
208        StatusCode::NOT_FOUND => AxumStatusCode::NOT_FOUND,
209        StatusCode::INTERNAL_SERVER_ERROR => AxumStatusCode::INTERNAL_SERVER_ERROR,
210        StatusCode::BAD_REQUEST => AxumStatusCode::BAD_REQUEST,
211        StatusCode::UNAUTHORIZED => AxumStatusCode::UNAUTHORIZED,
212        StatusCode::FORBIDDEN => AxumStatusCode::FORBIDDEN,
213        StatusCode::NO_CONTENT => AxumStatusCode::NO_CONTENT,
214        _ => AxumStatusCode::OK, // Default fallback
215    };
216
217    // Convert body to bytes
218    let body_bytes = response.body.bytes().await.unwrap_or_default();
219
220    (axum_status, header_map, body_bytes)
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Body, Headers, StatusCode};

    #[tokio::test]
    async fn test_axum_adapter_creation() {
        let adapter = AxumAdapter::new();
        assert!(adapter.middleware.is_empty());
        assert!(adapter.addr.is_none());
    }

    #[tokio::test]
    async fn test_axum_adapter_bind() {
        let mut adapter = AxumAdapter::new();
        let result = adapter.bind("127.0.0.1:0").await;
        assert!(result.is_ok());
        assert!(adapter.addr.is_some());
    }

    #[tokio::test]
    async fn test_axum_adapter_bind_rejects_invalid_address() {
        let mut adapter = AxumAdapter::new();
        assert!(adapter.bind("not an address").await.is_err());
        assert!(adapter.addr.is_none());
    }

    #[tokio::test]
    async fn test_response_conversion() {
        let response = Response {
            status: StatusCode::OK,
            headers: {
                let mut h = Headers::new();
                h.insert("content-type".to_string(), "application/json".to_string());
                h
            },
            body: Body::from("test response"),
        };

        let axum_response = convert_our_response_to_axum(response).await.into_response();
        assert_eq!(axum_response.status(), AxumStatusCode::OK);
        assert_eq!(
            axum_response
                .headers()
                .get("content-type")
                .and_then(|v| v.to_str().ok()),
            Some("application/json")
        );

        let body = axum::body::to_bytes(axum_response.into_body(), usize::MAX)
            .await
            .expect("in-memory body is readable");
        assert_eq!(&body[..], &b"test response"[..]);
    }

    #[tokio::test]
    async fn test_request_conversion() {
        let req = http::Request::builder()
            .method(Method::POST)
            .uri("/items?page=2")
            .header("x-test", "value")
            .body(axum::body::Body::from("payload"))
            .expect("valid request");

        let converted = convert_axum_request_to_ours(req)
            .await
            .expect("conversion succeeds");
        assert!(matches!(converted.method, HttpMethod::POST));
        assert_eq!(converted.uri.path(), "/items");
        assert_eq!(converted.uri.query(), Some("page=2"));
    }
}