1use axum::{
2 body::Body,
3 http::{Request, Response},
4};
5use std::sync::Arc;
6
7pub mod adapter;
8
9pub use spikard_core::lifecycle::{HookResult, LifecycleHook};
10
/// Lifecycle hook registry specialised to axum's `Request<Body>` / `Response<Body>`.
///
/// Holds hooks for each phase (`on_request`, `pre_validation`, `pre_handler`,
/// `on_response`, `on_error`) and runs a phase via the matching `execute_*` method.
pub type LifecycleHooks = spikard_core::lifecycle::LifecycleHooks<Request<Body>, Response<Body>>;
/// Builder for [`LifecycleHooks`], specialised to axum's request/response types.
pub type LifecycleHooksBuilder = spikard_core::lifecycle::LifecycleHooksBuilder<Request<Body>, Response<Body>>;
13
14#[cfg(not(target_arch = "wasm32"))]
16pub fn request_hook<F, Fut>(name: impl Into<String>, func: F) -> Arc<dyn LifecycleHook<Request<Body>, Response<Body>>>
17where
18 F: Fn(Request<Body>) -> Fut + Send + Sync + 'static,
19 Fut: std::future::Future<Output = Result<HookResult<Request<Body>, Response<Body>>, String>> + Send + 'static,
20{
21 spikard_core::lifecycle::request_hook::<Request<Body>, Response<Body>, _, _>(name, func)
22}
23
/// Builds a request-phase lifecycle hook from an async closure (wasm32 build).
///
/// `func` receives the request by value and resolves to one of:
/// - `Ok(HookResult::Continue(req))` — pass a (possibly modified) request on,
/// - `Ok(HookResult::ShortCircuit(resp))` — answer immediately with `resp`,
/// - `Err(message)` — report a hook failure.
///
/// Unlike the native build, the future returned by `func` is not required to
/// be `Send` here. Everything else forwards to
/// `spikard_core::lifecycle::request_hook` with the types fixed to axum's
/// `Request<Body>` / `Response<Body>`.
///
/// NOTE(review): `F` itself still requires `Send + Sync` on wasm32; confirm
/// this matches the wasm signature of `spikard_core::lifecycle::request_hook`.
///
/// A hook has no effect until it is registered, hence `#[must_use]`.
#[cfg(target_arch = "wasm32")]
#[must_use]
pub fn request_hook<F, Fut>(name: impl Into<String>, func: F) -> Arc<dyn LifecycleHook<Request<Body>, Response<Body>>>
where
    F: Fn(Request<Body>) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<HookResult<Request<Body>, Response<Body>>, String>> + 'static,
{
    spikard_core::lifecycle::request_hook::<Request<Body>, Response<Body>, _, _>(name, func)
}
33
34#[cfg(not(target_arch = "wasm32"))]
36pub fn response_hook<F, Fut>(name: impl Into<String>, func: F) -> Arc<dyn LifecycleHook<Request<Body>, Response<Body>>>
37where
38 F: Fn(Response<Body>) -> Fut + Send + Sync + 'static,
39 Fut: std::future::Future<Output = Result<HookResult<Response<Body>, Response<Body>>, String>> + Send + 'static,
40{
41 spikard_core::lifecycle::response_hook::<Request<Body>, Response<Body>, _, _>(name, func)
42}
43
/// Builds a response-phase lifecycle hook from an async closure (wasm32 build).
///
/// `func` receives the response by value and resolves to
/// `Ok(HookResult::Continue(resp))` to pass a (possibly modified) response on,
/// or `Err(message)` to report a hook failure. It may also return
/// `Ok(HookResult::ShortCircuit(resp))`; presumably that ends the response
/// chain early with `resp` — semantics are defined by `spikard_core`, confirm there.
///
/// Unlike the native build, the future returned by `func` is not required to
/// be `Send` here.
///
/// NOTE(review): `F` itself still requires `Send + Sync` on wasm32; confirm
/// this matches the wasm signature of `spikard_core::lifecycle::response_hook`.
///
/// A hook has no effect until it is registered, hence `#[must_use]`.
#[cfg(target_arch = "wasm32")]
#[must_use]
pub fn response_hook<F, Fut>(name: impl Into<String>, func: F) -> Arc<dyn LifecycleHook<Request<Body>, Response<Body>>>
where
    F: Fn(Response<Body>) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<HookResult<Response<Body>, Response<Body>>, String>> + 'static,
{
    spikard_core::lifecycle::response_hook::<Request<Body>, Response<Body>, _, _>(name, func)
}
53
54#[cfg(test)]
55mod tests {
56 use super::*;
57 use axum::body::Body;
58 use axum::http::{Request, Response, StatusCode};
59 use std::future::Future;
60 use std::pin::Pin;
61
62 struct ContinueHook {
64 name: String,
65 }
66
67 impl LifecycleHook<Request<Body>, Response<Body>> for ContinueHook {
68 fn name(&self) -> &str {
69 &self.name
70 }
71
72 fn execute_request<'a>(
73 &self,
74 req: Request<Body>,
75 ) -> Pin<Box<dyn Future<Output = Result<HookResult<Request<Body>, Response<Body>>, String>> + Send + 'a>>
76 {
77 Box::pin(async move { Ok(HookResult::Continue(req)) })
78 }
79
80 fn execute_response<'a>(
81 &self,
82 resp: Response<Body>,
83 ) -> Pin<Box<dyn Future<Output = Result<HookResult<Response<Body>, Response<Body>>, String>> + Send + 'a>>
84 {
85 Box::pin(async move { Ok(HookResult::Continue(resp)) })
86 }
87 }
88
89 struct ShortCircuitHook {
91 name: String,
92 }
93
94 impl LifecycleHook<Request<Body>, Response<Body>> for ShortCircuitHook {
95 fn name(&self) -> &str {
96 &self.name
97 }
98
99 fn execute_request<'a>(
100 &self,
101 _req: Request<Body>,
102 ) -> Pin<Box<dyn Future<Output = Result<HookResult<Request<Body>, Response<Body>>, String>> + Send + 'a>>
103 {
104 Box::pin(async move {
105 let response = Response::builder()
106 .status(StatusCode::UNAUTHORIZED)
107 .body(Body::from("Unauthorized"))
108 .unwrap();
109 Ok(HookResult::ShortCircuit(response))
110 })
111 }
112
113 fn execute_response<'a>(
114 &self,
115 _resp: Response<Body>,
116 ) -> Pin<Box<dyn Future<Output = Result<HookResult<Response<Body>, Response<Body>>, String>> + Send + 'a>>
117 {
118 Box::pin(async move {
119 let response = Response::builder()
120 .status(StatusCode::UNAUTHORIZED)
121 .body(Body::from("Unauthorized"))
122 .unwrap();
123 Ok(HookResult::ShortCircuit(response))
124 })
125 }
126 }
127
128 #[tokio::test]
129 async fn test_empty_hooks_fast_path() {
130 let hooks = LifecycleHooks::new();
131 assert!(hooks.is_empty());
132
133 let req = Request::builder().body(Body::empty()).unwrap();
134 let result = hooks.execute_on_request(req).await.unwrap();
135 assert!(matches!(result, HookResult::Continue(_)));
136 }
137
138 #[tokio::test]
139 async fn test_on_request_continue() {
140 let mut hooks = LifecycleHooks::new();
141 hooks.add_on_request(Arc::new(ContinueHook {
142 name: "test".to_string(),
143 }));
144
145 let req = Request::builder().body(Body::empty()).unwrap();
146 let result = hooks.execute_on_request(req).await.unwrap();
147 assert!(matches!(result, HookResult::Continue(_)));
148 }
149
150 #[tokio::test]
151 async fn test_on_request_short_circuit() {
152 let mut hooks = LifecycleHooks::new();
153 hooks.add_on_request(Arc::new(ShortCircuitHook {
154 name: "auth_check".to_string(),
155 }));
156
157 let req = Request::builder().body(Body::empty()).unwrap();
158 let result = hooks.execute_on_request(req).await.unwrap();
159
160 match result {
161 HookResult::ShortCircuit(resp) => {
162 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
163 }
164 HookResult::Continue(_) => panic!("Expected ShortCircuit, got Continue"),
165 }
166 }
167
168 #[tokio::test]
169 async fn test_multiple_hooks_in_order() {
170 let mut hooks = LifecycleHooks::new();
171
172 hooks.add_on_request(Arc::new(ContinueHook {
173 name: "first".to_string(),
174 }));
175 hooks.add_on_request(Arc::new(ContinueHook {
176 name: "second".to_string(),
177 }));
178
179 let req = Request::builder().body(Body::empty()).unwrap();
180 let result = hooks.execute_on_request(req).await.unwrap();
181 assert!(matches!(result, HookResult::Continue(_)));
182 }
183
184 #[tokio::test]
185 async fn test_short_circuit_stops_execution() {
186 let mut hooks = LifecycleHooks::new();
187
188 hooks.add_on_request(Arc::new(ShortCircuitHook {
189 name: "short_circuit".to_string(),
190 }));
191 hooks.add_on_request(Arc::new(ContinueHook {
192 name: "never_executed".to_string(),
193 }));
194
195 let req = Request::builder().body(Body::empty()).unwrap();
196 let result = hooks.execute_on_request(req).await.unwrap();
197
198 match result {
199 HookResult::ShortCircuit(_) => {}
200 HookResult::Continue(_) => panic!("Expected ShortCircuit, got Continue"),
201 }
202 }
203
204 #[tokio::test]
205 async fn test_on_response_hooks() {
206 let mut hooks = LifecycleHooks::new();
207 hooks.add_on_response(Arc::new(ContinueHook {
208 name: "response_hook".to_string(),
209 }));
210
211 let resp = Response::builder().status(StatusCode::OK).body(Body::empty()).unwrap();
212
213 let result = hooks.execute_on_response(resp).await.unwrap();
214 assert_eq!(result.status(), StatusCode::OK);
215 }
216
217 #[tokio::test]
218 async fn test_request_hook_builder() {
219 let hook = request_hook("test", |req| async move { Ok(HookResult::Continue(req)) });
220
221 let req = Request::builder().body(Body::empty()).unwrap();
222 let result = hook.execute_request(req).await.unwrap();
223
224 assert!(matches!(result, HookResult::Continue(_)));
225 }
226
227 #[tokio::test]
228 async fn test_request_hook_with_modification() {
229 let hook = request_hook("add_header", |mut req| async move {
230 req.headers_mut()
231 .insert("X-Custom-Header", axum::http::HeaderValue::from_static("test-value"));
232 Ok(HookResult::Continue(req))
233 });
234
235 let req = Request::builder().body(Body::empty()).unwrap();
236 let result = hook.execute_request(req).await.unwrap();
237
238 match result {
239 HookResult::Continue(req) => {
240 assert_eq!(req.headers().get("X-Custom-Header").unwrap(), "test-value");
241 }
242 HookResult::ShortCircuit(_) => panic!("Expected Continue"),
243 }
244 }
245
246 #[tokio::test]
247 async fn test_request_hook_short_circuit() {
248 let hook = request_hook("auth", |_req| async move {
249 let response = Response::builder()
250 .status(StatusCode::UNAUTHORIZED)
251 .body(Body::from("Unauthorized"))
252 .unwrap();
253 Ok(HookResult::ShortCircuit(response))
254 });
255
256 let req = Request::builder().body(Body::empty()).unwrap();
257 let result = hook.execute_request(req).await.unwrap();
258
259 match result {
260 HookResult::ShortCircuit(resp) => {
261 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
262 }
263 HookResult::Continue(_) => panic!("Expected ShortCircuit"),
264 }
265 }
266
267 #[tokio::test]
268 async fn test_response_hook_builder() {
269 let hook = response_hook("security", |mut resp| async move {
270 resp.headers_mut()
271 .insert("X-Frame-Options", axum::http::HeaderValue::from_static("DENY"));
272 Ok(HookResult::Continue(resp))
273 });
274
275 let resp = Response::builder().status(StatusCode::OK).body(Body::empty()).unwrap();
276
277 let result = hook.execute_response(resp).await.unwrap();
278
279 match result {
280 HookResult::Continue(resp) => {
281 assert_eq!(resp.headers().get("X-Frame-Options").unwrap(), "DENY");
282 assert_eq!(resp.status(), StatusCode::OK);
283 }
284 HookResult::ShortCircuit(_) => panic!("Expected Continue"),
285 }
286 }
287
288 #[tokio::test]
289 async fn test_builder_pattern() {
290 let hooks = LifecycleHooks::builder()
291 .on_request(request_hook(
292 "logger",
293 |req| async move { Ok(HookResult::Continue(req)) },
294 ))
295 .pre_handler(request_hook("auth", |req| async move { Ok(HookResult::Continue(req)) }))
296 .on_response(response_hook("security", |resp| async move {
297 Ok(HookResult::Continue(resp))
298 }))
299 .build();
300
301 assert!(!hooks.is_empty());
302
303 let req = Request::builder().body(Body::empty()).unwrap();
304 let result = hooks.execute_on_request(req).await.unwrap();
305 assert!(matches!(result, HookResult::Continue(_)));
306 }
307
308 #[tokio::test]
309 async fn test_builder_with_multiple_hooks() {
310 let hooks = LifecycleHooks::builder()
311 .on_request(request_hook("first", |mut req| async move {
312 req.headers_mut()
313 .insert("X-First", axum::http::HeaderValue::from_static("1"));
314 Ok(HookResult::Continue(req))
315 }))
316 .on_request(request_hook("second", |mut req| async move {
317 req.headers_mut()
318 .insert("X-Second", axum::http::HeaderValue::from_static("2"));
319 Ok(HookResult::Continue(req))
320 }))
321 .build();
322
323 let req = Request::builder().body(Body::empty()).unwrap();
324 let result = hooks.execute_on_request(req).await.unwrap();
325
326 match result {
327 HookResult::Continue(req) => {
328 assert_eq!(req.headers().get("X-First").unwrap(), "1");
329 assert_eq!(req.headers().get("X-Second").unwrap(), "2");
330 }
331 HookResult::ShortCircuit(_) => panic!("Expected Continue"),
332 }
333 }
334
335 #[tokio::test]
336 async fn test_builder_short_circuit_stops_chain() {
337 let hooks = LifecycleHooks::builder()
338 .on_request(request_hook(
339 "first",
340 |req| async move { Ok(HookResult::Continue(req)) },
341 ))
342 .on_request(request_hook("short_circuit", |_req| async move {
343 let response = Response::builder()
344 .status(StatusCode::FORBIDDEN)
345 .body(Body::from("Blocked"))
346 .unwrap();
347 Ok(HookResult::ShortCircuit(response))
348 }))
349 .on_request(request_hook("never_called", |mut req| async move {
350 req.headers_mut()
351 .insert("X-Should-Not-Exist", axum::http::HeaderValue::from_static("value"));
352 Ok(HookResult::Continue(req))
353 }))
354 .build();
355
356 let req = Request::builder().body(Body::empty()).unwrap();
357 let result = hooks.execute_on_request(req).await.unwrap();
358
359 match result {
360 HookResult::ShortCircuit(resp) => {
361 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
362 }
363 HookResult::Continue(_) => panic!("Expected ShortCircuit"),
364 }
365 }
366
367 #[tokio::test]
368 async fn test_all_hook_types() {
369 let hooks = LifecycleHooks::builder()
370 .on_request(request_hook("on_request", |req| async move {
371 Ok(HookResult::Continue(req))
372 }))
373 .pre_validation(request_hook("pre_validation", |req| async move {
374 Ok(HookResult::Continue(req))
375 }))
376 .pre_handler(request_hook("pre_handler", |req| async move {
377 Ok(HookResult::Continue(req))
378 }))
379 .on_response(response_hook("on_response", |resp| async move {
380 Ok(HookResult::Continue(resp))
381 }))
382 .on_error(response_hook("on_error", |resp| async move {
383 Ok(HookResult::Continue(resp))
384 }))
385 .build();
386
387 assert!(!hooks.is_empty());
388
389 let req = Request::builder().body(Body::empty()).unwrap();
390 assert!(matches!(
391 hooks.execute_on_request(req).await.unwrap(),
392 HookResult::Continue(_)
393 ));
394
395 let req = Request::builder().body(Body::empty()).unwrap();
396 assert!(matches!(
397 hooks.execute_pre_validation(req).await.unwrap(),
398 HookResult::Continue(_)
399 ));
400
401 let req = Request::builder().body(Body::empty()).unwrap();
402 assert!(matches!(
403 hooks.execute_pre_handler(req).await.unwrap(),
404 HookResult::Continue(_)
405 ));
406
407 let resp = Response::builder().status(StatusCode::OK).body(Body::empty()).unwrap();
408 let result = hooks.execute_on_response(resp).await.unwrap();
409 assert_eq!(result.status(), StatusCode::OK);
410
411 let resp = Response::builder()
412 .status(StatusCode::INTERNAL_SERVER_ERROR)
413 .body(Body::empty())
414 .unwrap();
415 let result = hooks.execute_on_error(resp).await.unwrap();
416 assert_eq!(result.status(), StatusCode::INTERNAL_SERVER_ERROR);
417 }
418
419 #[tokio::test]
420 async fn test_empty_builder() {
421 let hooks = LifecycleHooks::builder().build();
422 assert!(hooks.is_empty());
423
424 let req = Request::builder().body(Body::empty()).unwrap();
425 let result = hooks.execute_on_request(req).await.unwrap();
426 assert!(matches!(result, HookResult::Continue(_)));
427 }
428
429 #[tokio::test]
430 async fn test_hook_chaining_modifies_request_sequentially() {
431 let hooks = LifecycleHooks::builder()
432 .on_request(request_hook("add_header_1", |mut req| async move {
433 req.headers_mut()
434 .insert("X-Chain-1", axum::http::HeaderValue::from_static("first"));
435 Ok(HookResult::Continue(req))
436 }))
437 .on_request(request_hook("add_header_2", |mut req| async move {
438 req.headers_mut()
439 .insert("X-Chain-2", axum::http::HeaderValue::from_static("second"));
440 Ok(HookResult::Continue(req))
441 }))
442 .on_request(request_hook("add_header_3", |mut req| async move {
443 req.headers_mut()
444 .insert("X-Chain-3", axum::http::HeaderValue::from_static("third"));
445 Ok(HookResult::Continue(req))
446 }))
447 .build();
448
449 let req = Request::builder().body(Body::empty()).unwrap();
450 let result = hooks.execute_on_request(req).await.unwrap();
451
452 match result {
453 HookResult::Continue(req) => {
454 assert_eq!(req.headers().get("X-Chain-1").unwrap(), "first");
455 assert_eq!(req.headers().get("X-Chain-2").unwrap(), "second");
456 assert_eq!(req.headers().get("X-Chain-3").unwrap(), "third");
457 }
458 HookResult::ShortCircuit(_) => panic!("Expected Continue"),
459 }
460 }
461
462 #[tokio::test]
463 async fn test_response_hook_chaining_modifies_status_and_headers() {
464 let hooks = LifecycleHooks::builder()
465 .on_response(response_hook("add_security_header", |mut resp| async move {
466 resp.headers_mut().insert(
467 "X-Content-Type-Options",
468 axum::http::HeaderValue::from_static("nosniff"),
469 );
470 Ok(HookResult::Continue(resp))
471 }))
472 .on_response(response_hook("add_cache_header", |mut resp| async move {
473 resp.headers_mut()
474 .insert("Cache-Control", axum::http::HeaderValue::from_static("no-cache"));
475 Ok(HookResult::Continue(resp))
476 }))
477 .on_response(response_hook("add_custom_header", |mut resp| async move {
478 resp.headers_mut()
479 .insert("X-Custom", axum::http::HeaderValue::from_static("value"));
480 Ok(HookResult::Continue(resp))
481 }))
482 .build();
483
484 let resp = Response::builder().status(StatusCode::OK).body(Body::empty()).unwrap();
485
486 let result = hooks.execute_on_response(resp).await.unwrap();
487
488 assert_eq!(result.status(), StatusCode::OK);
489 assert_eq!(result.headers().get("X-Content-Type-Options").unwrap(), "nosniff");
490 assert_eq!(result.headers().get("Cache-Control").unwrap(), "no-cache");
491 assert_eq!(result.headers().get("X-Custom").unwrap(), "value");
492 }
493
494 #[tokio::test]
495 async fn test_pre_validation_and_pre_handler_chaining() {
496 let hooks = LifecycleHooks::builder()
497 .pre_validation(request_hook("validate_auth", |mut req| async move {
498 req.headers_mut()
499 .insert("X-Validated", axum::http::HeaderValue::from_static("true"));
500 Ok(HookResult::Continue(req))
501 }))
502 .pre_handler(request_hook("prepare_handler", |mut req| async move {
503 req.headers_mut()
504 .insert("X-Prepared", axum::http::HeaderValue::from_static("true"));
505 Ok(HookResult::Continue(req))
506 }))
507 .build();
508
509 let req = Request::builder().body(Body::empty()).unwrap();
510 let result = hooks.execute_pre_validation(req).await.unwrap();
511
512 match result {
513 HookResult::Continue(req) => {
514 assert_eq!(req.headers().get("X-Validated").unwrap(), "true");
515 assert!(!req.headers().contains_key("X-Prepared"));
516 }
517 HookResult::ShortCircuit(_) => panic!("Expected Continue"),
518 }
519
520 let req = Request::builder()
521 .header("X-Validated", "true")
522 .body(Body::empty())
523 .unwrap();
524 let result = hooks.execute_pre_handler(req).await.unwrap();
525
526 match result {
527 HookResult::Continue(req) => {
528 assert_eq!(req.headers().get("X-Prepared").unwrap(), "true");
529 }
530 HookResult::ShortCircuit(_) => panic!("Expected Continue"),
531 }
532 }
533
534 #[tokio::test]
535 async fn test_hook_chain_with_state_passing() {
536 let hooks = LifecycleHooks::builder()
537 .on_request(request_hook("add_user_id", |mut req| async move {
538 req.headers_mut()
539 .insert("X-User-ID", axum::http::HeaderValue::from_static("123"));
540 Ok(HookResult::Continue(req))
541 }))
542 .on_request(request_hook("add_session_id", |mut req| async move {
543 if let Some(user_id) = req.headers().get("X-User-ID") {
544 if user_id == "123" {
545 req.headers_mut()
546 .insert("X-Session-ID", axum::http::HeaderValue::from_static("session_abc"));
547 }
548 }
549 Ok(HookResult::Continue(req))
550 }))
551 .build();
552
553 let req = Request::builder().body(Body::empty()).unwrap();
554 let result = hooks.execute_on_request(req).await.unwrap();
555
556 match result {
557 HookResult::Continue(req) => {
558 assert_eq!(req.headers().get("X-User-ID").unwrap(), "123");
559 assert_eq!(req.headers().get("X-Session-ID").unwrap(), "session_abc");
560 }
561 HookResult::ShortCircuit(_) => panic!("Expected Continue"),
562 }
563 }
564
565 #[tokio::test]
566 async fn test_pre_validation_short_circuit_stops_subsequent_hooks() {
567 let hooks = LifecycleHooks::builder()
568 .on_request(request_hook("on_request", |req| async move {
569 println!("on_request executed");
570 Ok(HookResult::Continue(req))
571 }))
572 .pre_validation(request_hook("pre_validation_abort", |_req| async move {
573 println!("pre_validation executed - short circuiting");
574 let response = Response::builder()
575 .status(StatusCode::BAD_REQUEST)
576 .body(Body::from("Validation failed"))
577 .unwrap();
578 Ok(HookResult::ShortCircuit(response))
579 }))
580 .pre_handler(request_hook("pre_handler", |req| async move {
581 println!("pre_handler executed - should NOT happen");
582 Ok(HookResult::Continue(req))
583 }))
584 .build();
585
586 let req = Request::builder().body(Body::empty()).unwrap();
587 let result = hooks.execute_pre_validation(req).await.unwrap();
588
589 match result {
590 HookResult::ShortCircuit(resp) => {
591 assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
592 }
593 HookResult::Continue(_) => panic!("Expected ShortCircuit"),
594 }
595 }
596
597 #[tokio::test]
598 async fn test_pre_handler_short_circuit_returns_early_response() {
599 let hooks = LifecycleHooks::builder()
600 .pre_validation(request_hook("pre_validation", |req| async move {
601 Ok(HookResult::Continue(req))
602 }))
603 .pre_handler(request_hook("rate_limit_check", |_req| async move {
604 let response = Response::builder()
605 .status(StatusCode::TOO_MANY_REQUESTS)
606 .body(Body::from("Rate limit exceeded"))
607 .unwrap();
608 Ok(HookResult::ShortCircuit(response))
609 }))
610 .build();
611
612 let req = Request::builder().body(Body::empty()).unwrap();
613 let result = hooks.execute_pre_handler(req).await.unwrap();
614
615 match result {
616 HookResult::ShortCircuit(resp) => {
617 assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
618 }
619 HookResult::Continue(_) => panic!("Expected ShortCircuit"),
620 }
621 }
622
623 #[tokio::test]
624 async fn test_short_circuit_in_middle_of_chain() {
625 let hooks = LifecycleHooks::builder()
626 .on_request(request_hook("hook_1", |mut req| async move {
627 req.headers_mut()
628 .insert("X-Executed-1", axum::http::HeaderValue::from_static("yes"));
629 Ok(HookResult::Continue(req))
630 }))
631 .on_request(request_hook("hook_2_abort", |_req| async move {
632 let response = Response::builder()
633 .status(StatusCode::FORBIDDEN)
634 .body(Body::from("Access denied"))
635 .unwrap();
636 Ok(HookResult::ShortCircuit(response))
637 }))
638 .on_request(request_hook("hook_3", |mut req| async move {
639 req.headers_mut()
640 .insert("X-Executed-3", axum::http::HeaderValue::from_static("yes"));
641 Ok(HookResult::Continue(req))
642 }))
643 .build();
644
645 let req = Request::builder().body(Body::empty()).unwrap();
646 let result = hooks.execute_on_request(req).await.unwrap();
647
648 match result {
649 HookResult::ShortCircuit(resp) => {
650 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
651 }
652 HookResult::Continue(_) => panic!("Expected ShortCircuit"),
653 }
654 }
655
656 #[tokio::test]
657 async fn test_short_circuit_with_custom_response_headers() {
658 let hooks = LifecycleHooks::builder()
659 .pre_validation(request_hook("auth_check", |_req| async move {
660 let response = Response::builder()
661 .status(StatusCode::UNAUTHORIZED)
662 .header("WWW-Authenticate", "Bearer realm=\"api\"")
663 .body(Body::from("Authorization required"))
664 .unwrap();
665 Ok(HookResult::ShortCircuit(response))
666 }))
667 .build();
668
669 let req = Request::builder().body(Body::empty()).unwrap();
670 let result = hooks.execute_pre_validation(req).await.unwrap();
671
672 match result {
673 HookResult::ShortCircuit(resp) => {
674 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
675 assert_eq!(resp.headers().get("WWW-Authenticate").unwrap(), "Bearer realm=\"api\"");
676 }
677 HookResult::Continue(_) => panic!("Expected ShortCircuit"),
678 }
679 }
680
681 #[tokio::test]
682 async fn test_hook_error_propagates_through_chain() {
683 let hooks = LifecycleHooks::builder()
684 .on_request(request_hook("good_hook", |mut req| async move {
685 req.headers_mut()
686 .insert("X-Good", axum::http::HeaderValue::from_static("yes"));
687 Ok(HookResult::Continue(req))
688 }))
689 .on_request(request_hook("bad_hook", |_req| async move {
690 Err("Something went wrong in hook".to_string())
691 }))
692 .build();
693
694 let req = Request::builder().body(Body::empty()).unwrap();
695 let result = hooks.execute_on_request(req).await;
696
697 assert!(result.is_err());
698 assert_eq!(result.unwrap_err(), "Something went wrong in hook");
699 }
700
701 #[tokio::test]
702 async fn test_error_in_pre_validation_stops_chain() {
703 let hooks = LifecycleHooks::builder()
704 .pre_validation(request_hook("validation_hook", |_req| async move {
705 Err("Validation error: invalid input".to_string())
706 }))
707 .pre_handler(request_hook("handler_prep", |req| async move {
708 Ok(HookResult::Continue(req))
709 }))
710 .build();
711
712 let req = Request::builder().body(Body::empty()).unwrap();
713 let result = hooks.execute_pre_validation(req).await;
714
715 assert!(result.is_err());
716 assert!(result.unwrap_err().contains("Validation error"));
717 }
718
719 #[tokio::test]
720 async fn test_on_error_hook_transforms_response() {
721 let hooks = LifecycleHooks::builder()
722 .on_error(response_hook("transform_error", |mut resp| async move {
723 resp.headers_mut()
724 .insert("X-Error-Handled", axum::http::HeaderValue::from_static("true"));
725
726 let _status = resp.status();
727 Ok(HookResult::Continue(resp))
728 }))
729 .build();
730
731 let resp = Response::builder()
732 .status(StatusCode::INTERNAL_SERVER_ERROR)
733 .body(Body::empty())
734 .unwrap();
735
736 let result = hooks.execute_on_error(resp).await.unwrap();
737
738 assert_eq!(result.headers().get("X-Error-Handled").unwrap(), "true");
739 }
740
741 #[tokio::test]
742 async fn test_response_hook_error_propagates() {
743 let hooks = LifecycleHooks::builder()
744 .on_response(response_hook("good_response_hook", |mut resp| async move {
745 resp.headers_mut()
746 .insert("X-Processed", axum::http::HeaderValue::from_static("yes"));
747 Ok(HookResult::Continue(resp))
748 }))
749 .on_response(response_hook("bad_response_hook", |_resp| async move {
750 Err("Error processing response".to_string())
751 }))
752 .build();
753
754 let resp = Response::builder().status(StatusCode::OK).body(Body::empty()).unwrap();
755
756 let result = hooks.execute_on_response(resp).await;
757
758 assert!(result.is_err());
759 assert_eq!(result.unwrap_err(), "Error processing response");
760 }
761
762 #[tokio::test]
763 async fn test_error_hook_error_propagates() {
764 let hooks = LifecycleHooks::builder()
765 .on_error(response_hook("error_hook_1", |mut resp| async move {
766 resp.headers_mut()
767 .insert("X-Error-Processed", axum::http::HeaderValue::from_static("1"));
768 Ok(HookResult::Continue(resp))
769 }))
770 .on_error(response_hook("error_hook_2_fails", |_resp| async move {
771 Err("Error in error hook".to_string())
772 }))
773 .build();
774
775 let resp = Response::builder()
776 .status(StatusCode::INTERNAL_SERVER_ERROR)
777 .body(Body::empty())
778 .unwrap();
779
780 let result = hooks.execute_on_error(resp).await;
781
782 assert!(result.is_err());
783 assert_eq!(result.unwrap_err(), "Error in error hook");
784 }
785
786 #[tokio::test]
787 async fn test_on_request_adds_multiple_headers() {
788 let hooks = LifecycleHooks::builder()
789 .on_request(request_hook("add_request_headers", |mut req| async move {
790 req.headers_mut()
791 .insert("X-Request-ID", axum::http::HeaderValue::from_static("req_123"));
792 req.headers_mut()
793 .insert("X-Timestamp", axum::http::HeaderValue::from_static("2025-01-01"));
794 req.headers_mut()
795 .insert("X-Processed", axum::http::HeaderValue::from_static("true"));
796 Ok(HookResult::Continue(req))
797 }))
798 .build();
799
800 let req = Request::builder().body(Body::empty()).unwrap();
801 let result = hooks.execute_on_request(req).await.unwrap();
802
803 match result {
804 HookResult::Continue(req) => {
805 assert_eq!(req.headers().get("X-Request-ID").unwrap(), "req_123");
806 assert_eq!(req.headers().get("X-Timestamp").unwrap(), "2025-01-01");
807 assert_eq!(req.headers().get("X-Processed").unwrap(), "true");
808 }
809 HookResult::ShortCircuit(_) => panic!("Expected Continue"),
810 }
811 }
812
813 #[tokio::test]
814 async fn test_on_response_adds_security_headers() {
815 let hooks = LifecycleHooks::builder()
816 .on_response(response_hook("add_security_headers", |mut resp| async move {
817 resp.headers_mut()
818 .insert("X-Frame-Options", axum::http::HeaderValue::from_static("DENY"));
819 resp.headers_mut().insert(
820 "X-Content-Type-Options",
821 axum::http::HeaderValue::from_static("nosniff"),
822 );
823 resp.headers_mut().insert(
824 "Strict-Transport-Security",
825 axum::http::HeaderValue::from_static("max-age=31536000"),
826 );
827 Ok(HookResult::Continue(resp))
828 }))
829 .build();
830
831 let resp = Response::builder().status(StatusCode::OK).body(Body::empty()).unwrap();
832
833 let result = hooks.execute_on_response(resp).await.unwrap();
834
835 assert_eq!(result.headers().get("X-Frame-Options").unwrap(), "DENY");
836 assert_eq!(result.headers().get("X-Content-Type-Options").unwrap(), "nosniff");
837 assert_eq!(
838 result.headers().get("Strict-Transport-Security").unwrap(),
839 "max-age=31536000"
840 );
841 }
842
843 #[tokio::test]
844 async fn test_pre_handler_modifies_request_before_execution() {
845 let hooks = LifecycleHooks::builder()
846 .pre_handler(request_hook("inject_context", |mut req| async move {
847 req.headers_mut().insert(
848 "X-Handler-Context",
849 axum::http::HeaderValue::from_static("context_data"),
850 );
851 req.headers_mut()
852 .insert("X-Injected", axum::http::HeaderValue::from_static("true"));
853 Ok(HookResult::Continue(req))
854 }))
855 .build();
856
857 let req = Request::builder().body(Body::empty()).unwrap();
858 let result = hooks.execute_pre_handler(req).await.unwrap();
859
860 match result {
861 HookResult::Continue(req) => {
862 assert_eq!(req.headers().get("X-Handler-Context").unwrap(), "context_data");
863 assert_eq!(req.headers().get("X-Injected").unwrap(), "true");
864 }
865 HookResult::ShortCircuit(_) => panic!("Expected Continue"),
866 }
867 }
868
869 #[tokio::test]
870 async fn test_register_multiple_hooks_different_types() {
871 let mut hooks = LifecycleHooks::new();
872
873 hooks.add_on_request(request_hook("on_request_1", |req| async move {
874 Ok(HookResult::Continue(req))
875 }));
876
877 hooks.add_pre_validation(request_hook("pre_validation_1", |req| async move {
878 Ok(HookResult::Continue(req))
879 }));
880
881 hooks.add_pre_handler(request_hook("pre_handler_1", |req| async move {
882 Ok(HookResult::Continue(req))
883 }));
884
885 hooks.add_on_response(response_hook("on_response_1", |resp| async move {
886 Ok(HookResult::Continue(resp))
887 }));
888
889 hooks.add_on_error(response_hook("on_error_1", |resp| async move {
890 Ok(HookResult::Continue(resp))
891 }));
892
893 assert!(!hooks.is_empty());
894 }
895
896 #[tokio::test]
897 async fn test_builder_composition_with_request_and_response_hooks() {
898 let hooks = LifecycleHooks::builder()
899 .on_request(request_hook("req_1", |mut req| async move {
900 req.headers_mut()
901 .insert("X-R1", axum::http::HeaderValue::from_static("1"));
902 Ok(HookResult::Continue(req))
903 }))
904 .on_request(request_hook("req_2", |mut req| async move {
905 req.headers_mut()
906 .insert("X-R2", axum::http::HeaderValue::from_static("2"));
907 Ok(HookResult::Continue(req))
908 }))
909 .on_response(response_hook("resp_1", |mut resp| async move {
910 resp.headers_mut()
911 .insert("X-Resp1", axum::http::HeaderValue::from_static("resp1"));
912 Ok(HookResult::Continue(resp))
913 }))
914 .on_response(response_hook("resp_2", |mut resp| async move {
915 resp.headers_mut()
916 .insert("X-Resp2", axum::http::HeaderValue::from_static("resp2"));
917 Ok(HookResult::Continue(resp))
918 }))
919 .build();
920
921 let req = Request::builder().body(Body::empty()).unwrap();
922 let req_result = hooks.execute_on_request(req).await.unwrap();
923
924 match req_result {
925 HookResult::Continue(req) => {
926 assert_eq!(req.headers().get("X-R1").unwrap(), "1");
927 assert_eq!(req.headers().get("X-R2").unwrap(), "2");
928 }
929 HookResult::ShortCircuit(_) => panic!("Expected Continue"),
930 }
931
932 let resp = Response::builder().status(StatusCode::OK).body(Body::empty()).unwrap();
933 let resp_result = hooks.execute_on_response(resp).await.unwrap();
934
935 assert_eq!(resp_result.headers().get("X-Resp1").unwrap(), "resp1");
936 assert_eq!(resp_result.headers().get("X-Resp2").unwrap(), "resp2");
937 }
938
939 #[tokio::test]
940 async fn test_multiple_hooks_accumulate_state() {
941 let hooks = LifecycleHooks::builder()
942 .on_request(request_hook("init_counter", |mut req| async move {
943 req.headers_mut()
944 .insert("X-Count", axum::http::HeaderValue::from_static("0"));
945 Ok(HookResult::Continue(req))
946 }))
947 .on_request(request_hook("increment_1", |mut req| async move {
948 if let Some(count_header) = req.headers().get("X-Count") {
949 if count_header == "0" {
950 req.headers_mut()
951 .insert("X-Count", axum::http::HeaderValue::from_static("1"));
952 }
953 }
954 Ok(HookResult::Continue(req))
955 }))
956 .on_request(request_hook("increment_2", |mut req| async move {
957 if let Some(count_header) = req.headers().get("X-Count") {
958 if count_header == "1" {
959 req.headers_mut()
960 .insert("X-Count", axum::http::HeaderValue::from_static("2"));
961 }
962 }
963 Ok(HookResult::Continue(req))
964 }))
965 .build();
966
967 let req = Request::builder().body(Body::empty()).unwrap();
968 let result = hooks.execute_on_request(req).await.unwrap();
969
970 match result {
971 HookResult::Continue(req) => {
972 assert_eq!(req.headers().get("X-Count").unwrap(), "2");
973 }
974 HookResult::ShortCircuit(_) => panic!("Expected Continue"),
975 }
976 }
977
978 #[tokio::test]
979 async fn test_first_hook_short_circuits_second_continues() {
980 let hooks = LifecycleHooks::builder()
981 .on_request(request_hook("early_exit", |_req| async move {
982 let response = Response::builder()
983 .status(StatusCode::FORBIDDEN)
984 .body(Body::from("Early exit"))
985 .unwrap();
986 Ok(HookResult::ShortCircuit(response))
987 }))
988 .on_request(request_hook("never_runs", |req| async move {
989 Ok(HookResult::Continue(req))
990 }))
991 .build();
992
993 let req = Request::builder().body(Body::empty()).unwrap();
994 let result = hooks.execute_on_request(req).await.unwrap();
995
996 match result {
997 HookResult::ShortCircuit(resp) => {
998 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
999 }
1000 HookResult::Continue(_) => panic!("Expected ShortCircuit"),
1001 }
1002 }
1003
1004 #[tokio::test]
1005 async fn test_all_hook_phases_in_sequence() {
1006 let hooks = LifecycleHooks::builder()
1007 .on_request(request_hook("on_request", |req| async move {
1008 Ok(HookResult::Continue(req))
1009 }))
1010 .pre_validation(request_hook("pre_validation", |req| async move {
1011 Ok(HookResult::Continue(req))
1012 }))
1013 .pre_handler(request_hook("pre_handler", |req| async move {
1014 Ok(HookResult::Continue(req))
1015 }))
1016 .on_response(response_hook("on_response", |resp| async move {
1017 Ok(HookResult::Continue(resp))
1018 }))
1019 .on_error(response_hook("on_error", |resp| async move {
1020 Ok(HookResult::Continue(resp))
1021 }))
1022 .build();
1023
1024 let req = Request::builder().body(Body::empty()).unwrap();
1025 let _ = hooks.execute_on_request(req).await;
1026
1027 let req = Request::builder().body(Body::empty()).unwrap();
1028 let _ = hooks.execute_pre_validation(req).await;
1029
1030 let req = Request::builder().body(Body::empty()).unwrap();
1031 let _ = hooks.execute_pre_handler(req).await;
1032
1033 let resp = Response::builder().status(StatusCode::OK).body(Body::empty()).unwrap();
1034 let _ = hooks.execute_on_response(resp).await;
1035
1036 let resp = Response::builder()
1037 .status(StatusCode::INTERNAL_SERVER_ERROR)
1038 .body(Body::empty())
1039 .unwrap();
1040 let _ = hooks.execute_on_error(resp).await;
1041 }
1042
1043 #[tokio::test]
1044 async fn test_hook_with_complex_header_manipulation() {
1045 let hooks = LifecycleHooks::builder()
1046 .on_request(request_hook("parse_auth", |mut req| async move {
1047 let has_auth = req.headers().contains_key("Authorization");
1048 let auth_status = if has_auth { "authenticated" } else { "anonymous" };
1049 req.headers_mut()
1050 .insert("X-Auth-Status", axum::http::HeaderValue::from_static(auth_status));
1051 Ok(HookResult::Continue(req))
1052 }))
1053 .pre_validation(request_hook("validate_auth", |req| async move {
1054 if let Some(auth_header) = req.headers().get("X-Auth-Status") {
1055 if auth_header == "anonymous" {
1056 let response = Response::builder()
1057 .status(StatusCode::UNAUTHORIZED)
1058 .body(Body::from("Authentication required"))
1059 .unwrap();
1060 return Ok(HookResult::ShortCircuit(response));
1061 }
1062 }
1063 Ok(HookResult::Continue(req))
1064 }))
1065 .build();
1066
1067 let auth_req = Request::builder()
1068 .header("Authorization", "Bearer token123")
1069 .body(Body::empty())
1070 .unwrap();
1071
1072 let result = hooks.execute_on_request(auth_req).await.unwrap();
1073 assert!(matches!(result, HookResult::Continue(_)));
1074
1075 let anon_req = Request::builder().body(Body::empty()).unwrap();
1076 let on_req_result = hooks.execute_on_request(anon_req).await.unwrap();
1077
1078 match on_req_result {
1079 HookResult::Continue(req) => {
1080 assert_eq!(req.headers().get("X-Auth-Status").unwrap(), "anonymous");
1081
1082 let val_result = hooks.execute_pre_validation(req).await.unwrap();
1083 assert!(matches!(val_result, HookResult::ShortCircuit(_)));
1084 }
1085 HookResult::ShortCircuit(_) => panic!("Expected Continue from on_request"),
1086 }
1087 }
1088
1089 #[tokio::test]
1090 async fn test_empty_hooks_no_overhead() {
1091 let hooks = LifecycleHooks::new();
1092 assert!(hooks.is_empty());
1093
1094 let req = Request::builder().body(Body::empty()).unwrap();
1095 let result = hooks.execute_on_request(req).await.unwrap();
1096 assert!(matches!(result, HookResult::Continue(_)));
1097
1098 let req = Request::builder().body(Body::empty()).unwrap();
1099 let result = hooks.execute_pre_validation(req).await.unwrap();
1100 assert!(matches!(result, HookResult::Continue(_)));
1101
1102 let resp = Response::builder().status(StatusCode::OK).body(Body::empty()).unwrap();
1103 let result = hooks.execute_on_response(resp).await.unwrap();
1104 assert_eq!(result.status(), StatusCode::OK);
1105 }
1106
1107 #[tokio::test]
1108 async fn test_response_hook_short_circuit_treated_as_continue() {
1109 let hooks = LifecycleHooks::builder()
1110 .on_response(response_hook("hook_with_short_circuit", |mut resp| async move {
1111 resp.headers_mut()
1112 .insert("X-Processed", axum::http::HeaderValue::from_static("yes"));
1113 Ok(HookResult::ShortCircuit(resp))
1114 }))
1115 .on_response(response_hook("second_hook", |mut resp| async move {
1116 resp.headers_mut()
1117 .insert("X-Second", axum::http::HeaderValue::from_static("yes"));
1118 Ok(HookResult::Continue(resp))
1119 }))
1120 .build();
1121
1122 let resp = Response::builder().status(StatusCode::OK).body(Body::empty()).unwrap();
1123
1124 let result = hooks.execute_on_response(resp).await.unwrap();
1125
1126 assert_eq!(result.headers().get("X-Processed").unwrap(), "yes");
1127 assert_eq!(result.headers().get("X-Second").unwrap(), "yes");
1128 }
1129
1130 #[tokio::test]
1131 async fn test_complex_pre_validation_flow_with_auth_and_content_check() {
1132 let hooks = LifecycleHooks::builder()
1133 .pre_validation(request_hook("check_auth", |req| async move {
1134 if !req.headers().contains_key("Authorization") {
1135 return Ok(HookResult::ShortCircuit(
1136 Response::builder()
1137 .status(StatusCode::UNAUTHORIZED)
1138 .body(Body::from("Missing auth"))
1139 .unwrap(),
1140 ));
1141 }
1142 Ok(HookResult::Continue(req))
1143 }))
1144 .pre_validation(request_hook("check_content_type", |req| async move {
1145 if req.method() == axum::http::Method::POST {
1146 if !req.headers().contains_key("Content-Type") {
1147 return Ok(HookResult::ShortCircuit(
1148 Response::builder()
1149 .status(StatusCode::BAD_REQUEST)
1150 .body(Body::from("Missing Content-Type"))
1151 .unwrap(),
1152 ));
1153 }
1154 }
1155 Ok(HookResult::Continue(req))
1156 }))
1157 .build();
1158
1159 let req = Request::builder().body(Body::empty()).unwrap();
1160 let result = hooks.execute_pre_validation(req).await.unwrap();
1161
1162 match result {
1163 HookResult::ShortCircuit(resp) => {
1164 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
1165 }
1166 HookResult::Continue(_) => panic!("Expected ShortCircuit for missing auth"),
1167 }
1168
1169 let req = Request::builder()
1170 .method(axum::http::Method::POST)
1171 .header("Authorization", "Bearer token")
1172 .body(Body::empty())
1173 .unwrap();
1174 let result = hooks.execute_pre_validation(req).await.unwrap();
1175
1176 match result {
1177 HookResult::ShortCircuit(resp) => {
1178 assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
1179 }
1180 HookResult::Continue(_) => panic!("Expected ShortCircuit for missing content type"),
1181 }
1182
1183 let req = Request::builder()
1184 .method(axum::http::Method::POST)
1185 .header("Authorization", "Bearer token")
1186 .header("Content-Type", "application/json")
1187 .body(Body::empty())
1188 .unwrap();
1189 let result = hooks.execute_pre_validation(req).await.unwrap();
1190
1191 assert!(matches!(result, HookResult::Continue(_)));
1192 }
1193}