turul_http_mcp_server/middleware/
stack.rs

1//! Middleware stack execution
2
3use super::{McpMiddleware, RequestContext, SessionInjection, DispatcherResult, MiddlewareError};
4use turul_mcp_session_storage::SessionView;
5use std::sync::Arc;
6
7/// Ordered collection of middleware with execution logic
8///
9/// The stack executes middleware in two phases:
10///
11/// 1. **Before dispatch**: Middleware execute in registration order
12///    - First error stops the chain
13///    - Session injections accumulate across all middleware
14///
15/// 2. **After dispatch**: Middleware execute in reverse registration order
16///    - Allows proper cleanup/finalization
17///    - Errors replace the result
18///
19/// # Examples
20///
21/// ```rust,no_run
22/// use turul_http_mcp_server::middleware::{MiddlewareStack, McpMiddleware, RequestContext, SessionInjection, MiddlewareError};
23/// use turul_mcp_session_storage::SessionView;
24/// use async_trait::async_trait;
25/// use std::sync::Arc;
26///
27/// struct LoggingMiddleware;
28///
29/// #[async_trait]
30/// impl McpMiddleware for LoggingMiddleware {
31///     async fn before_dispatch(
32///         &self,
33///         ctx: &mut RequestContext<'_>,
34///         _session: Option<&dyn SessionView>,
35///         _injection: &mut SessionInjection,
36///     ) -> Result<(), MiddlewareError> {
37///         println!("Request: {}", ctx.method());
38///         Ok(())
39///     }
40/// }
41///
42/// # async fn example() {
43/// let mut stack = MiddlewareStack::new();
44/// stack.push(Arc::new(LoggingMiddleware));
45///
46/// assert_eq!(stack.len(), 1);
47/// # }
48/// ```
49#[derive(Default, Clone)]
50pub struct MiddlewareStack {
51    middleware: Vec<Arc<dyn McpMiddleware>>,
52}
53
54impl MiddlewareStack {
55    /// Create an empty middleware stack
56    pub fn new() -> Self {
57        Self::default()
58    }
59
60    /// Add middleware to the end of the stack
61    ///
62    /// # Parameters
63    ///
64    /// - `middleware`: Middleware implementation (must be Arc-wrapped for sharing)
65    ///
66    /// # Execution Order
67    ///
68    /// - Before dispatch: First added executes first
69    /// - After dispatch: First added executes last (reverse order)
70    pub fn push(&mut self, middleware: Arc<dyn McpMiddleware>) {
71        self.middleware.push(middleware);
72    }
73
74    /// Get the number of middleware in the stack
75    pub fn len(&self) -> usize {
76        self.middleware.len()
77    }
78
79    /// Check if the stack is empty
80    pub fn is_empty(&self) -> bool {
81        self.middleware.is_empty()
82    }
83
84    /// Execute all middleware before dispatch
85    ///
86    /// # Parameters
87    ///
88    /// - `ctx`: Mutable request context
89    /// - `session`: Optional read-only session view
90    ///   - `None` for `initialize` (session doesn't exist yet)
91    ///   - `Some(session)` for all other methods
92    ///
93    /// # Returns
94    ///
95    /// - `Ok(SessionInjection)`: All middleware succeeded, contains accumulated injections
96    /// - `Err(MiddlewareError)`: First middleware that failed
97    ///
98    /// # Execution
99    ///
100    /// 1. Execute each middleware in registration order
101    /// 2. Accumulate session injections from all middleware
102    /// 3. Stop on first error
103    pub async fn execute_before(
104        &self,
105        ctx: &mut RequestContext<'_>,
106        session: Option<&dyn SessionView>,
107    ) -> Result<SessionInjection, MiddlewareError> {
108        let mut combined_injection = SessionInjection::new();
109
110        for middleware in &self.middleware {
111            let mut injection = SessionInjection::new();
112            middleware.before_dispatch(ctx, session, &mut injection).await?;
113
114            // Accumulate injections (later middleware can override earlier ones)
115            for (key, value) in injection.state() {
116                combined_injection.set_state(key.clone(), value.clone());
117            }
118            for (key, value) in injection.metadata() {
119                combined_injection.set_metadata(key.clone(), value.clone());
120            }
121        }
122
123        Ok(combined_injection)
124    }
125
126    /// Execute all middleware after dispatch
127    ///
128    /// # Parameters
129    ///
130    /// - `ctx`: Read-only request context
131    /// - `result`: Mutable dispatcher result
132    ///
133    /// # Returns
134    ///
135    /// - `Ok(())`: All middleware succeeded
136    /// - `Err(MiddlewareError)`: First middleware that failed
137    ///
138    /// # Execution
139    ///
140    /// 1. Execute each middleware in reverse registration order
141    /// 2. Stop on first error
142    /// 3. Allow middleware to modify result
143    pub async fn execute_after(
144        &self,
145        ctx: &RequestContext<'_>,
146        result: &mut DispatcherResult,
147    ) -> Result<(), MiddlewareError> {
148        // Execute in reverse order
149        for middleware in self.middleware.iter().rev() {
150            middleware.after_dispatch(ctx, result).await?;
151        }
152
153        Ok(())
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use serde_json::json;

    /// Shared, ordered record of the hook invocations made during a test.
    type CallLog = Arc<std::sync::Mutex<Vec<String>>>;

    /// Middleware that records each hook call as `before_<id>` / `after_<id>`
    /// and injects a `<id> = true` state entry in the before phase.
    struct CountingMiddleware {
        id: String,
        counter: CallLog,
    }

    #[async_trait]
    impl McpMiddleware for CountingMiddleware {
        async fn before_dispatch(
            &self,
            _ctx: &mut RequestContext<'_>,
            _session: Option<&dyn SessionView>,
            injection: &mut SessionInjection,
        ) -> Result<(), MiddlewareError> {
            self.counter.lock().unwrap().push(format!("before_{}", self.id));
            injection.set_state(&self.id, json!(true));
            Ok(())
        }

        async fn after_dispatch(
            &self,
            _ctx: &RequestContext<'_>,
            _result: &mut DispatcherResult,
        ) -> Result<(), MiddlewareError> {
            self.counter.lock().unwrap().push(format!("after_{}", self.id));
            Ok(())
        }
    }

    /// Middleware whose before hook optionally fails with an unauthorized error.
    struct ErrorMiddleware {
        error_on_before: bool,
    }

    #[async_trait]
    impl McpMiddleware for ErrorMiddleware {
        async fn before_dispatch(
            &self,
            _ctx: &mut RequestContext<'_>,
            _session: Option<&dyn SessionView>,
            _injection: &mut SessionInjection,
        ) -> Result<(), MiddlewareError> {
            if self.error_on_before {
                Err(MiddlewareError::unauthorized("Test error"))
            } else {
                Ok(())
            }
        }
    }

    #[tokio::test]
    async fn test_middleware_execution_order() {
        let counter: CallLog = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut stack = MiddlewareStack::new();

        stack.push(Arc::new(CountingMiddleware {
            id: "first".to_string(),
            counter: counter.clone(),
        }));
        stack.push(Arc::new(CountingMiddleware {
            id: "second".to_string(),
            counter: counter.clone(),
        }));

        let mut ctx = RequestContext::new("test/method", None);

        // Execute before (no session needed for this test)
        let injection = stack.execute_before(&mut ctx, None).await.unwrap();
        assert_eq!(injection.state().len(), 2);
        assert!(injection.state().contains_key("first"));
        assert!(injection.state().contains_key("second"));

        // Execute after
        let mut result = DispatcherResult::Success(json!({"ok": true}));
        stack.execute_after(&ctx, &mut result).await.unwrap();

        // Neither middleware touches the result, so it must come back intact.
        assert!(matches!(
            &result,
            DispatcherResult::Success(value) if *value == json!({"ok": true})
        ));

        // Compare the whole log at once: before hooks in registration order,
        // after hooks in reverse. This also fails on missing or extra calls,
        // which index-by-index checks would let through (or panic on).
        let log = counter.lock().unwrap();
        assert_eq!(
            *log,
            ["before_first", "before_second", "after_second", "after_first"]
        );
    }

    #[tokio::test]
    async fn test_middleware_error_stops_chain() {
        let counter: CallLog = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut stack = MiddlewareStack::new();

        stack.push(Arc::new(CountingMiddleware {
            id: "first".to_string(),
            counter: counter.clone(),
        }));
        stack.push(Arc::new(ErrorMiddleware {
            error_on_before: true,
        }));
        stack.push(Arc::new(CountingMiddleware {
            id: "third".to_string(),
            counter: counter.clone(),
        }));

        let mut ctx = RequestContext::new("test/method", None);

        // Execute before - should fail at second middleware (no session needed)
        let result = stack.execute_before(&mut ctx, None).await;
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), MiddlewareError::unauthorized("Test error"));

        // Only the middleware ahead of the failing one may have run.
        let log = counter.lock().unwrap();
        assert_eq!(*log, ["before_first"]);
    }

    #[tokio::test]
    async fn test_empty_stack() {
        let stack = MiddlewareStack::new();
        assert!(stack.is_empty());
        assert_eq!(stack.len(), 0);

        let mut ctx = RequestContext::new("test/method", None);

        let injection = stack.execute_before(&mut ctx, None).await.unwrap();
        assert!(injection.is_empty());

        let mut result = DispatcherResult::Success(json!({"ok": true}));
        stack.execute_after(&ctx, &mut result).await.unwrap();

        // With no middleware registered the result must pass through untouched.
        assert!(matches!(
            &result,
            DispatcherResult::Success(value) if *value == json!({"ok": true})
        ));
    }
}