Skip to main content

turbomcp_server/middleware/
typed.rs

1//! Typed middleware with per-method hooks.
2//!
3//! This module provides a middleware trait with typed hooks for each MCP operation,
4//! enabling request interception, modification, and short-circuiting.
5
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9
10use serde_json::Value;
11
12use turbomcp_core::context::RequestContext;
13use turbomcp_core::error::McpResult;
14use turbomcp_core::handler::McpHandler;
15use turbomcp_types::{
16    Prompt, PromptResult, Resource, ResourceResult, ServerInfo, Tool, ToolResult,
17};
18
19/// Typed middleware trait with hooks for each MCP operation.
20///
21/// Implement this trait to intercept and modify MCP requests and responses.
22/// Each hook receives the request parameters and a `Next` object for calling
23/// the next middleware or the final handler.
24///
25/// # Default Implementations
26///
27/// All hooks have default implementations that simply pass through to the next
28/// middleware. Override only the hooks you need.
29///
30/// # Example
31///
32/// ```rust,ignore
33/// use turbomcp_server::middleware::{McpMiddleware, Next};
34///
35/// struct RateLimitMiddleware {
36///     max_calls_per_minute: u32,
37/// }
38///
39/// impl McpMiddleware for RateLimitMiddleware {
40///     async fn on_call_tool<'a>(
41///         &'a self,
42///         name: &'a str,
43///         args: Value,
44///         ctx: &'a RequestContext,
45///         next: Next<'a>,
46///     ) -> McpResult<ToolResult> {
47///         // Check rate limit
48///         if self.is_rate_limited(ctx) {
49///             return Err(McpError::internal("Rate limit exceeded"));
50///         }
51///         next.call_tool(name, args, ctx).await
52///     }
53/// }
54/// ```
55pub trait McpMiddleware: Send + Sync + 'static {
56    /// Hook called when listing tools.
57    ///
58    /// Can filter, modify, or replace the tool list.
59    fn on_list_tools<'a>(
60        &'a self,
61        next: Next<'a>,
62    ) -> Pin<Box<dyn Future<Output = Vec<Tool>> + Send + 'a>> {
63        Box::pin(async move { next.list_tools() })
64    }
65
66    /// Hook called when listing resources.
67    fn on_list_resources<'a>(
68        &'a self,
69        next: Next<'a>,
70    ) -> Pin<Box<dyn Future<Output = Vec<Resource>> + Send + 'a>> {
71        Box::pin(async move { next.list_resources() })
72    }
73
74    /// Hook called when listing prompts.
75    fn on_list_prompts<'a>(
76        &'a self,
77        next: Next<'a>,
78    ) -> Pin<Box<dyn Future<Output = Vec<Prompt>> + Send + 'a>> {
79        Box::pin(async move { next.list_prompts() })
80    }
81
82    /// Hook called when a tool is invoked.
83    ///
84    /// Can modify arguments, short-circuit with an error, or transform the result.
85    fn on_call_tool<'a>(
86        &'a self,
87        name: &'a str,
88        args: Value,
89        ctx: &'a RequestContext,
90        next: Next<'a>,
91    ) -> Pin<Box<dyn Future<Output = McpResult<ToolResult>> + Send + 'a>> {
92        Box::pin(async move { next.call_tool(name, args, ctx).await })
93    }
94
95    /// Hook called when a resource is read.
96    fn on_read_resource<'a>(
97        &'a self,
98        uri: &'a str,
99        ctx: &'a RequestContext,
100        next: Next<'a>,
101    ) -> Pin<Box<dyn Future<Output = McpResult<ResourceResult>> + Send + 'a>> {
102        Box::pin(async move { next.read_resource(uri, ctx).await })
103    }
104
105    /// Hook called when a prompt is retrieved.
106    fn on_get_prompt<'a>(
107        &'a self,
108        name: &'a str,
109        args: Option<Value>,
110        ctx: &'a RequestContext,
111        next: Next<'a>,
112    ) -> Pin<Box<dyn Future<Output = McpResult<PromptResult>> + Send + 'a>> {
113        Box::pin(async move { next.get_prompt(name, args, ctx).await })
114    }
115
116    /// Hook called when the server is initialized.
117    ///
118    /// Can perform setup tasks, validate configuration, or short-circuit
119    /// initialization by returning an error.
120    fn on_initialize<'a>(
121        &'a self,
122        next: Next<'a>,
123    ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
124        Box::pin(async move { next.initialize().await })
125    }
126
127    /// Hook called when the server is shutting down.
128    ///
129    /// Can perform cleanup tasks like flushing buffers or closing connections.
130    fn on_shutdown<'a>(
131        &'a self,
132        next: Next<'a>,
133    ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
134        Box::pin(async move { next.shutdown().await })
135    }
136}
137
138/// Continuation for calling the next middleware or handler.
139///
140/// This struct is passed to each middleware hook and provides methods
141/// to continue processing with the next middleware in the chain.
142pub struct Next<'a> {
143    handler: &'a dyn DynHandler,
144    middlewares: &'a [Arc<dyn McpMiddleware>],
145    index: usize,
146}
147
148impl<'a> std::fmt::Debug for Next<'a> {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        f.debug_struct("Next")
151            .field("index", &self.index)
152            .field(
153                "remaining_middlewares",
154                &(self.middlewares.len() - self.index),
155            )
156            .finish()
157    }
158}
159
160impl<'a> Next<'a> {
161    fn new(
162        handler: &'a dyn DynHandler,
163        middlewares: &'a [Arc<dyn McpMiddleware>],
164        index: usize,
165    ) -> Self {
166        Self {
167            handler,
168            middlewares,
169            index,
170        }
171    }
172
173    /// List tools from the next middleware or handler.
174    pub fn list_tools(self) -> Vec<Tool> {
175        if self.index < self.middlewares.len() {
176            // Can't easily make this recursive async, so just call handler directly
177            // In a full implementation, we'd use a different pattern
178            self.handler.dyn_list_tools()
179        } else {
180            self.handler.dyn_list_tools()
181        }
182    }
183
184    /// List resources from the next middleware or handler.
185    pub fn list_resources(self) -> Vec<Resource> {
186        self.handler.dyn_list_resources()
187    }
188
189    /// List prompts from the next middleware or handler.
190    pub fn list_prompts(self) -> Vec<Prompt> {
191        self.handler.dyn_list_prompts()
192    }
193
194    /// Call a tool through the next middleware or handler.
195    pub async fn call_tool(
196        self,
197        name: &str,
198        args: Value,
199        ctx: &RequestContext,
200    ) -> McpResult<ToolResult> {
201        if self.index < self.middlewares.len() {
202            let middleware = &self.middlewares[self.index];
203            let next = Next::new(self.handler, self.middlewares, self.index + 1);
204            middleware.on_call_tool(name, args, ctx, next).await
205        } else {
206            self.handler.dyn_call_tool(name, args, ctx).await
207        }
208    }
209
210    /// Read a resource through the next middleware or handler.
211    pub async fn read_resource(self, uri: &str, ctx: &RequestContext) -> McpResult<ResourceResult> {
212        if self.index < self.middlewares.len() {
213            let middleware = &self.middlewares[self.index];
214            let next = Next::new(self.handler, self.middlewares, self.index + 1);
215            middleware.on_read_resource(uri, ctx, next).await
216        } else {
217            self.handler.dyn_read_resource(uri, ctx).await
218        }
219    }
220
221    /// Get a prompt through the next middleware or handler.
222    pub async fn get_prompt(
223        self,
224        name: &str,
225        args: Option<Value>,
226        ctx: &RequestContext,
227    ) -> McpResult<PromptResult> {
228        if self.index < self.middlewares.len() {
229            let middleware = &self.middlewares[self.index];
230            let next = Next::new(self.handler, self.middlewares, self.index + 1);
231            middleware.on_get_prompt(name, args, ctx, next).await
232        } else {
233            self.handler.dyn_get_prompt(name, args, ctx).await
234        }
235    }
236
237    /// Run initialization through the next middleware or handler.
238    pub async fn initialize(self) -> McpResult<()> {
239        if self.index < self.middlewares.len() {
240            let middleware = &self.middlewares[self.index];
241            let next = Next::new(self.handler, self.middlewares, self.index + 1);
242            middleware.on_initialize(next).await
243        } else {
244            self.handler.dyn_on_initialize().await
245        }
246    }
247
248    /// Run shutdown through the next middleware or handler.
249    pub async fn shutdown(self) -> McpResult<()> {
250        if self.index < self.middlewares.len() {
251            let middleware = &self.middlewares[self.index];
252            let next = Next::new(self.handler, self.middlewares, self.index + 1);
253            middleware.on_shutdown(next).await
254        } else {
255            self.handler.dyn_on_shutdown().await
256        }
257    }
258}
259
260/// Internal trait for type-erased handler access.
261#[allow(dead_code)]
262trait DynHandler: Send + Sync {
263    fn dyn_server_info(&self) -> ServerInfo;
264    fn dyn_list_tools(&self) -> Vec<Tool>;
265    fn dyn_list_resources(&self) -> Vec<Resource>;
266    fn dyn_list_prompts(&self) -> Vec<Prompt>;
267    fn dyn_call_tool<'a>(
268        &'a self,
269        name: &'a str,
270        args: Value,
271        ctx: &'a RequestContext,
272    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<ToolResult>> + Send + 'a>>;
273    fn dyn_read_resource<'a>(
274        &'a self,
275        uri: &'a str,
276        ctx: &'a RequestContext,
277    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<ResourceResult>> + Send + 'a>>;
278    fn dyn_get_prompt<'a>(
279        &'a self,
280        name: &'a str,
281        args: Option<Value>,
282        ctx: &'a RequestContext,
283    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<PromptResult>> + Send + 'a>>;
284    fn dyn_on_initialize<'a>(
285        &'a self,
286    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<()>> + Send + 'a>>;
287    fn dyn_on_shutdown<'a>(
288        &'a self,
289    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<()>> + Send + 'a>>;
290}
291
292/// Wrapper for type-erased handler access.
293struct HandlerWrapper<H: McpHandler> {
294    handler: H,
295}
296
297impl<H: McpHandler> DynHandler for HandlerWrapper<H> {
298    fn dyn_server_info(&self) -> ServerInfo {
299        self.handler.server_info()
300    }
301
302    fn dyn_list_tools(&self) -> Vec<Tool> {
303        self.handler.list_tools()
304    }
305
306    fn dyn_list_resources(&self) -> Vec<Resource> {
307        self.handler.list_resources()
308    }
309
310    fn dyn_list_prompts(&self) -> Vec<Prompt> {
311        self.handler.list_prompts()
312    }
313
314    fn dyn_call_tool<'a>(
315        &'a self,
316        name: &'a str,
317        args: Value,
318        ctx: &'a RequestContext,
319    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<ToolResult>> + Send + 'a>>
320    {
321        Box::pin(self.handler.call_tool(name, args, ctx))
322    }
323
324    fn dyn_read_resource<'a>(
325        &'a self,
326        uri: &'a str,
327        ctx: &'a RequestContext,
328    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<ResourceResult>> + Send + 'a>>
329    {
330        Box::pin(self.handler.read_resource(uri, ctx))
331    }
332
333    fn dyn_get_prompt<'a>(
334        &'a self,
335        name: &'a str,
336        args: Option<Value>,
337        ctx: &'a RequestContext,
338    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<PromptResult>> + Send + 'a>>
339    {
340        Box::pin(self.handler.get_prompt(name, args, ctx))
341    }
342
343    fn dyn_on_initialize<'a>(
344        &'a self,
345    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<()>> + Send + 'a>> {
346        Box::pin(self.handler.on_initialize())
347    }
348
349    fn dyn_on_shutdown<'a>(
350        &'a self,
351    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<()>> + Send + 'a>> {
352        Box::pin(self.handler.on_shutdown())
353    }
354}
355
356/// A handler wrapped with a middleware stack.
357///
358/// This implements `McpHandler` and runs requests through the middleware chain.
359pub struct MiddlewareStack<H: McpHandler> {
360    handler: Arc<HandlerWrapper<H>>,
361    middlewares: Arc<Vec<Arc<dyn McpMiddleware>>>,
362}
363
364impl<H: McpHandler> Clone for MiddlewareStack<H> {
365    fn clone(&self) -> Self {
366        Self {
367            handler: Arc::clone(&self.handler),
368            middlewares: Arc::clone(&self.middlewares),
369        }
370    }
371}
372
373impl<H: McpHandler> std::fmt::Debug for MiddlewareStack<H> {
374    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
375        f.debug_struct("MiddlewareStack")
376            .field("middleware_count", &self.middlewares.len())
377            .finish()
378    }
379}
380
381impl<H: McpHandler> MiddlewareStack<H> {
382    /// Create a new middleware stack wrapping the given handler.
383    pub fn new(handler: H) -> Self {
384        Self {
385            handler: Arc::new(HandlerWrapper { handler }),
386            middlewares: Arc::new(Vec::new()),
387        }
388    }
389
390    /// Add a middleware to the stack.
391    ///
392    /// Middlewares are called in the order they are added.
393    #[must_use]
394    pub fn with_middleware<M: McpMiddleware>(mut self, middleware: M) -> Self {
395        let middlewares = Arc::make_mut(&mut self.middlewares);
396        middlewares.push(Arc::new(middleware));
397        self
398    }
399
400    /// Get the number of middlewares in the stack.
401    pub fn middleware_count(&self) -> usize {
402        self.middlewares.len()
403    }
404
405    fn next(&self) -> Next<'_> {
406        Next::new(self.handler.as_ref(), &self.middlewares, 0)
407    }
408}
409
410#[allow(clippy::manual_async_fn)]
411impl<H: McpHandler> McpHandler for MiddlewareStack<H> {
412    fn server_info(&self) -> ServerInfo {
413        self.handler.dyn_server_info()
414    }
415
416    fn list_tools(&self) -> Vec<Tool> {
417        self.handler.dyn_list_tools()
418    }
419
420    fn list_resources(&self) -> Vec<Resource> {
421        self.handler.dyn_list_resources()
422    }
423
424    fn list_prompts(&self) -> Vec<Prompt> {
425        self.handler.dyn_list_prompts()
426    }
427
428    fn call_tool<'a>(
429        &'a self,
430        name: &'a str,
431        args: Value,
432        ctx: &'a RequestContext,
433    ) -> impl std::future::Future<Output = McpResult<ToolResult>> + turbomcp_core::marker::MaybeSend + 'a
434    {
435        async move { self.next().call_tool(name, args, ctx).await }
436    }
437
438    fn read_resource<'a>(
439        &'a self,
440        uri: &'a str,
441        ctx: &'a RequestContext,
442    ) -> impl std::future::Future<Output = McpResult<ResourceResult>>
443    + turbomcp_core::marker::MaybeSend
444    + 'a {
445        async move { self.next().read_resource(uri, ctx).await }
446    }
447
448    fn get_prompt<'a>(
449        &'a self,
450        name: &'a str,
451        args: Option<Value>,
452        ctx: &'a RequestContext,
453    ) -> impl std::future::Future<Output = McpResult<PromptResult>> + turbomcp_core::marker::MaybeSend + 'a
454    {
455        async move { self.next().get_prompt(name, args, ctx).await }
456    }
457
458    fn on_initialize(
459        &self,
460    ) -> impl std::future::Future<Output = McpResult<()>> + turbomcp_core::marker::MaybeSend {
461        async move { self.next().initialize().await }
462    }
463
464    fn on_shutdown(
465        &self,
466    ) -> impl std::future::Future<Output = McpResult<()>> + turbomcp_core::marker::MaybeSend {
467        async move { self.next().shutdown().await }
468    }
469}
470
/// Unit tests for the middleware stack: passthrough behaviour, chaining,
/// invocation order, short-circuiting and lifecycle hooks.
#[cfg(test)]
#[allow(clippy::manual_async_fn)]
mod tests {
    use super::*;
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicU32, Ordering};
    use turbomcp_core::error::McpError;
    use turbomcp_core::marker::MaybeSend;

    /// Minimal handler exposing one tool, one resource and one prompt.
    #[derive(Clone)]
    struct TestHandler;

    impl McpHandler for TestHandler {
        fn server_info(&self) -> ServerInfo {
            ServerInfo::new("test", "1.0.0")
        }

        fn list_tools(&self) -> Vec<Tool> {
            vec![Tool::new("test_tool", "A test tool")]
        }

        fn list_resources(&self) -> Vec<Resource> {
            vec![Resource::new("test://resource", "A test resource")]
        }

        fn list_prompts(&self) -> Vec<Prompt> {
            vec![Prompt::new("test_prompt", "A test prompt")]
        }

        fn call_tool<'a>(
            &'a self,
            name: &'a str,
            _args: Value,
            _ctx: &'a RequestContext,
        ) -> impl std::future::Future<Output = McpResult<ToolResult>> + MaybeSend + 'a {
            async move {
                match name {
                    "test_tool" => Ok(ToolResult::text("Test result")),
                    _ => Err(McpError::tool_not_found(name)),
                }
            }
        }

        fn read_resource<'a>(
            &'a self,
            uri: &'a str,
            _ctx: &'a RequestContext,
        ) -> impl std::future::Future<Output = McpResult<ResourceResult>> + MaybeSend + 'a {
            let uri = uri.to_string();
            async move {
                if uri == "test://resource" {
                    Ok(ResourceResult::text(&uri, "Test content"))
                } else {
                    Err(McpError::resource_not_found(&uri))
                }
            }
        }

        fn get_prompt<'a>(
            &'a self,
            name: &'a str,
            _args: Option<Value>,
            _ctx: &'a RequestContext,
        ) -> impl std::future::Future<Output = McpResult<PromptResult>> + MaybeSend + 'a {
            let name = name.to_string();
            async move {
                if name == "test_prompt" {
                    Ok(PromptResult::user("Test prompt message"))
                } else {
                    Err(McpError::prompt_not_found(&name))
                }
            }
        }
    }

    /// A middleware that counts how often each hook is reached.
    #[derive(Default)]
    struct CountingMiddleware {
        tool_calls: AtomicU32,
        resource_reads: AtomicU32,
        prompt_gets: AtomicU32,
        initializes: AtomicU32,
        shutdowns: AtomicU32,
    }

    impl CountingMiddleware {
        /// Create a middleware with every counter at zero.
        fn new() -> Self {
            Self::default()
        }

        fn tool_calls(&self) -> u32 {
            self.tool_calls.load(Ordering::Relaxed)
        }

        fn resource_reads(&self) -> u32 {
            self.resource_reads.load(Ordering::Relaxed)
        }

        fn prompt_gets(&self) -> u32 {
            self.prompt_gets.load(Ordering::Relaxed)
        }

        fn initializes(&self) -> u32 {
            self.initializes.load(Ordering::Relaxed)
        }

        fn shutdowns(&self) -> u32 {
            self.shutdowns.load(Ordering::Relaxed)
        }
    }

    impl McpMiddleware for CountingMiddleware {
        fn on_call_tool<'a>(
            &'a self,
            name: &'a str,
            args: Value,
            ctx: &'a RequestContext,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<ToolResult>> + Send + 'a>> {
            Box::pin(async move {
                self.tool_calls.fetch_add(1, Ordering::Relaxed);
                next.call_tool(name, args, ctx).await
            })
        }

        fn on_read_resource<'a>(
            &'a self,
            uri: &'a str,
            ctx: &'a RequestContext,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<ResourceResult>> + Send + 'a>> {
            Box::pin(async move {
                self.resource_reads.fetch_add(1, Ordering::Relaxed);
                next.read_resource(uri, ctx).await
            })
        }

        fn on_get_prompt<'a>(
            &'a self,
            name: &'a str,
            args: Option<Value>,
            ctx: &'a RequestContext,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<PromptResult>> + Send + 'a>> {
            Box::pin(async move {
                self.prompt_gets.fetch_add(1, Ordering::Relaxed);
                next.get_prompt(name, args, ctx).await
            })
        }

        fn on_initialize<'a>(
            &'a self,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
            Box::pin(async move {
                self.initializes.fetch_add(1, Ordering::Relaxed);
                next.initialize().await
            })
        }

        fn on_shutdown<'a>(
            &'a self,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
            Box::pin(async move {
                self.shutdowns.fetch_add(1, Ordering::Relaxed);
                next.shutdown().await
            })
        }
    }

    /// Wrapper that lets a shared `Arc<CountingMiddleware>` be installed as a
    /// middleware while the test keeps its own handle to read the counters.
    struct CountingClone(Arc<CountingMiddleware>);

    impl McpMiddleware for CountingClone {
        fn on_call_tool<'a>(
            &'a self,
            name: &'a str,
            args: Value,
            ctx: &'a RequestContext,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<ToolResult>> + Send + 'a>> {
            self.0.on_call_tool(name, args, ctx, next)
        }

        fn on_read_resource<'a>(
            &'a self,
            uri: &'a str,
            ctx: &'a RequestContext,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<ResourceResult>> + Send + 'a>> {
            self.0.on_read_resource(uri, ctx, next)
        }

        fn on_get_prompt<'a>(
            &'a self,
            name: &'a str,
            args: Option<Value>,
            ctx: &'a RequestContext,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<PromptResult>> + Send + 'a>> {
            self.0.on_get_prompt(name, args, ctx, next)
        }

        fn on_initialize<'a>(
            &'a self,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
            self.0.on_initialize(next)
        }

        fn on_shutdown<'a>(
            &'a self,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
            self.0.on_shutdown(next)
        }
    }

    /// A middleware that blocks certain tools.
    struct BlockingMiddleware {
        blocked_tools: Vec<String>,
    }

    impl BlockingMiddleware {
        fn new(blocked: Vec<&str>) -> Self {
            Self {
                blocked_tools: blocked.into_iter().map(String::from).collect(),
            }
        }
    }

    impl McpMiddleware for BlockingMiddleware {
        fn on_call_tool<'a>(
            &'a self,
            name: &'a str,
            args: Value,
            ctx: &'a RequestContext,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<ToolResult>> + Send + 'a>> {
            Box::pin(async move {
                // Compare as `&str` so no `String` is allocated per call.
                if self.blocked_tools.iter().any(|blocked| blocked.as_str() == name) {
                    return Err(McpError::internal(format!("Tool '{}' is blocked", name)));
                }
                next.call_tool(name, args, ctx).await
            })
        }
    }

    /// A middleware that appends its label to a shared log whenever a tool
    /// call reaches it, so tests can assert the invocation order.
    struct RecordingMiddleware {
        label: &'static str,
        seen: Arc<Mutex<Vec<&'static str>>>,
    }

    impl McpMiddleware for RecordingMiddleware {
        fn on_call_tool<'a>(
            &'a self,
            name: &'a str,
            args: Value,
            ctx: &'a RequestContext,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<ToolResult>> + Send + 'a>> {
            Box::pin(async move {
                // The guard is a temporary dropped at the end of this
                // statement, so it is never held across the `.await` below.
                self.seen.lock().unwrap().push(self.label);
                next.call_tool(name, args, ctx).await
            })
        }
    }

    /// A middleware that blocks initialization.
    struct BlockInitMiddleware;

    impl McpMiddleware for BlockInitMiddleware {
        fn on_initialize<'a>(
            &'a self,
            _next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
            Box::pin(async move { Err(McpError::internal("initialization blocked by middleware")) })
        }
    }

    #[test]
    fn test_middleware_stack_creation() {
        let stack = MiddlewareStack::new(TestHandler)
            .with_middleware(CountingMiddleware::new())
            .with_middleware(BlockingMiddleware::new(vec!["blocked"]));

        assert_eq!(stack.middleware_count(), 2);
    }

    #[test]
    fn test_server_info_passthrough() {
        let stack = MiddlewareStack::new(TestHandler);
        let info = stack.server_info();
        assert_eq!(info.name, "test");
        assert_eq!(info.version, "1.0.0");
    }

    #[test]
    fn test_list_tools_passthrough() {
        let stack = MiddlewareStack::new(TestHandler);
        let tools = stack.list_tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "test_tool");
    }

    #[tokio::test]
    async fn test_call_tool_through_middleware() {
        let counting = Arc::new(CountingMiddleware::new());
        let stack =
            MiddlewareStack::new(TestHandler).with_middleware(CountingClone(counting.clone()));

        let ctx = RequestContext::default();
        let result = stack
            .call_tool("test_tool", serde_json::json!({}), &ctx)
            .await
            .unwrap();

        assert_eq!(result.first_text(), Some("Test result"));
        assert_eq!(counting.tool_calls(), 1);
    }

    #[tokio::test]
    async fn test_blocking_middleware() {
        let stack = MiddlewareStack::new(TestHandler)
            .with_middleware(BlockingMiddleware::new(vec!["test_tool"]));

        let ctx = RequestContext::default();
        let result = stack
            .call_tool("test_tool", serde_json::json!({}), &ctx)
            .await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("blocked"));
    }

    #[tokio::test]
    async fn test_middleware_chain_invokes_each_middleware() {
        let counting1 = Arc::new(CountingMiddleware::new());
        let counting2 = Arc::new(CountingMiddleware::new());

        let stack = MiddlewareStack::new(TestHandler)
            .with_middleware(CountingClone(counting1.clone()))
            .with_middleware(CountingClone(counting2.clone()));

        let ctx = RequestContext::default();
        stack
            .call_tool("test_tool", serde_json::json!({}), &ctx)
            .await
            .unwrap();

        // Both middlewares should be called exactly once.
        assert_eq!(counting1.tool_calls(), 1);
        assert_eq!(counting2.tool_calls(), 1);
    }

    #[tokio::test]
    async fn test_middleware_chain_order() {
        let seen = Arc::new(Mutex::new(Vec::new()));

        let stack = MiddlewareStack::new(TestHandler)
            .with_middleware(RecordingMiddleware {
                label: "first",
                seen: Arc::clone(&seen),
            })
            .with_middleware(RecordingMiddleware {
                label: "second",
                seen: Arc::clone(&seen),
            });

        let ctx = RequestContext::default();
        stack
            .call_tool("test_tool", serde_json::json!({}), &ctx)
            .await
            .unwrap();

        // Middlewares must run in the order they were added.
        assert_eq!(*seen.lock().unwrap(), vec!["first", "second"]);
    }

    #[tokio::test]
    async fn test_read_resource_through_middleware() {
        let counting = Arc::new(CountingMiddleware::new());
        let stack =
            MiddlewareStack::new(TestHandler).with_middleware(CountingClone(counting.clone()));

        let ctx = RequestContext::default();
        let result = stack.read_resource("test://resource", &ctx).await.unwrap();

        assert!(!result.contents.is_empty());
        assert_eq!(counting.resource_reads(), 1);
    }

    #[tokio::test]
    async fn test_get_prompt_through_middleware() {
        let counting = Arc::new(CountingMiddleware::new());
        let stack =
            MiddlewareStack::new(TestHandler).with_middleware(CountingClone(counting.clone()));

        let ctx = RequestContext::default();
        let result = stack.get_prompt("test_prompt", None, &ctx).await.unwrap();

        assert!(!result.messages.is_empty());
        assert_eq!(counting.prompt_gets(), 1);
    }

    #[tokio::test]
    async fn test_on_initialize_through_middleware() {
        let counting = Arc::new(CountingMiddleware::new());
        let stack =
            MiddlewareStack::new(TestHandler).with_middleware(CountingClone(counting.clone()));

        stack.on_initialize().await.unwrap();

        assert_eq!(counting.initializes(), 1);
    }

    #[tokio::test]
    async fn test_on_shutdown_through_middleware() {
        let counting = Arc::new(CountingMiddleware::new());
        let stack =
            MiddlewareStack::new(TestHandler).with_middleware(CountingClone(counting.clone()));

        stack.on_shutdown().await.unwrap();

        assert_eq!(counting.shutdowns(), 1);
    }

    #[tokio::test]
    async fn test_lifecycle_hooks_chain_through_multiple_middlewares() {
        let counting1 = Arc::new(CountingMiddleware::new());
        let counting2 = Arc::new(CountingMiddleware::new());

        let stack = MiddlewareStack::new(TestHandler)
            .with_middleware(CountingClone(counting1.clone()))
            .with_middleware(CountingClone(counting2.clone()));

        stack.on_initialize().await.unwrap();
        stack.on_shutdown().await.unwrap();

        assert_eq!(counting1.initializes(), 1);
        assert_eq!(counting2.initializes(), 1);
        assert_eq!(counting1.shutdowns(), 1);
        assert_eq!(counting2.shutdowns(), 1);
    }

    #[tokio::test]
    async fn test_on_initialize_short_circuit() {
        let stack = MiddlewareStack::new(TestHandler).with_middleware(BlockInitMiddleware);

        let result = stack.on_initialize().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("blocked"));
    }
}