turul_http_mcp_server/middleware/
stack.rs1use super::{DispatcherResult, McpMiddleware, MiddlewareError, RequestContext, SessionInjection};
4use std::sync::Arc;
5use turul_mcp_session_storage::SessionView;
6
7#[derive(Default, Clone)]
50pub struct MiddlewareStack {
51 middleware: Vec<Arc<dyn McpMiddleware>>,
52}
53
54impl MiddlewareStack {
55 pub fn new() -> Self {
57 Self::default()
58 }
59
60 pub fn push(&mut self, middleware: Arc<dyn McpMiddleware>) {
71 self.middleware.push(middleware);
72 }
73
74 pub fn len(&self) -> usize {
76 self.middleware.len()
77 }
78
79 pub fn is_empty(&self) -> bool {
81 self.middleware.is_empty()
82 }
83
84 pub fn has_pre_session_middleware(&self) -> bool {
86 self.middleware.iter().any(|m| m.runs_before_session())
87 }
88
89 pub async fn execute_before_session(
94 &self,
95 ctx: &mut RequestContext<'_>,
96 ) -> Result<(), MiddlewareError> {
97 for middleware in &self.middleware {
98 if middleware.runs_before_session() {
99 let mut injection = SessionInjection::new();
100 middleware
101 .before_dispatch(ctx, None, &mut injection)
102 .await?;
103 }
106 }
107 Ok(())
108 }
109
110 pub async fn execute_before(
130 &self,
131 ctx: &mut RequestContext<'_>,
132 session: Option<&dyn SessionView>,
133 ) -> Result<SessionInjection, MiddlewareError> {
134 let mut combined_injection = SessionInjection::new();
135
136 for middleware in &self.middleware {
137 if middleware.runs_before_session() {
139 continue;
140 }
141
142 let mut injection = SessionInjection::new();
143 middleware
144 .before_dispatch(ctx, session, &mut injection)
145 .await?;
146
147 for (key, value) in injection.state() {
149 combined_injection.set_state(key.clone(), value.clone());
150 }
151 for (key, value) in injection.metadata() {
152 combined_injection.set_metadata(key.clone(), value.clone());
153 }
154 }
155
156 Ok(combined_injection)
157 }
158
159 pub async fn execute_after(
177 &self,
178 ctx: &RequestContext<'_>,
179 result: &mut DispatcherResult,
180 ) -> Result<(), MiddlewareError> {
181 for middleware in self.middleware.iter().rev() {
183 middleware.after_dispatch(ctx, result).await?;
184 }
185
186 Ok(())
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use serde_json::json;
    use std::sync::Mutex;

    /// Shared, ordered record of which hooks fired.
    type CallLog = Arc<Mutex<Vec<String>>>;

    /// Test middleware that appends `before_<id>` / `after_<id>` to a shared log
    /// and injects `<id> = true` into the session state.
    struct CountingMiddleware {
        id: String,
        counter: CallLog,
        /// Value reported from `runs_before_session`.
        pre_session: bool,
    }

    /// Builds a `CountingMiddleware` that logs into `counter`.
    fn counting(id: &str, counter: &CallLog, pre_session: bool) -> Arc<CountingMiddleware> {
        Arc::new(CountingMiddleware {
            id: id.to_string(),
            counter: Arc::clone(counter),
            pre_session,
        })
    }

    #[async_trait]
    impl McpMiddleware for CountingMiddleware {
        fn runs_before_session(&self) -> bool {
            self.pre_session
        }

        async fn before_dispatch(
            &self,
            _ctx: &mut RequestContext<'_>,
            _session: Option<&dyn SessionView>,
            injection: &mut SessionInjection,
        ) -> Result<(), MiddlewareError> {
            self.counter
                .lock()
                .unwrap()
                .push(format!("before_{}", self.id));
            injection.set_state(&self.id, json!(true));
            Ok(())
        }

        async fn after_dispatch(
            &self,
            _ctx: &RequestContext<'_>,
            _result: &mut DispatcherResult,
        ) -> Result<(), MiddlewareError> {
            self.counter
                .lock()
                .unwrap()
                .push(format!("after_{}", self.id));
            Ok(())
        }
    }

    /// Test middleware whose `before_dispatch` fails with `unauthorized` when
    /// `error_on_before` is set.
    struct ErrorMiddleware {
        error_on_before: bool,
    }

    #[async_trait]
    impl McpMiddleware for ErrorMiddleware {
        async fn before_dispatch(
            &self,
            _ctx: &mut RequestContext<'_>,
            _session: Option<&dyn SessionView>,
            _injection: &mut SessionInjection,
        ) -> Result<(), MiddlewareError> {
            if self.error_on_before {
                Err(MiddlewareError::unauthorized("Test error"))
            } else {
                Ok(())
            }
        }
    }

    #[tokio::test]
    async fn test_middleware_execution_order() {
        let counter: CallLog = Arc::new(Mutex::new(Vec::new()));
        let mut stack = MiddlewareStack::new();
        stack.push(counting("first", &counter, false));
        stack.push(counting("second", &counter, false));

        let mut ctx = RequestContext::new("test/method", None);

        let injection = stack.execute_before(&mut ctx, None).await.unwrap();
        assert_eq!(injection.state().len(), 2);
        assert!(injection.state().contains_key("first"));
        assert!(injection.state().contains_key("second"));

        let mut result = DispatcherResult::Success(json!({"ok": true}));
        stack.execute_after(&ctx, &mut result).await.unwrap();

        // Comparing the full log (not just the first entries) also catches
        // extra or duplicated hook invocations.
        let log = counter.lock().unwrap();
        assert_eq!(
            *log,
            vec!["before_first", "before_second", "after_second", "after_first"]
        );
    }

    #[tokio::test]
    async fn test_middleware_error_stops_chain() {
        let counter: CallLog = Arc::new(Mutex::new(Vec::new()));
        let mut stack = MiddlewareStack::new();
        stack.push(counting("first", &counter, false));
        stack.push(Arc::new(ErrorMiddleware {
            error_on_before: true,
        }));
        stack.push(counting("third", &counter, false));

        let mut ctx = RequestContext::new("test/method", None);

        let result = stack.execute_before(&mut ctx, None).await;
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err(),
            MiddlewareError::unauthorized("Test error")
        );

        // "third" must never have run.
        let log = counter.lock().unwrap();
        assert_eq!(*log, vec!["before_first"]);
    }

    #[tokio::test]
    async fn test_pre_session_middleware_is_partitioned() {
        let counter: CallLog = Arc::new(Mutex::new(Vec::new()));
        let mut stack = MiddlewareStack::new();
        stack.push(counting("early", &counter, true));
        stack.push(counting("late", &counter, false));
        assert!(stack.has_pre_session_middleware());

        let mut ctx = RequestContext::new("test/method", None);

        stack.execute_before_session(&mut ctx).await.unwrap();
        let injection = stack.execute_before(&mut ctx, None).await.unwrap();

        // Only session-phase middleware contribute to the returned injection.
        assert_eq!(injection.state().len(), 1);
        assert!(injection.state().contains_key("late"));
        assert!(!injection.state().contains_key("early"));

        let mut result = DispatcherResult::Success(json!({"ok": true}));
        stack.execute_after(&ctx, &mut result).await.unwrap();

        // Each before-hook ran exactly once in its own phase; after-hooks run
        // for all middleware in reverse registration order.
        let log = counter.lock().unwrap();
        assert_eq!(
            *log,
            vec!["before_early", "before_late", "after_late", "after_early"]
        );
    }

    #[tokio::test]
    async fn test_empty_stack() {
        let stack = MiddlewareStack::new();
        assert!(stack.is_empty());
        assert_eq!(stack.len(), 0);
        assert!(!stack.has_pre_session_middleware());

        let mut ctx = RequestContext::new("test/method", None);

        stack.execute_before_session(&mut ctx).await.unwrap();

        let injection = stack.execute_before(&mut ctx, None).await.unwrap();
        assert!(injection.is_empty());

        let mut result = DispatcherResult::Success(json!({"ok": true}));
        stack.execute_after(&ctx, &mut result).await.unwrap();
    }
}