// rustapi_core/middleware/layer.rs

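//! Tower Layer integration for RustAPI middleware.
//!
//! This module provides the infrastructure for applying Tower-compatible layers
//! to the RustAPI request/response pipeline. Layers held in a `LayerStack` run in
//! insertion order: the first layer added is the outermost, seeing the request
//! first and the response last.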
5
6use crate::request::Request;
7use crate::response::Response;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use tower::Service;
13
14/// A boxed middleware function type
15pub type BoxedMiddleware = Arc<
16    dyn Fn(
17            Request,
18            BoxedNext,
19        ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
20        + Send
21        + Sync,
22>;
23
24/// A boxed next function for middleware chains
25pub type BoxedNext =
26    Arc<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>;
27
28/// Trait for middleware that can be applied to RustAPI
29///
30/// This trait allows both Tower layers and custom middleware to be used
31/// with the `.layer()` method.
32pub trait MiddlewareLayer: Send + Sync + 'static {
33    /// Apply this middleware to a request, calling `next` to continue the chain
34    fn call(
35        &self,
36        req: Request,
37        next: BoxedNext,
38    ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>;
39
40    /// Clone this middleware into a boxed trait object
41    fn clone_box(&self) -> Box<dyn MiddlewareLayer>;
42}
43
44impl Clone for Box<dyn MiddlewareLayer> {
45    fn clone(&self) -> Self {
46        self.clone_box()
47    }
48}
49
50/// A stack of middleware layers
51#[derive(Clone, Default)]
52pub struct LayerStack {
53    layers: Vec<Box<dyn MiddlewareLayer>>,
54}
55
56impl LayerStack {
57    /// Create a new empty layer stack
58    pub fn new() -> Self {
59        Self { layers: Vec::new() }
60    }
61
62    /// Add a middleware layer to the stack
63    ///
64    /// Layers are executed in the order they are added (outermost first).
65    pub fn push(&mut self, layer: Box<dyn MiddlewareLayer>) {
66        self.layers.push(layer);
67    }
68
69    /// Add a middleware layer to the beginning of the stack
70    ///
71    /// This layer will be executed first (outermost).
72    pub fn prepend(&mut self, layer: Box<dyn MiddlewareLayer>) {
73        self.layers.insert(0, layer);
74    }
75
76    /// Check if the stack is empty
77    pub fn is_empty(&self) -> bool {
78        self.layers.is_empty()
79    }
80
81    /// Get the number of layers
82    pub fn len(&self) -> usize {
83        self.layers.len()
84    }
85
86    /// Execute the middleware stack with a final handler
87    pub fn execute(
88        &self,
89        req: Request,
90        handler: BoxedNext,
91    ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
92        if self.layers.is_empty() {
93            return handler(req);
94        }
95
96        // Build the chain from inside out
97        // The last layer added should be the outermost (first to execute)
98        let mut next = handler;
99
100        for layer in self.layers.iter().rev() {
101            let layer = layer.clone_box();
102            let current_next = next;
103            next = Arc::new(move |req: Request| {
104                let layer = layer.clone_box();
105                let next = current_next.clone();
106                Box::pin(async move { layer.call(req, next).await })
107                    as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
108            });
109        }
110
111        next(req)
112    }
113}
114
/// Wrapper to adapt a Tower `Layer` to RustAPI's middleware system.
///
/// The wrapped layer can be retrieved with [`TowerLayerAdapter::layer`] or
/// [`TowerLayerAdapter::into_inner`].
// NOTE(review): no `MiddlewareLayer` impl for this type is visible in this file;
// confirm where (or whether) the adapter is wired into a `LayerStack`.
#[derive(Clone, Debug)]
pub struct TowerLayerAdapter<L> {
    layer: L,
}

impl<L> TowerLayerAdapter<L>
where
    L: Clone + Send + Sync + 'static,
{
    /// Creates a new adapter wrapping the given Tower layer.
    pub fn new(layer: L) -> Self {
        Self { layer }
    }
}

impl<L> TowerLayerAdapter<L> {
    /// Returns a reference to the wrapped layer.
    pub fn layer(&self) -> &L {
        &self.layer
    }

    /// Consumes the adapter and returns the wrapped layer.
    pub fn into_inner(self) -> L {
        self.layer
    }
}
140
141/// A simple service wrapper for the next handler in the chain
142pub struct NextService {
143    next: BoxedNext,
144}
145
146impl NextService {
147    pub fn new(next: BoxedNext) -> Self {
148        Self { next }
149    }
150}
151
152impl Clone for NextService {
153    fn clone(&self) -> Self {
154        Self {
155            next: self.next.clone(),
156        }
157    }
158}
159
160impl Service<Request> for NextService {
161    type Response = Response;
162    type Error = std::convert::Infallible;
163    type Future = Pin<Box<dyn Future<Output = Result<Response, Self::Error>> + Send>>;
164
165    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
166        Poll::Ready(Ok(()))
167    }
168
169    fn call(&mut self, req: Request) -> Self::Future {
170        let next = self.next.clone();
171        Box::pin(async move { Ok(next(req).await) })
172    }
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::Request;
    use crate::response::Response;
    use bytes::Bytes;
    use http::{Extensions, Method, StatusCode};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Mutex;

    /// Shared log of `(middleware id, phase)` entries, in the order they were recorded.
    type OrderLog = Arc<Mutex<Vec<(usize, &'static str)>>>;

    /// Boxed future type produced by handlers and middleware in these tests.
    type BoxedResponseFuture = Pin<Box<dyn Future<Output = Response> + Send + 'static>>;

    /// Ids at or above this value mark middleware placed *inside* a short-circuiting layer.
    const INNER_ID_BASE: usize = 100;

    /// Builds a lightweight single-threaded Tokio runtime for one test case.
    ///
    /// Property tests run many cases, so avoiding a multi-threaded runtime (and
    /// its worker pool) per case keeps the suite fast.
    fn runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("failed to build Tokio runtime")
    }

    /// Create a test request with the given method and path.
    fn create_test_request(method: Method, path: &str) -> Request {
        let (parts, _body) = http::Request::builder()
            .method(method)
            .uri(path)
            .body(())
            .expect("test request must be valid")
            .into_parts();

        Request::new(parts, Bytes::new(), Arc::new(Extensions::new()), HashMap::new())
    }

    /// Builds a response with the given status and a static body.
    fn response_with(status: StatusCode, body: &'static str) -> Response {
        http::Response::builder()
            .status(status)
            .body(http_body_util::Full::new(Bytes::from(body)))
            .unwrap()
    }

    /// Boxes a handler that always responds with `status` and `body`.
    fn status_handler(status: StatusCode, body: &'static str) -> BoxedNext {
        Arc::new(move |_req: Request| {
            Box::pin(async move { response_with(status, body) }) as BoxedResponseFuture
        })
    }

    /// Boxes a handler that sets `called` when it runs, then responds `200 OK`.
    fn tracking_handler(called: Arc<AtomicBool>) -> BoxedNext {
        Arc::new(move |_req: Request| {
            let called = Arc::clone(&called);
            Box::pin(async move {
                called.store(true, Ordering::SeqCst);
                response_with(StatusCode::OK, "handler")
            }) as BoxedResponseFuture
        })
    }

    /// A simple test middleware that tracks execution order.
    #[derive(Clone)]
    struct OrderTrackingMiddleware {
        id: usize,
        order: OrderLog,
    }

    impl OrderTrackingMiddleware {
        fn new(id: usize, order: OrderLog) -> Self {
            Self { id, order }
        }
    }

    impl MiddlewareLayer for OrderTrackingMiddleware {
        fn call(&self, req: Request, next: BoxedNext) -> BoxedResponseFuture {
            let id = self.id;
            let order = Arc::clone(&self.order);

            Box::pin(async move {
                // Each guard is a temporary, so the lock is released before `.await`.
                order.lock().unwrap().push((id, "pre"));
                let response = next(req).await;
                order.lock().unwrap().push((id, "post"));
                response
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    /// A middleware that overwrites the response status on the way out.
    #[derive(Clone)]
    struct StatusModifyingMiddleware {
        status: StatusCode,
    }

    impl StatusModifyingMiddleware {
        fn new(status: StatusCode) -> Self {
            Self { status }
        }
    }

    impl MiddlewareLayer for StatusModifyingMiddleware {
        fn call(&self, req: Request, next: BoxedNext) -> BoxedResponseFuture {
            let status = self.status;

            Box::pin(async move {
                let mut response = next(req).await;
                *response.status_mut() = status;
                response
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    /// A middleware that, when enabled, answers with an error without calling `next`.
    #[derive(Clone)]
    struct ShortCircuitMiddleware {
        error_status: StatusCode,
        should_short_circuit: bool,
    }

    impl ShortCircuitMiddleware {
        fn new(error_status: StatusCode, should_short_circuit: bool) -> Self {
            Self { error_status, should_short_circuit }
        }
    }

    impl MiddlewareLayer for ShortCircuitMiddleware {
        fn call(&self, req: Request, next: BoxedNext) -> BoxedResponseFuture {
            let error_status = self.error_status;
            let should_short_circuit = self.should_short_circuit;

            Box::pin(async move {
                if should_short_circuit {
                    // Return the error response without calling next (short-circuit).
                    response_with(error_status, "error")
                } else {
                    // Continue to the next middleware/handler.
                    next(req).await
                }
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    // **Feature: phase3-batteries-included, Property 1: Layer application preserves handler behavior**
    //
    // For any Tower-compatible layer L and handler H, applying L via `.layer(L)` SHALL result
    // in requests being processed by L before reaching H, and responses being processed by L
    // after leaving H.
    //
    // **Validates: Requirements 1.1**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_layer_application_preserves_handler_behavior(
            handler_status in 200u16..600u16,
        ) {
            let result: Result<(), TestCaseError> = runtime().block_on(async {
                let order: OrderLog = Arc::new(Mutex::new(Vec::new()));

                let mut stack = LayerStack::new();
                stack.push(Box::new(OrderTrackingMiddleware::new(1, Arc::clone(&order))));

                let handler_status =
                    StatusCode::from_u16(handler_status).unwrap_or(StatusCode::OK);
                let handler = status_handler(handler_status, "test");

                let request = create_test_request(Method::GET, "/test");
                let response = stack.execute(request, handler).await;

                // The handler's response passes through unchanged.
                prop_assert_eq!(response.status(), handler_status);

                // The middleware ran around the handler: pre, then post.
                let execution_order = order.lock().unwrap();
                prop_assert_eq!(execution_order.len(), 2);
                prop_assert_eq!(execution_order[0], (1, "pre"));
                prop_assert_eq!(execution_order[1], (1, "post"));

                Ok(())
            });
            result?;
        }
    }

    #[test]
    fn test_empty_layer_stack_calls_handler_directly() {
        runtime().block_on(async {
            let stack = LayerStack::new();
            let handler = status_handler(StatusCode::OK, "direct");

            let request = create_test_request(Method::GET, "/test");
            let response = stack.execute(request, handler).await;

            assert_eq!(response.status(), StatusCode::OK);
        });
    }

    #[test]
    fn test_layer_can_rewrite_response_status() {
        runtime().block_on(async {
            let mut stack = LayerStack::new();
            stack.push(Box::new(StatusModifyingMiddleware::new(StatusCode::ACCEPTED)));
            let handler = status_handler(StatusCode::OK, "test");

            let request = create_test_request(Method::GET, "/test");
            let response = stack.execute(request, handler).await;

            // The layer post-processes the handler's response.
            assert_eq!(response.status(), StatusCode::ACCEPTED);
        });
    }

    #[test]
    fn test_prepend_places_layer_outermost() {
        runtime().block_on(async {
            let order: OrderLog = Arc::new(Mutex::new(Vec::new()));

            let mut stack = LayerStack::new();
            assert!(stack.is_empty());
            stack.push(Box::new(OrderTrackingMiddleware::new(1, Arc::clone(&order))));
            stack.prepend(Box::new(OrderTrackingMiddleware::new(0, Arc::clone(&order))));
            assert_eq!(stack.len(), 2);
            assert!(!stack.is_empty());

            let handler = status_handler(StatusCode::OK, "test");
            let request = create_test_request(Method::GET, "/test");
            let _response = stack.execute(request, handler).await;

            let execution_order = order.lock().unwrap();
            assert_eq!(
                *execution_order,
                vec![(0, "pre"), (1, "pre"), (1, "post"), (0, "post")]
            );
        });
    }

    // **Feature: phase3-batteries-included, Property 2: Middleware execution order**
    //
    // For any sequence of layers [L1, L2, ..., Ln] added via `.layer()`, requests SHALL pass
    // through layers in the order L1 → L2 → ... → Ln → Handler → Ln → ... → L2 → L1
    // (outermost first on request, innermost first on response).
    //
    // **Validates: Requirements 1.2**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_middleware_execution_order(
            num_layers in 1usize..10usize,
        ) {
            let result: Result<(), TestCaseError> = runtime().block_on(async {
                let order: OrderLog = Arc::new(Mutex::new(Vec::new()));

                let mut stack = LayerStack::new();
                for id in 0..num_layers {
                    stack.push(Box::new(OrderTrackingMiddleware::new(id, Arc::clone(&order))));
                }

                let handler = status_handler(StatusCode::OK, "test");
                let request = create_test_request(Method::GET, "/test");
                let _response = stack.execute(request, handler).await;

                let execution_order = order.lock().unwrap();

                // One "pre" and one "post" entry per layer.
                prop_assert_eq!(execution_order.len(), num_layers * 2);

                let (pre, post) = execution_order.split_at(num_layers);

                // Requests pass outermost first: 0, 1, 2, ...
                for (i, entry) in pre.iter().enumerate() {
                    prop_assert_eq!(*entry, (i, "pre"),
                        "Pre-handler order mismatch at index {}", i);
                }

                // Responses unwind innermost first: n-1, n-2, ..., 0
                for (i, entry) in post.iter().enumerate() {
                    let expected_id = num_layers - 1 - i;
                    prop_assert_eq!(*entry, (expected_id, "post"),
                        "Post-handler order mismatch at index {}", i);
                }

                Ok(())
            });
            result?;
        }
    }

    // **Feature: phase3-batteries-included, Property 4: Middleware short-circuit on error**
    //
    // For any middleware that returns an error response, the handler SHALL NOT be invoked,
    // and the error response SHALL be returned directly to the client.
    //
    // **Validates: Requirements 1.5**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_middleware_short_circuit_on_error(
            error_status in 400u16..600u16,
            num_middleware_before in 0usize..5usize,
            num_middleware_after in 0usize..5usize,
        ) {
            let result: Result<(), TestCaseError> = runtime().block_on(async {
                let order: OrderLog = Arc::new(Mutex::new(Vec::new()));
                let handler_called = Arc::new(AtomicBool::new(false));

                let mut stack = LayerStack::new();

                // Outer middleware: runs before the short-circuit and unwinds after it.
                for i in 0..num_middleware_before {
                    stack.push(Box::new(OrderTrackingMiddleware::new(i, Arc::clone(&order))));
                }

                let error_status = StatusCode::from_u16(error_status)
                    .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
                stack.push(Box::new(ShortCircuitMiddleware::new(error_status, true)));

                // Inner middleware: sits behind the short-circuit and must never run.
                for i in 0..num_middleware_after {
                    let id = INNER_ID_BASE + i;
                    stack.push(Box::new(OrderTrackingMiddleware::new(id, Arc::clone(&order))));
                }

                let handler = tracking_handler(Arc::clone(&handler_called));

                let request = create_test_request(Method::GET, "/test");
                let response = stack.execute(request, handler).await;

                prop_assert_eq!(response.status(), error_status,
                    "Response should have the error status from short-circuit middleware");
                prop_assert!(!handler_called.load(Ordering::SeqCst),
                    "Handler should NOT be called when middleware short-circuits");

                let execution_order = order.lock().unwrap();
                let outer_phase_count = |phase: &str| {
                    execution_order
                        .iter()
                        .filter(|&&(id, p)| id < INNER_ID_BASE && p == phase)
                        .count()
                };

                prop_assert_eq!(outer_phase_count("pre"), num_middleware_before,
                    "All middleware before short-circuit should have pre recorded");
                prop_assert_eq!(outer_phase_count("post"), num_middleware_before,
                    "All middleware before short-circuit should have post recorded (unwinding)");

                let inner_entries = execution_order
                    .iter()
                    .filter(|&&(id, _)| id >= INNER_ID_BASE)
                    .count();
                prop_assert_eq!(inner_entries, 0,
                    "Middleware after short-circuit should NOT be executed");

                Ok(())
            });
            result?;
        }
    }

    #[test]
    fn test_short_circuit_returns_error_response() {
        runtime().block_on(async {
            let mut stack = LayerStack::new();
            stack.push(Box::new(ShortCircuitMiddleware::new(StatusCode::UNAUTHORIZED, true)));

            let handler_called = Arc::new(AtomicBool::new(false));
            let handler = tracking_handler(Arc::clone(&handler_called));

            let request = create_test_request(Method::GET, "/test");
            let response = stack.execute(request, handler).await;

            assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
            assert!(!handler_called.load(Ordering::SeqCst));
        });
    }

    #[test]
    fn test_non_short_circuiting_middleware_reaches_handler() {
        runtime().block_on(async {
            let mut stack = LayerStack::new();
            stack.push(Box::new(ShortCircuitMiddleware::new(StatusCode::UNAUTHORIZED, false)));

            let handler_called = Arc::new(AtomicBool::new(false));
            let handler = tracking_handler(Arc::clone(&handler_called));

            let request = create_test_request(Method::GET, "/test");
            let response = stack.execute(request, handler).await;

            assert_eq!(response.status(), StatusCode::OK);
            assert!(handler_called.load(Ordering::SeqCst));
        });
    }
}