1use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9
10use serde_json::Value;
11
12use turbomcp_core::context::RequestContext;
13use turbomcp_core::error::McpResult;
14use turbomcp_core::handler::McpHandler;
15use turbomcp_types::{
16 Prompt, PromptResult, Resource, ResourceResult, ServerInfo, Tool, ToolResult,
17};
18
19pub trait McpMiddleware: Send + Sync + 'static {
56 fn on_list_tools<'a>(
60 &'a self,
61 next: Next<'a>,
62 ) -> Pin<Box<dyn Future<Output = Vec<Tool>> + Send + 'a>> {
63 Box::pin(async move { next.list_tools() })
64 }
65
66 fn on_list_resources<'a>(
68 &'a self,
69 next: Next<'a>,
70 ) -> Pin<Box<dyn Future<Output = Vec<Resource>> + Send + 'a>> {
71 Box::pin(async move { next.list_resources() })
72 }
73
74 fn on_list_prompts<'a>(
76 &'a self,
77 next: Next<'a>,
78 ) -> Pin<Box<dyn Future<Output = Vec<Prompt>> + Send + 'a>> {
79 Box::pin(async move { next.list_prompts() })
80 }
81
82 fn on_call_tool<'a>(
86 &'a self,
87 name: &'a str,
88 args: Value,
89 ctx: &'a RequestContext,
90 next: Next<'a>,
91 ) -> Pin<Box<dyn Future<Output = McpResult<ToolResult>> + Send + 'a>> {
92 Box::pin(async move { next.call_tool(name, args, ctx).await })
93 }
94
95 fn on_read_resource<'a>(
97 &'a self,
98 uri: &'a str,
99 ctx: &'a RequestContext,
100 next: Next<'a>,
101 ) -> Pin<Box<dyn Future<Output = McpResult<ResourceResult>> + Send + 'a>> {
102 Box::pin(async move { next.read_resource(uri, ctx).await })
103 }
104
105 fn on_get_prompt<'a>(
107 &'a self,
108 name: &'a str,
109 args: Option<Value>,
110 ctx: &'a RequestContext,
111 next: Next<'a>,
112 ) -> Pin<Box<dyn Future<Output = McpResult<PromptResult>> + Send + 'a>> {
113 Box::pin(async move { next.get_prompt(name, args, ctx).await })
114 }
115
116 fn on_initialize<'a>(
121 &'a self,
122 next: Next<'a>,
123 ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
124 Box::pin(async move { next.initialize().await })
125 }
126
127 fn on_shutdown<'a>(
131 &'a self,
132 next: Next<'a>,
133 ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
134 Box::pin(async move { next.shutdown().await })
135 }
136}
137
138pub struct Next<'a> {
143 handler: &'a dyn DynHandler,
144 middlewares: &'a [Arc<dyn McpMiddleware>],
145 index: usize,
146}
147
148impl<'a> std::fmt::Debug for Next<'a> {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 f.debug_struct("Next")
151 .field("index", &self.index)
152 .field(
153 "remaining_middlewares",
154 &(self.middlewares.len() - self.index),
155 )
156 .finish()
157 }
158}
159
160impl<'a> Next<'a> {
161 fn new(
162 handler: &'a dyn DynHandler,
163 middlewares: &'a [Arc<dyn McpMiddleware>],
164 index: usize,
165 ) -> Self {
166 Self {
167 handler,
168 middlewares,
169 index,
170 }
171 }
172
173 pub fn list_tools(self) -> Vec<Tool> {
175 if self.index < self.middlewares.len() {
176 self.handler.dyn_list_tools()
179 } else {
180 self.handler.dyn_list_tools()
181 }
182 }
183
184 pub fn list_resources(self) -> Vec<Resource> {
186 self.handler.dyn_list_resources()
187 }
188
189 pub fn list_prompts(self) -> Vec<Prompt> {
191 self.handler.dyn_list_prompts()
192 }
193
194 pub async fn call_tool(
196 self,
197 name: &str,
198 args: Value,
199 ctx: &RequestContext,
200 ) -> McpResult<ToolResult> {
201 if self.index < self.middlewares.len() {
202 let middleware = &self.middlewares[self.index];
203 let next = Next::new(self.handler, self.middlewares, self.index + 1);
204 middleware.on_call_tool(name, args, ctx, next).await
205 } else {
206 self.handler.dyn_call_tool(name, args, ctx).await
207 }
208 }
209
210 pub async fn read_resource(self, uri: &str, ctx: &RequestContext) -> McpResult<ResourceResult> {
212 if self.index < self.middlewares.len() {
213 let middleware = &self.middlewares[self.index];
214 let next = Next::new(self.handler, self.middlewares, self.index + 1);
215 middleware.on_read_resource(uri, ctx, next).await
216 } else {
217 self.handler.dyn_read_resource(uri, ctx).await
218 }
219 }
220
221 pub async fn get_prompt(
223 self,
224 name: &str,
225 args: Option<Value>,
226 ctx: &RequestContext,
227 ) -> McpResult<PromptResult> {
228 if self.index < self.middlewares.len() {
229 let middleware = &self.middlewares[self.index];
230 let next = Next::new(self.handler, self.middlewares, self.index + 1);
231 middleware.on_get_prompt(name, args, ctx, next).await
232 } else {
233 self.handler.dyn_get_prompt(name, args, ctx).await
234 }
235 }
236
237 pub async fn initialize(self) -> McpResult<()> {
239 if self.index < self.middlewares.len() {
240 let middleware = &self.middlewares[self.index];
241 let next = Next::new(self.handler, self.middlewares, self.index + 1);
242 middleware.on_initialize(next).await
243 } else {
244 self.handler.dyn_on_initialize().await
245 }
246 }
247
248 pub async fn shutdown(self) -> McpResult<()> {
250 if self.index < self.middlewares.len() {
251 let middleware = &self.middlewares[self.index];
252 let next = Next::new(self.handler, self.middlewares, self.index + 1);
253 middleware.on_shutdown(next).await
254 } else {
255 self.handler.dyn_on_shutdown().await
256 }
257 }
258}
259
260#[allow(dead_code)]
262trait DynHandler: Send + Sync {
263 fn dyn_server_info(&self) -> ServerInfo;
264 fn dyn_list_tools(&self) -> Vec<Tool>;
265 fn dyn_list_resources(&self) -> Vec<Resource>;
266 fn dyn_list_prompts(&self) -> Vec<Prompt>;
267 fn dyn_call_tool<'a>(
268 &'a self,
269 name: &'a str,
270 args: Value,
271 ctx: &'a RequestContext,
272 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<ToolResult>> + Send + 'a>>;
273 fn dyn_read_resource<'a>(
274 &'a self,
275 uri: &'a str,
276 ctx: &'a RequestContext,
277 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<ResourceResult>> + Send + 'a>>;
278 fn dyn_get_prompt<'a>(
279 &'a self,
280 name: &'a str,
281 args: Option<Value>,
282 ctx: &'a RequestContext,
283 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<PromptResult>> + Send + 'a>>;
284 fn dyn_on_initialize<'a>(
285 &'a self,
286 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<()>> + Send + 'a>>;
287 fn dyn_on_shutdown<'a>(
288 &'a self,
289 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<()>> + Send + 'a>>;
290}
291
292struct HandlerWrapper<H: McpHandler> {
294 handler: H,
295}
296
297impl<H: McpHandler> DynHandler for HandlerWrapper<H> {
298 fn dyn_server_info(&self) -> ServerInfo {
299 self.handler.server_info()
300 }
301
302 fn dyn_list_tools(&self) -> Vec<Tool> {
303 self.handler.list_tools()
304 }
305
306 fn dyn_list_resources(&self) -> Vec<Resource> {
307 self.handler.list_resources()
308 }
309
310 fn dyn_list_prompts(&self) -> Vec<Prompt> {
311 self.handler.list_prompts()
312 }
313
314 fn dyn_call_tool<'a>(
315 &'a self,
316 name: &'a str,
317 args: Value,
318 ctx: &'a RequestContext,
319 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<ToolResult>> + Send + 'a>>
320 {
321 Box::pin(self.handler.call_tool(name, args, ctx))
322 }
323
324 fn dyn_read_resource<'a>(
325 &'a self,
326 uri: &'a str,
327 ctx: &'a RequestContext,
328 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<ResourceResult>> + Send + 'a>>
329 {
330 Box::pin(self.handler.read_resource(uri, ctx))
331 }
332
333 fn dyn_get_prompt<'a>(
334 &'a self,
335 name: &'a str,
336 args: Option<Value>,
337 ctx: &'a RequestContext,
338 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<PromptResult>> + Send + 'a>>
339 {
340 Box::pin(self.handler.get_prompt(name, args, ctx))
341 }
342
343 fn dyn_on_initialize<'a>(
344 &'a self,
345 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<()>> + Send + 'a>> {
346 Box::pin(self.handler.on_initialize())
347 }
348
349 fn dyn_on_shutdown<'a>(
350 &'a self,
351 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<()>> + Send + 'a>> {
352 Box::pin(self.handler.on_shutdown())
353 }
354}
355
356pub struct MiddlewareStack<H: McpHandler> {
360 handler: Arc<HandlerWrapper<H>>,
361 middlewares: Arc<Vec<Arc<dyn McpMiddleware>>>,
362}
363
364impl<H: McpHandler> Clone for MiddlewareStack<H> {
365 fn clone(&self) -> Self {
366 Self {
367 handler: Arc::clone(&self.handler),
368 middlewares: Arc::clone(&self.middlewares),
369 }
370 }
371}
372
373impl<H: McpHandler> std::fmt::Debug for MiddlewareStack<H> {
374 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
375 f.debug_struct("MiddlewareStack")
376 .field("middleware_count", &self.middlewares.len())
377 .finish()
378 }
379}
380
381impl<H: McpHandler> MiddlewareStack<H> {
382 pub fn new(handler: H) -> Self {
384 Self {
385 handler: Arc::new(HandlerWrapper { handler }),
386 middlewares: Arc::new(Vec::new()),
387 }
388 }
389
390 #[must_use]
394 pub fn with_middleware<M: McpMiddleware>(mut self, middleware: M) -> Self {
395 let middlewares = Arc::make_mut(&mut self.middlewares);
396 middlewares.push(Arc::new(middleware));
397 self
398 }
399
400 pub fn middleware_count(&self) -> usize {
402 self.middlewares.len()
403 }
404
405 fn next(&self) -> Next<'_> {
406 Next::new(self.handler.as_ref(), &self.middlewares, 0)
407 }
408}
409
410#[allow(clippy::manual_async_fn)]
411impl<H: McpHandler> McpHandler for MiddlewareStack<H> {
412 fn server_info(&self) -> ServerInfo {
413 self.handler.dyn_server_info()
414 }
415
416 fn list_tools(&self) -> Vec<Tool> {
417 self.handler.dyn_list_tools()
418 }
419
420 fn list_resources(&self) -> Vec<Resource> {
421 self.handler.dyn_list_resources()
422 }
423
424 fn list_prompts(&self) -> Vec<Prompt> {
425 self.handler.dyn_list_prompts()
426 }
427
428 fn call_tool<'a>(
429 &'a self,
430 name: &'a str,
431 args: Value,
432 ctx: &'a RequestContext,
433 ) -> impl std::future::Future<Output = McpResult<ToolResult>> + turbomcp_core::marker::MaybeSend + 'a
434 {
435 async move { self.next().call_tool(name, args, ctx).await }
436 }
437
438 fn read_resource<'a>(
439 &'a self,
440 uri: &'a str,
441 ctx: &'a RequestContext,
442 ) -> impl std::future::Future<Output = McpResult<ResourceResult>>
443 + turbomcp_core::marker::MaybeSend
444 + 'a {
445 async move { self.next().read_resource(uri, ctx).await }
446 }
447
448 fn get_prompt<'a>(
449 &'a self,
450 name: &'a str,
451 args: Option<Value>,
452 ctx: &'a RequestContext,
453 ) -> impl std::future::Future<Output = McpResult<PromptResult>> + turbomcp_core::marker::MaybeSend + 'a
454 {
455 async move { self.next().get_prompt(name, args, ctx).await }
456 }
457
458 fn on_initialize(
459 &self,
460 ) -> impl std::future::Future<Output = McpResult<()>> + turbomcp_core::marker::MaybeSend {
461 async move { self.next().initialize().await }
462 }
463
464 fn on_shutdown(
465 &self,
466 ) -> impl std::future::Future<Output = McpResult<()>> + turbomcp_core::marker::MaybeSend {
467 async move { self.next().shutdown().await }
468 }
469}
470
#[cfg(test)]
// `TestHandler` implements `McpHandler`'s `-> impl Future + MaybeSend` methods
// with explicit `async move` blocks, which clippy's `manual_async_fn` flags.
#[allow(clippy::manual_async_fn)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Mutex;
    use turbomcp_core::error::McpError;
    use turbomcp_core::marker::MaybeSend;

    /// Handler exposing one tool, one resource and one prompt; any other
    /// name or URI yields the matching "not found" error.
    #[derive(Clone)]
    struct TestHandler;

    impl McpHandler for TestHandler {
        fn server_info(&self) -> ServerInfo {
            ServerInfo::new("test", "1.0.0")
        }

        fn list_tools(&self) -> Vec<Tool> {
            vec![Tool::new("test_tool", "A test tool")]
        }

        fn list_resources(&self) -> Vec<Resource> {
            vec![Resource::new("test://resource", "A test resource")]
        }

        fn list_prompts(&self) -> Vec<Prompt> {
            vec![Prompt::new("test_prompt", "A test prompt")]
        }

        fn call_tool<'a>(
            &'a self,
            name: &'a str,
            _args: Value,
            _ctx: &'a RequestContext,
        ) -> impl std::future::Future<Output = McpResult<ToolResult>> + MaybeSend + 'a {
            async move {
                match name {
                    "test_tool" => Ok(ToolResult::text("Test result")),
                    _ => Err(McpError::tool_not_found(name)),
                }
            }
        }

        fn read_resource<'a>(
            &'a self,
            uri: &'a str,
            _ctx: &'a RequestContext,
        ) -> impl std::future::Future<Output = McpResult<ResourceResult>> + MaybeSend + 'a {
            let uri = uri.to_string();
            async move {
                if uri == "test://resource" {
                    Ok(ResourceResult::text(&uri, "Test content"))
                } else {
                    Err(McpError::resource_not_found(&uri))
                }
            }
        }

        fn get_prompt<'a>(
            &'a self,
            name: &'a str,
            _args: Option<Value>,
            _ctx: &'a RequestContext,
        ) -> impl std::future::Future<Output = McpResult<PromptResult>> + MaybeSend + 'a {
            let name = name.to_string();
            async move {
                if name == "test_prompt" {
                    Ok(PromptResult::user("Test prompt message"))
                } else {
                    Err(McpError::prompt_not_found(&name))
                }
            }
        }
    }

    /// Counts how often each request/lifecycle hook runs, then forwards to
    /// `next`.
    struct CountingMiddleware {
        tool_calls: AtomicU32,
        resource_reads: AtomicU32,
        prompt_gets: AtomicU32,
        initializes: AtomicU32,
        shutdowns: AtomicU32,
    }

    impl CountingMiddleware {
        fn new() -> Self {
            Self {
                tool_calls: AtomicU32::new(0),
                resource_reads: AtomicU32::new(0),
                prompt_gets: AtomicU32::new(0),
                initializes: AtomicU32::new(0),
                shutdowns: AtomicU32::new(0),
            }
        }

        fn tool_calls(&self) -> u32 {
            self.tool_calls.load(Ordering::Relaxed)
        }

        fn resource_reads(&self) -> u32 {
            self.resource_reads.load(Ordering::Relaxed)
        }

        fn prompt_gets(&self) -> u32 {
            self.prompt_gets.load(Ordering::Relaxed)
        }

        fn initializes(&self) -> u32 {
            self.initializes.load(Ordering::Relaxed)
        }

        fn shutdowns(&self) -> u32 {
            self.shutdowns.load(Ordering::Relaxed)
        }
    }

    impl McpMiddleware for CountingMiddleware {
        fn on_call_tool<'a>(
            &'a self,
            name: &'a str,
            args: Value,
            ctx: &'a RequestContext,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<ToolResult>> + Send + 'a>> {
            Box::pin(async move {
                self.tool_calls.fetch_add(1, Ordering::Relaxed);
                next.call_tool(name, args, ctx).await
            })
        }

        fn on_read_resource<'a>(
            &'a self,
            uri: &'a str,
            ctx: &'a RequestContext,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<ResourceResult>> + Send + 'a>> {
            Box::pin(async move {
                self.resource_reads.fetch_add(1, Ordering::Relaxed);
                next.read_resource(uri, ctx).await
            })
        }

        fn on_get_prompt<'a>(
            &'a self,
            name: &'a str,
            args: Option<Value>,
            ctx: &'a RequestContext,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<PromptResult>> + Send + 'a>> {
            Box::pin(async move {
                self.prompt_gets.fetch_add(1, Ordering::Relaxed);
                next.get_prompt(name, args, ctx).await
            })
        }

        fn on_initialize<'a>(
            &'a self,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
            Box::pin(async move {
                self.initializes.fetch_add(1, Ordering::Relaxed);
                next.initialize().await
            })
        }

        fn on_shutdown<'a>(
            &'a self,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
            Box::pin(async move {
                self.shutdowns.fetch_add(1, Ordering::Relaxed);
                next.shutdown().await
            })
        }
    }

    /// Shares a `CountingMiddleware` with the test through an `Arc`, so its
    /// counters can still be read after the middleware is moved into a stack.
    struct CountingClone(Arc<CountingMiddleware>);

    impl McpMiddleware for CountingClone {
        fn on_call_tool<'a>(
            &'a self,
            name: &'a str,
            args: Value,
            ctx: &'a RequestContext,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<ToolResult>> + Send + 'a>> {
            self.0.on_call_tool(name, args, ctx, next)
        }

        fn on_read_resource<'a>(
            &'a self,
            uri: &'a str,
            ctx: &'a RequestContext,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<ResourceResult>> + Send + 'a>> {
            self.0.on_read_resource(uri, ctx, next)
        }

        fn on_get_prompt<'a>(
            &'a self,
            name: &'a str,
            args: Option<Value>,
            ctx: &'a RequestContext,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<PromptResult>> + Send + 'a>> {
            self.0.on_get_prompt(name, args, ctx, next)
        }

        fn on_initialize<'a>(
            &'a self,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
            self.0.on_initialize(next)
        }

        fn on_shutdown<'a>(
            &'a self,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
            self.0.on_shutdown(next)
        }
    }

    /// Rejects calls to the listed tools without invoking `next`.
    struct BlockingMiddleware {
        blocked_tools: Vec<String>,
    }

    impl BlockingMiddleware {
        fn new(blocked: Vec<&str>) -> Self {
            Self {
                blocked_tools: blocked.into_iter().map(String::from).collect(),
            }
        }

        fn is_blocked(&self, name: &str) -> bool {
            // Compare as `&str` to avoid allocating a `String` per lookup.
            self.blocked_tools.iter().any(|blocked| blocked.as_str() == name)
        }
    }

    impl McpMiddleware for BlockingMiddleware {
        fn on_call_tool<'a>(
            &'a self,
            name: &'a str,
            args: Value,
            ctx: &'a RequestContext,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<ToolResult>> + Send + 'a>> {
            Box::pin(async move {
                if self.is_blocked(name) {
                    return Err(McpError::internal(format!("Tool '{}' is blocked", name)));
                }
                next.call_tool(name, args, ctx).await
            })
        }
    }

    /// Appends `"<label>:before"` and `"<label>:after"` to a shared log around
    /// each tool call, exposing how the chain nests.
    struct RecordingMiddleware {
        label: &'static str,
        log: Arc<Mutex<Vec<String>>>,
    }

    impl RecordingMiddleware {
        fn new(label: &'static str, log: Arc<Mutex<Vec<String>>>) -> Self {
            Self { label, log }
        }

        fn record(&self, phase: &str) {
            // The guard is a statement temporary, so it is released before any
            // `.await` in the caller.
            self.log
                .lock()
                .unwrap()
                .push(format!("{}:{}", self.label, phase));
        }
    }

    impl McpMiddleware for RecordingMiddleware {
        fn on_call_tool<'a>(
            &'a self,
            name: &'a str,
            args: Value,
            ctx: &'a RequestContext,
            next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<ToolResult>> + Send + 'a>> {
            Box::pin(async move {
                self.record("before");
                let result = next.call_tool(name, args, ctx).await;
                self.record("after");
                result
            })
        }
    }

    /// Fails initialization without calling `next`.
    struct BlockInitMiddleware;

    impl McpMiddleware for BlockInitMiddleware {
        fn on_initialize<'a>(
            &'a self,
            _next: Next<'a>,
        ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + 'a>> {
            Box::pin(async move { Err(McpError::internal("initialization blocked by middleware")) })
        }
    }

    #[test]
    fn test_middleware_stack_creation() {
        let stack = MiddlewareStack::new(TestHandler)
            .with_middleware(CountingMiddleware::new())
            .with_middleware(BlockingMiddleware::new(vec!["blocked"]));

        assert_eq!(stack.middleware_count(), 2);
    }

    #[test]
    fn test_server_info_passthrough() {
        let stack = MiddlewareStack::new(TestHandler);
        let info = stack.server_info();
        assert_eq!(info.name, "test");
        assert_eq!(info.version, "1.0.0");
    }

    #[test]
    fn test_list_tools_passthrough() {
        let stack = MiddlewareStack::new(TestHandler);
        let tools = stack.list_tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "test_tool");
    }

    #[tokio::test]
    async fn test_call_tool_through_middleware() {
        let counting = Arc::new(CountingMiddleware::new());
        let stack =
            MiddlewareStack::new(TestHandler).with_middleware(CountingClone(counting.clone()));

        let ctx = RequestContext::default();
        let result = stack
            .call_tool("test_tool", serde_json::json!({}), &ctx)
            .await
            .unwrap();

        assert_eq!(result.first_text(), Some("Test result"));
        assert_eq!(counting.tool_calls(), 1);
    }

    #[tokio::test]
    async fn test_handler_error_propagates_through_middleware() {
        let counting = Arc::new(CountingMiddleware::new());
        let stack =
            MiddlewareStack::new(TestHandler).with_middleware(CountingClone(counting.clone()));

        let ctx = RequestContext::default();
        let result = stack
            .call_tool("missing_tool", serde_json::json!({}), &ctx)
            .await;

        assert!(result.is_err());
        assert_eq!(counting.tool_calls(), 1);
    }

    #[tokio::test]
    async fn test_blocking_middleware() {
        let inner = Arc::new(CountingMiddleware::new());
        let stack = MiddlewareStack::new(TestHandler)
            .with_middleware(BlockingMiddleware::new(vec!["test_tool"]))
            .with_middleware(CountingClone(inner.clone()));

        let ctx = RequestContext::default();
        let result = stack
            .call_tool("test_tool", serde_json::json!({}), &ctx)
            .await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("blocked"));
        // Short-circuiting must stop the request before any inner middleware.
        assert_eq!(inner.tool_calls(), 0);
    }

    #[tokio::test]
    async fn test_middleware_chain_order() {
        let counting1 = Arc::new(CountingMiddleware::new());
        let counting2 = Arc::new(CountingMiddleware::new());
        let log = Arc::new(Mutex::new(Vec::new()));

        let stack = MiddlewareStack::new(TestHandler)
            .with_middleware(RecordingMiddleware::new("first", Arc::clone(&log)))
            .with_middleware(CountingClone(counting1.clone()))
            .with_middleware(RecordingMiddleware::new("second", Arc::clone(&log)))
            .with_middleware(CountingClone(counting2.clone()));

        let ctx = RequestContext::default();
        stack
            .call_tool("test_tool", serde_json::json!({}), &ctx)
            .await
            .unwrap();

        assert_eq!(counting1.tool_calls(), 1);
        assert_eq!(counting2.tool_calls(), 1);
        // Middlewares run in insertion order and unwind in reverse: the first
        // one added is the outermost layer.
        assert_eq!(
            *log.lock().unwrap(),
            ["first:before", "second:before", "second:after", "first:after"]
        );
    }

    #[tokio::test]
    async fn test_read_resource_through_middleware() {
        let counting = Arc::new(CountingMiddleware::new());
        let stack =
            MiddlewareStack::new(TestHandler).with_middleware(CountingClone(counting.clone()));

        let ctx = RequestContext::default();
        let result = stack.read_resource("test://resource", &ctx).await.unwrap();

        assert!(!result.contents.is_empty());
        assert_eq!(counting.resource_reads(), 1);
    }

    #[tokio::test]
    async fn test_get_prompt_through_middleware() {
        let counting = Arc::new(CountingMiddleware::new());
        let stack =
            MiddlewareStack::new(TestHandler).with_middleware(CountingClone(counting.clone()));

        let ctx = RequestContext::default();
        let result = stack.get_prompt("test_prompt", None, &ctx).await.unwrap();

        assert!(!result.messages.is_empty());
        assert_eq!(counting.prompt_gets(), 1);
    }

    #[tokio::test]
    async fn test_on_initialize_through_middleware() {
        let counting = Arc::new(CountingMiddleware::new());
        let stack =
            MiddlewareStack::new(TestHandler).with_middleware(CountingClone(counting.clone()));

        stack.on_initialize().await.unwrap();

        assert_eq!(counting.initializes(), 1);
    }

    #[tokio::test]
    async fn test_on_shutdown_through_middleware() {
        let counting = Arc::new(CountingMiddleware::new());
        let stack =
            MiddlewareStack::new(TestHandler).with_middleware(CountingClone(counting.clone()));

        stack.on_shutdown().await.unwrap();

        assert_eq!(counting.shutdowns(), 1);
    }

    #[tokio::test]
    async fn test_lifecycle_hooks_chain_through_multiple_middlewares() {
        let counting1 = Arc::new(CountingMiddleware::new());
        let counting2 = Arc::new(CountingMiddleware::new());

        let stack = MiddlewareStack::new(TestHandler)
            .with_middleware(CountingClone(counting1.clone()))
            .with_middleware(CountingClone(counting2.clone()));

        stack.on_initialize().await.unwrap();
        stack.on_shutdown().await.unwrap();

        assert_eq!(counting1.initializes(), 1);
        assert_eq!(counting2.initializes(), 1);
        assert_eq!(counting1.shutdowns(), 1);
        assert_eq!(counting2.shutdowns(), 1);
    }

    #[tokio::test]
    async fn test_on_initialize_short_circuit() {
        let stack = MiddlewareStack::new(TestHandler).with_middleware(BlockInitMiddleware);

        let result = stack.on_initialize().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("blocked"));
    }
}
885}