1use crate::request::Request;
7use crate::response::Response;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use tower::Service;
13
14#[allow(dead_code)]
16pub type BoxedMiddleware = Arc<
17 dyn Fn(Request, BoxedNext) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
18 + Send
19 + Sync,
20>;
21
22pub type BoxedNext =
24 Arc<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>;
25
26pub trait MiddlewareLayer: Send + Sync + 'static {
62 fn call(
64 &self,
65 req: Request,
66 next: BoxedNext,
67 ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>;
68
69 fn clone_box(&self) -> Box<dyn MiddlewareLayer>;
71}
72
73impl Clone for Box<dyn MiddlewareLayer> {
74 fn clone(&self) -> Self {
75 self.clone_box()
76 }
77}
78
79#[derive(Clone, Default)]
81pub struct LayerStack {
82 layers: Vec<Box<dyn MiddlewareLayer>>,
83}
84
85impl LayerStack {
86 pub fn new() -> Self {
88 Self { layers: Vec::new() }
89 }
90
91 pub fn push(&mut self, layer: Box<dyn MiddlewareLayer>) {
95 self.layers.push(layer);
96 }
97
98 pub fn prepend(&mut self, layer: Box<dyn MiddlewareLayer>) {
102 self.layers.insert(0, layer);
103 }
104
105 pub fn is_empty(&self) -> bool {
107 self.layers.is_empty()
108 }
109
110 pub fn len(&self) -> usize {
112 self.layers.len()
113 }
114
115 pub fn execute(
117 &self,
118 req: Request,
119 handler: BoxedNext,
120 ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
121 if self.layers.is_empty() {
122 return handler(req);
123 }
124
125 let mut next = handler;
128
129 for layer in self.layers.iter().rev() {
130 let layer = layer.clone_box();
131 let current_next = next;
132 next = Arc::new(move |req: Request| {
133 let layer = layer.clone_box();
134 let next = current_next.clone();
135 Box::pin(async move { layer.call(req, next).await })
136 as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
137 });
138 }
139
140 next(req)
141 }
142}
143
144impl IntoIterator for LayerStack {
145 type Item = Box<dyn MiddlewareLayer>;
146 type IntoIter = std::vec::IntoIter<Self::Item>;
147
148 fn into_iter(self) -> Self::IntoIter {
149 self.layers.into_iter()
150 }
151}
152
153impl Extend<Box<dyn MiddlewareLayer>> for LayerStack {
154 fn extend<T: IntoIterator<Item = Box<dyn MiddlewareLayer>>>(&mut self, iter: T) {
155 self.layers.extend(iter);
156 }
157}
158
/// Wrapper that owns a layer value of type `L`.
///
/// NOTE(review): only `new` and `Clone` are visible in this module. This is
/// presumably meant to adapt `tower` layers into the middleware stack; confirm
/// before relying on it.
// Not constructed anywhere in this module yet, hence the `dead_code` allowances.
#[allow(dead_code)]
pub struct TowerLayerAdapter<L> {
    layer: L,
}

impl<L> TowerLayerAdapter<L>
where
    L: Clone + Send + Sync + 'static,
{
    /// Wraps `layer` without modifying it.
    #[allow(dead_code)]
    pub fn new(layer: L) -> Self {
        TowerLayerAdapter { layer }
    }
}

/// Clones the wrapped layer.
impl<L> Clone for TowerLayerAdapter<L>
where
    L: Clone,
{
    fn clone(&self) -> Self {
        let layer = self.layer.clone();
        TowerLayerAdapter { layer }
    }
}
186
187#[allow(dead_code)]
189pub struct NextService {
190 next: BoxedNext,
191}
192
193impl NextService {
194 #[allow(dead_code)]
195 pub fn new(next: BoxedNext) -> Self {
196 Self { next }
197 }
198}
199
200impl Clone for NextService {
201 fn clone(&self) -> Self {
202 Self {
203 next: self.next.clone(),
204 }
205 }
206}
207
208impl Service<Request> for NextService {
209 type Response = Response;
210 type Error = std::convert::Infallible;
211 type Future = Pin<Box<dyn Future<Output = Result<Response, Self::Error>> + Send>>;
212
213 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
214 Poll::Ready(Ok(()))
215 }
216
217 fn call(&mut self, req: Request) -> Self::Future {
218 let next = self.next.clone();
219 Box::pin(async move { Ok(next(req).await) })
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use crate::path_params::PathParams;
    use crate::request::Request;
    use crate::response::Response;
    use bytes::Bytes;
    use http::{Extensions, Method, StatusCode};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Mutex;

    /// Boxed response future returned by every test handler and middleware.
    type BoxedResponseFuture = Pin<Box<dyn Future<Output = Response> + Send + 'static>>;

    /// Shared log of `(middleware id, phase)` entries; phase is `"pre"` or `"post"`.
    type OrderLog = Arc<Mutex<Vec<(usize, &'static str)>>>;

    /// Runs `future` to completion on a lightweight current-thread runtime.
    ///
    /// Property tests call this once per generated case. A current-thread
    /// runtime avoids spinning up a multi-threaded worker pool hundreds of times.
    fn block_on<F: Future>(future: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("failed to build current-thread tokio runtime")
            .block_on(future)
    }

    /// Builds a request with the given method and path, an empty buffered body,
    /// no extensions and no path parameters.
    fn create_test_request(method: Method, path: &str) -> Request {
        let uri: http::Uri = path.parse().unwrap();
        let (parts, _) = http::Request::builder()
            .method(method)
            .uri(uri)
            .body(())
            .unwrap()
            .into_parts();

        Request::new(
            parts,
            crate::request::BodyVariant::Buffered(Bytes::new()),
            Arc::new(Extensions::new()),
            PathParams::new(),
        )
    }

    /// Builds a handler that always answers with `status` and a fixed `body`.
    fn fixed_handler(status: StatusCode, body: &'static str) -> BoxedNext {
        Arc::new(move |_req: Request| {
            Box::pin(async move {
                http::Response::builder()
                    .status(status)
                    .body(crate::response::Body::from(body))
                    .unwrap()
            }) as BoxedResponseFuture
        })
    }

    /// Builds a `200 OK` handler that sets `called` whenever it runs.
    fn flagging_handler(called: Arc<AtomicBool>) -> BoxedNext {
        Arc::new(move |_req: Request| {
            let called = Arc::clone(&called);
            Box::pin(async move {
                called.store(true, Ordering::SeqCst);
                http::Response::builder()
                    .status(StatusCode::OK)
                    .body(crate::response::Body::from("handler"))
                    .unwrap()
            }) as BoxedResponseFuture
        })
    }

    /// Middleware that logs `(id, "pre")` before and `(id, "post")` after the inner chain.
    #[derive(Clone)]
    struct OrderTrackingMiddleware {
        id: usize,
        order: OrderLog,
    }

    impl OrderTrackingMiddleware {
        fn new(id: usize, order: OrderLog) -> Self {
            Self { id, order }
        }
    }

    impl MiddlewareLayer for OrderTrackingMiddleware {
        fn call(&self, req: Request, next: BoxedNext) -> BoxedResponseFuture {
            let id = self.id;
            let order = self.order.clone();

            Box::pin(async move {
                // Each guard is a temporary dropped before the `.await` below.
                order.lock().unwrap().push((id, "pre"));
                let response = next(req).await;
                order.lock().unwrap().push((id, "post"));
                response
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    /// Middleware that overwrites the inner chain's response status.
    #[derive(Clone)]
    struct StatusModifyingMiddleware {
        status: StatusCode,
    }

    impl StatusModifyingMiddleware {
        fn new(status: StatusCode) -> Self {
            Self { status }
        }
    }

    impl MiddlewareLayer for StatusModifyingMiddleware {
        fn call(&self, req: Request, next: BoxedNext) -> BoxedResponseFuture {
            let status = self.status;

            Box::pin(async move {
                let mut response = next(req).await;
                *response.status_mut() = status;
                response
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    /// Middleware that answers `error_status` without calling `next` when
    /// `should_short_circuit` is set, and otherwise passes through.
    #[derive(Clone)]
    struct ShortCircuitMiddleware {
        error_status: StatusCode,
        should_short_circuit: bool,
    }

    impl ShortCircuitMiddleware {
        fn new(error_status: StatusCode, should_short_circuit: bool) -> Self {
            Self {
                error_status,
                should_short_circuit,
            }
        }
    }

    impl MiddlewareLayer for ShortCircuitMiddleware {
        fn call(&self, req: Request, next: BoxedNext) -> BoxedResponseFuture {
            let error_status = self.error_status;
            let should_short_circuit = self.should_short_circuit;

            Box::pin(async move {
                if should_short_circuit {
                    http::Response::builder()
                        .status(error_status)
                        .body(crate::response::Body::from("error"))
                        .unwrap()
                } else {
                    next(req).await
                }
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // A single pass-through layer must not alter the handler's status and
        // must record exactly one "pre" and one "post" entry.
        #[test]
        fn prop_layer_application_preserves_handler_behavior(
            handler_status in 200u16..600u16,
        ) {
            let result: Result<(), TestCaseError> = block_on(async {
                let order: OrderLog = Arc::new(Mutex::new(Vec::new()));

                let mut stack = LayerStack::new();
                stack.push(Box::new(OrderTrackingMiddleware::new(1, Arc::clone(&order))));

                let handler_status = StatusCode::from_u16(handler_status).unwrap_or(StatusCode::OK);
                let handler = fixed_handler(handler_status, "test");

                let request = create_test_request(Method::GET, "/test");
                let response = stack.execute(request, handler).await;

                prop_assert_eq!(response.status(), handler_status);

                let log = order.lock().unwrap();
                prop_assert_eq!(log.len(), 2);
                prop_assert_eq!(log[0], (1, "pre"));
                prop_assert_eq!(log[1], (1, "post"));

                Ok(())
            });
            result?;
        }
    }

    #[test]
    fn test_empty_layer_stack_calls_handler_directly() {
        block_on(async {
            let stack = LayerStack::new();
            assert!(stack.is_empty());
            assert_eq!(stack.len(), 0);

            let request = create_test_request(Method::GET, "/test");
            let response = stack
                .execute(request, fixed_handler(StatusCode::OK, "direct"))
                .await;

            assert_eq!(response.status(), StatusCode::OK);
        });
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // Layers run in insertion order on the way in ("pre") and in reverse
        // order on the way out ("post"), like nested function calls.
        #[test]
        fn prop_middleware_execution_order(
            num_layers in 1usize..10usize,
        ) {
            let result: Result<(), TestCaseError> = block_on(async {
                let order: OrderLog = Arc::new(Mutex::new(Vec::new()));

                let mut stack = LayerStack::new();
                for id in 0..num_layers {
                    stack.push(Box::new(OrderTrackingMiddleware::new(id, Arc::clone(&order))));
                }

                let request = create_test_request(Method::GET, "/test");
                let _response = stack
                    .execute(request, fixed_handler(StatusCode::OK, "test"))
                    .await;

                let expected: Vec<(usize, &'static str)> = (0..num_layers)
                    .map(|id| (id, "pre"))
                    .chain((0..num_layers).rev().map(|id| (id, "post")))
                    .collect();

                let log = order.lock().unwrap();
                prop_assert_eq!(
                    log.as_slice(),
                    expected.as_slice(),
                    "layers must run outermost-first inbound and innermost-first outbound"
                );

                Ok(())
            });
            result?;
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // A short-circuiting layer returns its own response. Outer layers
        // (ids < 100) still see "pre" and "post"; inner layers (ids >= 100) and
        // the handler never run.
        #[test]
        fn prop_middleware_short_circuit_on_error(
            error_status in 400u16..600u16,
            num_middleware_before in 0usize..5usize,
            num_middleware_after in 0usize..5usize,
        ) {
            let result: Result<(), TestCaseError> = block_on(async {
                let order: OrderLog = Arc::new(Mutex::new(Vec::new()));
                let handler_called = Arc::new(AtomicBool::new(false));

                let mut stack = LayerStack::new();
                for id in 0..num_middleware_before {
                    stack.push(Box::new(OrderTrackingMiddleware::new(id, Arc::clone(&order))));
                }

                let error_status = StatusCode::from_u16(error_status)
                    .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
                stack.push(Box::new(ShortCircuitMiddleware::new(error_status, true)));

                for offset in 0..num_middleware_after {
                    stack.push(Box::new(OrderTrackingMiddleware::new(
                        100 + offset,
                        Arc::clone(&order),
                    )));
                }

                let handler = flagging_handler(Arc::clone(&handler_called));
                let request = create_test_request(Method::GET, "/test");
                let response = stack.execute(request, handler).await;

                prop_assert_eq!(response.status(), error_status,
                    "Response should have the error status from short-circuit middleware");

                prop_assert!(!handler_called.load(Ordering::SeqCst),
                    "Handler should NOT be called when middleware short-circuits");

                let log = order.lock().unwrap();

                let pre_count = log.iter().filter(|(id, phase)| *id < 100 && *phase == "pre").count();
                let post_count = log.iter().filter(|(id, phase)| *id < 100 && *phase == "post").count();

                prop_assert_eq!(pre_count, num_middleware_before,
                    "All middleware before short-circuit should have pre recorded");
                prop_assert_eq!(post_count, num_middleware_before,
                    "All middleware before short-circuit should have post recorded (unwinding)");

                let after_entries = log.iter().filter(|(id, _)| *id >= 100).count();
                prop_assert_eq!(after_entries, 0,
                    "Middleware after short-circuit should NOT be executed");

                Ok(())
            });
            result?;
        }
    }

    #[test]
    fn test_short_circuit_returns_error_response() {
        block_on(async {
            let mut stack = LayerStack::new();
            stack.push(Box::new(ShortCircuitMiddleware::new(
                StatusCode::UNAUTHORIZED,
                true,
            )));

            let handler_called = Arc::new(AtomicBool::new(false));
            let handler = flagging_handler(Arc::clone(&handler_called));

            let request = create_test_request(Method::GET, "/test");
            let response = stack.execute(request, handler).await;

            assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
            assert!(!handler_called.load(Ordering::SeqCst));
        });
    }

    #[test]
    fn test_non_short_circuiting_middleware_calls_handler() {
        block_on(async {
            let mut stack = LayerStack::new();
            stack.push(Box::new(ShortCircuitMiddleware::new(
                StatusCode::UNAUTHORIZED,
                false,
            )));

            let handler_called = Arc::new(AtomicBool::new(false));
            let handler = flagging_handler(Arc::clone(&handler_called));

            let request = create_test_request(Method::GET, "/test");
            let response = stack.execute(request, handler).await;

            assert_eq!(response.status(), StatusCode::OK);
            assert!(handler_called.load(Ordering::SeqCst));
        });
    }

    #[test]
    fn test_status_modifying_middleware_overrides_handler_status() {
        block_on(async {
            let mut stack = LayerStack::new();
            stack.push(Box::new(StatusModifyingMiddleware::new(StatusCode::ACCEPTED)));

            let request = create_test_request(Method::GET, "/test");
            let response = stack
                .execute(request, fixed_handler(StatusCode::OK, "test"))
                .await;

            assert_eq!(response.status(), StatusCode::ACCEPTED);
        });
    }

    #[test]
    fn test_prepend_places_layer_outermost() {
        block_on(async {
            let order: OrderLog = Arc::new(Mutex::new(Vec::new()));

            let mut stack = LayerStack::new();
            stack.push(Box::new(OrderTrackingMiddleware::new(1, Arc::clone(&order))));
            stack.prepend(Box::new(OrderTrackingMiddleware::new(0, Arc::clone(&order))));
            assert_eq!(stack.len(), 2);

            let request = create_test_request(Method::GET, "/test");
            let _response = stack
                .execute(request, fixed_handler(StatusCode::OK, "test"))
                .await;

            let expected: [(usize, &'static str); 4] =
                [(0, "pre"), (1, "pre"), (1, "post"), (0, "post")];
            let log = order.lock().unwrap();
            assert_eq!(log.as_slice(), &expected[..]);
        });
    }
}