1use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9
10use serde_json::Value;
11
12use turbomcp_core::context::RequestContext;
13use turbomcp_core::error::McpResult;
14use turbomcp_core::handler::McpHandler;
15use turbomcp_types::{
16 Prompt, PromptResult, Resource, ResourceResult, ServerInfo, Tool, ToolResult,
17};
18
19pub trait McpMiddleware: Send + Sync + 'static {
56 fn on_list_tools<'a>(
60 &'a self,
61 next: Next<'a>,
62 ) -> Pin<Box<dyn Future<Output = Vec<Tool>> + Send + 'a>> {
63 Box::pin(async move { next.list_tools() })
64 }
65
66 fn on_list_resources<'a>(
68 &'a self,
69 next: Next<'a>,
70 ) -> Pin<Box<dyn Future<Output = Vec<Resource>> + Send + 'a>> {
71 Box::pin(async move { next.list_resources() })
72 }
73
74 fn on_list_prompts<'a>(
76 &'a self,
77 next: Next<'a>,
78 ) -> Pin<Box<dyn Future<Output = Vec<Prompt>> + Send + 'a>> {
79 Box::pin(async move { next.list_prompts() })
80 }
81
82 fn on_call_tool<'a>(
86 &'a self,
87 name: &'a str,
88 args: Value,
89 ctx: &'a RequestContext,
90 next: Next<'a>,
91 ) -> Pin<Box<dyn Future<Output = McpResult<ToolResult>> + Send + 'a>> {
92 Box::pin(async move { next.call_tool(name, args, ctx).await })
93 }
94
95 fn on_read_resource<'a>(
97 &'a self,
98 uri: &'a str,
99 ctx: &'a RequestContext,
100 next: Next<'a>,
101 ) -> Pin<Box<dyn Future<Output = McpResult<ResourceResult>> + Send + 'a>> {
102 Box::pin(async move { next.read_resource(uri, ctx).await })
103 }
104
105 fn on_get_prompt<'a>(
107 &'a self,
108 name: &'a str,
109 args: Option<Value>,
110 ctx: &'a RequestContext,
111 next: Next<'a>,
112 ) -> Pin<Box<dyn Future<Output = McpResult<PromptResult>> + Send + 'a>> {
113 Box::pin(async move { next.get_prompt(name, args, ctx).await })
114 }
115
116 fn on_initialize<'a>(
121 &'a self,
122 next: Next<'a>,
123 ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
124 Box::pin(async move { next.initialize().await })
125 }
126
127 fn on_shutdown<'a>(
131 &'a self,
132 next: Next<'a>,
133 ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
134 Box::pin(async move { next.shutdown().await })
135 }
136}
137
138pub struct Next<'a> {
143 handler: &'a dyn DynHandler,
144 middlewares: &'a [Arc<dyn McpMiddleware>],
145 index: usize,
146}
147
148impl<'a> std::fmt::Debug for Next<'a> {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 f.debug_struct("Next")
151 .field("index", &self.index)
152 .field(
153 "remaining_middlewares",
154 &(self.middlewares.len() - self.index),
155 )
156 .finish()
157 }
158}
159
160impl<'a> Next<'a> {
161 fn new(
162 handler: &'a dyn DynHandler,
163 middlewares: &'a [Arc<dyn McpMiddleware>],
164 index: usize,
165 ) -> Self {
166 Self {
167 handler,
168 middlewares,
169 index,
170 }
171 }
172
173 pub fn list_tools(self) -> Vec<Tool> {
175 if self.index < self.middlewares.len() {
176 self.handler.dyn_list_tools()
179 } else {
180 self.handler.dyn_list_tools()
181 }
182 }
183
184 pub fn list_resources(self) -> Vec<Resource> {
186 self.handler.dyn_list_resources()
187 }
188
189 pub fn list_prompts(self) -> Vec<Prompt> {
191 self.handler.dyn_list_prompts()
192 }
193
194 pub async fn call_tool(
196 self,
197 name: &str,
198 args: Value,
199 ctx: &RequestContext,
200 ) -> McpResult<ToolResult> {
201 if self.index < self.middlewares.len() {
202 let middleware = &self.middlewares[self.index];
203 let next = Next::new(self.handler, self.middlewares, self.index + 1);
204 middleware.on_call_tool(name, args, ctx, next).await
205 } else {
206 self.handler.dyn_call_tool(name, args, ctx).await
207 }
208 }
209
210 pub async fn read_resource(self, uri: &str, ctx: &RequestContext) -> McpResult<ResourceResult> {
212 if self.index < self.middlewares.len() {
213 let middleware = &self.middlewares[self.index];
214 let next = Next::new(self.handler, self.middlewares, self.index + 1);
215 middleware.on_read_resource(uri, ctx, next).await
216 } else {
217 self.handler.dyn_read_resource(uri, ctx).await
218 }
219 }
220
221 pub async fn get_prompt(
223 self,
224 name: &str,
225 args: Option<Value>,
226 ctx: &RequestContext,
227 ) -> McpResult<PromptResult> {
228 if self.index < self.middlewares.len() {
229 let middleware = &self.middlewares[self.index];
230 let next = Next::new(self.handler, self.middlewares, self.index + 1);
231 middleware.on_get_prompt(name, args, ctx, next).await
232 } else {
233 self.handler.dyn_get_prompt(name, args, ctx).await
234 }
235 }
236
237 pub async fn initialize(self) -> McpResult<()> {
239 if self.index < self.middlewares.len() {
240 let middleware = &self.middlewares[self.index];
241 let next = Next::new(self.handler, self.middlewares, self.index + 1);
242 middleware.on_initialize(next).await
243 } else {
244 self.handler.dyn_on_initialize().await
245 }
246 }
247
248 pub async fn shutdown(self) -> McpResult<()> {
250 if self.index < self.middlewares.len() {
251 let middleware = &self.middlewares[self.index];
252 let next = Next::new(self.handler, self.middlewares, self.index + 1);
253 middleware.on_shutdown(next).await
254 } else {
255 self.handler.dyn_on_shutdown().await
256 }
257 }
258}
259
260trait DynHandler: Send + Sync {
262 fn dyn_server_info(&self) -> ServerInfo;
263 fn dyn_list_tools(&self) -> Vec<Tool>;
264 fn dyn_list_resources(&self) -> Vec<Resource>;
265 fn dyn_list_prompts(&self) -> Vec<Prompt>;
266 fn dyn_call_tool<'a>(
267 &'a self,
268 name: &'a str,
269 args: Value,
270 ctx: &'a RequestContext,
271 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<ToolResult>> + Send + 'a>>;
272 fn dyn_read_resource<'a>(
273 &'a self,
274 uri: &'a str,
275 ctx: &'a RequestContext,
276 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<ResourceResult>> + Send + 'a>>;
277 fn dyn_get_prompt<'a>(
278 &'a self,
279 name: &'a str,
280 args: Option<Value>,
281 ctx: &'a RequestContext,
282 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<PromptResult>> + Send + 'a>>;
283 fn dyn_on_initialize<'a>(
284 &'a self,
285 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<()>> + Send + 'a>>;
286 fn dyn_on_shutdown<'a>(
287 &'a self,
288 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<()>> + Send + 'a>>;
289}
290
291struct HandlerWrapper<H: McpHandler> {
293 handler: H,
294}
295
296impl<H: McpHandler> DynHandler for HandlerWrapper<H> {
297 fn dyn_server_info(&self) -> ServerInfo {
298 self.handler.server_info()
299 }
300
301 fn dyn_list_tools(&self) -> Vec<Tool> {
302 self.handler.list_tools()
303 }
304
305 fn dyn_list_resources(&self) -> Vec<Resource> {
306 self.handler.list_resources()
307 }
308
309 fn dyn_list_prompts(&self) -> Vec<Prompt> {
310 self.handler.list_prompts()
311 }
312
313 fn dyn_call_tool<'a>(
314 &'a self,
315 name: &'a str,
316 args: Value,
317 ctx: &'a RequestContext,
318 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<ToolResult>> + Send + 'a>>
319 {
320 Box::pin(self.handler.call_tool(name, args, ctx))
321 }
322
323 fn dyn_read_resource<'a>(
324 &'a self,
325 uri: &'a str,
326 ctx: &'a RequestContext,
327 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<ResourceResult>> + Send + 'a>>
328 {
329 Box::pin(self.handler.read_resource(uri, ctx))
330 }
331
332 fn dyn_get_prompt<'a>(
333 &'a self,
334 name: &'a str,
335 args: Option<Value>,
336 ctx: &'a RequestContext,
337 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<PromptResult>> + Send + 'a>>
338 {
339 Box::pin(self.handler.get_prompt(name, args, ctx))
340 }
341
342 fn dyn_on_initialize<'a>(
343 &'a self,
344 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<()>> + Send + 'a>> {
345 Box::pin(self.handler.on_initialize())
346 }
347
348 fn dyn_on_shutdown<'a>(
349 &'a self,
350 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<()>> + Send + 'a>> {
351 Box::pin(self.handler.on_shutdown())
352 }
353}
354
355pub struct MiddlewareStack<H: McpHandler> {
359 handler: Arc<HandlerWrapper<H>>,
360 middlewares: Arc<Vec<Arc<dyn McpMiddleware>>>,
361}
362
363impl<H: McpHandler> Clone for MiddlewareStack<H> {
364 fn clone(&self) -> Self {
365 Self {
366 handler: Arc::clone(&self.handler),
367 middlewares: Arc::clone(&self.middlewares),
368 }
369 }
370}
371
372impl<H: McpHandler> std::fmt::Debug for MiddlewareStack<H> {
373 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374 f.debug_struct("MiddlewareStack")
375 .field("middleware_count", &self.middlewares.len())
376 .finish()
377 }
378}
379
380impl<H: McpHandler> MiddlewareStack<H> {
381 pub fn new(handler: H) -> Self {
383 Self {
384 handler: Arc::new(HandlerWrapper { handler }),
385 middlewares: Arc::new(Vec::new()),
386 }
387 }
388
389 #[must_use]
393 pub fn with_middleware<M: McpMiddleware>(mut self, middleware: M) -> Self {
394 let middlewares = Arc::make_mut(&mut self.middlewares);
395 middlewares.push(Arc::new(middleware));
396 self
397 }
398
399 pub fn middleware_count(&self) -> usize {
401 self.middlewares.len()
402 }
403
404 fn next(&self) -> Next<'_> {
405 Next::new(self.handler.as_ref(), &self.middlewares, 0)
406 }
407}
408
409#[allow(clippy::manual_async_fn)]
410impl<H: McpHandler> McpHandler for MiddlewareStack<H> {
411 fn server_info(&self) -> ServerInfo {
412 self.handler.dyn_server_info()
413 }
414
415 fn list_tools(&self) -> Vec<Tool> {
416 self.handler.dyn_list_tools()
417 }
418
419 fn list_resources(&self) -> Vec<Resource> {
420 self.handler.dyn_list_resources()
421 }
422
423 fn list_prompts(&self) -> Vec<Prompt> {
424 self.handler.dyn_list_prompts()
425 }
426
427 fn call_tool<'a>(
428 &'a self,
429 name: &'a str,
430 args: Value,
431 ctx: &'a RequestContext,
432 ) -> impl std::future::Future<Output = McpResult<ToolResult>> + turbomcp_core::marker::MaybeSend + 'a
433 {
434 async move { self.next().call_tool(name, args, ctx).await }
435 }
436
437 fn read_resource<'a>(
438 &'a self,
439 uri: &'a str,
440 ctx: &'a RequestContext,
441 ) -> impl std::future::Future<Output = McpResult<ResourceResult>>
442 + turbomcp_core::marker::MaybeSend
443 + 'a {
444 async move { self.next().read_resource(uri, ctx).await }
445 }
446
447 fn get_prompt<'a>(
448 &'a self,
449 name: &'a str,
450 args: Option<Value>,
451 ctx: &'a RequestContext,
452 ) -> impl std::future::Future<Output = McpResult<PromptResult>> + turbomcp_core::marker::MaybeSend + 'a
453 {
454 async move { self.next().get_prompt(name, args, ctx).await }
455 }
456
457 fn on_initialize(
458 &self,
459 ) -> impl std::future::Future<Output = McpResult<()>> + turbomcp_core::marker::MaybeSend {
460 async move { self.next().initialize().await }
461 }
462
463 fn on_shutdown(
464 &self,
465 ) -> impl std::future::Future<Output = McpResult<()>> + turbomcp_core::marker::MaybeSend {
466 async move { self.next().shutdown().await }
467 }
468}
469
470#[cfg(test)]
471#[allow(clippy::manual_async_fn)]
472mod tests {
473 use super::*;
474 use std::sync::atomic::{AtomicU32, Ordering};
475 use turbomcp_core::error::McpError;
476 use turbomcp_core::marker::MaybeSend;
477
478 #[derive(Clone)]
479 struct TestHandler;
480
481 impl McpHandler for TestHandler {
482 fn server_info(&self) -> ServerInfo {
483 ServerInfo::new("test", "1.0.0")
484 }
485
486 fn list_tools(&self) -> Vec<Tool> {
487 vec![Tool::new("test_tool", "A test tool")]
488 }
489
490 fn list_resources(&self) -> Vec<Resource> {
491 vec![Resource::new("test://resource", "A test resource")]
492 }
493
494 fn list_prompts(&self) -> Vec<Prompt> {
495 vec![Prompt::new("test_prompt", "A test prompt")]
496 }
497
498 fn call_tool<'a>(
499 &'a self,
500 name: &'a str,
501 _args: Value,
502 _ctx: &'a RequestContext,
503 ) -> impl std::future::Future<Output = McpResult<ToolResult>> + MaybeSend + 'a {
504 async move {
505 match name {
506 "test_tool" => Ok(ToolResult::text("Test result")),
507 _ => Err(McpError::tool_not_found(name)),
508 }
509 }
510 }
511
512 fn read_resource<'a>(
513 &'a self,
514 uri: &'a str,
515 _ctx: &'a RequestContext,
516 ) -> impl std::future::Future<Output = McpResult<ResourceResult>> + MaybeSend + 'a {
517 let uri = uri.to_string();
518 async move {
519 if uri == "test://resource" {
520 Ok(ResourceResult::text(&uri, "Test content"))
521 } else {
522 Err(McpError::resource_not_found(&uri))
523 }
524 }
525 }
526
527 fn get_prompt<'a>(
528 &'a self,
529 name: &'a str,
530 _args: Option<Value>,
531 _ctx: &'a RequestContext,
532 ) -> impl std::future::Future<Output = McpResult<PromptResult>> + MaybeSend + 'a {
533 let name = name.to_string();
534 async move {
535 if name == "test_prompt" {
536 Ok(PromptResult::user("Test prompt message"))
537 } else {
538 Err(McpError::prompt_not_found(&name))
539 }
540 }
541 }
542 }
543
544 struct CountingMiddleware {
546 tool_calls: AtomicU32,
547 resource_reads: AtomicU32,
548 prompt_gets: AtomicU32,
549 initializes: AtomicU32,
550 shutdowns: AtomicU32,
551 }
552
553 impl CountingMiddleware {
554 fn new() -> Self {
555 Self {
556 tool_calls: AtomicU32::new(0),
557 resource_reads: AtomicU32::new(0),
558 prompt_gets: AtomicU32::new(0),
559 initializes: AtomicU32::new(0),
560 shutdowns: AtomicU32::new(0),
561 }
562 }
563
564 fn tool_calls(&self) -> u32 {
565 self.tool_calls.load(Ordering::Relaxed)
566 }
567
568 fn resource_reads(&self) -> u32 {
569 self.resource_reads.load(Ordering::Relaxed)
570 }
571
572 fn prompt_gets(&self) -> u32 {
573 self.prompt_gets.load(Ordering::Relaxed)
574 }
575
576 fn initializes(&self) -> u32 {
577 self.initializes.load(Ordering::Relaxed)
578 }
579
580 fn shutdowns(&self) -> u32 {
581 self.shutdowns.load(Ordering::Relaxed)
582 }
583 }
584
585 impl McpMiddleware for CountingMiddleware {
586 fn on_call_tool<'a>(
587 &'a self,
588 name: &'a str,
589 args: Value,
590 ctx: &'a RequestContext,
591 next: Next<'a>,
592 ) -> Pin<Box<dyn Future<Output = McpResult<ToolResult>> + Send + 'a>> {
593 Box::pin(async move {
594 self.tool_calls.fetch_add(1, Ordering::Relaxed);
595 next.call_tool(name, args, ctx).await
596 })
597 }
598
599 fn on_read_resource<'a>(
600 &'a self,
601 uri: &'a str,
602 ctx: &'a RequestContext,
603 next: Next<'a>,
604 ) -> Pin<Box<dyn Future<Output = McpResult<ResourceResult>> + Send + 'a>> {
605 Box::pin(async move {
606 self.resource_reads.fetch_add(1, Ordering::Relaxed);
607 next.read_resource(uri, ctx).await
608 })
609 }
610
611 fn on_get_prompt<'a>(
612 &'a self,
613 name: &'a str,
614 args: Option<Value>,
615 ctx: &'a RequestContext,
616 next: Next<'a>,
617 ) -> Pin<Box<dyn Future<Output = McpResult<PromptResult>> + Send + 'a>> {
618 Box::pin(async move {
619 self.prompt_gets.fetch_add(1, Ordering::Relaxed);
620 next.get_prompt(name, args, ctx).await
621 })
622 }
623
624 fn on_initialize<'a>(
625 &'a self,
626 next: Next<'a>,
627 ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
628 Box::pin(async move {
629 self.initializes.fetch_add(1, Ordering::Relaxed);
630 next.initialize().await
631 })
632 }
633
634 fn on_shutdown<'a>(
635 &'a self,
636 next: Next<'a>,
637 ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
638 Box::pin(async move {
639 self.shutdowns.fetch_add(1, Ordering::Relaxed);
640 next.shutdown().await
641 })
642 }
643 }
644
645 struct BlockingMiddleware {
647 blocked_tools: Vec<String>,
648 }
649
650 impl BlockingMiddleware {
651 fn new(blocked: Vec<&str>) -> Self {
652 Self {
653 blocked_tools: blocked.into_iter().map(String::from).collect(),
654 }
655 }
656 }
657
658 impl McpMiddleware for BlockingMiddleware {
659 fn on_call_tool<'a>(
660 &'a self,
661 name: &'a str,
662 args: Value,
663 ctx: &'a RequestContext,
664 next: Next<'a>,
665 ) -> Pin<Box<dyn Future<Output = McpResult<ToolResult>> + Send + 'a>> {
666 Box::pin(async move {
667 if self.blocked_tools.contains(&name.to_string()) {
668 return Err(McpError::internal(format!("Tool '{}' is blocked", name)));
669 }
670 next.call_tool(name, args, ctx).await
671 })
672 }
673 }
674
#[test]
fn test_middleware_stack_creation() {
    let blocker = BlockingMiddleware::new(vec!["blocked"]);
    let stack = MiddlewareStack::new(TestHandler)
        .with_middleware(CountingMiddleware::new())
        .with_middleware(blocker);

    assert_eq!(stack.middleware_count(), 2);
}

#[test]
fn test_server_info_passthrough() {
    let info = MiddlewareStack::new(TestHandler).server_info();

    assert_eq!(info.name, "test");
    assert_eq!(info.version, "1.0.0");
}

#[test]
fn test_list_tools_passthrough() {
    let tools = MiddlewareStack::new(TestHandler).list_tools();

    match tools.as_slice() {
        [only] => assert_eq!(only.name, "test_tool"),
        other => panic!("expected exactly one tool, got {}", other.len()),
    }
}
699
700 #[tokio::test]
701 async fn test_call_tool_through_middleware() {
702 let counting = Arc::new(CountingMiddleware::new());
703 let stack =
704 MiddlewareStack::new(TestHandler).with_middleware(CountingClone(counting.clone()));
705
706 let ctx = RequestContext::default();
707 let result = stack
708 .call_tool("test_tool", serde_json::json!({}), &ctx)
709 .await
710 .unwrap();
711
712 assert_eq!(result.first_text(), Some("Test result"));
713 assert_eq!(counting.tool_calls(), 1);
714 }
715
716 #[tokio::test]
717 async fn test_blocking_middleware() {
718 let stack = MiddlewareStack::new(TestHandler)
719 .with_middleware(BlockingMiddleware::new(vec!["test_tool"]));
720
721 let ctx = RequestContext::default();
722 let result = stack
723 .call_tool("test_tool", serde_json::json!({}), &ctx)
724 .await;
725
726 assert!(result.is_err());
727 assert!(result.unwrap_err().to_string().contains("blocked"));
728 }
729
730 #[tokio::test]
731 async fn test_middleware_chain_order() {
732 let counting1 = Arc::new(CountingMiddleware::new());
733 let counting2 = Arc::new(CountingMiddleware::new());
734
735 let stack = MiddlewareStack::new(TestHandler)
736 .with_middleware(CountingClone(counting1.clone()))
737 .with_middleware(CountingClone(counting2.clone()));
738
739 let ctx = RequestContext::default();
740 stack
741 .call_tool("test_tool", serde_json::json!({}), &ctx)
742 .await
743 .unwrap();
744
745 assert_eq!(counting1.tool_calls(), 1);
747 assert_eq!(counting2.tool_calls(), 1);
748 }
749
750 #[tokio::test]
751 async fn test_read_resource_through_middleware() {
752 let counting = Arc::new(CountingMiddleware::new());
753 let stack =
754 MiddlewareStack::new(TestHandler).with_middleware(CountingClone(counting.clone()));
755
756 let ctx = RequestContext::default();
757 let result = stack.read_resource("test://resource", &ctx).await.unwrap();
758
759 assert!(!result.contents.is_empty());
760 assert_eq!(counting.resource_reads(), 1);
761 }
762
763 #[tokio::test]
764 async fn test_get_prompt_through_middleware() {
765 let counting = Arc::new(CountingMiddleware::new());
766 let stack =
767 MiddlewareStack::new(TestHandler).with_middleware(CountingClone(counting.clone()));
768
769 let ctx = RequestContext::default();
770 let result = stack.get_prompt("test_prompt", None, &ctx).await.unwrap();
771
772 assert!(!result.messages.is_empty());
773 assert_eq!(counting.prompt_gets(), 1);
774 }
775
776 struct CountingClone(Arc<CountingMiddleware>);
778
779 impl McpMiddleware for CountingClone {
780 fn on_call_tool<'a>(
781 &'a self,
782 name: &'a str,
783 args: Value,
784 ctx: &'a RequestContext,
785 next: Next<'a>,
786 ) -> Pin<Box<dyn Future<Output = McpResult<ToolResult>> + Send + 'a>> {
787 self.0.on_call_tool(name, args, ctx, next)
788 }
789
790 fn on_read_resource<'a>(
791 &'a self,
792 uri: &'a str,
793 ctx: &'a RequestContext,
794 next: Next<'a>,
795 ) -> Pin<Box<dyn Future<Output = McpResult<ResourceResult>> + Send + 'a>> {
796 self.0.on_read_resource(uri, ctx, next)
797 }
798
799 fn on_get_prompt<'a>(
800 &'a self,
801 name: &'a str,
802 args: Option<Value>,
803 ctx: &'a RequestContext,
804 next: Next<'a>,
805 ) -> Pin<Box<dyn Future<Output = McpResult<PromptResult>> + Send + 'a>> {
806 self.0.on_get_prompt(name, args, ctx, next)
807 }
808
809 fn on_initialize<'a>(
810 &'a self,
811 next: Next<'a>,
812 ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
813 self.0.on_initialize(next)
814 }
815
816 fn on_shutdown<'a>(
817 &'a self,
818 next: Next<'a>,
819 ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
820 self.0.on_shutdown(next)
821 }
822 }
823
824 #[tokio::test]
825 async fn test_on_initialize_through_middleware() {
826 let counting = Arc::new(CountingMiddleware::new());
827 let stack =
828 MiddlewareStack::new(TestHandler).with_middleware(CountingClone(counting.clone()));
829
830 stack.on_initialize().await.unwrap();
831
832 assert_eq!(counting.initializes(), 1);
833 }
834
835 #[tokio::test]
836 async fn test_on_shutdown_through_middleware() {
837 let counting = Arc::new(CountingMiddleware::new());
838 let stack =
839 MiddlewareStack::new(TestHandler).with_middleware(CountingClone(counting.clone()));
840
841 stack.on_shutdown().await.unwrap();
842
843 assert_eq!(counting.shutdowns(), 1);
844 }
845
846 #[tokio::test]
847 async fn test_lifecycle_hooks_chain_through_multiple_middlewares() {
848 let counting1 = Arc::new(CountingMiddleware::new());
849 let counting2 = Arc::new(CountingMiddleware::new());
850
851 let stack = MiddlewareStack::new(TestHandler)
852 .with_middleware(CountingClone(counting1.clone()))
853 .with_middleware(CountingClone(counting2.clone()));
854
855 stack.on_initialize().await.unwrap();
856 stack.on_shutdown().await.unwrap();
857
858 assert_eq!(counting1.initializes(), 1);
859 assert_eq!(counting2.initializes(), 1);
860 assert_eq!(counting1.shutdowns(), 1);
861 assert_eq!(counting2.shutdowns(), 1);
862 }
863
864 struct BlockInitMiddleware;
866
867 impl McpMiddleware for BlockInitMiddleware {
868 fn on_initialize<'a>(
869 &'a self,
870 _next: Next<'a>,
871 ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
872 Box::pin(async move { Err(McpError::internal("initialization blocked by middleware")) })
873 }
874 }
875
876 #[tokio::test]
877 async fn test_on_initialize_short_circuit() {
878 let stack = MiddlewareStack::new(TestHandler).with_middleware(BlockInitMiddleware);
879
880 let result = stack.on_initialize().await;
881 assert!(result.is_err());
882 assert!(result.unwrap_err().to_string().contains("blocked"));
883 }
884}