Skip to main content

turbomcp_server/middleware/
typed.rs

1//! Typed middleware with per-method hooks.
2//!
3//! This module provides a middleware trait with typed hooks for each MCP operation,
4//! enabling request interception, modification, and short-circuiting.
5
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9
10use serde_json::Value;
11
12use turbomcp_core::context::RequestContext;
13use turbomcp_core::error::McpResult;
14use turbomcp_core::handler::McpHandler;
15use turbomcp_types::{
16    Prompt, PromptResult, Resource, ResourceResult, ServerInfo, Tool, ToolResult,
17};
18
19/// Typed middleware trait with hooks for each MCP operation.
20///
21/// Implement this trait to intercept and modify MCP requests and responses.
22/// Each hook receives the request parameters and a `Next` object for calling
23/// the next middleware or the final handler.
24///
25/// # Default Implementations
26///
27/// All hooks have default implementations that simply pass through to the next
28/// middleware. Override only the hooks you need.
29///
30/// # Example
31///
32/// ```rust,ignore
33/// use turbomcp_server::middleware::{McpMiddleware, Next};
34///
35/// struct RateLimitMiddleware {
36///     max_calls_per_minute: u32,
37/// }
38///
39/// impl McpMiddleware for RateLimitMiddleware {
40///     async fn on_call_tool<'a>(
41///         &'a self,
42///         name: &'a str,
43///         args: Value,
44///         ctx: &'a RequestContext,
45///         next: Next<'a>,
46///     ) -> McpResult<ToolResult> {
47///         // Check rate limit
48///         if self.is_rate_limited(ctx) {
49///             return Err(McpError::internal("Rate limit exceeded"));
50///         }
51///         next.call_tool(name, args, ctx).await
52///     }
53/// }
54/// ```
55pub trait McpMiddleware: Send + Sync + 'static {
56    /// Hook called when listing tools.
57    ///
58    /// Can filter, modify, or replace the tool list.
59    fn on_list_tools<'a>(
60        &'a self,
61        next: Next<'a>,
62    ) -> Pin<Box<dyn Future<Output = Vec<Tool>> + Send + 'a>> {
63        Box::pin(async move { next.list_tools() })
64    }
65
66    /// Hook called when listing resources.
67    fn on_list_resources<'a>(
68        &'a self,
69        next: Next<'a>,
70    ) -> Pin<Box<dyn Future<Output = Vec<Resource>> + Send + 'a>> {
71        Box::pin(async move { next.list_resources() })
72    }
73
74    /// Hook called when listing prompts.
75    fn on_list_prompts<'a>(
76        &'a self,
77        next: Next<'a>,
78    ) -> Pin<Box<dyn Future<Output = Vec<Prompt>> + Send + 'a>> {
79        Box::pin(async move { next.list_prompts() })
80    }
81
82    /// Hook called when a tool is invoked.
83    ///
84    /// Can modify arguments, short-circuit with an error, or transform the result.
85    fn on_call_tool<'a>(
86        &'a self,
87        name: &'a str,
88        args: Value,
89        ctx: &'a RequestContext,
90        next: Next<'a>,
91    ) -> Pin<Box<dyn Future<Output = McpResult<ToolResult>> + Send + 'a>> {
92        Box::pin(async move { next.call_tool(name, args, ctx).await })
93    }
94
95    /// Hook called when a resource is read.
96    fn on_read_resource<'a>(
97        &'a self,
98        uri: &'a str,
99        ctx: &'a RequestContext,
100        next: Next<'a>,
101    ) -> Pin<Box<dyn Future<Output = McpResult<ResourceResult>> + Send + 'a>> {
102        Box::pin(async move { next.read_resource(uri, ctx).await })
103    }
104
105    /// Hook called when a prompt is retrieved.
106    fn on_get_prompt<'a>(
107        &'a self,
108        name: &'a str,
109        args: Option<Value>,
110        ctx: &'a RequestContext,
111        next: Next<'a>,
112    ) -> Pin<Box<dyn Future<Output = McpResult<PromptResult>> + Send + 'a>> {
113        Box::pin(async move { next.get_prompt(name, args, ctx).await })
114    }
115
116    /// Hook called when the server is initialized.
117    ///
118    /// Can perform setup tasks, validate configuration, or short-circuit
119    /// initialization by returning an error.
120    fn on_initialize<'a>(
121        &'a self,
122        next: Next<'a>,
123    ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
124        Box::pin(async move { next.initialize().await })
125    }
126
127    /// Hook called when the server is shutting down.
128    ///
129    /// Can perform cleanup tasks like flushing buffers or closing connections.
130    fn on_shutdown<'a>(
131        &'a self,
132        next: Next<'a>,
133    ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
134        Box::pin(async move { next.shutdown().await })
135    }
136}
137
138/// Continuation for calling the next middleware or handler.
139///
140/// This struct is passed to each middleware hook and provides methods
141/// to continue processing with the next middleware in the chain.
142pub struct Next<'a> {
143    handler: &'a dyn DynHandler,
144    middlewares: &'a [Arc<dyn McpMiddleware>],
145    index: usize,
146}
147
148impl<'a> std::fmt::Debug for Next<'a> {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        f.debug_struct("Next")
151            .field("index", &self.index)
152            .field(
153                "remaining_middlewares",
154                &(self.middlewares.len() - self.index),
155            )
156            .finish()
157    }
158}
159
160impl<'a> Next<'a> {
161    fn new(
162        handler: &'a dyn DynHandler,
163        middlewares: &'a [Arc<dyn McpMiddleware>],
164        index: usize,
165    ) -> Self {
166        Self {
167            handler,
168            middlewares,
169            index,
170        }
171    }
172
173    /// List tools from the next middleware or handler.
174    pub fn list_tools(self) -> Vec<Tool> {
175        if self.index < self.middlewares.len() {
176            // Can't easily make this recursive async, so just call handler directly
177            // In a full implementation, we'd use a different pattern
178            self.handler.dyn_list_tools()
179        } else {
180            self.handler.dyn_list_tools()
181        }
182    }
183
184    /// List resources from the next middleware or handler.
185    pub fn list_resources(self) -> Vec<Resource> {
186        self.handler.dyn_list_resources()
187    }
188
189    /// List prompts from the next middleware or handler.
190    pub fn list_prompts(self) -> Vec<Prompt> {
191        self.handler.dyn_list_prompts()
192    }
193
194    /// Call a tool through the next middleware or handler.
195    pub async fn call_tool(
196        self,
197        name: &str,
198        args: Value,
199        ctx: &RequestContext,
200    ) -> McpResult<ToolResult> {
201        if self.index < self.middlewares.len() {
202            let middleware = &self.middlewares[self.index];
203            let next = Next::new(self.handler, self.middlewares, self.index + 1);
204            middleware.on_call_tool(name, args, ctx, next).await
205        } else {
206            self.handler.dyn_call_tool(name, args, ctx).await
207        }
208    }
209
210    /// Read a resource through the next middleware or handler.
211    pub async fn read_resource(self, uri: &str, ctx: &RequestContext) -> McpResult<ResourceResult> {
212        if self.index < self.middlewares.len() {
213            let middleware = &self.middlewares[self.index];
214            let next = Next::new(self.handler, self.middlewares, self.index + 1);
215            middleware.on_read_resource(uri, ctx, next).await
216        } else {
217            self.handler.dyn_read_resource(uri, ctx).await
218        }
219    }
220
221    /// Get a prompt through the next middleware or handler.
222    pub async fn get_prompt(
223        self,
224        name: &str,
225        args: Option<Value>,
226        ctx: &RequestContext,
227    ) -> McpResult<PromptResult> {
228        if self.index < self.middlewares.len() {
229            let middleware = &self.middlewares[self.index];
230            let next = Next::new(self.handler, self.middlewares, self.index + 1);
231            middleware.on_get_prompt(name, args, ctx, next).await
232        } else {
233            self.handler.dyn_get_prompt(name, args, ctx).await
234        }
235    }
236
237    /// Run initialization through the next middleware or handler.
238    pub async fn initialize(self) -> McpResult<()> {
239        if self.index < self.middlewares.len() {
240            let middleware = &self.middlewares[self.index];
241            let next = Next::new(self.handler, self.middlewares, self.index + 1);
242            middleware.on_initialize(next).await
243        } else {
244            self.handler.dyn_on_initialize().await
245        }
246    }
247
248    /// Run shutdown through the next middleware or handler.
249    pub async fn shutdown(self) -> McpResult<()> {
250        if self.index < self.middlewares.len() {
251            let middleware = &self.middlewares[self.index];
252            let next = Next::new(self.handler, self.middlewares, self.index + 1);
253            middleware.on_shutdown(next).await
254        } else {
255            self.handler.dyn_on_shutdown().await
256        }
257    }
258}
259
260/// Internal trait for type-erased handler access.
261trait DynHandler: Send + Sync {
262    fn dyn_server_info(&self) -> ServerInfo;
263    fn dyn_list_tools(&self) -> Vec<Tool>;
264    fn dyn_list_resources(&self) -> Vec<Resource>;
265    fn dyn_list_prompts(&self) -> Vec<Prompt>;
266    fn dyn_call_tool<'a>(
267        &'a self,
268        name: &'a str,
269        args: Value,
270        ctx: &'a RequestContext,
271    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<ToolResult>> + Send + 'a>>;
272    fn dyn_read_resource<'a>(
273        &'a self,
274        uri: &'a str,
275        ctx: &'a RequestContext,
276    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<ResourceResult>> + Send + 'a>>;
277    fn dyn_get_prompt<'a>(
278        &'a self,
279        name: &'a str,
280        args: Option<Value>,
281        ctx: &'a RequestContext,
282    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<PromptResult>> + Send + 'a>>;
283    fn dyn_on_initialize<'a>(
284        &'a self,
285    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<()>> + Send + 'a>>;
286    fn dyn_on_shutdown<'a>(
287        &'a self,
288    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<()>> + Send + 'a>>;
289}
290
291/// Wrapper for type-erased handler access.
292struct HandlerWrapper<H: McpHandler> {
293    handler: H,
294}
295
296impl<H: McpHandler> DynHandler for HandlerWrapper<H> {
297    fn dyn_server_info(&self) -> ServerInfo {
298        self.handler.server_info()
299    }
300
301    fn dyn_list_tools(&self) -> Vec<Tool> {
302        self.handler.list_tools()
303    }
304
305    fn dyn_list_resources(&self) -> Vec<Resource> {
306        self.handler.list_resources()
307    }
308
309    fn dyn_list_prompts(&self) -> Vec<Prompt> {
310        self.handler.list_prompts()
311    }
312
313    fn dyn_call_tool<'a>(
314        &'a self,
315        name: &'a str,
316        args: Value,
317        ctx: &'a RequestContext,
318    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<ToolResult>> + Send + 'a>>
319    {
320        Box::pin(self.handler.call_tool(name, args, ctx))
321    }
322
323    fn dyn_read_resource<'a>(
324        &'a self,
325        uri: &'a str,
326        ctx: &'a RequestContext,
327    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<ResourceResult>> + Send + 'a>>
328    {
329        Box::pin(self.handler.read_resource(uri, ctx))
330    }
331
332    fn dyn_get_prompt<'a>(
333        &'a self,
334        name: &'a str,
335        args: Option<Value>,
336        ctx: &'a RequestContext,
337    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<PromptResult>> + Send + 'a>>
338    {
339        Box::pin(self.handler.get_prompt(name, args, ctx))
340    }
341
342    fn dyn_on_initialize<'a>(
343        &'a self,
344    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<()>> + Send + 'a>> {
345        Box::pin(self.handler.on_initialize())
346    }
347
348    fn dyn_on_shutdown<'a>(
349        &'a self,
350    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<()>> + Send + 'a>> {
351        Box::pin(self.handler.on_shutdown())
352    }
353}
354
355/// A handler wrapped with a middleware stack.
356///
357/// This implements `McpHandler` and runs requests through the middleware chain.
358pub struct MiddlewareStack<H: McpHandler> {
359    handler: Arc<HandlerWrapper<H>>,
360    middlewares: Arc<Vec<Arc<dyn McpMiddleware>>>,
361}
362
363impl<H: McpHandler> Clone for MiddlewareStack<H> {
364    fn clone(&self) -> Self {
365        Self {
366            handler: Arc::clone(&self.handler),
367            middlewares: Arc::clone(&self.middlewares),
368        }
369    }
370}
371
372impl<H: McpHandler> std::fmt::Debug for MiddlewareStack<H> {
373    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374        f.debug_struct("MiddlewareStack")
375            .field("middleware_count", &self.middlewares.len())
376            .finish()
377    }
378}
379
380impl<H: McpHandler> MiddlewareStack<H> {
381    /// Create a new middleware stack wrapping the given handler.
382    pub fn new(handler: H) -> Self {
383        Self {
384            handler: Arc::new(HandlerWrapper { handler }),
385            middlewares: Arc::new(Vec::new()),
386        }
387    }
388
389    /// Add a middleware to the stack.
390    ///
391    /// Middlewares are called in the order they are added.
392    #[must_use]
393    pub fn with_middleware<M: McpMiddleware>(mut self, middleware: M) -> Self {
394        let middlewares = Arc::make_mut(&mut self.middlewares);
395        middlewares.push(Arc::new(middleware));
396        self
397    }
398
399    /// Get the number of middlewares in the stack.
400    pub fn middleware_count(&self) -> usize {
401        self.middlewares.len()
402    }
403
404    fn next(&self) -> Next<'_> {
405        Next::new(self.handler.as_ref(), &self.middlewares, 0)
406    }
407}
408
409#[allow(clippy::manual_async_fn)]
410impl<H: McpHandler> McpHandler for MiddlewareStack<H> {
411    fn server_info(&self) -> ServerInfo {
412        self.handler.dyn_server_info()
413    }
414
415    fn list_tools(&self) -> Vec<Tool> {
416        self.handler.dyn_list_tools()
417    }
418
419    fn list_resources(&self) -> Vec<Resource> {
420        self.handler.dyn_list_resources()
421    }
422
423    fn list_prompts(&self) -> Vec<Prompt> {
424        self.handler.dyn_list_prompts()
425    }
426
427    fn call_tool<'a>(
428        &'a self,
429        name: &'a str,
430        args: Value,
431        ctx: &'a RequestContext,
432    ) -> impl std::future::Future<Output = McpResult<ToolResult>> + turbomcp_core::marker::MaybeSend + 'a
433    {
434        async move { self.next().call_tool(name, args, ctx).await }
435    }
436
437    fn read_resource<'a>(
438        &'a self,
439        uri: &'a str,
440        ctx: &'a RequestContext,
441    ) -> impl std::future::Future<Output = McpResult<ResourceResult>>
442    + turbomcp_core::marker::MaybeSend
443    + 'a {
444        async move { self.next().read_resource(uri, ctx).await }
445    }
446
447    fn get_prompt<'a>(
448        &'a self,
449        name: &'a str,
450        args: Option<Value>,
451        ctx: &'a RequestContext,
452    ) -> impl std::future::Future<Output = McpResult<PromptResult>> + turbomcp_core::marker::MaybeSend + 'a
453    {
454        async move { self.next().get_prompt(name, args, ctx).await }
455    }
456
457    fn on_initialize(
458        &self,
459    ) -> impl std::future::Future<Output = McpResult<()>> + turbomcp_core::marker::MaybeSend {
460        async move { self.next().initialize().await }
461    }
462
463    fn on_shutdown(
464        &self,
465    ) -> impl std::future::Future<Output = McpResult<()>> + turbomcp_core::marker::MaybeSend {
466        async move { self.next().shutdown().await }
467    }
468}
469
// Test handlers return `impl Future` built from `async move` blocks to match the
// `McpHandler` method signatures; `manual_async_fn` would suggest `async fn`.
#[cfg(test)]
#[allow(clippy::manual_async_fn)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use turbomcp_core::error::McpError;
    use turbomcp_core::marker::MaybeSend;

    /// Minimal handler exposing one tool, one resource and one prompt.
    #[derive(Clone)]
    struct TestHandler;

    impl McpHandler for TestHandler {
        fn server_info(&self) -> ServerInfo {
            ServerInfo::new("test", "1.0.0")
        }

        fn list_tools(&self) -> Vec<Tool> {
            vec![Tool::new("test_tool", "A test tool")]
        }

        fn list_resources(&self) -> Vec<Resource> {
            vec![Resource::new("test://resource", "A test resource")]
        }

        fn list_prompts(&self) -> Vec<Prompt> {
            vec![Prompt::new("test_prompt", "A test prompt")]
        }

        fn call_tool<'a>(
            &'a self,
            name: &'a str,
            _args: Value,
            _ctx: &'a RequestContext,
        ) -> impl std::future::Future<Output = McpResult<ToolResult>> + MaybeSend + 'a {
            async move {
                match name {
                    "test_tool" => Ok(ToolResult::text("Test result")),
                    _ => Err(McpError::tool_not_found(name)),
                }
            }
        }

        fn read_resource<'a>(
            &'a self,
            uri: &'a str,
            _ctx: &'a RequestContext,
        ) -> impl std::future::Future<Output = McpResult<ResourceResult>> + MaybeSend + 'a {
            let uri = uri.to_string();
            async move {
                if uri == "test://resource" {
                    Ok(ResourceResult::text(&uri, "Test content"))
                } else {
                    Err(McpError::resource_not_found(&uri))
                }
            }
        }

        fn get_prompt<'a>(
            &'a self,
            name: &'a str,
            _args: Option<Value>,
            _ctx: &'a RequestContext,
        ) -> impl std::future::Future<Output = McpResult<PromptResult>> + MaybeSend + 'a {
            let name = name.to_string();
            async move {
                if name == "test_prompt" {
                    Ok(PromptResult::user("Test prompt message"))
                } else {
                    Err(McpError::prompt_not_found(&name))
                }
            }
        }
    }

    /// A simple counting middleware for testing.
    struct CountingMiddleware {
        tool_calls: AtomicU32,
        resource_reads: AtomicU32,
        prompt_gets: AtomicU32,
        initializes: AtomicU32,
        shutdowns: AtomicU32,
    }

    impl CountingMiddleware {
        fn new() -> Self {
            Self {
                tool_calls: AtomicU32::new(0),
                resource_reads: AtomicU32::new(0),
                prompt_gets: AtomicU32::new(0),
                initializes: AtomicU32::new(0),
                shutdowns: AtomicU32::new(0),
            }
        }

        fn tool_calls(&self) -> u32 {
            self.tool_calls.load(Ordering::Relaxed)
        }

        fn resource_reads(&self) -> u32 {
            self.resource_reads.load(Ordering::Relaxed)
        }

        fn prompt_gets(&self) -> u32 {
            self.prompt_gets.load(Ordering::Relaxed)
        }

        fn initializes(&self) -> u32 {
            self.initializes.load(Ordering::Relaxed)
        }

        fn shutdowns(&self) -> u32 {
            self.shutdowns.load(Ordering::Relaxed)
        }
    }

    impl McpMiddleware for CountingMiddleware {
        fn on_call_tool<'a>(
            &'a self,
            name: &'a str,
            args: Value,
            ctx: &'a RequestContext,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<ToolResult>> + Send + 'a>> {
            Box::pin(async move {
                self.tool_calls.fetch_add(1, Ordering::Relaxed);
                next.call_tool(name, args, ctx).await
            })
        }

        fn on_read_resource<'a>(
            &'a self,
            uri: &'a str,
            ctx: &'a RequestContext,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<ResourceResult>> + Send + 'a>> {
            Box::pin(async move {
                self.resource_reads.fetch_add(1, Ordering::Relaxed);
                next.read_resource(uri, ctx).await
            })
        }

        fn on_get_prompt<'a>(
            &'a self,
            name: &'a str,
            args: Option<Value>,
            ctx: &'a RequestContext,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<PromptResult>> + Send + 'a>> {
            Box::pin(async move {
                self.prompt_gets.fetch_add(1, Ordering::Relaxed);
                next.get_prompt(name, args, ctx).await
            })
        }

        fn on_initialize<'a>(
            &'a self,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
            Box::pin(async move {
                self.initializes.fetch_add(1, Ordering::Relaxed);
                next.initialize().await
            })
        }

        fn on_shutdown<'a>(
            &'a self,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
            Box::pin(async move {
                self.shutdowns.fetch_add(1, Ordering::Relaxed);
                next.shutdown().await
            })
        }
    }

    /// A middleware that blocks certain tools.
    struct BlockingMiddleware {
        blocked_tools: Vec<String>,
    }

    impl BlockingMiddleware {
        fn new(blocked: Vec<&str>) -> Self {
            Self {
                blocked_tools: blocked.into_iter().map(String::from).collect(),
            }
        }

        /// Compares against the stored names without allocating a `String` per call.
        fn is_blocked(&self, name: &str) -> bool {
            self.blocked_tools.iter().any(|blocked| blocked.as_str() == name)
        }
    }

    impl McpMiddleware for BlockingMiddleware {
        fn on_call_tool<'a>(
            &'a self,
            name: &'a str,
            args: Value,
            ctx: &'a RequestContext,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<ToolResult>> + Send + 'a>> {
            Box::pin(async move {
                if self.is_blocked(name) {
                    return Err(McpError::internal(format!("Tool '{name}' is blocked")));
                }
                next.call_tool(name, args, ctx).await
            })
        }
    }

    #[test]
    fn test_middleware_stack_creation() {
        let stack = MiddlewareStack::new(TestHandler)
            .with_middleware(CountingMiddleware::new())
            .with_middleware(BlockingMiddleware::new(vec!["blocked"]));

        assert_eq!(stack.middleware_count(), 2);
    }

    #[test]
    fn test_server_info_passthrough() {
        let stack = MiddlewareStack::new(TestHandler);
        let info = stack.server_info();
        assert_eq!(info.name, "test");
        assert_eq!(info.version, "1.0.0");
    }

    #[test]
    fn test_list_tools_passthrough() {
        let stack = MiddlewareStack::new(TestHandler);
        let tools = stack.list_tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "test_tool");
    }

    #[test]
    fn test_list_resources_passthrough() {
        let stack = MiddlewareStack::new(TestHandler).with_middleware(CountingMiddleware::new());
        assert_eq!(stack.list_resources().len(), 1);
    }

    #[test]
    fn test_list_prompts_passthrough() {
        let stack = MiddlewareStack::new(TestHandler).with_middleware(CountingMiddleware::new());
        assert_eq!(stack.list_prompts().len(), 1);
    }

    #[tokio::test]
    async fn test_call_tool_through_middleware() {
        let counting = Arc::new(CountingMiddleware::new());
        let stack =
            MiddlewareStack::new(TestHandler).with_middleware(CountingClone(counting.clone()));

        let ctx = RequestContext::default();
        let result = stack
            .call_tool("test_tool", serde_json::json!({}), &ctx)
            .await
            .unwrap();

        assert_eq!(result.first_text(), Some("Test result"));
        assert_eq!(counting.tool_calls(), 1);
    }

    #[tokio::test]
    async fn test_blocking_middleware() {
        let stack = MiddlewareStack::new(TestHandler)
            .with_middleware(BlockingMiddleware::new(vec!["test_tool"]));

        let ctx = RequestContext::default();
        let result = stack
            .call_tool("test_tool", serde_json::json!({}), &ctx)
            .await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("blocked"));
    }

    #[tokio::test]
    async fn test_blocking_middleware_allows_unlisted_tools() {
        let stack = MiddlewareStack::new(TestHandler)
            .with_middleware(BlockingMiddleware::new(vec!["other_tool"]));

        let ctx = RequestContext::default();
        let result = stack
            .call_tool("test_tool", serde_json::json!({}), &ctx)
            .await
            .unwrap();

        assert_eq!(result.first_text(), Some("Test result"));
    }

    #[tokio::test]
    async fn test_middleware_chain_order() {
        let counting1 = Arc::new(CountingMiddleware::new());
        let counting2 = Arc::new(CountingMiddleware::new());

        let stack = MiddlewareStack::new(TestHandler)
            .with_middleware(CountingClone(counting1.clone()))
            .with_middleware(CountingClone(counting2.clone()));

        let ctx = RequestContext::default();
        stack
            .call_tool("test_tool", serde_json::json!({}), &ctx)
            .await
            .unwrap();

        // Both middlewares should be called
        assert_eq!(counting1.tool_calls(), 1);
        assert_eq!(counting2.tool_calls(), 1);
    }

    #[tokio::test]
    async fn test_read_resource_through_middleware() {
        let counting = Arc::new(CountingMiddleware::new());
        let stack =
            MiddlewareStack::new(TestHandler).with_middleware(CountingClone(counting.clone()));

        let ctx = RequestContext::default();
        let result = stack.read_resource("test://resource", &ctx).await.unwrap();

        assert!(!result.contents.is_empty());
        assert_eq!(counting.resource_reads(), 1);
    }

    #[tokio::test]
    async fn test_get_prompt_through_middleware() {
        let counting = Arc::new(CountingMiddleware::new());
        let stack =
            MiddlewareStack::new(TestHandler).with_middleware(CountingClone(counting.clone()));

        let ctx = RequestContext::default();
        let result = stack.get_prompt("test_prompt", None, &ctx).await.unwrap();

        assert!(!result.messages.is_empty());
        assert_eq!(counting.prompt_gets(), 1);
    }

    /// Wrapper to make Arc<CountingMiddleware> work as middleware.
    struct CountingClone(Arc<CountingMiddleware>);

    impl McpMiddleware for CountingClone {
        fn on_call_tool<'a>(
            &'a self,
            name: &'a str,
            args: Value,
            ctx: &'a RequestContext,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<ToolResult>> + Send + 'a>> {
            self.0.on_call_tool(name, args, ctx, next)
        }

        fn on_read_resource<'a>(
            &'a self,
            uri: &'a str,
            ctx: &'a RequestContext,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<ResourceResult>> + Send + 'a>> {
            self.0.on_read_resource(uri, ctx, next)
        }

        fn on_get_prompt<'a>(
            &'a self,
            name: &'a str,
            args: Option<Value>,
            ctx: &'a RequestContext,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<PromptResult>> + Send + 'a>> {
            self.0.on_get_prompt(name, args, ctx, next)
        }

        fn on_initialize<'a>(
            &'a self,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
            self.0.on_initialize(next)
        }

        fn on_shutdown<'a>(
            &'a self,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
            self.0.on_shutdown(next)
        }
    }

    #[tokio::test]
    async fn test_on_initialize_through_middleware() {
        let counting = Arc::new(CountingMiddleware::new());
        let stack =
            MiddlewareStack::new(TestHandler).with_middleware(CountingClone(counting.clone()));

        stack.on_initialize().await.unwrap();

        assert_eq!(counting.initializes(), 1);
    }

    #[tokio::test]
    async fn test_on_shutdown_through_middleware() {
        let counting = Arc::new(CountingMiddleware::new());
        let stack =
            MiddlewareStack::new(TestHandler).with_middleware(CountingClone(counting.clone()));

        stack.on_shutdown().await.unwrap();

        assert_eq!(counting.shutdowns(), 1);
    }

    #[tokio::test]
    async fn test_lifecycle_hooks_chain_through_multiple_middlewares() {
        let counting1 = Arc::new(CountingMiddleware::new());
        let counting2 = Arc::new(CountingMiddleware::new());

        let stack = MiddlewareStack::new(TestHandler)
            .with_middleware(CountingClone(counting1.clone()))
            .with_middleware(CountingClone(counting2.clone()));

        stack.on_initialize().await.unwrap();
        stack.on_shutdown().await.unwrap();

        assert_eq!(counting1.initializes(), 1);
        assert_eq!(counting2.initializes(), 1);
        assert_eq!(counting1.shutdowns(), 1);
        assert_eq!(counting2.shutdowns(), 1);
    }

    /// A middleware that blocks initialization.
    struct BlockInitMiddleware;

    impl McpMiddleware for BlockInitMiddleware {
        fn on_initialize<'a>(
            &'a self,
            _next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
            Box::pin(async move { Err(McpError::internal("initialization blocked by middleware")) })
        }
    }

    #[tokio::test]
    async fn test_on_initialize_short_circuit() {
        let stack = MiddlewareStack::new(TestHandler).with_middleware(BlockInitMiddleware);

        let result = stack.on_initialize().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("blocked"));
    }
}