rustapi_core/middleware/
layer.rs

1//! Tower Layer integration for RustAPI middleware
2//!
3//! This module provides the infrastructure for applying Tower-compatible layers
4//! to the RustAPI request/response pipeline.
5
6use crate::request::Request;
7use crate::response::Response;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use tower::Service;
13
14/// A boxed middleware function type
15#[allow(dead_code)]
16pub type BoxedMiddleware = Arc<
17    dyn Fn(Request, BoxedNext) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
18        + Send
19        + Sync,
20>;
21
22/// A boxed next function for middleware chains
23pub type BoxedNext =
24    Arc<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>;
25
26/// Trait for middleware that can be applied to RustAPI
27///
28/// This trait allows both Tower layers and custom middleware to be used
29/// with the `.layer()` method.
30pub trait MiddlewareLayer: Send + Sync + 'static {
31    /// Apply this middleware to a request, calling `next` to continue the chain
32    fn call(
33        &self,
34        req: Request,
35        next: BoxedNext,
36    ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>;
37
38    /// Clone this middleware into a boxed trait object
39    fn clone_box(&self) -> Box<dyn MiddlewareLayer>;
40}
41
42impl Clone for Box<dyn MiddlewareLayer> {
43    fn clone(&self) -> Self {
44        self.clone_box()
45    }
46}
47
48/// A stack of middleware layers
49#[derive(Clone, Default)]
50pub struct LayerStack {
51    layers: Vec<Box<dyn MiddlewareLayer>>,
52}
53
54impl LayerStack {
55    /// Create a new empty layer stack
56    pub fn new() -> Self {
57        Self { layers: Vec::new() }
58    }
59
60    /// Add a middleware layer to the stack
61    ///
62    /// Layers are executed in the order they are added (outermost first).
63    pub fn push(&mut self, layer: Box<dyn MiddlewareLayer>) {
64        self.layers.push(layer);
65    }
66
67    /// Add a middleware layer to the beginning of the stack
68    ///
69    /// This layer will be executed first (outermost).
70    pub fn prepend(&mut self, layer: Box<dyn MiddlewareLayer>) {
71        self.layers.insert(0, layer);
72    }
73
74    /// Check if the stack is empty
75    pub fn is_empty(&self) -> bool {
76        self.layers.is_empty()
77    }
78
79    /// Get the number of layers
80    pub fn len(&self) -> usize {
81        self.layers.len()
82    }
83
84    /// Execute the middleware stack with a final handler
85    pub fn execute(
86        &self,
87        req: Request,
88        handler: BoxedNext,
89    ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
90        if self.layers.is_empty() {
91            return handler(req);
92        }
93
94        // Build the chain from inside out
95        // The last layer added should be the outermost (first to execute)
96        let mut next = handler;
97
98        for layer in self.layers.iter().rev() {
99            let layer = layer.clone_box();
100            let current_next = next;
101            next = Arc::new(move |req: Request| {
102                let layer = layer.clone_box();
103                let next = current_next.clone();
104                Box::pin(async move { layer.call(req, next).await })
105                    as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
106            });
107        }
108
109        next(req)
110    }
111}
112
/// Adapter meant to let a Tower `Layer` take part in RustAPI's middleware system.
///
/// NOTE(review): within this file the wrapped layer is only stored and cloned;
/// no `MiddlewareLayer` impl for this type is visible here — confirm whether
/// one exists elsewhere.
// `dead_code` is allowed, presumably because the adapter is not wired in yet.
#[allow(dead_code)]
pub struct TowerLayerAdapter<L> {
    /// The wrapped Tower layer.
    layer: L,
}
118
119impl<L> TowerLayerAdapter<L>
120where
121    L: Clone + Send + Sync + 'static,
122{
123    /// Create a new adapter from a Tower layer
124    #[allow(dead_code)]
125    pub fn new(layer: L) -> Self {
126        Self { layer }
127    }
128}
129
130impl<L> Clone for TowerLayerAdapter<L>
131where
132    L: Clone,
133{
134    fn clone(&self) -> Self {
135        Self {
136            layer: self.layer.clone(),
137        }
138    }
139}
140
141/// A simple service wrapper for the next handler in the chain
142#[allow(dead_code)]
143pub struct NextService {
144    next: BoxedNext,
145}
146
147impl NextService {
148    #[allow(dead_code)]
149    pub fn new(next: BoxedNext) -> Self {
150        Self { next }
151    }
152}
153
154impl Clone for NextService {
155    fn clone(&self) -> Self {
156        Self {
157            next: self.next.clone(),
158        }
159    }
160}
161
162impl Service<Request> for NextService {
163    type Response = Response;
164    type Error = std::convert::Infallible;
165    type Future = Pin<Box<dyn Future<Output = Result<Response, Self::Error>> + Send>>;
166
167    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
168        Poll::Ready(Ok(()))
169    }
170
171    fn call(&mut self, req: Request) -> Self::Future {
172        let next = self.next.clone();
173        Box::pin(async move { Ok(next(req).await) })
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::Request;
    use crate::response::Response;
    use bytes::Bytes;
    use http::{Extensions, Method, StatusCode};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::collections::HashMap;

    /// Boxed response future, as returned by handlers and `MiddlewareLayer::call`.
    type ResponseFuture = Pin<Box<dyn Future<Output = Response> + Send + 'static>>;

    /// Shared log of `(middleware id, phase)` entries written by `OrderTrackingMiddleware`.
    type OrderLog = Arc<std::sync::Mutex<Vec<(usize, &'static str)>>>;

    /// Run `fut` to completion on a fresh current-thread Tokio runtime.
    ///
    /// The futures under test never spawn tasks, so a single-threaded runtime
    /// is enough. Unlike `Runtime::new()`, it does not start a multi-threaded
    /// worker pool for each of the hundreds of proptest cases.
    fn block_on<F: Future>(fut: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("failed to build a current-thread Tokio runtime")
            .block_on(fut)
    }

    /// Create a test request with the given method and path and an empty body.
    fn create_test_request(method: Method, path: &str) -> Request {
        let uri: http::Uri = path.parse().unwrap();
        let (parts, _) = http::Request::builder()
            .method(method)
            .uri(uri)
            .body(())
            .unwrap()
            .into_parts();

        Request::new(
            parts,
            Bytes::new(),
            Arc::new(Extensions::new()),
            HashMap::new(),
        )
    }

    /// Build a response with `status` and a static text `body`, as a handler would.
    fn text_response(status: StatusCode, body: &'static str) -> Response {
        http::Response::builder()
            .status(status)
            .body(http_body_util::Full::new(Bytes::from(body)))
            .expect("a StatusCode and an in-memory body always build a valid response")
    }

    /// Middleware that logs `(id, "pre")` before and `(id, "post")` after calling `next`.
    #[derive(Clone)]
    struct OrderTrackingMiddleware {
        id: usize,
        order: OrderLog,
    }

    impl OrderTrackingMiddleware {
        fn new(id: usize, order: OrderLog) -> Self {
            Self { id, order }
        }
    }

    impl MiddlewareLayer for OrderTrackingMiddleware {
        fn call(&self, req: Request, next: BoxedNext) -> ResponseFuture {
            // The returned future is 'static, so copy the state out of `self` first.
            let id = self.id;
            let order = self.order.clone();

            Box::pin(async move {
                order.lock().unwrap().push((id, "pre"));
                let response = next(req).await;
                order.lock().unwrap().push((id, "post"));
                response
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    /// Middleware that overwrites the status of whatever response `next` produces.
    #[derive(Clone)]
    struct StatusModifyingMiddleware {
        status: StatusCode,
    }

    impl StatusModifyingMiddleware {
        fn new(status: StatusCode) -> Self {
            Self { status }
        }
    }

    impl MiddlewareLayer for StatusModifyingMiddleware {
        fn call(&self, req: Request, next: BoxedNext) -> ResponseFuture {
            let status = self.status;

            Box::pin(async move {
                let mut response = next(req).await;
                *response.status_mut() = status;
                response
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    // **Feature: phase3-batteries-included, Property 1: Layer application preserves handler behavior**
    //
    // For any Tower-compatible layer L and handler H, applying L via `.layer(L)` SHALL result
    // in requests being processed by L before reaching H, and responses being processed by L
    // after leaving H. (Exercised here directly through `LayerStack`.)
    //
    // **Validates: Requirements 1.1**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_layer_application_preserves_handler_behavior(
            handler_status in 200u16..600u16,
        ) {
            let result: Result<(), TestCaseError> = block_on(async {
                let order: OrderLog = Arc::new(std::sync::Mutex::new(Vec::new()));

                // A stack with a single tracking middleware.
                let mut stack = LayerStack::new();
                stack.push(Box::new(OrderTrackingMiddleware::new(1, order.clone())));

                // Every value in 200..600 is a valid status code; the fallback is defensive.
                let handler_status = StatusCode::from_u16(handler_status).unwrap_or(StatusCode::OK);
                let handler: BoxedNext = Arc::new(move |_req: Request| {
                    Box::pin(async move { text_response(handler_status, "test") }) as ResponseFuture
                });

                let request = create_test_request(Method::GET, "/test");
                let response = stack.execute(request, handler).await;

                // The handler's status must survive the trip through the layer.
                prop_assert_eq!(response.status(), handler_status);

                // The layer must wrap the handler: "pre" first, then "post".
                let execution_order = order.lock().unwrap();
                prop_assert_eq!(execution_order.len(), 2);
                prop_assert_eq!(execution_order[0], (1, "pre"));
                prop_assert_eq!(execution_order[1], (1, "post"));

                Ok(())
            });
            result?;
        }
    }

    #[test]
    fn test_empty_layer_stack_calls_handler_directly() {
        block_on(async {
            let stack = LayerStack::new();

            let handler: BoxedNext = Arc::new(|_req: Request| {
                Box::pin(async { text_response(StatusCode::OK, "direct") }) as ResponseFuture
            });

            let request = create_test_request(Method::GET, "/test");
            let response = stack.execute(request, handler).await;

            assert_eq!(response.status(), StatusCode::OK);
        });
    }

    #[test]
    fn test_status_modifying_middleware_overrides_handler_status() {
        block_on(async {
            let mut stack = LayerStack::new();
            stack.push(Box::new(StatusModifyingMiddleware::new(StatusCode::ACCEPTED)));

            let handler: BoxedNext = Arc::new(|_req: Request| {
                Box::pin(async { text_response(StatusCode::OK, "test") }) as ResponseFuture
            });

            let request = create_test_request(Method::GET, "/test");
            let response = stack.execute(request, handler).await;

            // The middleware runs after the handler and rewrites its status.
            assert_eq!(response.status(), StatusCode::ACCEPTED);
        });
    }

    // **Feature: phase3-batteries-included, Property 2: Middleware execution order**
    //
    // For any sequence of layers [L1, L2, ..., Ln] added via `.layer()`, requests SHALL pass
    // through layers in the order L1 → L2 → ... → Ln → Handler → Ln → ... → L2 → L1
    // (outermost first on request, innermost first on response).
    //
    // **Validates: Requirements 1.2**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_middleware_execution_order(
            num_layers in 1usize..10usize,
        ) {
            let result: Result<(), TestCaseError> = block_on(async {
                let order: OrderLog = Arc::new(std::sync::Mutex::new(Vec::new()));

                // A stack of `num_layers` tracking middleware with ids 0..num_layers.
                let mut stack = LayerStack::new();
                for i in 0..num_layers {
                    stack.push(Box::new(OrderTrackingMiddleware::new(i, order.clone())));
                }

                let handler: BoxedNext = Arc::new(|_req: Request| {
                    Box::pin(async { text_response(StatusCode::OK, "test") }) as ResponseFuture
                });

                let request = create_test_request(Method::GET, "/test");
                let _response = stack.execute(request, handler).await;

                let execution_order = order.lock().unwrap();

                // One "pre" and one "post" entry per layer.
                prop_assert_eq!(execution_order.len(), num_layers * 2);
                let (pres, posts) = execution_order.split_at(num_layers);

                // "pre" entries appear outermost first: 0, 1, ..., n-1.
                for (i, entry) in pres.iter().enumerate() {
                    prop_assert_eq!(*entry, (i, "pre"),
                        "Pre-handler order mismatch at index {}", i);
                }

                // "post" entries appear innermost first: n-1, ..., 1, 0.
                for (i, entry) in posts.iter().enumerate() {
                    let expected_id = num_layers - 1 - i;
                    prop_assert_eq!(*entry, (expected_id, "post"),
                        "Post-handler order mismatch at index {}", i);
                }

                Ok(())
            });
            result?;
        }
    }

    /// Middleware that, when `should_short_circuit` is set, answers with
    /// `error_status` without calling `next`; otherwise it passes the request on.
    #[derive(Clone)]
    struct ShortCircuitMiddleware {
        error_status: StatusCode,
        should_short_circuit: bool,
    }

    impl ShortCircuitMiddleware {
        fn new(error_status: StatusCode, should_short_circuit: bool) -> Self {
            Self {
                error_status,
                should_short_circuit,
            }
        }
    }

    impl MiddlewareLayer for ShortCircuitMiddleware {
        fn call(&self, req: Request, next: BoxedNext) -> ResponseFuture {
            let error_status = self.error_status;
            let should_short_circuit = self.should_short_circuit;

            Box::pin(async move {
                if should_short_circuit {
                    // Answer immediately; the rest of the chain never runs.
                    text_response(error_status, "error")
                } else {
                    next(req).await
                }
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    // **Feature: phase3-batteries-included, Property 4: Middleware short-circuit on error**
    //
    // For any middleware that returns an error response, the handler SHALL NOT be invoked,
    // and the error response SHALL be returned directly to the client.
    //
    // **Validates: Requirements 1.5**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_middleware_short_circuit_on_error(
            error_status in 400u16..600u16,
            num_middleware_before in 0usize..5usize,
            num_middleware_after in 0usize..5usize,
        ) {
            let result: Result<(), TestCaseError> = block_on(async {
                let order: OrderLog = Arc::new(std::sync::Mutex::new(Vec::new()));
                let handler_called = Arc::new(std::sync::atomic::AtomicBool::new(false));

                let mut stack = LayerStack::new();

                // Outer middleware (ids 0..before): reached before the short-circuit.
                for i in 0..num_middleware_before {
                    stack.push(Box::new(OrderTrackingMiddleware::new(i, order.clone())));
                }

                // The short-circuiting middleware itself.
                let error_status = StatusCode::from_u16(error_status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
                stack.push(Box::new(ShortCircuitMiddleware::new(error_status, true)));

                // Inner middleware (ids 100..): must never run.
                for i in 0..num_middleware_after {
                    stack.push(Box::new(OrderTrackingMiddleware::new(100 + i, order.clone())));
                }

                // Handler that records whether it was ever invoked.
                let handler_called_clone = handler_called.clone();
                let handler: BoxedNext = Arc::new(move |_req: Request| {
                    let handler_called = handler_called_clone.clone();
                    Box::pin(async move {
                        handler_called.store(true, std::sync::atomic::Ordering::SeqCst);
                        text_response(StatusCode::OK, "handler")
                    }) as ResponseFuture
                });

                let request = create_test_request(Method::GET, "/test");
                let response = stack.execute(request, handler).await;

                prop_assert_eq!(response.status(), error_status,
                    "Response should have the error status from short-circuit middleware");

                prop_assert!(!handler_called.load(std::sync::atomic::Ordering::SeqCst),
                    "Handler should NOT be called when middleware short-circuits");

                // Outer middleware must record both "pre" and "post" (the response
                // unwinds through them). Inner middleware must record nothing.
                let execution_order = order.lock().unwrap();

                let pre_count = execution_order.iter().filter(|(id, phase)| *id < 100 && *phase == "pre").count();
                let post_count = execution_order.iter().filter(|(id, phase)| *id < 100 && *phase == "post").count();

                prop_assert_eq!(pre_count, num_middleware_before,
                    "All middleware before short-circuit should have pre recorded");
                prop_assert_eq!(post_count, num_middleware_before,
                    "All middleware before short-circuit should have post recorded (unwinding)");

                let after_entries = execution_order.iter().filter(|(id, _)| *id >= 100).count();
                prop_assert_eq!(after_entries, 0,
                    "Middleware after short-circuit should NOT be executed");

                Ok(())
            });
            result?;
        }
    }

    #[test]
    fn test_short_circuit_returns_error_response() {
        block_on(async {
            let mut stack = LayerStack::new();
            stack.push(Box::new(ShortCircuitMiddleware::new(
                StatusCode::UNAUTHORIZED,
                true,
            )));

            let handler_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
            let handler_called_clone = handler_called.clone();

            let handler: BoxedNext = Arc::new(move |_req: Request| {
                let handler_called = handler_called_clone.clone();
                Box::pin(async move {
                    handler_called.store(true, std::sync::atomic::Ordering::SeqCst);
                    text_response(StatusCode::OK, "handler")
                }) as ResponseFuture
            });

            let request = create_test_request(Method::GET, "/test");
            let response = stack.execute(request, handler).await;

            assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
            assert!(!handler_called.load(std::sync::atomic::Ordering::SeqCst));
        });
    }

    #[test]
    fn test_non_short_circuiting_middleware_reaches_handler() {
        block_on(async {
            let mut stack = LayerStack::new();
            stack.push(Box::new(ShortCircuitMiddleware::new(
                StatusCode::UNAUTHORIZED,
                false,
            )));

            let handler_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
            let handler_called_clone = handler_called.clone();

            let handler: BoxedNext = Arc::new(move |_req: Request| {
                let handler_called = handler_called_clone.clone();
                Box::pin(async move {
                    handler_called.store(true, std::sync::atomic::Ordering::SeqCst);
                    text_response(StatusCode::OK, "handler")
                }) as ResponseFuture
            });

            let request = create_test_request(Method::GET, "/test");
            let response = stack.execute(request, handler).await;

            // With short-circuiting disabled the request passes straight through.
            assert_eq!(response.status(), StatusCode::OK);
            assert!(handler_called.load(std::sync::atomic::Ordering::SeqCst));
        });
    }
}