Skip to main content

rustapi_core/middleware/
layer.rs

1//! Tower Layer integration for RustAPI middleware
2//!
3//! This module provides the infrastructure for applying Tower-compatible layers
4//! to the RustAPI request/response pipeline.
5
6use crate::request::Request;
7use crate::response::Response;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use tower::Service;
13
14/// A boxed middleware function type
15#[allow(dead_code)]
16pub type BoxedMiddleware = Arc<
17    dyn Fn(Request, BoxedNext) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
18        + Send
19        + Sync,
20>;
21
22/// A boxed next function for middleware chains
23pub type BoxedNext =
24    Arc<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>;
25
26/// Trait for middleware that can be applied to RustAPI
27///
28/// This trait allows both Tower layers and custom middleware to be used
29/// with the `.layer()` method.
30///
31/// # Example: Implementing a custom simple logger middleware
32///
33/// ```rust
34/// use rustapi_core::middleware::{MiddlewareLayer, BoxedNext};
35/// use rustapi_core::{Request, Response};
36/// use std::pin::Pin;
37/// use std::future::Future;
38///
39/// #[derive(Clone)]
40/// struct SimpleLogger;
41///
42/// impl MiddlewareLayer for SimpleLogger {
43///     fn call(
44///         &self,
45///         req: Request,
46///         next: BoxedNext,
47///     ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
48///         Box::pin(async move {
49///             println!("Incoming request: {} {}", req.method(), req.uri());
50///             let response = next(req).await;
51///             println!("Response status: {}", response.status());
52///             response
53///         })
54///     }
55///
56///     fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
57///         Box::new(self.clone())
58///     }
59/// }
60/// ```
61pub trait MiddlewareLayer: Send + Sync + 'static {
62    /// Apply this middleware to a request, calling `next` to continue the chain
63    fn call(
64        &self,
65        req: Request,
66        next: BoxedNext,
67    ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>;
68
69    /// Clone this middleware into a boxed trait object
70    fn clone_box(&self) -> Box<dyn MiddlewareLayer>;
71}
72
73impl Clone for Box<dyn MiddlewareLayer> {
74    fn clone(&self) -> Self {
75        self.clone_box()
76    }
77}
78
79/// A stack of middleware layers
80#[derive(Clone, Default)]
81pub struct LayerStack {
82    layers: Vec<Box<dyn MiddlewareLayer>>,
83}
84
85impl LayerStack {
86    /// Create a new empty layer stack
87    pub fn new() -> Self {
88        Self { layers: Vec::new() }
89    }
90
91    /// Add a middleware layer to the stack
92    ///
93    /// Layers are executed in the order they are added (outermost first).
94    pub fn push(&mut self, layer: Box<dyn MiddlewareLayer>) {
95        self.layers.push(layer);
96    }
97
98    /// Add a middleware layer to the beginning of the stack
99    ///
100    /// This layer will be executed first (outermost).
101    pub fn prepend(&mut self, layer: Box<dyn MiddlewareLayer>) {
102        self.layers.insert(0, layer);
103    }
104
105    /// Check if the stack is empty
106    pub fn is_empty(&self) -> bool {
107        self.layers.is_empty()
108    }
109
110    /// Get the number of layers
111    pub fn len(&self) -> usize {
112        self.layers.len()
113    }
114
115    /// Execute the middleware stack with a final handler
116    pub fn execute(
117        &self,
118        req: Request,
119        handler: BoxedNext,
120    ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
121        if self.layers.is_empty() {
122            return handler(req);
123        }
124
125        // Build the chain from inside out
126        // The last layer added should be the outermost (first to execute)
127        let mut next = handler;
128
129        for layer in self.layers.iter().rev() {
130            let layer = layer.clone_box();
131            let current_next = next;
132            next = Arc::new(move |req: Request| {
133                let layer = layer.clone_box();
134                let next = current_next.clone();
135                Box::pin(async move { layer.call(req, next).await })
136                    as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
137            });
138        }
139
140        next(req)
141    }
142}
143
144impl IntoIterator for LayerStack {
145    type Item = Box<dyn MiddlewareLayer>;
146    type IntoIter = std::vec::IntoIter<Self::Item>;
147
148    fn into_iter(self) -> Self::IntoIter {
149        self.layers.into_iter()
150    }
151}
152
153impl Extend<Box<dyn MiddlewareLayer>> for LayerStack {
154    fn extend<T: IntoIterator<Item = Box<dyn MiddlewareLayer>>>(&mut self, iter: T) {
155        self.layers.extend(iter);
156    }
157}
158
/// Holds a Tower layer so it can be adapted to RustAPI's middleware system.
// NOTE(review): the allow suggests the adapter is not constructed anywhere in the crate
// yet; confirm before removing it.
#[allow(dead_code)]
pub struct TowerLayerAdapter<L> {
    /// The wrapped Tower layer.
    layer: L,
}
164
165impl<L> TowerLayerAdapter<L>
166where
167    L: Clone + Send + Sync + 'static,
168{
169    /// Create a new adapter from a Tower layer
170    #[allow(dead_code)]
171    pub fn new(layer: L) -> Self {
172        Self { layer }
173    }
174}
175
176impl<L> Clone for TowerLayerAdapter<L>
177where
178    L: Clone,
179{
180    fn clone(&self) -> Self {
181        Self {
182            layer: self.layer.clone(),
183        }
184    }
185}
186
187/// A simple service wrapper for the next handler in the chain
188#[allow(dead_code)]
189pub struct NextService {
190    next: BoxedNext,
191}
192
193impl NextService {
194    #[allow(dead_code)]
195    pub fn new(next: BoxedNext) -> Self {
196        Self { next }
197    }
198}
199
200impl Clone for NextService {
201    fn clone(&self) -> Self {
202        Self {
203            next: self.next.clone(),
204        }
205    }
206}
207
208impl Service<Request> for NextService {
209    type Response = Response;
210    type Error = std::convert::Infallible;
211    type Future = Pin<Box<dyn Future<Output = Result<Response, Self::Error>> + Send>>;
212
213    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
214        Poll::Ready(Ok(()))
215    }
216
217    fn call(&mut self, req: Request) -> Self::Future {
218        let next = self.next.clone();
219        Box::pin(async move { Ok(next(req).await) })
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use crate::path_params::PathParams;
    use bytes::Bytes;
    use http::{Extensions, Method, StatusCode};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Mutex;

    /// Shared, ordered log of `(middleware id, phase)` entries; phase is `"pre"` or `"post"`.
    type OrderLog = Arc<Mutex<Vec<(usize, &'static str)>>>;

    /// The boxed future type produced by every continuation in these tests.
    type ResponseFuture = Pin<Box<dyn Future<Output = Response> + Send + 'static>>;

    /// Id offset for tracking layers pushed *after* (inside) a short-circuiting layer.
    const INNER_ID_BASE: usize = 100;

    /// Drives `future` to completion on a fresh current-thread Tokio runtime.
    ///
    /// The futures here do no I/O and use no timers, so a current-thread runtime is enough.
    /// It spawns no worker threads, which keeps per-case setup cheap; every proptest case
    /// builds its own runtime.
    fn block_on<F: Future>(future: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("failed to build a current-thread Tokio runtime")
            .block_on(future)
    }

    /// Create a test request with the given method and path
    fn create_test_request(method: Method, path: &str) -> Request {
        let uri: http::Uri = path.parse().unwrap();
        let builder = http::Request::builder().method(method).uri(uri);

        let req = builder.body(()).unwrap();
        let (parts, _) = req.into_parts();

        Request::new(
            parts,
            crate::request::BodyVariant::Buffered(Bytes::new()),
            Arc::new(Extensions::new()),
            PathParams::new(),
        )
    }

    /// Builds a handler that always answers with `status` and `body`.
    fn status_handler(status: StatusCode, body: &'static str) -> BoxedNext {
        Arc::new(move |_req: Request| {
            Box::pin(async move {
                http::Response::builder()
                    .status(status)
                    .body(crate::response::Body::from(body))
                    .unwrap()
            }) as ResponseFuture
        })
    }

    /// Builds a handler that answers `200 OK` and sets `called` when it runs.
    fn tracking_handler(called: Arc<AtomicBool>) -> BoxedNext {
        Arc::new(move |_req: Request| {
            let called = Arc::clone(&called);
            Box::pin(async move {
                called.store(true, Ordering::SeqCst);
                http::Response::builder()
                    .status(StatusCode::OK)
                    .body(crate::response::Body::from("handler"))
                    .unwrap()
            }) as ResponseFuture
        })
    }

    /// Middleware that logs `(id, "pre")` before and `(id, "post")` after calling `next`.
    #[derive(Clone)]
    struct OrderTrackingMiddleware {
        id: usize,
        order: OrderLog,
    }

    impl OrderTrackingMiddleware {
        fn new(id: usize, order: OrderLog) -> Self {
            Self { id, order }
        }
    }

    impl MiddlewareLayer for OrderTrackingMiddleware {
        fn call(&self, req: Request, next: BoxedNext) -> ResponseFuture {
            let id = self.id;
            let order = self.order.clone();

            Box::pin(async move {
                // The guard is a temporary, released before the `.await` below.
                order.lock().unwrap().push((id, "pre"));
                let response = next(req).await;
                order.lock().unwrap().push((id, "post"));
                response
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    /// Middleware that overwrites the status of the response produced by the rest of the chain.
    #[derive(Clone)]
    struct StatusModifyingMiddleware {
        status: StatusCode,
    }

    impl StatusModifyingMiddleware {
        fn new(status: StatusCode) -> Self {
            Self { status }
        }
    }

    impl MiddlewareLayer for StatusModifyingMiddleware {
        fn call(&self, req: Request, next: BoxedNext) -> ResponseFuture {
            let status = self.status;

            Box::pin(async move {
                let mut response = next(req).await;
                *response.status_mut() = status;
                response
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    /// Middleware that, when enabled, answers with `error_status` without calling `next`.
    #[derive(Clone)]
    struct ShortCircuitMiddleware {
        error_status: StatusCode,
        should_short_circuit: bool,
    }

    impl ShortCircuitMiddleware {
        fn new(error_status: StatusCode, should_short_circuit: bool) -> Self {
            Self {
                error_status,
                should_short_circuit,
            }
        }
    }

    impl MiddlewareLayer for ShortCircuitMiddleware {
        fn call(&self, req: Request, next: BoxedNext) -> ResponseFuture {
            let error_status = self.error_status;
            let should_short_circuit = self.should_short_circuit;

            Box::pin(async move {
                if should_short_circuit {
                    // Return error response without calling next (short-circuit)
                    http::Response::builder()
                        .status(error_status)
                        .body(crate::response::Body::from("error"))
                        .unwrap()
                } else {
                    // Continue to next middleware/handler
                    next(req).await
                }
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    // **Feature: phase3-batteries-included, Property 1: Layer application preserves handler behavior**
    //
    // For any Tower-compatible layer L and handler H, applying L via `.layer(L)` SHALL result
    // in requests being processed by L before reaching H, and responses being processed by L
    // after leaving H.
    //
    // **Validates: Requirements 1.1**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_layer_application_preserves_handler_behavior(
            handler_status in 200u16..600u16,
        ) {
            let result: Result<(), TestCaseError> = block_on(async {
                let order: OrderLog = Arc::default();

                let mut stack = LayerStack::new();
                stack.push(Box::new(OrderTrackingMiddleware::new(1, order.clone())));

                let handler_status = StatusCode::from_u16(handler_status).unwrap_or(StatusCode::OK);
                let request = create_test_request(Method::GET, "/test");
                let response = stack
                    .execute(request, status_handler(handler_status, "test"))
                    .await;

                // The handler ran: its status made it through the layer untouched.
                prop_assert_eq!(response.status(), handler_status);

                // The layer wrapped the handler: pre before post.
                let execution_order = order.lock().unwrap();
                prop_assert_eq!(execution_order.len(), 2);
                prop_assert_eq!(execution_order[0], (1, "pre"));
                prop_assert_eq!(execution_order[1], (1, "post"));

                Ok(())
            });
            result?;
        }
    }

    #[test]
    fn test_empty_layer_stack_calls_handler_directly() {
        block_on(async {
            let stack = LayerStack::new();
            assert!(stack.is_empty());

            let request = create_test_request(Method::GET, "/test");
            let response = stack
                .execute(request, status_handler(StatusCode::OK, "direct"))
                .await;

            assert_eq!(response.status(), StatusCode::OK);
        });
    }

    #[test]
    fn test_status_modifying_middleware_rewrites_response_status() {
        block_on(async {
            let mut stack = LayerStack::new();
            stack.push(Box::new(StatusModifyingMiddleware::new(StatusCode::ACCEPTED)));

            let request = create_test_request(Method::GET, "/test");
            let response = stack
                .execute(request, status_handler(StatusCode::OK, "test"))
                .await;

            assert_eq!(response.status(), StatusCode::ACCEPTED);
        });
    }

    #[test]
    fn test_prepend_makes_layer_outermost() {
        block_on(async {
            let order: OrderLog = Arc::default();

            let mut stack = LayerStack::new();
            stack.push(Box::new(OrderTrackingMiddleware::new(1, order.clone())));
            stack.prepend(Box::new(OrderTrackingMiddleware::new(0, order.clone())));
            assert_eq!(stack.len(), 2);
            assert!(!stack.is_empty());

            let request = create_test_request(Method::GET, "/test");
            let _response = stack
                .execute(request, status_handler(StatusCode::OK, "test"))
                .await;

            assert_eq!(
                *order.lock().unwrap(),
                vec![(0, "pre"), (1, "pre"), (1, "post"), (0, "post")]
            );
        });
    }

    #[test]
    fn test_next_service_forwards_to_continuation() {
        block_on(async {
            let mut service = NextService::new(status_handler(StatusCode::CREATED, "created"));

            let response = service
                .call(create_test_request(Method::POST, "/items"))
                .await
                .expect("NextService is infallible");

            assert_eq!(response.status(), StatusCode::CREATED);
        });
    }

    // **Feature: phase3-batteries-included, Property 2: Middleware execution order**
    //
    // For any sequence of layers [L1, L2, ..., Ln] added via `.layer()`, requests SHALL pass
    // through layers in the order L1 → L2 → ... → Ln → Handler → Ln → ... → L2 → L1
    // (outermost first on request, innermost first on response).
    //
    // **Validates: Requirements 1.2**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_middleware_execution_order(
            num_layers in 1usize..10usize,
        ) {
            let result: Result<(), TestCaseError> = block_on(async {
                let order: OrderLog = Arc::default();

                let mut stack = LayerStack::new();
                for id in 0..num_layers {
                    stack.push(Box::new(OrderTrackingMiddleware::new(id, order.clone())));
                }

                let request = create_test_request(Method::GET, "/test");
                let _response = stack
                    .execute(request, status_handler(StatusCode::OK, "test"))
                    .await;

                let execution_order = order.lock().unwrap();

                // One "pre" and one "post" entry per layer.
                prop_assert_eq!(execution_order.len(), num_layers * 2);
                let (pre, post) = execution_order.split_at(num_layers);

                // Requests: outermost first (0, 1, 2, ...).
                for (i, entry) in pre.iter().enumerate() {
                    prop_assert_eq!(*entry, (i, "pre"),
                        "Pre-handler order mismatch at index {}", i);
                }

                // Responses: innermost first (n-1, n-2, ..., 0).
                for (i, entry) in post.iter().enumerate() {
                    let expected_id = num_layers - 1 - i;
                    prop_assert_eq!(*entry, (expected_id, "post"),
                        "Post-handler order mismatch at index {}", i);
                }

                Ok(())
            });
            result?;
        }
    }

    // **Feature: phase3-batteries-included, Property 4: Middleware short-circuit on error**
    //
    // For any middleware that returns an error response, the handler SHALL NOT be invoked,
    // and the error response SHALL be returned directly to the client.
    //
    // **Validates: Requirements 1.5**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_middleware_short_circuit_on_error(
            error_status in 400u16..600u16,
            num_middleware_before in 0usize..5usize,
            num_middleware_after in 0usize..5usize,
        ) {
            let result: Result<(), TestCaseError> = block_on(async {
                let order: OrderLog = Arc::default();
                let handler_called = Arc::new(AtomicBool::new(false));

                let mut stack = LayerStack::new();

                // Layers outside the short-circuit: ids 0..INNER_ID_BASE.
                for id in 0..num_middleware_before {
                    stack.push(Box::new(OrderTrackingMiddleware::new(id, order.clone())));
                }

                let error_status = StatusCode::from_u16(error_status)
                    .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
                stack.push(Box::new(ShortCircuitMiddleware::new(error_status, true)));

                // Layers inside the short-circuit: ids >= INNER_ID_BASE; they must never run.
                for id in 0..num_middleware_after {
                    stack.push(Box::new(OrderTrackingMiddleware::new(
                        INNER_ID_BASE + id,
                        order.clone(),
                    )));
                }

                let request = create_test_request(Method::GET, "/test");
                let response = stack
                    .execute(request, tracking_handler(Arc::clone(&handler_called)))
                    .await;

                prop_assert_eq!(response.status(), error_status,
                    "Response should have the error status from short-circuit middleware");

                prop_assert!(!handler_called.load(Ordering::SeqCst),
                    "Handler should NOT be called when middleware short-circuits");

                let execution_order = order.lock().unwrap();

                // Outer layers ran their pre step and unwound through their post step.
                let pre_count = execution_order
                    .iter()
                    .filter(|(id, phase)| *id < INNER_ID_BASE && *phase == "pre")
                    .count();
                let post_count = execution_order
                    .iter()
                    .filter(|(id, phase)| *id < INNER_ID_BASE && *phase == "post")
                    .count();

                prop_assert_eq!(pre_count, num_middleware_before,
                    "All middleware before short-circuit should have pre recorded");
                prop_assert_eq!(post_count, num_middleware_before,
                    "All middleware before short-circuit should have post recorded (unwinding)");

                // Inner layers were never reached.
                let after_entries = execution_order
                    .iter()
                    .filter(|(id, _)| *id >= INNER_ID_BASE)
                    .count();
                prop_assert_eq!(after_entries, 0,
                    "Middleware after short-circuit should NOT be executed");

                Ok(())
            });
            result?;
        }
    }

    #[test]
    fn test_short_circuit_returns_error_response() {
        block_on(async {
            let mut stack = LayerStack::new();
            stack.push(Box::new(ShortCircuitMiddleware::new(
                StatusCode::UNAUTHORIZED,
                true,
            )));

            let handler_called = Arc::new(AtomicBool::new(false));
            let request = create_test_request(Method::GET, "/test");
            let response = stack
                .execute(request, tracking_handler(Arc::clone(&handler_called)))
                .await;

            assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
            assert!(!handler_called.load(Ordering::SeqCst));
        });
    }

    #[test]
    fn test_non_short_circuiting_layer_reaches_handler() {
        block_on(async {
            let mut stack = LayerStack::new();
            stack.push(Box::new(ShortCircuitMiddleware::new(
                StatusCode::UNAUTHORIZED,
                false,
            )));

            let handler_called = Arc::new(AtomicBool::new(false));
            let request = create_test_request(Method::GET, "/test");
            let response = stack
                .execute(request, tracking_handler(Arc::clone(&handler_called)))
                .await;

            assert_eq!(response.status(), StatusCode::OK);
            assert!(handler_called.load(Ordering::SeqCst));
        });
    }
}