turul_http_mcp_server/middleware/
stack.rs1use super::{DispatcherResult, McpMiddleware, MiddlewareError, RequestContext, SessionInjection};
4use std::sync::Arc;
5use turul_mcp_session_storage::SessionView;
6
7#[derive(Default, Clone)]
50pub struct MiddlewareStack {
51 middleware: Vec<Arc<dyn McpMiddleware>>,
52}
53
54impl MiddlewareStack {
55 pub fn new() -> Self {
57 Self::default()
58 }
59
60 pub fn push(&mut self, middleware: Arc<dyn McpMiddleware>) {
71 self.middleware.push(middleware);
72 }
73
74 pub fn len(&self) -> usize {
76 self.middleware.len()
77 }
78
79 pub fn is_empty(&self) -> bool {
81 self.middleware.is_empty()
82 }
83
84 pub fn has_pre_session_middleware(&self) -> bool {
86 self.middleware.iter().any(|m| m.runs_before_session())
87 }
88
89 pub async fn execute_before_session(
94 &self,
95 ctx: &mut RequestContext<'_>,
96 ) -> Result<(), MiddlewareError> {
97 for middleware in &self.middleware {
98 if middleware.runs_before_session() {
99 let mut injection = SessionInjection::new();
100 middleware
101 .before_dispatch(ctx, None, &mut injection)
102 .await?;
103 }
106 }
107 Ok(())
108 }
109
110 pub async fn execute_before(
130 &self,
131 ctx: &mut RequestContext<'_>,
132 session: Option<&dyn SessionView>,
133 ) -> Result<SessionInjection, MiddlewareError> {
134 let mut combined_injection = SessionInjection::new();
135
136 for middleware in &self.middleware {
137 if middleware.runs_before_session() {
139 continue;
140 }
141
142 let mut injection = SessionInjection::new();
143 middleware
144 .before_dispatch(ctx, session, &mut injection)
145 .await?;
146
147 for (key, value) in injection.state() {
149 combined_injection.set_state(key.clone(), value.clone());
150 }
151 for (key, value) in injection.metadata() {
152 combined_injection.set_metadata(key.clone(), value.clone());
153 }
154 }
155
156 Ok(combined_injection)
157 }
158
159 pub async fn execute_after(
177 &self,
178 ctx: &RequestContext<'_>,
179 result: &mut DispatcherResult,
180 ) -> Result<(), MiddlewareError> {
181 for middleware in self.middleware.iter().rev() {
183 middleware.after_dispatch(ctx, result).await?;
184 }
185
186 Ok(())
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use serde_json::json;

    /// Records each hook call as `before_<id>` / `after_<id>` in a shared log,
    /// and injects `state[id] = true` from `before_dispatch`.
    struct CountingMiddleware {
        id: String,
        counter: Arc<std::sync::Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl McpMiddleware for CountingMiddleware {
        async fn before_dispatch(
            &self,
            _ctx: &mut RequestContext<'_>,
            _session: Option<&dyn SessionView>,
            injection: &mut SessionInjection,
        ) -> Result<(), MiddlewareError> {
            self.counter
                .lock()
                .unwrap()
                .push(format!("before_{}", self.id));
            injection.set_state(&self.id, json!(true));
            Ok(())
        }

        async fn after_dispatch(
            &self,
            _ctx: &RequestContext<'_>,
            _result: &mut DispatcherResult,
        ) -> Result<(), MiddlewareError> {
            self.counter
                .lock()
                .unwrap()
                .push(format!("after_{}", self.id));
            Ok(())
        }
    }

    /// Fails `before_dispatch` with `unauthorized("Test error")` when
    /// `error_on_before` is set.
    struct ErrorMiddleware {
        error_on_before: bool,
    }

    #[async_trait]
    impl McpMiddleware for ErrorMiddleware {
        async fn before_dispatch(
            &self,
            _ctx: &mut RequestContext<'_>,
            _session: Option<&dyn SessionView>,
            _injection: &mut SessionInjection,
        ) -> Result<(), MiddlewareError> {
            if self.error_on_before {
                Err(MiddlewareError::unauthorized("Test error"))
            } else {
                Ok(())
            }
        }
    }

    /// Opts into the pre-session phase. It logs whether it was handed a
    /// session and writes a `pre` state key, which must never reach the
    /// injection merged by `execute_before`.
    struct PreSessionMiddleware {
        counter: Arc<std::sync::Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl McpMiddleware for PreSessionMiddleware {
        fn runs_before_session(&self) -> bool {
            true
        }

        async fn before_dispatch(
            &self,
            _ctx: &mut RequestContext<'_>,
            session: Option<&dyn SessionView>,
            injection: &mut SessionInjection,
        ) -> Result<(), MiddlewareError> {
            self.counter
                .lock()
                .unwrap()
                .push(format!("before_pre(session={})", session.is_some()));
            injection.set_state("pre", json!(true));
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_middleware_execution_order() {
        let counter = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut stack = MiddlewareStack::new();

        stack.push(Arc::new(CountingMiddleware {
            id: "first".to_string(),
            counter: counter.clone(),
        }));
        stack.push(Arc::new(CountingMiddleware {
            id: "second".to_string(),
            counter: counter.clone(),
        }));

        let mut ctx = RequestContext::new("test/method", None);

        let injection = stack.execute_before(&mut ctx, None).await.unwrap();
        assert_eq!(injection.state().len(), 2);
        assert!(injection.state().contains_key("first"));
        assert!(injection.state().contains_key("second"));

        let mut result = DispatcherResult::Success(json!({"ok": true}));
        stack.execute_after(&ctx, &mut result).await.unwrap();

        // Compare the whole log so any extra or missing hook call fails the test:
        // before hooks in push order, after hooks in reverse.
        let log = counter.lock().unwrap();
        assert_eq!(
            *log,
            vec!["before_first", "before_second", "after_second", "after_first"]
        );
    }

    #[tokio::test]
    async fn test_middleware_error_stops_chain() {
        let counter = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut stack = MiddlewareStack::new();

        stack.push(Arc::new(CountingMiddleware {
            id: "first".to_string(),
            counter: counter.clone(),
        }));
        stack.push(Arc::new(ErrorMiddleware {
            error_on_before: true,
        }));
        stack.push(Arc::new(CountingMiddleware {
            id: "third".to_string(),
            counter: counter.clone(),
        }));

        let mut ctx = RequestContext::new("test/method", None);

        let result = stack.execute_before(&mut ctx, None).await;
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err(),
            MiddlewareError::unauthorized("Test error")
        );

        // Only the middleware registered before the failing one ran.
        let log = counter.lock().unwrap();
        assert_eq!(*log, vec!["before_first"]);
    }

    #[tokio::test]
    async fn test_pre_session_middleware_is_partitioned() {
        let counter = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut stack = MiddlewareStack::new();

        stack.push(Arc::new(PreSessionMiddleware {
            counter: counter.clone(),
        }));
        stack.push(Arc::new(CountingMiddleware {
            id: "regular".to_string(),
            counter: counter.clone(),
        }));
        assert!(stack.has_pre_session_middleware());

        let mut ctx = RequestContext::new("test/method", None);

        // The pre-session phase runs only the pre-session layer, with no session.
        stack.execute_before_session(&mut ctx).await.unwrap();
        assert_eq!(*counter.lock().unwrap(), vec!["before_pre(session=false)"]);

        // The session phase skips the pre-session layer, and the pre-session
        // layer's injection is not merged into the result.
        let injection = stack.execute_before(&mut ctx, None).await.unwrap();
        assert_eq!(injection.state().len(), 1);
        assert!(injection.state().contains_key("regular"));
        assert!(!injection.state().contains_key("pre"));

        assert_eq!(
            *counter.lock().unwrap(),
            vec!["before_pre(session=false)", "before_regular"]
        );
    }

    /// Removes any tool named `secret_tool` from a successful `tools` list.
    struct ToolFilteringMiddleware;

    #[async_trait]
    impl McpMiddleware for ToolFilteringMiddleware {
        async fn before_dispatch(
            &self,
            _ctx: &mut RequestContext<'_>,
            _session: Option<&dyn SessionView>,
            _injection: &mut SessionInjection,
        ) -> Result<(), MiddlewareError> {
            Ok(())
        }

        async fn after_dispatch(
            &self,
            _ctx: &RequestContext<'_>,
            result: &mut DispatcherResult,
        ) -> Result<(), MiddlewareError> {
            if let DispatcherResult::Success(val) = result {
                if let Some(tools) = val.get_mut("tools") {
                    if let Some(arr) = tools.as_array_mut() {
                        arr.retain(|t| t["name"] != "secret_tool");
                    }
                }
            }
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_after_dispatch_success_mutation_visible() {
        let mut stack = MiddlewareStack::new();
        stack.push(Arc::new(ToolFilteringMiddleware));

        let ctx = RequestContext::new("tools/list", None);
        let mut result = DispatcherResult::Success(json!({
            "tools": [
                {"name": "public_tool"},
                {"name": "secret_tool"},
                {"name": "another_tool"}
            ]
        }));

        stack.execute_after(&ctx, &mut result).await.unwrap();

        let val = result.success().unwrap();
        let tools = val["tools"].as_array().unwrap();
        assert_eq!(tools.len(), 2);
        assert!(tools.iter().all(|t| t["name"] != "secret_tool"));
    }

    #[tokio::test]
    async fn test_after_dispatch_success_to_error_mutation_visible() {
        /// Turns any successful result into an error.
        struct RejectingMiddleware;

        #[async_trait]
        impl McpMiddleware for RejectingMiddleware {
            async fn before_dispatch(
                &self,
                _ctx: &mut RequestContext<'_>,
                _session: Option<&dyn SessionView>,
                _injection: &mut SessionInjection,
            ) -> Result<(), MiddlewareError> {
                Ok(())
            }

            async fn after_dispatch(
                &self,
                _ctx: &RequestContext<'_>,
                result: &mut DispatcherResult,
            ) -> Result<(), MiddlewareError> {
                if result.is_success() {
                    *result = DispatcherResult::Error("rejected by policy".to_string());
                }
                Ok(())
            }
        }

        let mut stack = MiddlewareStack::new();
        stack.push(Arc::new(RejectingMiddleware));

        let ctx = RequestContext::new("tools/list", None);
        let mut result = DispatcherResult::Success(json!({"tools": []}));

        stack.execute_after(&ctx, &mut result).await.unwrap();

        assert!(result.is_error());
        assert_eq!(result.error().unwrap(), "rejected by policy");
    }

    #[tokio::test]
    async fn test_empty_stack() {
        let stack = MiddlewareStack::new();
        assert!(stack.is_empty());
        assert_eq!(stack.len(), 0);
        assert!(!stack.has_pre_session_middleware());

        let mut ctx = RequestContext::new("test/method", None);

        stack.execute_before_session(&mut ctx).await.unwrap();

        let injection = stack.execute_before(&mut ctx, None).await.unwrap();
        assert!(injection.is_empty());

        // With no middleware, the result must pass through untouched.
        let mut result = DispatcherResult::Success(json!({"ok": true}));
        stack.execute_after(&ctx, &mut result).await.unwrap();
        let val = result.success().unwrap();
        assert_eq!(val["ok"], json!(true));
    }
}