Skip to main content

turul_http_mcp_server/middleware/
stack.rs

1//! Middleware stack execution
2
3use super::{DispatcherResult, McpMiddleware, MiddlewareError, RequestContext, SessionInjection};
4use std::sync::Arc;
5use turul_mcp_session_storage::SessionView;
6
7/// Ordered collection of middleware with execution logic
8///
9/// The stack executes middleware in two phases:
10///
11/// 1. **Before dispatch**: Middleware execute in registration order
12///    - First error stops the chain
13///    - Session injections accumulate across all middleware
14///
15/// 2. **After dispatch**: Middleware execute in reverse registration order
16///    - Allows proper cleanup/finalization
17///    - Errors replace the result
18///
19/// # Examples
20///
21/// ```rust,no_run
22/// use turul_http_mcp_server::middleware::{MiddlewareStack, McpMiddleware, RequestContext, SessionInjection, MiddlewareError};
23/// use turul_mcp_session_storage::SessionView;
24/// use async_trait::async_trait;
25/// use std::sync::Arc;
26///
27/// struct LoggingMiddleware;
28///
29/// #[async_trait]
30/// impl McpMiddleware for LoggingMiddleware {
31///     async fn before_dispatch(
32///         &self,
33///         ctx: &mut RequestContext<'_>,
34///         _session: Option<&dyn SessionView>,
35///         _injection: &mut SessionInjection,
36///     ) -> Result<(), MiddlewareError> {
37///         println!("Request: {}", ctx.method());
38///         Ok(())
39///     }
40/// }
41///
42/// # async fn example() {
43/// let mut stack = MiddlewareStack::new();
44/// stack.push(Arc::new(LoggingMiddleware));
45///
46/// assert_eq!(stack.len(), 1);
47/// # }
48/// ```
49#[derive(Default, Clone)]
50pub struct MiddlewareStack {
51    middleware: Vec<Arc<dyn McpMiddleware>>,
52}
53
54impl MiddlewareStack {
55    /// Create an empty middleware stack
56    pub fn new() -> Self {
57        Self::default()
58    }
59
60    /// Add middleware to the end of the stack
61    ///
62    /// # Parameters
63    ///
64    /// - `middleware`: Middleware implementation (must be Arc-wrapped for sharing)
65    ///
66    /// # Execution Order
67    ///
68    /// - Before dispatch: First added executes first
69    /// - After dispatch: First added executes last (reverse order)
70    pub fn push(&mut self, middleware: Arc<dyn McpMiddleware>) {
71        self.middleware.push(middleware);
72    }
73
74    /// Get the number of middleware in the stack
75    pub fn len(&self) -> usize {
76        self.middleware.len()
77    }
78
79    /// Check if the stack is empty
80    pub fn is_empty(&self) -> bool {
81        self.middleware.is_empty()
82    }
83
84    /// Check if any middleware in the stack runs before session creation
85    pub fn has_pre_session_middleware(&self) -> bool {
86        self.middleware.iter().any(|m| m.runs_before_session())
87    }
88
89    /// Execute only pre-session middleware (those with `runs_before_session() == true`)
90    ///
91    /// Called by the transport layer before session lookup/creation.
92    /// Session is always `None` in this phase.
93    pub async fn execute_before_session(
94        &self,
95        ctx: &mut RequestContext<'_>,
96    ) -> Result<(), MiddlewareError> {
97        for middleware in &self.middleware {
98            if middleware.runs_before_session() {
99                let mut injection = SessionInjection::new();
100                middleware
101                    .before_dispatch(ctx, None, &mut injection)
102                    .await?;
103                // Pre-session injections are intentionally discarded —
104                // session doesn't exist yet. Use ctx.extensions instead.
105            }
106        }
107        Ok(())
108    }
109
110    /// Execute all middleware before dispatch
111    ///
112    /// # Parameters
113    ///
114    /// - `ctx`: Mutable request context
115    /// - `session`: Optional read-only session view
116    ///   - `None` for `initialize` (session doesn't exist yet)
117    ///   - `Some(session)` for all other methods
118    ///
119    /// # Returns
120    ///
121    /// - `Ok(SessionInjection)`: All middleware succeeded, contains accumulated injections
122    /// - `Err(MiddlewareError)`: First middleware that failed
123    ///
124    /// # Execution
125    ///
126    /// 1. Execute each middleware in registration order
127    /// 2. Accumulate session injections from all middleware
128    /// 3. Stop on first error
129    pub async fn execute_before(
130        &self,
131        ctx: &mut RequestContext<'_>,
132        session: Option<&dyn SessionView>,
133    ) -> Result<SessionInjection, MiddlewareError> {
134        let mut combined_injection = SessionInjection::new();
135
136        for middleware in &self.middleware {
137            // Skip pre-session middleware — they already ran in execute_before_session()
138            if middleware.runs_before_session() {
139                continue;
140            }
141
142            let mut injection = SessionInjection::new();
143            middleware
144                .before_dispatch(ctx, session, &mut injection)
145                .await?;
146
147            // Accumulate injections (later middleware can override earlier ones)
148            for (key, value) in injection.state() {
149                combined_injection.set_state(key.clone(), value.clone());
150            }
151            for (key, value) in injection.metadata() {
152                combined_injection.set_metadata(key.clone(), value.clone());
153            }
154        }
155
156        Ok(combined_injection)
157    }
158
159    /// Execute all middleware after dispatch
160    ///
161    /// # Parameters
162    ///
163    /// - `ctx`: Read-only request context
164    /// - `result`: Mutable dispatcher result
165    ///
166    /// # Returns
167    ///
168    /// - `Ok(())`: All middleware succeeded
169    /// - `Err(MiddlewareError)`: First middleware that failed
170    ///
171    /// # Execution
172    ///
173    /// 1. Execute each middleware in reverse registration order
174    /// 2. Stop on first error
175    /// 3. Allow middleware to modify result
176    pub async fn execute_after(
177        &self,
178        ctx: &RequestContext<'_>,
179        result: &mut DispatcherResult,
180    ) -> Result<(), MiddlewareError> {
181        // Execute in reverse order
182        for middleware in self.middleware.iter().rev() {
183            middleware.after_dispatch(ctx, result).await?;
184        }
185
186        Ok(())
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use serde_json::json;

    /// Shared, ordered record of which hooks were invoked.
    type CallLog = Arc<std::sync::Mutex<Vec<String>>>;

    /// Middleware that records each hook call in a shared log and sets a
    /// state key named after its own `id` during `before_dispatch`.
    struct CountingMiddleware {
        id: String,
        counter: CallLog,
    }

    #[async_trait]
    impl McpMiddleware for CountingMiddleware {
        async fn before_dispatch(
            &self,
            _ctx: &mut RequestContext<'_>,
            _session: Option<&dyn SessionView>,
            injection: &mut SessionInjection,
        ) -> Result<(), MiddlewareError> {
            let entry = format!("before_{}", self.id);
            self.counter.lock().unwrap().push(entry);
            injection.set_state(&self.id, json!(true));
            Ok(())
        }

        async fn after_dispatch(
            &self,
            _ctx: &RequestContext<'_>,
            _result: &mut DispatcherResult,
        ) -> Result<(), MiddlewareError> {
            let entry = format!("after_{}", self.id);
            self.counter.lock().unwrap().push(entry);
            Ok(())
        }
    }

    /// Middleware whose `before_dispatch` fails when `error_on_before` is set.
    struct ErrorMiddleware {
        error_on_before: bool,
    }

    #[async_trait]
    impl McpMiddleware for ErrorMiddleware {
        async fn before_dispatch(
            &self,
            _ctx: &mut RequestContext<'_>,
            _session: Option<&dyn SessionView>,
            _injection: &mut SessionInjection,
        ) -> Result<(), MiddlewareError> {
            if !self.error_on_before {
                return Ok(());
            }
            Err(MiddlewareError::unauthorized("Test error"))
        }
    }

    /// Build a `CountingMiddleware` that writes to `counter`.
    fn counting(id: &str, counter: &CallLog) -> Arc<dyn McpMiddleware> {
        Arc::new(CountingMiddleware {
            id: id.to_string(),
            counter: counter.clone(),
        })
    }

    #[tokio::test]
    async fn test_middleware_execution_order() {
        let counter: CallLog = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut stack = MiddlewareStack::new();
        stack.push(counting("first", &counter));
        stack.push(counting("second", &counter));

        let mut ctx = RequestContext::new("test/method", None);

        // Before phase; this test does not need a session.
        let injection = stack.execute_before(&mut ctx, None).await.unwrap();
        assert_eq!(injection.state().len(), 2);
        for key in ["first", "second"] {
            assert!(injection.state().contains_key(key));
        }

        // After phase.
        let mut result = DispatcherResult::Success(json!({"ok": true}));
        stack.execute_after(&ctx, &mut result).await.unwrap();

        // Before hooks follow registration order, after hooks the reverse.
        let expected = [
            "before_first",
            "before_second",
            "after_second",
            "after_first",
        ];
        let log = counter.lock().unwrap();
        for (position, want) in expected.iter().enumerate() {
            assert_eq!(log[position], *want);
        }
    }

    #[tokio::test]
    async fn test_middleware_error_stops_chain() {
        let counter: CallLog = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut stack = MiddlewareStack::new();
        stack.push(counting("first", &counter));
        stack.push(Arc::new(ErrorMiddleware {
            error_on_before: true,
        }));
        stack.push(counting("third", &counter));

        let mut ctx = RequestContext::new("test/method", None);

        // The second middleware fails, so the whole before phase fails.
        let outcome = stack.execute_before(&mut ctx, None).await;
        assert!(outcome.is_err());
        assert_eq!(
            outcome.unwrap_err(),
            MiddlewareError::unauthorized("Test error")
        );

        // Only the middleware ahead of the failing one got to run.
        let log = counter.lock().unwrap();
        assert_eq!(log.len(), 1);
        assert_eq!(log[0], "before_first");
    }

    #[tokio::test]
    async fn test_empty_stack() {
        let stack = MiddlewareStack::new();
        assert!(stack.is_empty());
        assert_eq!(stack.len(), 0);

        let mut ctx = RequestContext::new("test/method", None);

        // Both phases are no-ops on an empty stack.
        let injection = stack.execute_before(&mut ctx, None).await.unwrap();
        assert!(injection.is_empty());

        let mut result = DispatcherResult::Success(json!({"ok": true}));
        stack.execute_after(&ctx, &mut result).await.unwrap();
    }
}