spikard_http/server/
lifecycle_execution.rs1use crate::handler_trait::Handler;
4use axum::body::Body;
5use axum::http::StatusCode;
6use std::sync::Arc;
7
8pub async fn execute_with_lifecycle_hooks(
17 req: axum::http::Request<Body>,
18 request_data: crate::handler_trait::RequestData,
19 handler: Arc<dyn Handler>,
20 hooks: Option<Arc<crate::LifecycleHooks>>,
21) -> Result<axum::http::Response<Body>, (axum::http::StatusCode, String)> {
22 use crate::lifecycle::HookResult;
23
24 let Some(hooks) = hooks else {
25 return handler.call(req, request_data).await;
26 };
27
28 if hooks.is_empty() {
29 return handler.call(req, request_data).await;
30 }
31
32 let req = match hooks.execute_on_request(req).await {
33 Ok(HookResult::Continue(r)) => r,
34 Ok(HookResult::ShortCircuit(response)) => return Ok(response),
35 Err(e) => {
36 let error_response = axum::http::Response::builder()
37 .status(StatusCode::INTERNAL_SERVER_ERROR)
38 .body(Body::from(format!("{{\"error\":\"onRequest hook failed: {}\"}}", e)))
39 .unwrap();
40
41 return match hooks.execute_on_error(error_response).await {
42 Ok(resp) => Ok(resp),
43 Err(_) => Ok(axum::http::Response::builder()
44 .status(StatusCode::INTERNAL_SERVER_ERROR)
45 .body(Body::from("{\"error\":\"Hook execution failed\"}"))
46 .unwrap()),
47 };
48 }
49 };
50
51 let req = match hooks.execute_pre_validation(req).await {
52 Ok(HookResult::Continue(r)) => r,
53 Ok(HookResult::ShortCircuit(response)) => return Ok(response),
54 Err(e) => {
55 let error_response = axum::http::Response::builder()
56 .status(StatusCode::INTERNAL_SERVER_ERROR)
57 .body(Body::from(format!(
58 "{{\"error\":\"preValidation hook failed: {}\"}}",
59 e
60 )))
61 .unwrap();
62
63 return match hooks.execute_on_error(error_response).await {
64 Ok(resp) => Ok(resp),
65 Err(_) => Ok(axum::http::Response::builder()
66 .status(StatusCode::INTERNAL_SERVER_ERROR)
67 .body(Body::from("{\"error\":\"Hook execution failed\"}"))
68 .unwrap()),
69 };
70 }
71 };
72
73 let req = match hooks.execute_pre_handler(req).await {
74 Ok(HookResult::Continue(r)) => r,
75 Ok(HookResult::ShortCircuit(response)) => return Ok(response),
76 Err(e) => {
77 let error_response = axum::http::Response::builder()
78 .status(StatusCode::INTERNAL_SERVER_ERROR)
79 .body(Body::from(format!("{{\"error\":\"preHandler hook failed: {}\"}}", e)))
80 .unwrap();
81
82 return match hooks.execute_on_error(error_response).await {
83 Ok(resp) => Ok(resp),
84 Err(_) => Ok(axum::http::Response::builder()
85 .status(StatusCode::INTERNAL_SERVER_ERROR)
86 .body(Body::from("{\"error\":\"Hook execution failed\"}"))
87 .unwrap()),
88 };
89 }
90 };
91
92 let response = match handler.call(req, request_data).await {
93 Ok(resp) => resp,
94 Err((status, message)) => {
95 let error_response = axum::http::Response::builder()
96 .status(status)
97 .body(Body::from(message))
98 .unwrap();
99
100 return match hooks.execute_on_error(error_response).await {
101 Ok(resp) => Ok(resp),
102 Err(e) => Ok(axum::http::Response::builder()
103 .status(StatusCode::INTERNAL_SERVER_ERROR)
104 .body(Body::from(format!("{{\"error\":\"onError hook failed: {}\"}}", e)))
105 .unwrap()),
106 };
107 }
108 };
109
110 match hooks.execute_on_response(response).await {
111 Ok(resp) => Ok(resp),
112 Err(e) => Ok(axum::http::Response::builder()
113 .status(StatusCode::INTERNAL_SERVER_ERROR)
114 .body(Body::from(format!("{{\"error\":\"onResponse hook failed: {}\"}}", e)))
115 .unwrap()),
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lifecycle::{HookResult, request_hook, response_hook};
    use axum::http::{Request, Response, StatusCode};
    use http_body_util::BodyExt;
    use serde_json::json;
    use std::collections::HashMap;

    /// Boxed future type produced by the test handlers below.
    type HandlerFuture<'a> =
        std::pin::Pin<Box<dyn std::future::Future<Output = crate::handler_trait::HandlerResult> + Send + 'a>>;

    /// Handler double that always answers `200 OK` with the body `ok`.
    struct OkHandler;

    impl Handler for OkHandler {
        fn call(&self, _request: Request<Body>, _request_data: crate::handler_trait::RequestData) -> HandlerFuture<'_> {
            Box::pin(async move {
                let response = Response::builder()
                    .status(StatusCode::OK)
                    .body(Body::from("ok"))
                    .unwrap();
                Ok(response)
            })
        }
    }

    /// Handler double that always fails with `400 Bad Request` and the message `bad`.
    struct ErrHandler;

    impl Handler for ErrHandler {
        fn call(&self, _request: Request<Body>, _request_data: crate::handler_trait::RequestData) -> HandlerFuture<'_> {
            Box::pin(async move {
                let failure = (StatusCode::BAD_REQUEST, "bad".to_string());
                Err(failure)
            })
        }
    }

    /// Builds a `RequestData` for `GET /` with no params, headers, cookies or body.
    fn empty_request_data() -> crate::handler_trait::RequestData {
        crate::handler_trait::RequestData {
            path_params: Arc::new(HashMap::new()),
            query_params: Arc::new(json!({})),
            validated_params: None,
            raw_query_params: Arc::new(HashMap::new()),
            body: Arc::new(json!(null)),
            raw_body: None,
            headers: Arc::new(HashMap::new()),
            cookies: Arc::new(HashMap::new()),
            method: "GET".to_string(),
            path: "/".to_string(),
            #[cfg(feature = "di")]
            dependencies: None,
        }
    }

    /// Builds an empty-bodied request for `/`.
    fn root_request() -> Request<Body> {
        Request::builder().uri("/").body(Body::empty()).unwrap()
    }

    /// Drives the pipeline under test with `handler` and `hooks`, unwrapping the result.
    async fn run_hooks(handler: Arc<dyn Handler>, hooks: crate::LifecycleHooks) -> Response<Body> {
        execute_with_lifecycle_hooks(root_request(), empty_request_data(), handler, Some(Arc::new(hooks)))
            .await
            .unwrap()
    }

    /// Collects the whole response body into a byte vector.
    async fn body_of(response: Response<Body>) -> Vec<u8> {
        response.into_body().collect().await.unwrap().to_bytes().to_vec()
    }

    #[tokio::test]
    async fn pre_validation_error_with_failing_on_error_hook_returns_fallback() {
        let mut hooks = crate::LifecycleHooks::new();
        hooks.add_pre_validation(request_hook("boom", |_req| async move { Err("boom".to_string()) }));
        hooks.add_on_error(response_hook("fail-on-error", |_resp| async move {
            Err("fail".to_string())
        }));

        let response = run_hooks(Arc::new(OkHandler), hooks).await;

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let payload = body_of(response).await;
        assert_eq!(payload.as_slice(), b"{\"error\":\"Hook execution failed\"}");
    }

    #[tokio::test]
    async fn handler_error_with_failing_on_error_hook_returns_on_error_hook_failed_response() {
        let mut hooks = crate::LifecycleHooks::new();
        hooks.add_on_error(response_hook("fail-on-error", |_resp| async move {
            Err("boom".to_string())
        }));

        let response = run_hooks(Arc::new(ErrHandler), hooks).await;

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let payload = body_of(response).await;
        let text = std::str::from_utf8(&payload).unwrap();
        assert!(text.contains("\"error\":\"onError hook failed:"));
        assert!(text.contains("boom"));
    }

    #[tokio::test]
    async fn on_response_hook_error_returns_on_response_hook_failed_response() {
        let mut hooks = crate::LifecycleHooks::new();
        hooks.add_on_response(response_hook("fail-on-response", |_resp| async move {
            Err("boom".to_string())
        }));

        let response = run_hooks(Arc::new(OkHandler), hooks).await;

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let payload = body_of(response).await;
        let text = std::str::from_utf8(&payload).unwrap();
        assert!(text.contains("\"error\":\"onResponse hook failed:"));
        assert!(text.contains("boom"));
    }

    #[tokio::test]
    async fn pre_validation_short_circuit_skips_handler_and_returns_response() {
        let mut hooks = crate::LifecycleHooks::new();
        hooks.add_pre_validation(request_hook("short-circuit", |_req| async move {
            let denied = Response::builder()
                .status(StatusCode::UNAUTHORIZED)
                .body(Body::from("nope"))
                .unwrap();
            Ok(HookResult::ShortCircuit(denied))
        }));

        // `ErrHandler` would yield 400, so a 401 shows the handler never ran.
        let response = run_hooks(Arc::new(ErrHandler), hooks).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        let payload = body_of(response).await;
        assert_eq!(payload.as_slice(), b"nope");
    }
}