Skip to main content

spikard_http/server/
lifecycle_execution.rs

1//! Lifecycle hooks execution logic
2
3use crate::handler_trait::Handler;
4use axum::body::Body;
5use axum::http::StatusCode;
6use std::sync::Arc;
7
8/// Execute a handler with lifecycle hooks
9///
10/// This wraps the handler execution with lifecycle hooks at appropriate points:
11/// 1. preValidation hooks (before handler, which does validation)
12/// 2. preHandler hooks (after validation, before handler)
13/// 3. Handler execution
14/// 4. onResponse hooks (after successful handler execution)
15/// 5. onError hooks (if handler or any hook fails)
16pub async fn execute_with_lifecycle_hooks(
17    req: axum::http::Request<Body>,
18    request_data: crate::handler_trait::RequestData,
19    handler: Arc<dyn Handler>,
20    hooks: Option<Arc<crate::LifecycleHooks>>,
21) -> Result<axum::http::Response<Body>, (axum::http::StatusCode, String)> {
22    use crate::lifecycle::HookResult;
23
24    let Some(hooks) = hooks else {
25        return handler.call(req, request_data).await;
26    };
27
28    if hooks.is_empty() {
29        return handler.call(req, request_data).await;
30    }
31
32    let req = match hooks.execute_on_request(req).await {
33        Ok(HookResult::Continue(r)) => r,
34        Ok(HookResult::ShortCircuit(response)) => return Ok(response),
35        Err(e) => {
36            let error_response = axum::http::Response::builder()
37                .status(StatusCode::INTERNAL_SERVER_ERROR)
38                .body(Body::from(format!("{{\"error\":\"onRequest hook failed: {}\"}}", e)))
39                .unwrap();
40
41            return match hooks.execute_on_error(error_response).await {
42                Ok(resp) => Ok(resp),
43                Err(_) => Ok(axum::http::Response::builder()
44                    .status(StatusCode::INTERNAL_SERVER_ERROR)
45                    .body(Body::from("{\"error\":\"Hook execution failed\"}"))
46                    .unwrap()),
47            };
48        }
49    };
50
51    let req = match hooks.execute_pre_validation(req).await {
52        Ok(HookResult::Continue(r)) => r,
53        Ok(HookResult::ShortCircuit(response)) => return Ok(response),
54        Err(e) => {
55            let error_response = axum::http::Response::builder()
56                .status(StatusCode::INTERNAL_SERVER_ERROR)
57                .body(Body::from(format!(
58                    "{{\"error\":\"preValidation hook failed: {}\"}}",
59                    e
60                )))
61                .unwrap();
62
63            return match hooks.execute_on_error(error_response).await {
64                Ok(resp) => Ok(resp),
65                Err(_) => Ok(axum::http::Response::builder()
66                    .status(StatusCode::INTERNAL_SERVER_ERROR)
67                    .body(Body::from("{\"error\":\"Hook execution failed\"}"))
68                    .unwrap()),
69            };
70        }
71    };
72
73    let req = match hooks.execute_pre_handler(req).await {
74        Ok(HookResult::Continue(r)) => r,
75        Ok(HookResult::ShortCircuit(response)) => return Ok(response),
76        Err(e) => {
77            let error_response = axum::http::Response::builder()
78                .status(StatusCode::INTERNAL_SERVER_ERROR)
79                .body(Body::from(format!("{{\"error\":\"preHandler hook failed: {}\"}}", e)))
80                .unwrap();
81
82            return match hooks.execute_on_error(error_response).await {
83                Ok(resp) => Ok(resp),
84                Err(_) => Ok(axum::http::Response::builder()
85                    .status(StatusCode::INTERNAL_SERVER_ERROR)
86                    .body(Body::from("{\"error\":\"Hook execution failed\"}"))
87                    .unwrap()),
88            };
89        }
90    };
91
92    let response = match handler.call(req, request_data).await {
93        Ok(resp) => resp,
94        Err((status, message)) => {
95            let error_response = axum::http::Response::builder()
96                .status(status)
97                .body(Body::from(message))
98                .unwrap();
99
100            return match hooks.execute_on_error(error_response).await {
101                Ok(resp) => Ok(resp),
102                Err(e) => Ok(axum::http::Response::builder()
103                    .status(StatusCode::INTERNAL_SERVER_ERROR)
104                    .body(Body::from(format!("{{\"error\":\"onError hook failed: {}\"}}", e)))
105                    .unwrap()),
106            };
107        }
108    };
109
110    match hooks.execute_on_response(response).await {
111        Ok(resp) => Ok(resp),
112        Err(e) => Ok(axum::http::Response::builder()
113            .status(StatusCode::INTERNAL_SERVER_ERROR)
114            .body(Body::from(format!("{{\"error\":\"onResponse hook failed: {}\"}}", e)))
115            .unwrap()),
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handler_trait::{HandlerResult, RequestData};
    use crate::lifecycle::{HookResult, request_hook, response_hook};
    use axum::http::{Request, Response, StatusCode};
    use http_body_util::BodyExt;
    use serde_json::json;
    use std::collections::HashMap;
    use std::future::Future;
    use std::pin::Pin;

    /// Boxed future type returned by `Handler::call`.
    type HandlerFuture<'a> = Pin<Box<dyn Future<Output = HandlerResult> + Send + 'a>>;

    /// Handler that always succeeds with `200 OK` and the body `ok`.
    struct OkHandler;

    impl Handler for OkHandler {
        fn call(&self, _request: Request<Body>, _request_data: RequestData) -> HandlerFuture<'_> {
            Box::pin(async move {
                Ok(Response::builder()
                    .status(StatusCode::OK)
                    .body(Body::from("ok"))
                    .unwrap())
            })
        }
    }

    /// Handler that always fails with `400 Bad Request` and the message `bad`.
    struct ErrHandler;

    impl Handler for ErrHandler {
        fn call(&self, _request: Request<Body>, _request_data: RequestData) -> HandlerFuture<'_> {
            Box::pin(async move { Err((StatusCode::BAD_REQUEST, "bad".to_string())) })
        }
    }

    /// Request data for a bare `GET /` with every collection empty.
    fn empty_request_data() -> RequestData {
        RequestData {
            path_params: Arc::new(HashMap::new()),
            query_params: Arc::new(json!({})),
            validated_params: None,
            raw_query_params: Arc::new(HashMap::new()),
            body: Arc::new(json!(null)),
            raw_body: None,
            headers: Arc::new(HashMap::new()),
            cookies: Arc::new(HashMap::new()),
            method: "GET".to_string(),
            path: "/".to_string(),
            #[cfg(feature = "di")]
            dependencies: None,
        }
    }

    /// Run `handler` behind `hooks` for an empty-bodied request to `/`, unwrapping the outer result.
    async fn run(handler: Arc<dyn Handler>, hooks: crate::LifecycleHooks) -> Response<Body> {
        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        execute_with_lifecycle_hooks(request, empty_request_data(), handler, Some(Arc::new(hooks)))
            .await
            .unwrap()
    }

    /// Collect a response body into a UTF-8 string.
    async fn body_text(response: Response<Body>) -> String {
        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn pre_validation_error_with_failing_on_error_hook_returns_fallback() {
        let mut hooks = crate::LifecycleHooks::new();
        hooks.add_pre_validation(request_hook("boom", |_req| async move { Err("boom".to_string()) }));
        hooks.add_on_error(response_hook("fail-on-error", |_resp| async move {
            Err("fail".to_string())
        }));

        let response = run(Arc::new(OkHandler), hooks).await;

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(body_text(response).await, "{\"error\":\"Hook execution failed\"}");
    }

    #[tokio::test]
    async fn handler_error_with_failing_on_error_hook_returns_on_error_hook_failed_response() {
        let mut hooks = crate::LifecycleHooks::new();
        hooks.add_on_error(response_hook("fail-on-error", |_resp| async move {
            Err("boom".to_string())
        }));

        let response = run(Arc::new(ErrHandler), hooks).await;

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let text = body_text(response).await;
        assert!(text.contains("\"error\":\"onError hook failed:"));
        assert!(text.contains("boom"));
    }

    #[tokio::test]
    async fn on_response_hook_error_returns_on_response_hook_failed_response() {
        let mut hooks = crate::LifecycleHooks::new();
        hooks.add_on_response(response_hook("fail-on-response", |_resp| async move {
            Err("boom".to_string())
        }));

        let response = run(Arc::new(OkHandler), hooks).await;

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let text = body_text(response).await;
        assert!(text.contains("\"error\":\"onResponse hook failed:"));
        assert!(text.contains("boom"));
    }

    #[tokio::test]
    async fn pre_validation_short_circuit_skips_handler_and_returns_response() {
        let mut hooks = crate::LifecycleHooks::new();
        hooks.add_pre_validation(request_hook("short-circuit", |_req| async move {
            Ok(HookResult::ShortCircuit(
                Response::builder()
                    .status(StatusCode::UNAUTHORIZED)
                    .body(Body::from("nope"))
                    .unwrap(),
            ))
        }));

        // ErrHandler would yield 400 if it ran; the short-circuit must win.
        let response = run(Arc::new(ErrHandler), hooks).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(body_text(response).await, "nope");
    }
}