rustapi_core/middleware/
layer.rs

1//! Tower Layer integration for RustAPI middleware
2//!
3//! This module provides the infrastructure for applying Tower-compatible layers
4//! to the RustAPI request/response pipeline.
5
6use crate::request::Request;
7use crate::response::Response;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use tower::Service;
13
14/// A boxed middleware function type
15#[allow(dead_code)]
16pub type BoxedMiddleware = Arc<
17    dyn Fn(Request, BoxedNext) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
18        + Send
19        + Sync,
20>;
21
22/// A boxed next function for middleware chains
23pub type BoxedNext =
24    Arc<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>;
25
26/// Trait for middleware that can be applied to RustAPI
27///
28/// This trait allows both Tower layers and custom middleware to be used
29/// with the `.layer()` method.
30pub trait MiddlewareLayer: Send + Sync + 'static {
31    /// Apply this middleware to a request, calling `next` to continue the chain
32    fn call(
33        &self,
34        req: Request,
35        next: BoxedNext,
36    ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>;
37
38    /// Clone this middleware into a boxed trait object
39    fn clone_box(&self) -> Box<dyn MiddlewareLayer>;
40}
41
42impl Clone for Box<dyn MiddlewareLayer> {
43    fn clone(&self) -> Self {
44        self.clone_box()
45    }
46}
47
48/// A stack of middleware layers
49#[derive(Clone, Default)]
50pub struct LayerStack {
51    layers: Vec<Box<dyn MiddlewareLayer>>,
52}
53
54impl LayerStack {
55    /// Create a new empty layer stack
56    pub fn new() -> Self {
57        Self { layers: Vec::new() }
58    }
59
60    /// Add a middleware layer to the stack
61    ///
62    /// Layers are executed in the order they are added (outermost first).
63    pub fn push(&mut self, layer: Box<dyn MiddlewareLayer>) {
64        self.layers.push(layer);
65    }
66
67    /// Add a middleware layer to the beginning of the stack
68    ///
69    /// This layer will be executed first (outermost).
70    pub fn prepend(&mut self, layer: Box<dyn MiddlewareLayer>) {
71        self.layers.insert(0, layer);
72    }
73
74    /// Check if the stack is empty
75    pub fn is_empty(&self) -> bool {
76        self.layers.is_empty()
77    }
78
79    /// Get the number of layers
80    pub fn len(&self) -> usize {
81        self.layers.len()
82    }
83
84    /// Execute the middleware stack with a final handler
85    pub fn execute(
86        &self,
87        req: Request,
88        handler: BoxedNext,
89    ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
90        if self.layers.is_empty() {
91            return handler(req);
92        }
93
94        // Build the chain from inside out
95        // The last layer added should be the outermost (first to execute)
96        let mut next = handler;
97
98        for layer in self.layers.iter().rev() {
99            let layer = layer.clone_box();
100            let current_next = next;
101            next = Arc::new(move |req: Request| {
102                let layer = layer.clone_box();
103                let next = current_next.clone();
104                Box::pin(async move { layer.call(req, next).await })
105                    as Pin<Box<dyn Future<Output = Response> + Send + 'static>>
106            });
107        }
108
109        next(req)
110    }
111}
112
113impl IntoIterator for LayerStack {
114    type Item = Box<dyn MiddlewareLayer>;
115    type IntoIter = std::vec::IntoIter<Self::Item>;
116
117    fn into_iter(self) -> Self::IntoIter {
118        self.layers.into_iter()
119    }
120}
121
122impl Extend<Box<dyn MiddlewareLayer>> for LayerStack {
123    fn extend<T: IntoIterator<Item = Box<dyn MiddlewareLayer>>>(&mut self, iter: T) {
124        self.layers.extend(iter);
125    }
126}
127
/// Wrapper intended to adapt a Tower `Layer` to RustAPI's middleware system.
///
/// `Clone` and `Debug` are derived, so they are available whenever `L` implements them.
///
/// NOTE(review): the adapter currently only stores the layer. It does not implement
/// [`MiddlewareLayer`] in this file, which is why the `dead_code` allowances are present.
#[allow(dead_code)]
#[derive(Clone, Debug)]
pub struct TowerLayerAdapter<L> {
    /// The wrapped Tower layer.
    layer: L,
}

impl<L> TowerLayerAdapter<L>
where
    L: Clone + Send + Sync + 'static,
{
    /// Wraps `layer` in a new adapter.
    #[allow(dead_code)]
    pub fn new(layer: L) -> Self {
        Self { layer }
    }
}
155
156/// A simple service wrapper for the next handler in the chain
157#[allow(dead_code)]
158pub struct NextService {
159    next: BoxedNext,
160}
161
162impl NextService {
163    #[allow(dead_code)]
164    pub fn new(next: BoxedNext) -> Self {
165        Self { next }
166    }
167}
168
169impl Clone for NextService {
170    fn clone(&self) -> Self {
171        Self {
172            next: self.next.clone(),
173        }
174    }
175}
176
177impl Service<Request> for NextService {
178    type Response = Response;
179    type Error = std::convert::Infallible;
180    type Future = Pin<Box<dyn Future<Output = Result<Response, Self::Error>> + Send>>;
181
182    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
183        Poll::Ready(Ok(()))
184    }
185
186    fn call(&mut self, req: Request) -> Self::Future {
187        let next = self.next.clone();
188        Box::pin(async move { Ok(next(req).await) })
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::Request;
    use crate::response::Response;
    use bytes::Bytes;
    use http::{Extensions, Method, StatusCode};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::collections::HashMap;

    /// Boxed response future returned by the handlers and middleware in these tests.
    type TestFuture = Pin<Box<dyn Future<Output = Response> + Send + 'static>>;

    /// Shared execution log of `(middleware id, phase)` pairs; phase is `"pre"` or `"post"`.
    type OrderLog = Arc<std::sync::Mutex<Vec<(usize, &'static str)>>>;

    /// Drives `fut` to completion on a fresh single-threaded Tokio runtime.
    ///
    /// Nothing in these tests spawns tasks, so a current-thread runtime suffices. It is
    /// cheaper to build for every proptest case than a multi-threaded runtime.
    fn run_async<F: Future>(fut: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("failed to build Tokio runtime")
            .block_on(fut)
    }

    /// Create a test request with the given method and path
    fn create_test_request(method: Method, path: &str) -> Request {
        let uri: http::Uri = path.parse().unwrap();
        let builder = http::Request::builder().method(method).uri(uri);

        let req = builder.body(()).unwrap();
        let (parts, _) = req.into_parts();

        Request::new(
            parts,
            Bytes::new(),
            Arc::new(Extensions::new()),
            HashMap::new(),
        )
    }

    /// Builds a response with the given status and a static text body.
    fn text_response(status: StatusCode, body: &'static str) -> Response {
        http::Response::builder()
            .status(status)
            .body(http_body_util::Full::new(Bytes::from(body)))
            .unwrap()
    }

    /// Returns a handler that always responds with `status` and `body`.
    fn fixed_handler(status: StatusCode, body: &'static str) -> BoxedNext {
        Arc::new(move |_req: Request| {
            Box::pin(async move { text_response(status, body) }) as TestFuture
        })
    }

    /// Returns a handler that responds `200 OK` and sets `called` once it runs.
    fn tracking_handler(called: Arc<std::sync::atomic::AtomicBool>) -> BoxedNext {
        Arc::new(move |_req: Request| {
            let called = called.clone();
            Box::pin(async move {
                called.store(true, std::sync::atomic::Ordering::SeqCst);
                text_response(StatusCode::OK, "handler")
            }) as TestFuture
        })
    }

    /// A simple test middleware that tracks execution order
    #[derive(Clone)]
    struct OrderTrackingMiddleware {
        id: usize,
        order: OrderLog,
    }

    impl OrderTrackingMiddleware {
        fn new(id: usize, order: OrderLog) -> Self {
            Self { id, order }
        }
    }

    impl MiddlewareLayer for OrderTrackingMiddleware {
        fn call(&self, req: Request, next: BoxedNext) -> TestFuture {
            let id = self.id;
            let order = self.order.clone();

            Box::pin(async move {
                // Record pre-handler execution
                order.lock().unwrap().push((id, "pre"));

                // Call next
                let response = next(req).await;

                // Record post-handler execution
                order.lock().unwrap().push((id, "post"));

                response
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    /// A middleware that overwrites the response status after the inner chain runs
    #[derive(Clone)]
    struct StatusModifyingMiddleware {
        status: StatusCode,
    }

    impl StatusModifyingMiddleware {
        fn new(status: StatusCode) -> Self {
            Self { status }
        }
    }

    impl MiddlewareLayer for StatusModifyingMiddleware {
        fn call(&self, req: Request, next: BoxedNext) -> TestFuture {
            let status = self.status;

            Box::pin(async move {
                let mut response = next(req).await;
                *response.status_mut() = status;
                response
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    // **Feature: phase3-batteries-included, Property 1: Layer application preserves handler behavior**
    //
    // For any Tower-compatible layer L and handler H, applying L via `.layer(L)` SHALL result
    // in requests being processed by L before reaching H, and responses being processed by L
    // after leaving H.
    //
    // **Validates: Requirements 1.1**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_layer_application_preserves_handler_behavior(
            handler_status in 200u16..600u16,
        ) {
            let result: Result<(), TestCaseError> = run_async(async {
                let order = Arc::new(std::sync::Mutex::new(Vec::new()));

                // Create a layer stack with one middleware
                let mut stack = LayerStack::new();
                stack.push(Box::new(OrderTrackingMiddleware::new(1, order.clone())));

                // Create a handler that returns the specified status
                let handler_status = StatusCode::from_u16(handler_status).unwrap_or(StatusCode::OK);
                let handler = fixed_handler(handler_status, "test");

                // Execute through the stack
                let request = create_test_request(Method::GET, "/test");
                let response = stack.execute(request, handler).await;

                // Verify the handler was called (response has expected status)
                prop_assert_eq!(response.status(), handler_status);

                // Verify middleware executed in correct order (pre before post)
                let execution_order = order.lock().unwrap();
                prop_assert_eq!(execution_order.len(), 2);
                prop_assert_eq!(execution_order[0], (1, "pre"));
                prop_assert_eq!(execution_order[1], (1, "post"));

                Ok(())
            });
            result?;
        }
    }

    #[test]
    fn test_empty_layer_stack_calls_handler_directly() {
        run_async(async {
            let stack = LayerStack::new();

            let request = create_test_request(Method::GET, "/test");
            let response = stack
                .execute(request, fixed_handler(StatusCode::OK, "direct"))
                .await;

            assert_eq!(response.status(), StatusCode::OK);
        });
    }

    #[test]
    fn test_status_modifying_middleware_overrides_handler_status() {
        run_async(async {
            let mut stack = LayerStack::new();
            stack.push(Box::new(StatusModifyingMiddleware::new(StatusCode::ACCEPTED)));

            let request = create_test_request(Method::GET, "/test");
            let response = stack
                .execute(request, fixed_handler(StatusCode::OK, "handler"))
                .await;

            assert_eq!(response.status(), StatusCode::ACCEPTED);
        });
    }

    #[test]
    fn test_prepend_makes_layer_outermost() {
        run_async(async {
            let order = Arc::new(std::sync::Mutex::new(Vec::new()));

            let mut stack = LayerStack::new();
            stack.push(Box::new(OrderTrackingMiddleware::new(1, order.clone())));
            stack.prepend(Box::new(OrderTrackingMiddleware::new(0, order.clone())));
            assert_eq!(stack.len(), 2);
            assert!(!stack.is_empty());

            let request = create_test_request(Method::GET, "/test");
            let _response = stack
                .execute(request, fixed_handler(StatusCode::OK, "test"))
                .await;

            let execution_order = order.lock().unwrap();
            assert_eq!(
                *execution_order,
                vec![(0, "pre"), (1, "pre"), (1, "post"), (0, "post")]
            );
        });
    }

    // **Feature: phase3-batteries-included, Property 2: Middleware execution order**
    //
    // For any sequence of layers [L1, L2, ..., Ln] added via `.layer()`, requests SHALL pass
    // through layers in the order L1 → L2 → ... → Ln → Handler → Ln → ... → L2 → L1
    // (outermost first on request, innermost first on response).
    //
    // **Validates: Requirements 1.2**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_middleware_execution_order(
            num_layers in 1usize..10usize,
        ) {
            let result: Result<(), TestCaseError> = run_async(async {
                let order = Arc::new(std::sync::Mutex::new(Vec::new()));

                // Create a layer stack with multiple middleware
                let mut stack = LayerStack::new();
                for i in 0..num_layers {
                    stack.push(Box::new(OrderTrackingMiddleware::new(i, order.clone())));
                }

                // Execute through the stack
                let request = create_test_request(Method::GET, "/test");
                let _response = stack
                    .execute(request, fixed_handler(StatusCode::OK, "test"))
                    .await;

                // Verify execution order
                let execution_order = order.lock().unwrap();

                // Should have 2 * num_layers entries (pre and post for each)
                prop_assert_eq!(execution_order.len(), num_layers * 2);

                // First half should be "pre" in order 0, 1, 2, ... (outermost first)
                for i in 0..num_layers {
                    prop_assert_eq!(execution_order[i], (i, "pre"),
                        "Pre-handler order mismatch at index {}", i);
                }

                // Second half should be "post" in reverse order n-1, n-2, ..., 0 (innermost first)
                for i in 0..num_layers {
                    let expected_id = num_layers - 1 - i;
                    prop_assert_eq!(execution_order[num_layers + i], (expected_id, "post"),
                        "Post-handler order mismatch at index {}", i);
                }

                Ok(())
            });
            result?;
        }
    }

    /// A middleware that short-circuits with an error response without calling next
    #[derive(Clone)]
    struct ShortCircuitMiddleware {
        error_status: StatusCode,
        should_short_circuit: bool,
    }

    impl ShortCircuitMiddleware {
        fn new(error_status: StatusCode, should_short_circuit: bool) -> Self {
            Self {
                error_status,
                should_short_circuit,
            }
        }
    }

    impl MiddlewareLayer for ShortCircuitMiddleware {
        fn call(&self, req: Request, next: BoxedNext) -> TestFuture {
            let error_status = self.error_status;
            let should_short_circuit = self.should_short_circuit;

            Box::pin(async move {
                if should_short_circuit {
                    // Return error response without calling next (short-circuit)
                    text_response(error_status, "error")
                } else {
                    // Continue to next middleware/handler
                    next(req).await
                }
            })
        }

        fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
            Box::new(self.clone())
        }
    }

    // **Feature: phase3-batteries-included, Property 4: Middleware short-circuit on error**
    //
    // For any middleware that returns an error response, the handler SHALL NOT be invoked,
    // and the error response SHALL be returned directly to the client.
    //
    // **Validates: Requirements 1.5**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_middleware_short_circuit_on_error(
            error_status in 400u16..600u16,
            num_middleware_before in 0usize..5usize,
            num_middleware_after in 0usize..5usize,
        ) {
            let result: Result<(), TestCaseError> = run_async(async {
                let order = Arc::new(std::sync::Mutex::new(Vec::new()));
                let handler_called = Arc::new(std::sync::atomic::AtomicBool::new(false));

                let mut stack = LayerStack::new();

                // Add middleware before (outside) the short-circuit middleware
                for i in 0..num_middleware_before {
                    stack.push(Box::new(OrderTrackingMiddleware::new(i, order.clone())));
                }

                // Add the short-circuit middleware
                let error_status = StatusCode::from_u16(error_status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
                stack.push(Box::new(ShortCircuitMiddleware::new(error_status, true)));

                // Add middleware after (inside) the short-circuit middleware; ids are offset
                // by 100 and these should never execute
                for i in 0..num_middleware_after {
                    stack.push(Box::new(OrderTrackingMiddleware::new(100 + i, order.clone())));
                }

                // Execute through the stack with a handler that records whether it ran
                let request = create_test_request(Method::GET, "/test");
                let response = stack
                    .execute(request, tracking_handler(handler_called.clone()))
                    .await;

                // Verify the error response was returned
                prop_assert_eq!(response.status(), error_status,
                    "Response should have the error status from short-circuit middleware");

                // Verify the handler was NOT called
                prop_assert!(!handler_called.load(std::sync::atomic::Ordering::SeqCst),
                    "Handler should NOT be called when middleware short-circuits");

                // Verify execution order:
                // - Middleware before short-circuit should have "pre" recorded
                // - Middleware after short-circuit should NOT have "pre" recorded (never reached)
                // - All middleware before short-circuit should have "post" recorded (unwinding)
                let execution_order = order.lock().unwrap();

                let pre_count = execution_order.iter().filter(|(id, phase)| *id < 100 && *phase == "pre").count();
                let post_count = execution_order.iter().filter(|(id, phase)| *id < 100 && *phase == "post").count();

                prop_assert_eq!(pre_count, num_middleware_before,
                    "All middleware before short-circuit should have pre recorded");
                prop_assert_eq!(post_count, num_middleware_before,
                    "All middleware before short-circuit should have post recorded (unwinding)");

                // Middleware after short-circuit should NOT have any entries
                let after_entries = execution_order.iter().filter(|(id, _)| *id >= 100).count();
                prop_assert_eq!(after_entries, 0,
                    "Middleware after short-circuit should NOT be executed");

                Ok(())
            });
            result?;
        }
    }

    #[test]
    fn test_short_circuit_returns_error_response() {
        run_async(async {
            let mut stack = LayerStack::new();
            stack.push(Box::new(ShortCircuitMiddleware::new(
                StatusCode::UNAUTHORIZED,
                true,
            )));

            let handler_called = Arc::new(std::sync::atomic::AtomicBool::new(false));

            let request = create_test_request(Method::GET, "/test");
            let response = stack
                .execute(request, tracking_handler(handler_called.clone()))
                .await;

            assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
            assert!(!handler_called.load(std::sync::atomic::Ordering::SeqCst));
        });
    }
}