1use crate::request::Request;
7use crate::response::Response;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use tower::Service;
13
14#[allow(dead_code)]
16pub type BoxedMiddleware = Arc<
17 dyn Fn(Request, BoxedNext) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
18 + Send
19 + Sync,
20>;
21
22pub type BoxedNext =
24 Arc<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>;
25
26pub trait MiddlewareLayer: Send + Sync + 'static {
31 fn call(
33 &self,
34 req: Request,
35 next: BoxedNext,
36 ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>;
37
38 fn clone_box(&self) -> Box<dyn MiddlewareLayer>;
40}
41
42impl Clone for Box<dyn MiddlewareLayer> {
43 fn clone(&self) -> Self {
44 self.clone_box()
45 }
46}
47
48#[derive(Clone, Default)]
50pub struct LayerStack {
51 layers: Vec<Box<dyn MiddlewareLayer>>,
52}
53
54impl LayerStack {
55 pub fn new() -> Self {
57 Self { layers: Vec::new() }
58 }
59
60 pub fn push(&mut self, layer: Box<dyn MiddlewareLayer>) {
64 self.layers.push(layer);
65 }
66
67 pub fn prepend(&mut self, layer: Box<dyn MiddlewareLayer>) {
71 self.layers.insert(0, layer);
72 }
73
74 pub fn is_empty(&self) -> bool {
76 self.layers.is_empty()
77 }
78
79 pub fn len(&self) -> usize {
81 self.layers.len()
82 }
83
84 pub fn execute(
86 &self,
87 req: Request,
88 handler: BoxedNext,
89 ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
90 if self.layers.is_empty() {
91 return handler(req);
92 }
93
94 let mut next = handler;
97
98 for layer in self.layers.iter().rev() {
99 let layer = layer.clone_box();
100 let current_next = next;
101 next = Arc::new(move |req: Request| {
102 let layer = layer.clone_box();
103 let next = current_next.clone();
104 Box::pin(async move { layer.call(req, next).await })
105 as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
106 });
107 }
108
109 next(req)
110 }
111}
112
113impl IntoIterator for LayerStack {
114 type Item = Box<dyn MiddlewareLayer>;
115 type IntoIter = std::vec::IntoIter<Self::Item>;
116
117 fn into_iter(self) -> Self::IntoIter {
118 self.layers.into_iter()
119 }
120}
121
122impl Extend<Box<dyn MiddlewareLayer>> for LayerStack {
123 fn extend<T: IntoIterator<Item = Box<dyn MiddlewareLayer>>>(&mut self, iter: T) {
124 self.layers.extend(iter);
125 }
126}
127
/// Holds an arbitrary layer value `L`.
///
/// NOTE(review): the name suggests `L` is meant to be a `tower::Layer`, but nothing
/// in this module constrains or uses it that way yet. Confirm before relying on it.
#[allow(dead_code)] // not referenced elsewhere in this module
#[derive(Clone)]
pub struct TowerLayerAdapter<L> {
    layer: L,
}

impl<L: Clone + Send + Sync + 'static> TowerLayerAdapter<L> {
    /// Wraps `layer` without modification.
    #[allow(dead_code)] // not referenced elsewhere in this module
    pub fn new(layer: L) -> Self {
        TowerLayerAdapter { layer }
    }
}
155
156#[allow(dead_code)]
158pub struct NextService {
159 next: BoxedNext,
160}
161
162impl NextService {
163 #[allow(dead_code)]
164 pub fn new(next: BoxedNext) -> Self {
165 Self { next }
166 }
167}
168
169impl Clone for NextService {
170 fn clone(&self) -> Self {
171 Self {
172 next: self.next.clone(),
173 }
174 }
175}
176
177impl Service<Request> for NextService {
178 type Response = Response;
179 type Error = std::convert::Infallible;
180 type Future = Pin<Box<dyn Future<Output = Result<Response, Self::Error>> + Send>>;
181
182 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
183 Poll::Ready(Ok(()))
184 }
185
186 fn call(&mut self, req: Request) -> Self::Future {
187 let next = self.next.clone();
188 Box::pin(async move { Ok(next(req).await) })
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::Request;
    use crate::response::Response;
    use bytes::Bytes;
    use http::{Extensions, Method, StatusCode};
    use proptest::prelude::*;
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Mutex;

    /// Shared `(middleware id, phase)` log written by `OrderTrackingMiddleware`.
    type OrderLog = Arc<Mutex<Vec<(usize, &'static str)>>>;

    /// Drives `fut` to completion on a single-threaded tokio runtime.
    ///
    /// Avoids `Runtime::new()`, which spins up a multi-threaded worker pool, once
    /// per proptest case.
    fn block_on<F: Future>(fut: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("failed to build a current-thread tokio runtime")
            .block_on(fut)
    }

    /// Builds a GET-style request for `path` with an empty body and no extensions.
    fn create_test_request(method: Method, path: &str) -> Request {
        let uri: http::Uri = path.parse().unwrap();
        let (parts, _) = http::Request::builder()
            .method(method)
            .uri(uri)
            .body(())
            .unwrap()
            .into_parts();

        Request::new(
            parts,
            Bytes::new(),
            Arc::new(Extensions::new()),
            HashMap::new(),
        )
    }

    /// Builds a handler that replies with `status` and `body`, setting `called`
    /// (when given) as soon as it runs.
    fn handler_with(
        status: StatusCode,
        body: &'static str,
        called: Option<Arc<AtomicBool>>,
    ) -> BoxedNext {
        Arc::new(move |_req: Request| {
            let called = called.clone();
            Box::pin(async move {
                if let Some(flag) = called {
                    flag.store(true, Ordering::SeqCst);
                }
                http::Response::builder()
                    .status(status)
                    .body(http_body_util::Full::new(Bytes::from(body)))
                    .unwrap()
            }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
        })
    }

    /// The log expected when the layers with `ids` run in order and then unwind in reverse.
    fn nested_log(ids: std::ops::Range<usize>) -> Vec<(usize, &'static str)> {
        ids.clone()
            .map(|id| (id, "pre"))
            .chain(ids.rev().map(|id| (id, "post")))
            .collect()
    }

    /// Records `(id, "pre")` before delegating and `(id, "post")` after the inner response arrives.
    #[derive(Clone)]
    struct OrderTrackingMiddleware {
        id: usize,
        order: OrderLog,
    }

    impl OrderTrackingMiddleware {
        fn new(id: usize, order: OrderLog) -> Self {
            Self { id, order }
        }
    }

    impl MiddlewareLayer for OrderTrackingMiddleware {
        fn call(
            &self,
            req: Request,
            next: BoxedNext,
        ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
            let id = self.id;
            let order = Arc::clone(&self.order);

            Box::pin(async move {
                order.lock().unwrap().push((id, "pre"));
                let response = next(req).await;
                order.lock().unwrap().push((id, "post"));
                response
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    /// Delegates, then overwrites the inner response's status with `status`.
    #[derive(Clone)]
    struct StatusModifyingMiddleware {
        status: StatusCode,
    }

    impl StatusModifyingMiddleware {
        fn new(status: StatusCode) -> Self {
            Self { status }
        }
    }

    impl MiddlewareLayer for StatusModifyingMiddleware {
        fn call(
            &self,
            req: Request,
            next: BoxedNext,
        ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
            let status = self.status;

            Box::pin(async move {
                let mut response = next(req).await;
                *response.status_mut() = status;
                response
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    /// Either answers immediately with `error_status` or passes the request through.
    #[derive(Clone)]
    struct ShortCircuitMiddleware {
        error_status: StatusCode,
        should_short_circuit: bool,
    }

    impl ShortCircuitMiddleware {
        fn new(error_status: StatusCode, should_short_circuit: bool) -> Self {
            Self {
                error_status,
                should_short_circuit,
            }
        }
    }

    impl MiddlewareLayer for ShortCircuitMiddleware {
        fn call(
            &self,
            req: Request,
            next: BoxedNext,
        ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
            let error_status = self.error_status;
            let should_short_circuit = self.should_short_circuit;

            Box::pin(async move {
                if should_short_circuit {
                    http::Response::builder()
                        .status(error_status)
                        .body(http_body_util::Full::new(Bytes::from("error")))
                        .unwrap()
                } else {
                    next(req).await
                }
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // A single pass-through layer must leave the handler's status untouched
        // and record exactly one pre/post pair.
        #[test]
        fn prop_layer_application_preserves_handler_behavior(
            handler_status in 200u16..600u16,
        ) {
            let handler_status = StatusCode::from_u16(handler_status).unwrap_or(StatusCode::OK);
            let order = OrderLog::default();

            let mut stack = LayerStack::new();
            stack.push(Box::new(OrderTrackingMiddleware::new(1, Arc::clone(&order))));

            let response = block_on(stack.execute(
                create_test_request(Method::GET, "/test"),
                handler_with(handler_status, "test", None),
            ));

            prop_assert_eq!(response.status(), handler_status);

            let log = order.lock().unwrap();
            prop_assert_eq!(log.to_vec(), vec![(1, "pre"), (1, "post")]);
        }

        // Layers run in insertion order on the way in and unwind in reverse on the way out.
        #[test]
        fn prop_middleware_execution_order(
            num_layers in 1usize..10usize,
        ) {
            let order = OrderLog::default();

            let mut stack = LayerStack::new();
            for id in 0..num_layers {
                stack.push(Box::new(OrderTrackingMiddleware::new(id, Arc::clone(&order))));
            }

            let _response = block_on(stack.execute(
                create_test_request(Method::GET, "/test"),
                handler_with(StatusCode::OK, "test", None),
            ));

            let log = order.lock().unwrap();
            prop_assert_eq!(log.to_vec(), nested_log(0..num_layers),
                "Layers must run outermost-first and unwind innermost-first");
        }

        // A short-circuiting layer answers for the whole stack: the handler and every
        // layer behind it are skipped, while layers in front of it still unwind.
        #[test]
        fn prop_middleware_short_circuit_on_error(
            error_status in 400u16..600u16,
            num_middleware_before in 0usize..5usize,
            num_middleware_after in 0usize..5usize,
        ) {
            let error_status =
                StatusCode::from_u16(error_status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
            let order = OrderLog::default();
            let handler_called = Arc::new(AtomicBool::new(false));

            // Ids below 100 sit in front of the short-circuit; ids from 100 sit behind it.
            let mut stack = LayerStack::new();
            for id in 0..num_middleware_before {
                stack.push(Box::new(OrderTrackingMiddleware::new(id, Arc::clone(&order))));
            }
            stack.push(Box::new(ShortCircuitMiddleware::new(error_status, true)));
            for i in 0..num_middleware_after {
                stack.push(Box::new(OrderTrackingMiddleware::new(100 + i, Arc::clone(&order))));
            }

            let response = block_on(stack.execute(
                create_test_request(Method::GET, "/test"),
                handler_with(StatusCode::OK, "handler", Some(Arc::clone(&handler_called))),
            ));

            prop_assert_eq!(response.status(), error_status,
                "Response should have the error status from short-circuit middleware");
            prop_assert!(!handler_called.load(Ordering::SeqCst),
                "Handler should NOT be called when middleware short-circuits");

            let log = order.lock().unwrap();
            prop_assert_eq!(log.to_vec(), nested_log(0..num_middleware_before),
                "Only middleware before the short-circuit should run, unwinding in reverse");
        }
    }

    #[test]
    fn test_empty_layer_stack_calls_handler_directly() {
        let stack = LayerStack::new();

        let response = block_on(stack.execute(
            create_test_request(Method::GET, "/test"),
            handler_with(StatusCode::OK, "direct", None),
        ));

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[test]
    fn test_short_circuit_returns_error_response() {
        let mut stack = LayerStack::new();
        stack.push(Box::new(ShortCircuitMiddleware::new(
            StatusCode::UNAUTHORIZED,
            true,
        )));
        let handler_called = Arc::new(AtomicBool::new(false));

        let response = block_on(stack.execute(
            create_test_request(Method::GET, "/test"),
            handler_with(StatusCode::OK, "handler", Some(Arc::clone(&handler_called))),
        ));

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert!(!handler_called.load(Ordering::SeqCst));
    }

    #[test]
    fn test_outer_layer_can_rewrite_response_status() {
        let mut stack = LayerStack::new();
        stack.push(Box::new(StatusModifyingMiddleware::new(
            StatusCode::IM_A_TEAPOT,
        )));

        let response = block_on(stack.execute(
            create_test_request(Method::GET, "/test"),
            handler_with(StatusCode::OK, "handler", None),
        ));

        assert_eq!(response.status(), StatusCode::IM_A_TEAPOT);
    }

    #[test]
    fn test_prepend_makes_layer_outermost() {
        let order = OrderLog::default();
        let mut stack = LayerStack::new();
        stack.push(Box::new(OrderTrackingMiddleware::new(1, Arc::clone(&order))));
        stack.prepend(Box::new(OrderTrackingMiddleware::new(0, Arc::clone(&order))));
        assert_eq!(stack.len(), 2);

        let _ = block_on(stack.execute(
            create_test_request(Method::GET, "/test"),
            handler_with(StatusCode::OK, "test", None),
        ));

        assert_eq!(order.lock().unwrap().to_vec(), nested_log(0..2));
    }
}