Skip to main content

turul_http_mcp_server/middleware/stack.rs

1//! Middleware stack execution
2
3use super::{DispatcherResult, McpMiddleware, MiddlewareError, RequestContext, SessionInjection};
4use std::sync::Arc;
5use turul_mcp_session_storage::SessionView;
6
7/// Ordered collection of middleware with execution logic
8///
9/// The stack executes middleware in two phases:
10///
11/// 1. **Before dispatch**: Middleware execute in registration order
12///    - First error stops the chain
13///    - Session injections accumulate across all middleware
14///
15/// 2. **After dispatch**: Middleware execute in reverse registration order
16///    - Allows proper cleanup/finalization
17///    - Errors replace the result
18///
19/// # Examples
20///
21/// ```rust,no_run
22/// use turul_http_mcp_server::middleware::{MiddlewareStack, McpMiddleware, RequestContext, SessionInjection, MiddlewareError};
23/// use turul_mcp_session_storage::SessionView;
24/// use async_trait::async_trait;
25/// use std::sync::Arc;
26///
27/// struct LoggingMiddleware;
28///
29/// #[async_trait]
30/// impl McpMiddleware for LoggingMiddleware {
31///     async fn before_dispatch(
32///         &self,
33///         ctx: &mut RequestContext<'_>,
34///         _session: Option<&dyn SessionView>,
35///         _injection: &mut SessionInjection,
36///     ) -> Result<(), MiddlewareError> {
37///         println!("Request: {}", ctx.method());
38///         Ok(())
39///     }
40/// }
41///
42/// # async fn example() {
43/// let mut stack = MiddlewareStack::new();
44/// stack.push(Arc::new(LoggingMiddleware));
45///
46/// assert_eq!(stack.len(), 1);
47/// # }
48/// ```
49#[derive(Default, Clone)]
50pub struct MiddlewareStack {
51    middleware: Vec<Arc<dyn McpMiddleware>>,
52}
53
54impl MiddlewareStack {
55    /// Create an empty middleware stack
56    pub fn new() -> Self {
57        Self::default()
58    }
59
60    /// Add middleware to the end of the stack
61    ///
62    /// # Parameters
63    ///
64    /// - `middleware`: Middleware implementation (must be Arc-wrapped for sharing)
65    ///
66    /// # Execution Order
67    ///
68    /// - Before dispatch: First added executes first
69    /// - After dispatch: First added executes last (reverse order)
70    pub fn push(&mut self, middleware: Arc<dyn McpMiddleware>) {
71        self.middleware.push(middleware);
72    }
73
74    /// Get the number of middleware in the stack
75    pub fn len(&self) -> usize {
76        self.middleware.len()
77    }
78
79    /// Check if the stack is empty
80    pub fn is_empty(&self) -> bool {
81        self.middleware.is_empty()
82    }
83
84    /// Execute all middleware before dispatch
85    ///
86    /// # Parameters
87    ///
88    /// - `ctx`: Mutable request context
89    /// - `session`: Optional read-only session view
90    ///   - `None` for `initialize` (session doesn't exist yet)
91    ///   - `Some(session)` for all other methods
92    ///
93    /// # Returns
94    ///
95    /// - `Ok(SessionInjection)`: All middleware succeeded, contains accumulated injections
96    /// - `Err(MiddlewareError)`: First middleware that failed
97    ///
98    /// # Execution
99    ///
100    /// 1. Execute each middleware in registration order
101    /// 2. Accumulate session injections from all middleware
102    /// 3. Stop on first error
103    pub async fn execute_before(
104        &self,
105        ctx: &mut RequestContext<'_>,
106        session: Option<&dyn SessionView>,
107    ) -> Result<SessionInjection, MiddlewareError> {
108        let mut combined_injection = SessionInjection::new();
109
110        for middleware in &self.middleware {
111            let mut injection = SessionInjection::new();
112            middleware
113                .before_dispatch(ctx, session, &mut injection)
114                .await?;
115
116            // Accumulate injections (later middleware can override earlier ones)
117            for (key, value) in injection.state() {
118                combined_injection.set_state(key.clone(), value.clone());
119            }
120            for (key, value) in injection.metadata() {
121                combined_injection.set_metadata(key.clone(), value.clone());
122            }
123        }
124
125        Ok(combined_injection)
126    }
127
128    /// Execute all middleware after dispatch
129    ///
130    /// # Parameters
131    ///
132    /// - `ctx`: Read-only request context
133    /// - `result`: Mutable dispatcher result
134    ///
135    /// # Returns
136    ///
137    /// - `Ok(())`: All middleware succeeded
138    /// - `Err(MiddlewareError)`: First middleware that failed
139    ///
140    /// # Execution
141    ///
142    /// 1. Execute each middleware in reverse registration order
143    /// 2. Stop on first error
144    /// 3. Allow middleware to modify result
145    pub async fn execute_after(
146        &self,
147        ctx: &RequestContext<'_>,
148        result: &mut DispatcherResult,
149    ) -> Result<(), MiddlewareError> {
150        // Execute in reverse order
151        for middleware in self.middleware.iter().rev() {
152            middleware.after_dispatch(ctx, result).await?;
153        }
154
155        Ok(())
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use serde_json::json;

    /// Appends `before_<id>` / `after_<id>` to a shared log and injects `<id> = true`
    /// into the session state.
    struct CountingMiddleware {
        id: String,
        counter: Arc<std::sync::Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl McpMiddleware for CountingMiddleware {
        async fn before_dispatch(
            &self,
            _ctx: &mut RequestContext<'_>,
            _session: Option<&dyn SessionView>,
            injection: &mut SessionInjection,
        ) -> Result<(), MiddlewareError> {
            self.counter
                .lock()
                .unwrap()
                .push(format!("before_{}", self.id));
            injection.set_state(&self.id, json!(true));
            Ok(())
        }

        async fn after_dispatch(
            &self,
            _ctx: &RequestContext<'_>,
            _result: &mut DispatcherResult,
        ) -> Result<(), MiddlewareError> {
            self.counter
                .lock()
                .unwrap()
                .push(format!("after_{}", self.id));
            Ok(())
        }
    }

    /// Fails `before_dispatch` with an `unauthorized` error when `error_on_before` is set.
    struct ErrorMiddleware {
        error_on_before: bool,
    }

    #[async_trait]
    impl McpMiddleware for ErrorMiddleware {
        async fn before_dispatch(
            &self,
            _ctx: &mut RequestContext<'_>,
            _session: Option<&dyn SessionView>,
            _injection: &mut SessionInjection,
        ) -> Result<(), MiddlewareError> {
            if self.error_on_before {
                Err(MiddlewareError::unauthorized("Test error"))
            } else {
                Ok(())
            }
        }
    }

    /// Succeeds before dispatch but fails every `after_dispatch` call.
    struct AfterErrorMiddleware;

    #[async_trait]
    impl McpMiddleware for AfterErrorMiddleware {
        async fn before_dispatch(
            &self,
            _ctx: &mut RequestContext<'_>,
            _session: Option<&dyn SessionView>,
            _injection: &mut SessionInjection,
        ) -> Result<(), MiddlewareError> {
            Ok(())
        }

        async fn after_dispatch(
            &self,
            _ctx: &RequestContext<'_>,
            _result: &mut DispatcherResult,
        ) -> Result<(), MiddlewareError> {
            Err(MiddlewareError::unauthorized("After error"))
        }
    }

    /// Injects a single fixed state entry `key = value`.
    struct StateMiddleware {
        key: &'static str,
        value: serde_json::Value,
    }

    #[async_trait]
    impl McpMiddleware for StateMiddleware {
        async fn before_dispatch(
            &self,
            _ctx: &mut RequestContext<'_>,
            _session: Option<&dyn SessionView>,
            injection: &mut SessionInjection,
        ) -> Result<(), MiddlewareError> {
            injection.set_state(self.key, self.value.clone());
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_middleware_execution_order() {
        let counter = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut stack = MiddlewareStack::new();

        stack.push(Arc::new(CountingMiddleware {
            id: "first".to_string(),
            counter: counter.clone(),
        }));
        stack.push(Arc::new(CountingMiddleware {
            id: "second".to_string(),
            counter: counter.clone(),
        }));

        let mut ctx = RequestContext::new("test/method", None);

        // Execute before (no session needed for this test)
        let injection = stack.execute_before(&mut ctx, None).await.unwrap();
        assert_eq!(injection.state().len(), 2);
        assert!(injection.state().contains_key("first"));
        assert!(injection.state().contains_key("second"));

        // Execute after
        let mut result = DispatcherResult::Success(json!({"ok": true}));
        stack.execute_after(&ctx, &mut result).await.unwrap();

        // Compare the whole log so duplicate or extra hook calls are caught too:
        // before in registration order, after in reverse.
        let log = counter.lock().unwrap();
        assert_eq!(
            *log,
            vec!["before_first", "before_second", "after_second", "after_first"]
        );
    }

    #[tokio::test]
    async fn test_middleware_error_stops_chain() {
        let counter = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut stack = MiddlewareStack::new();

        stack.push(Arc::new(CountingMiddleware {
            id: "first".to_string(),
            counter: counter.clone(),
        }));
        stack.push(Arc::new(ErrorMiddleware {
            error_on_before: true,
        }));
        stack.push(Arc::new(CountingMiddleware {
            id: "third".to_string(),
            counter: counter.clone(),
        }));

        let mut ctx = RequestContext::new("test/method", None);

        // Execute before - should fail at second middleware (no session needed)
        let result = stack.execute_before(&mut ctx, None).await;
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err(),
            MiddlewareError::unauthorized("Test error")
        );

        // Verify only first middleware executed
        let log = counter.lock().unwrap();
        assert_eq!(*log, vec!["before_first"]);
    }

    #[tokio::test]
    async fn test_later_injection_overrides_earlier() {
        let mut stack = MiddlewareStack::new();
        stack.push(Arc::new(StateMiddleware {
            key: "shared",
            value: json!("first"),
        }));
        stack.push(Arc::new(StateMiddleware {
            key: "shared",
            value: json!("second"),
        }));

        let mut ctx = RequestContext::new("test/method", None);
        let injection = stack.execute_before(&mut ctx, None).await.unwrap();

        // Same key written twice: the later-registered middleware wins.
        assert_eq!(injection.state().len(), 1);
        assert_eq!(injection.state().get("shared"), Some(&json!("second")));
    }

    #[tokio::test]
    async fn test_after_error_stops_chain() {
        let counter = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut stack = MiddlewareStack::new();

        stack.push(Arc::new(CountingMiddleware {
            id: "first".to_string(),
            counter: counter.clone(),
        }));
        stack.push(Arc::new(AfterErrorMiddleware));
        stack.push(Arc::new(CountingMiddleware {
            id: "third".to_string(),
            counter: counter.clone(),
        }));

        let ctx = RequestContext::new("test/method", None);
        let mut result = DispatcherResult::Success(json!({"ok": true}));
        let outcome = stack.execute_after(&ctx, &mut result).await;

        assert_eq!(
            outcome.unwrap_err(),
            MiddlewareError::unauthorized("After error")
        );

        // Reverse order: "third" runs, then the error stops the chain before "first".
        let log = counter.lock().unwrap();
        assert_eq!(*log, vec!["after_third"]);
    }

    #[tokio::test]
    async fn test_empty_stack() {
        let stack = MiddlewareStack::new();
        assert!(stack.is_empty());
        assert_eq!(stack.len(), 0);

        let mut ctx = RequestContext::new("test/method", None);

        let injection = stack.execute_before(&mut ctx, None).await.unwrap();
        assert!(injection.is_empty());

        let mut result = DispatcherResult::Success(json!({"ok": true}));
        stack.execute_after(&ctx, &mut result).await.unwrap();
    }
}