1use crate::request::Request;
7use crate::response::Response;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use tower::Service;
13
14#[allow(dead_code)]
16pub type BoxedMiddleware = Arc<
17 dyn Fn(Request, BoxedNext) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
18 + Send
19 + Sync,
20>;
21
22pub type BoxedNext =
24 Arc<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>;
25
26pub trait MiddlewareLayer: Send + Sync + 'static {
31 fn call(
33 &self,
34 req: Request,
35 next: BoxedNext,
36 ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>;
37
38 fn clone_box(&self) -> Box<dyn MiddlewareLayer>;
40}
41
42impl Clone for Box<dyn MiddlewareLayer> {
43 fn clone(&self) -> Self {
44 self.clone_box()
45 }
46}
47
48#[derive(Clone, Default)]
50pub struct LayerStack {
51 layers: Vec<Box<dyn MiddlewareLayer>>,
52}
53
54impl LayerStack {
55 pub fn new() -> Self {
57 Self { layers: Vec::new() }
58 }
59
60 pub fn push(&mut self, layer: Box<dyn MiddlewareLayer>) {
64 self.layers.push(layer);
65 }
66
67 pub fn prepend(&mut self, layer: Box<dyn MiddlewareLayer>) {
71 self.layers.insert(0, layer);
72 }
73
74 pub fn is_empty(&self) -> bool {
76 self.layers.is_empty()
77 }
78
79 pub fn len(&self) -> usize {
81 self.layers.len()
82 }
83
84 pub fn execute(
86 &self,
87 req: Request,
88 handler: BoxedNext,
89 ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
90 if self.layers.is_empty() {
91 return handler(req);
92 }
93
94 let mut next = handler;
97
98 for layer in self.layers.iter().rev() {
99 let layer = layer.clone_box();
100 let current_next = next;
101 next = Arc::new(move |req: Request| {
102 let layer = layer.clone_box();
103 let next = current_next.clone();
104 Box::pin(async move { layer.call(req, next).await })
105 as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
106 });
107 }
108
109 next(req)
110 }
111}
112
/// Holds a layer value of type `L` for later adaptation into this module's middleware.
///
/// NOTE(review): the name suggests `L` is meant to be a `tower::Layer`, but nothing
/// here requires or uses that yet. The adapter only stores and clones the value.
#[allow(dead_code)] // not referenced anywhere else in this module yet
#[derive(Clone)]
pub struct TowerLayerAdapter<L> {
    layer: L,
}

impl<L> TowerLayerAdapter<L>
where
    L: Clone + Send + Sync + 'static,
{
    /// Wraps `layer` unchanged.
    #[allow(dead_code)] // no caller in this module yet
    pub fn new(layer: L) -> Self {
        TowerLayerAdapter { layer }
    }
}
140
141#[allow(dead_code)]
143pub struct NextService {
144 next: BoxedNext,
145}
146
147impl NextService {
148 #[allow(dead_code)]
149 pub fn new(next: BoxedNext) -> Self {
150 Self { next }
151 }
152}
153
154impl Clone for NextService {
155 fn clone(&self) -> Self {
156 Self {
157 next: self.next.clone(),
158 }
159 }
160}
161
162impl Service<Request> for NextService {
163 type Response = Response;
164 type Error = std::convert::Infallible;
165 type Future = Pin<Box<dyn Future<Output = Result<Response, Self::Error>> + Send>>;
166
167 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
168 Poll::Ready(Ok(()))
169 }
170
171 fn call(&mut self, req: Request) -> Self::Future {
172 let next = self.next.clone();
173 Box::pin(async move { Ok(next(req).await) })
174 }
175}
176
// Unit and property-based tests for `LayerStack` execution semantics: handler
// passthrough, layer ordering and short-circuiting.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::Request;
    use crate::response::Response;
    use bytes::Bytes;
    use http::{Extensions, Method, StatusCode};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::collections::HashMap;

    /// Builds a `Request` for `method` and `path` with an empty body, empty
    /// `Extensions` and an empty `HashMap` as the fourth `Request::new` argument
    /// (presumably path parameters; confirm against `Request::new`).
    ///
    /// Panics if `path` does not parse as an `http::Uri`, which is acceptable in tests.
    fn create_test_request(method: Method, path: &str) -> Request {
        let uri: http::Uri = path.parse().unwrap();
        let builder = http::Request::builder().method(method).uri(uri);

        // Only the request head is needed; the body is supplied separately as `Bytes`.
        let req = builder.body(()).unwrap();
        let (parts, _) = req.into_parts();

        Request::new(
            parts,
            Bytes::new(),
            Arc::new(Extensions::new()),
            HashMap::new(),
        )
    }

    /// Test layer that appends `(id, "pre")` to the shared log before awaiting `next`
    /// and `(id, "post")` after it completes.
    #[derive(Clone)]
    struct OrderTrackingMiddleware {
        id: usize,
        order: Arc<std::sync::Mutex<Vec<(usize, &'static str)>>>,
    }

    impl OrderTrackingMiddleware {
        fn new(id: usize, order: Arc<std::sync::Mutex<Vec<(usize, &'static str)>>>) -> Self {
            Self { id, order }
        }
    }

    impl MiddlewareLayer for OrderTrackingMiddleware {
        fn call(
            &self,
            req: Request,
            next: BoxedNext,
        ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
            // Copy state out of `self` so the returned future is `'static`.
            let id = self.id;
            let order = self.order.clone();

            Box::pin(async move {
                // The guard is a temporary dropped at the end of the statement, so
                // the std mutex is never held across the `.await` below.
                order.lock().unwrap().push((id, "pre"));

                let response = next(req).await;

                order.lock().unwrap().push((id, "post"));

                response
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    /// Test layer that overwrites the status of whatever response `next` produces.
    #[derive(Clone)]
    struct StatusModifyingMiddleware {
        status: StatusCode,
    }

    impl StatusModifyingMiddleware {
        fn new(status: StatusCode) -> Self {
            Self { status }
        }
    }

    impl MiddlewareLayer for StatusModifyingMiddleware {
        fn call(
            &self,
            req: Request,
            next: BoxedNext,
        ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
            let status = self.status;

            Box::pin(async move {
                let mut response = next(req).await;
                *response.status_mut() = status;
                response
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    proptest! {
        // A single pass-through layer must not alter the handler's status code and
        // must record exactly one "pre" and one "post" entry.
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_layer_application_preserves_handler_behavior(
            handler_status in 200u16..600u16,
        ) {
            let rt = tokio::runtime::Runtime::new().unwrap();
            let result: Result<(), TestCaseError> = rt.block_on(async {
                let order = Arc::new(std::sync::Mutex::new(Vec::new()));

                let mut stack = LayerStack::new();
                stack.push(Box::new(OrderTrackingMiddleware::new(1, order.clone())));

                // Every value in 200..600 is a valid status; the fallback is defensive only.
                let handler_status = StatusCode::from_u16(handler_status).unwrap_or(StatusCode::OK);
                let handler: BoxedNext = Arc::new(move |_req: Request| {
                    let status = handler_status;
                    Box::pin(async move {
                        http::Response::builder()
                            .status(status)
                            .body(http_body_util::Full::new(Bytes::from("test")))
                            .unwrap()
                    }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
                });

                let request = create_test_request(Method::GET, "/test");
                let response = stack.execute(request, handler).await;

                prop_assert_eq!(response.status(), handler_status);

                let execution_order = order.lock().unwrap();
                prop_assert_eq!(execution_order.len(), 2);
                prop_assert_eq!(execution_order[0], (1, "pre"));
                prop_assert_eq!(execution_order[1], (1, "post"));

                Ok(())
            });
            // Surface any prop_assert failure from inside the async block.
            result?;
        }
    }

    /// With no layers, `execute` must return the handler's response unchanged.
    #[test]
    fn test_empty_layer_stack_calls_handler_directly() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let stack = LayerStack::new();

            let handler: BoxedNext = Arc::new(|_req: Request| {
                Box::pin(async {
                    http::Response::builder()
                        .status(StatusCode::OK)
                        .body(http_body_util::Full::new(Bytes::from("direct")))
                        .unwrap()
                }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
            });

            let request = create_test_request(Method::GET, "/test");
            let response = stack.execute(request, handler).await;

            assert_eq!(response.status(), StatusCode::OK);
        });
    }

    proptest! {
        // Layers run as nested wrappers: "pre" entries appear in push order and
        // "post" entries in reverse push order.
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_middleware_execution_order(
            num_layers in 1usize..10usize,
        ) {
            let rt = tokio::runtime::Runtime::new().unwrap();
            let result: Result<(), TestCaseError> = rt.block_on(async {
                let order = Arc::new(std::sync::Mutex::new(Vec::new()));

                let mut stack = LayerStack::new();
                for i in 0..num_layers {
                    stack.push(Box::new(OrderTrackingMiddleware::new(i, order.clone())));
                }

                let handler: BoxedNext = Arc::new(|_req: Request| {
                    Box::pin(async {
                        http::Response::builder()
                            .status(StatusCode::OK)
                            .body(http_body_util::Full::new(Bytes::from("test")))
                            .unwrap()
                    }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
                });

                let request = create_test_request(Method::GET, "/test");
                let _response = stack.execute(request, handler).await;

                let execution_order = order.lock().unwrap();

                // Each layer records exactly one "pre" and one "post".
                prop_assert_eq!(execution_order.len(), num_layers * 2);

                // First half: pre-handler entries in push order (0, 1, ..., n-1).
                for i in 0..num_layers {
                    prop_assert_eq!(execution_order[i], (i, "pre"),
                        "Pre-handler order mismatch at index {}", i);
                }

                // Second half: post-handler entries unwinding in reverse (n-1, ..., 0).
                for i in 0..num_layers {
                    let expected_id = num_layers - 1 - i;
                    prop_assert_eq!(execution_order[num_layers + i], (expected_id, "post"),
                        "Post-handler order mismatch at index {}", i);
                }

                Ok(())
            });
            result?;
        }
    }

    /// Test layer that, when `should_short_circuit` is set, returns a response with
    /// `error_status` without calling `next`; otherwise it delegates to `next`.
    #[derive(Clone)]
    struct ShortCircuitMiddleware {
        error_status: StatusCode,
        should_short_circuit: bool,
    }

    impl ShortCircuitMiddleware {
        fn new(error_status: StatusCode, should_short_circuit: bool) -> Self {
            Self {
                error_status,
                should_short_circuit,
            }
        }
    }

    impl MiddlewareLayer for ShortCircuitMiddleware {
        fn call(
            &self,
            req: Request,
            next: BoxedNext,
        ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
            let error_status = self.error_status;
            let should_short_circuit = self.should_short_circuit;

            Box::pin(async move {
                if should_short_circuit {
                    // Respond immediately; `next` (and everything inside it) is skipped.
                    http::Response::builder()
                        .status(error_status)
                        .body(http_body_util::Full::new(Bytes::from("error")))
                        .unwrap()
                } else {
                    next(req).await
                }
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    proptest! {
        // A short-circuiting layer must stop the chain. Layers outside it still run
        // both phases, while layers inside it and the handler never run.
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_middleware_short_circuit_on_error(
            error_status in 400u16..600u16,
            num_middleware_before in 0usize..5usize,
            num_middleware_after in 0usize..5usize,
        ) {
            let rt = tokio::runtime::Runtime::new().unwrap();
            let result: Result<(), TestCaseError> = rt.block_on(async {
                let order = Arc::new(std::sync::Mutex::new(Vec::new()));
                let handler_called = Arc::new(std::sync::atomic::AtomicBool::new(false));

                let mut stack = LayerStack::new();

                // Outer layers use ids 0..before.
                for i in 0..num_middleware_before {
                    stack.push(Box::new(OrderTrackingMiddleware::new(i, order.clone())));
                }

                let error_status = StatusCode::from_u16(error_status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
                stack.push(Box::new(ShortCircuitMiddleware::new(error_status, true)));

                // Inner layers use ids >= 100 so they can be told apart in the log.
                for i in 0..num_middleware_after {
                    stack.push(Box::new(OrderTrackingMiddleware::new(100 + i, order.clone())));
                }

                let handler_called_clone = handler_called.clone();
                let handler: BoxedNext = Arc::new(move |_req: Request| {
                    let handler_called = handler_called_clone.clone();
                    Box::pin(async move {
                        handler_called.store(true, std::sync::atomic::Ordering::SeqCst);
                        http::Response::builder()
                            .status(StatusCode::OK)
                            .body(http_body_util::Full::new(Bytes::from("handler")))
                            .unwrap()
                    }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
                });

                let request = create_test_request(Method::GET, "/test");
                let response = stack.execute(request, handler).await;

                prop_assert_eq!(response.status(), error_status,
                    "Response should have the error status from short-circuit middleware");

                prop_assert!(!handler_called.load(std::sync::atomic::Ordering::SeqCst),
                    "Handler should NOT be called when middleware short-circuits");

                let execution_order = order.lock().unwrap();

                // Outer layers (id < 100) must record both phases.
                let pre_count = execution_order.iter().filter(|(id, phase)| *id < 100 && *phase == "pre").count();
                let post_count = execution_order.iter().filter(|(id, phase)| *id < 100 && *phase == "post").count();

                prop_assert_eq!(pre_count, num_middleware_before,
                    "All middleware before short-circuit should have pre recorded");
                prop_assert_eq!(post_count, num_middleware_before,
                    "All middleware before short-circuit should have post recorded (unwinding)");

                // Inner layers (id >= 100) must leave no trace.
                let after_entries = execution_order.iter().filter(|(id, _)| *id >= 100).count();
                prop_assert_eq!(after_entries, 0,
                    "Middleware after short-circuit should NOT be executed");

                Ok(())
            });
            result?;
        }
    }

    /// A lone short-circuiting layer returns its own status and never reaches the handler.
    #[test]
    fn test_short_circuit_returns_error_response() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let mut stack = LayerStack::new();
            stack.push(Box::new(ShortCircuitMiddleware::new(
                StatusCode::UNAUTHORIZED,
                true,
            )));

            let handler_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
            let handler_called_clone = handler_called.clone();

            let handler: BoxedNext = Arc::new(move |_req: Request| {
                let handler_called = handler_called_clone.clone();
                Box::pin(async move {
                    handler_called.store(true, std::sync::atomic::Ordering::SeqCst);
                    http::Response::builder()
                        .status(StatusCode::OK)
                        .body(http_body_util::Full::new(Bytes::from("handler")))
                        .unwrap()
                }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
            });

            let request = create_test_request(Method::GET, "/test");
            let response = stack.execute(request, handler).await;

            assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
            assert!(!handler_called.load(std::sync::atomic::Ordering::SeqCst));
        });
    }
}