1use crate::request::Request;
7use crate::response::Response;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use tower::Service;
13
14pub type BoxedMiddleware = Arc<
16 dyn Fn(
17 Request,
18 BoxedNext,
19 ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
20 + Send
21 + Sync,
22>;
23
24pub type BoxedNext =
26 Arc<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>;
27
28pub trait MiddlewareLayer: Send + Sync + 'static {
33 fn call(
35 &self,
36 req: Request,
37 next: BoxedNext,
38 ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>;
39
40 fn clone_box(&self) -> Box<dyn MiddlewareLayer>;
42}
43
44impl Clone for Box<dyn MiddlewareLayer> {
45 fn clone(&self) -> Self {
46 self.clone_box()
47 }
48}
49
50#[derive(Clone, Default)]
52pub struct LayerStack {
53 layers: Vec<Box<dyn MiddlewareLayer>>,
54}
55
56impl LayerStack {
57 pub fn new() -> Self {
59 Self { layers: Vec::new() }
60 }
61
62 pub fn push(&mut self, layer: Box<dyn MiddlewareLayer>) {
66 self.layers.push(layer);
67 }
68
69 pub fn prepend(&mut self, layer: Box<dyn MiddlewareLayer>) {
73 self.layers.insert(0, layer);
74 }
75
76 pub fn is_empty(&self) -> bool {
78 self.layers.is_empty()
79 }
80
81 pub fn len(&self) -> usize {
83 self.layers.len()
84 }
85
86 pub fn execute(
88 &self,
89 req: Request,
90 handler: BoxedNext,
91 ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
92 if self.layers.is_empty() {
93 return handler(req);
94 }
95
96 let mut next = handler;
99
100 for layer in self.layers.iter().rev() {
101 let layer = layer.clone_box();
102 let current_next = next;
103 next = Arc::new(move |req: Request| {
104 let layer = layer.clone_box();
105 let next = current_next.clone();
106 Box::pin(async move { layer.call(req, next).await })
107 as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
108 });
109 }
110
111 next(req)
112 }
113}
114
/// Wrapper that holds a Tower-style layer `L`.
///
/// NOTE(review): only construction and cloning are visible in this section; how the wrapped
/// layer is actually applied (e.g. a `MiddlewareLayer` impl) is not shown here — confirm.
pub struct TowerLayerAdapter<L> {
    /// The wrapped layer.
    layer: L,
}
119
120impl<L> TowerLayerAdapter<L>
121where
122 L: Clone + Send + Sync + 'static,
123{
124 pub fn new(layer: L) -> Self {
126 Self { layer }
127 }
128}
129
130impl<L> Clone for TowerLayerAdapter<L>
131where
132 L: Clone,
133{
134 fn clone(&self) -> Self {
135 Self {
136 layer: self.layer.clone(),
137 }
138 }
139}
140
141pub struct NextService {
143 next: BoxedNext,
144}
145
146impl NextService {
147 pub fn new(next: BoxedNext) -> Self {
148 Self { next }
149 }
150}
151
152impl Clone for NextService {
153 fn clone(&self) -> Self {
154 Self {
155 next: self.next.clone(),
156 }
157 }
158}
159
160impl Service<Request> for NextService {
161 type Response = Response;
162 type Error = std::convert::Infallible;
163 type Future = Pin<Box<dyn Future<Output = Result<Response, Self::Error>> + Send>>;
164
165 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
166 Poll::Ready(Ok(()))
167 }
168
169 fn call(&mut self, req: Request) -> Self::Future {
170 let next = self.next.clone();
171 Box::pin(async move { Ok(next(req).await) })
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::Request;
    use crate::response::Response;
    use bytes::Bytes;
    use http::{Extensions, Method, StatusCode};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Mutex;

    /// Shared log of `(middleware id, phase)` entries appended by `OrderTrackingMiddleware`.
    type OrderLog = Arc<Mutex<Vec<(usize, &'static str)>>>;

    /// Drives `fut` to completion on a fresh single-threaded Tokio runtime.
    ///
    /// These tests need no worker threads, so a current-thread runtime per proptest case is
    /// much cheaper than starting a multi-threaded pool every time.
    fn block_on<F: Future>(fut: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("failed to build Tokio runtime")
            .block_on(fut)
    }

    /// Builds a [`Request`] for `method` and `path` with an empty body, fresh
    /// [`Extensions`], and an empty map.
    fn create_test_request(method: Method, path: &str) -> Request {
        let uri: http::Uri = path.parse().expect("test path must be a valid URI");
        let (parts, _) = http::Request::builder()
            .method(method)
            .uri(uri)
            .body(())
            .expect("test request must build")
            .into_parts();

        Request::new(parts, Bytes::new(), Arc::new(Extensions::new()), HashMap::new())
    }

    /// Builds a handler that always responds with `status` and `body`.
    fn static_handler(status: StatusCode, body: &'static str) -> BoxedNext {
        Arc::new(move |_req: Request| {
            Box::pin(async move {
                http::Response::builder()
                    .status(status)
                    .body(http_body_util::Full::new(Bytes::from(body)))
                    .unwrap()
            }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
        })
    }

    /// Builds a `200 OK` handler that sets `called` when it runs.
    ///
    /// Tests use this to check whether the chain actually reached the handler.
    fn tracking_handler(called: Arc<AtomicBool>) -> BoxedNext {
        Arc::new(move |_req: Request| {
            let called = Arc::clone(&called);
            Box::pin(async move {
                called.store(true, Ordering::SeqCst);
                http::Response::builder()
                    .status(StatusCode::OK)
                    .body(http_body_util::Full::new(Bytes::from("handler")))
                    .unwrap()
            }) as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
        })
    }

    /// Middleware that records `(id, "pre")` before and `(id, "post")` after calling `next`.
    #[derive(Clone)]
    struct OrderTrackingMiddleware {
        id: usize,
        order: OrderLog,
    }

    impl OrderTrackingMiddleware {
        fn new(id: usize, order: OrderLog) -> Self {
            Self { id, order }
        }
    }

    impl MiddlewareLayer for OrderTrackingMiddleware {
        fn call(
            &self,
            req: Request,
            next: BoxedNext,
        ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
            let id = self.id;
            let order = Arc::clone(&self.order);

            Box::pin(async move {
                // Each lock is released before the await, so no guard is held across it.
                order.lock().unwrap().push((id, "pre"));
                let response = next(req).await;
                order.lock().unwrap().push((id, "post"));
                response
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    /// Middleware that overwrites the status of whatever response comes back from `next`.
    #[derive(Clone)]
    struct StatusModifyingMiddleware {
        status: StatusCode,
    }

    impl StatusModifyingMiddleware {
        fn new(status: StatusCode) -> Self {
            Self { status }
        }
    }

    impl MiddlewareLayer for StatusModifyingMiddleware {
        fn call(
            &self,
            req: Request,
            next: BoxedNext,
        ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
            let status = self.status;

            Box::pin(async move {
                let mut response = next(req).await;
                *response.status_mut() = status;
                response
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    /// Middleware that either answers immediately with `error_status` (skipping `next`)
    /// or passes the request through, depending on `should_short_circuit`.
    #[derive(Clone)]
    struct ShortCircuitMiddleware {
        error_status: StatusCode,
        should_short_circuit: bool,
    }

    impl ShortCircuitMiddleware {
        fn new(error_status: StatusCode, should_short_circuit: bool) -> Self {
            Self {
                error_status,
                should_short_circuit,
            }
        }
    }

    impl MiddlewareLayer for ShortCircuitMiddleware {
        fn call(
            &self,
            req: Request,
            next: BoxedNext,
        ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
            let error_status = self.error_status;
            let should_short_circuit = self.should_short_circuit;

            Box::pin(async move {
                if should_short_circuit {
                    http::Response::builder()
                        .status(error_status)
                        .body(http_body_util::Full::new(Bytes::from("error")))
                        .unwrap()
                } else {
                    next(req).await
                }
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // A single pass-through layer must not change the handler's status and must
        // record exactly one "pre" and one "post" entry.
        #[test]
        fn prop_layer_application_preserves_handler_behavior(
            handler_status in 200u16..600u16,
        ) {
            let result: Result<(), TestCaseError> = block_on(async {
                let order: OrderLog = Arc::new(Mutex::new(Vec::new()));

                let mut stack = LayerStack::new();
                stack.push(Box::new(OrderTrackingMiddleware::new(1, Arc::clone(&order))));

                let handler_status =
                    StatusCode::from_u16(handler_status).unwrap_or(StatusCode::OK);
                let handler = static_handler(handler_status, "test");

                let request = create_test_request(Method::GET, "/test");
                let response = stack.execute(request, handler).await;

                prop_assert_eq!(response.status(), handler_status);

                let execution_order = order.lock().unwrap();
                prop_assert_eq!(execution_order.len(), 2);
                prop_assert_eq!(execution_order[0], (1, "pre"));
                prop_assert_eq!(execution_order[1], (1, "post"));

                Ok(())
            });
            result?;
        }
    }

    #[test]
    fn test_empty_layer_stack_calls_handler_directly() {
        block_on(async {
            let stack = LayerStack::new();
            assert!(stack.is_empty());

            let request = create_test_request(Method::GET, "/test");
            let response = stack
                .execute(request, static_handler(StatusCode::OK, "direct"))
                .await;

            assert_eq!(response.status(), StatusCode::OK);
        });
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // Layers run in push order on the way in and in reverse order on the way out.
        #[test]
        fn prop_middleware_execution_order(
            num_layers in 1usize..10usize,
        ) {
            let result: Result<(), TestCaseError> = block_on(async {
                let order: OrderLog = Arc::new(Mutex::new(Vec::new()));

                let mut stack = LayerStack::new();
                for i in 0..num_layers {
                    stack.push(Box::new(OrderTrackingMiddleware::new(i, Arc::clone(&order))));
                }
                prop_assert_eq!(stack.len(), num_layers);

                let request = create_test_request(Method::GET, "/test");
                let _response = stack
                    .execute(request, static_handler(StatusCode::OK, "test"))
                    .await;

                let execution_order = order.lock().unwrap();
                prop_assert_eq!(execution_order.len(), num_layers * 2);

                for i in 0..num_layers {
                    prop_assert_eq!(execution_order[i], (i, "pre"),
                        "Pre-handler order mismatch at index {}", i);
                }

                for i in 0..num_layers {
                    let expected_id = num_layers - 1 - i;
                    prop_assert_eq!(execution_order[num_layers + i], (expected_id, "post"),
                        "Post-handler order mismatch at index {}", i);
                }

                Ok(())
            });
            result?;
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // A short-circuiting layer stops the chain: outer layers still unwind, while
        // inner layers and the handler never run.
        #[test]
        fn prop_middleware_short_circuit_on_error(
            error_status in 400u16..600u16,
            num_middleware_before in 0usize..5usize,
            num_middleware_after in 0usize..5usize,
        ) {
            let result: Result<(), TestCaseError> = block_on(async {
                let order: OrderLog = Arc::new(Mutex::new(Vec::new()));
                let handler_called = Arc::new(AtomicBool::new(false));

                let mut stack = LayerStack::new();

                // Layers before the short circuit use ids 0.., layers after use ids 100...
                for i in 0..num_middleware_before {
                    stack.push(Box::new(OrderTrackingMiddleware::new(i, Arc::clone(&order))));
                }

                let error_status = StatusCode::from_u16(error_status)
                    .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
                stack.push(Box::new(ShortCircuitMiddleware::new(error_status, true)));

                for i in 0..num_middleware_after {
                    stack.push(Box::new(OrderTrackingMiddleware::new(
                        100 + i,
                        Arc::clone(&order),
                    )));
                }

                let handler = tracking_handler(Arc::clone(&handler_called));
                let request = create_test_request(Method::GET, "/test");
                let response = stack.execute(request, handler).await;

                prop_assert_eq!(response.status(), error_status,
                    "Response should have the error status from short-circuit middleware");

                prop_assert!(!handler_called.load(Ordering::SeqCst),
                    "Handler should NOT be called when middleware short-circuits");

                let execution_order = order.lock().unwrap();

                let pre_count = execution_order
                    .iter()
                    .filter(|(id, phase)| *id < 100 && *phase == "pre")
                    .count();
                let post_count = execution_order
                    .iter()
                    .filter(|(id, phase)| *id < 100 && *phase == "post")
                    .count();

                prop_assert_eq!(pre_count, num_middleware_before,
                    "All middleware before short-circuit should have pre recorded");
                prop_assert_eq!(post_count, num_middleware_before,
                    "All middleware before short-circuit should have post recorded (unwinding)");

                let after_entries = execution_order.iter().filter(|(id, _)| *id >= 100).count();
                prop_assert_eq!(after_entries, 0,
                    "Middleware after short-circuit should NOT be executed");

                Ok(())
            });
            result?;
        }
    }

    #[test]
    fn test_short_circuit_returns_error_response() {
        block_on(async {
            let mut stack = LayerStack::new();
            stack.push(Box::new(ShortCircuitMiddleware::new(StatusCode::UNAUTHORIZED, true)));

            let handler_called = Arc::new(AtomicBool::new(false));
            let handler = tracking_handler(Arc::clone(&handler_called));

            let request = create_test_request(Method::GET, "/test");
            let response = stack.execute(request, handler).await;

            assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
            assert!(!handler_called.load(Ordering::SeqCst));
        });
    }

    #[test]
    fn test_non_short_circuiting_middleware_calls_handler() {
        block_on(async {
            let mut stack = LayerStack::new();
            stack.push(Box::new(ShortCircuitMiddleware::new(StatusCode::FORBIDDEN, false)));

            let handler_called = Arc::new(AtomicBool::new(false));
            let handler = tracking_handler(Arc::clone(&handler_called));

            let request = create_test_request(Method::GET, "/test");
            let response = stack.execute(request, handler).await;

            assert_eq!(response.status(), StatusCode::OK);
            assert!(handler_called.load(Ordering::SeqCst));
        });
    }

    #[test]
    fn test_status_modifying_middleware_overrides_handler_status() {
        block_on(async {
            let mut stack = LayerStack::new();
            stack.push(Box::new(StatusModifyingMiddleware::new(StatusCode::ACCEPTED)));

            let request = create_test_request(Method::GET, "/test");
            let response = stack
                .execute(request, static_handler(StatusCode::OK, "test"))
                .await;

            assert_eq!(response.status(), StatusCode::ACCEPTED);
        });
    }

    #[test]
    fn test_prepend_runs_before_pushed_layers() {
        block_on(async {
            let order: OrderLog = Arc::default();

            let mut stack = LayerStack::new();
            stack.push(Box::new(OrderTrackingMiddleware::new(1, Arc::clone(&order))));
            stack.prepend(Box::new(OrderTrackingMiddleware::new(0, Arc::clone(&order))));
            assert_eq!(stack.len(), 2);

            let request = create_test_request(Method::GET, "/test");
            let _response = stack
                .execute(request, static_handler(StatusCode::OK, "test"))
                .await;

            assert_eq!(
                *order.lock().unwrap(),
                vec![(0, "pre"), (1, "pre"), (1, "post"), (0, "post")]
            );
        });
    }
}