1use crate::mcp::{CallToolRequest, CallToolResult, McpError, McpResult};
4use serde_json::Value;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8#[derive(Debug, Clone)]
10pub struct MiddlewareContext {
11 pub request_id: String,
13 pub start_time: Instant,
15 pub metadata: std::collections::HashMap<String, Value>,
17}
18
19impl MiddlewareContext {
20 #[must_use]
22 pub fn new(request_id: String) -> Self {
23 Self {
24 request_id,
25 start_time: Instant::now(),
26 metadata: std::collections::HashMap::new(),
27 }
28 }
29
30 #[must_use]
32 pub fn elapsed(&self) -> Duration {
33 self.start_time.elapsed()
34 }
35
36 pub fn set_metadata(&mut self, key: String, value: Value) {
38 self.metadata.insert(key, value);
39 }
40
41 #[must_use]
43 pub fn get_metadata(&self, key: &str) -> Option<&Value> {
44 self.metadata.get(key)
45 }
46}
47
48#[derive(Debug)]
50pub enum MiddlewareResult {
51 Continue,
53 Stop(CallToolResult),
55 Error(McpError),
57}
58
59#[async_trait::async_trait]
61pub trait McpMiddleware: Send + Sync {
62 fn name(&self) -> &str;
64
65 fn priority(&self) -> i32 {
67 0
68 }
69
70 async fn before_request(
72 &self,
73 request: &CallToolRequest,
74 context: &mut MiddlewareContext,
75 ) -> McpResult<MiddlewareResult> {
76 let _ = (request, context);
77 Ok(MiddlewareResult::Continue)
78 }
79
80 async fn after_request(
82 &self,
83 request: &CallToolRequest,
84 response: &mut CallToolResult,
85 context: &mut MiddlewareContext,
86 ) -> McpResult<MiddlewareResult> {
87 let _ = (request, response, context);
88 Ok(MiddlewareResult::Continue)
89 }
90
91 async fn on_error(
93 &self,
94 request: &CallToolRequest,
95 error: &McpError,
96 context: &mut MiddlewareContext,
97 ) -> McpResult<MiddlewareResult> {
98 let _ = (request, error, context);
99 Ok(MiddlewareResult::Continue)
100 }
101}
102
103pub struct MiddlewareChain {
105 middlewares: Vec<Arc<dyn McpMiddleware>>,
106}
107
108impl MiddlewareChain {
109 #[must_use]
111 pub fn new() -> Self {
112 Self {
113 middlewares: Vec::new(),
114 }
115 }
116
117 #[must_use]
119 pub fn add_middleware<M: McpMiddleware + 'static>(mut self, middleware: M) -> Self {
120 self.middlewares.push(Arc::new(middleware));
121 self.sort_by_priority();
122 self
123 }
124
125 #[must_use]
127 pub fn add_arc(mut self, middleware: Arc<dyn McpMiddleware>) -> Self {
128 self.middlewares.push(middleware);
129 self.sort_by_priority();
130 self
131 }
132
133 fn sort_by_priority(&mut self) {
135 self.middlewares.sort_by_key(|m| m.priority());
136 }
137
138 pub async fn execute<F, Fut>(
147 &self,
148 request: CallToolRequest,
149 handler: F,
150 ) -> McpResult<CallToolResult>
151 where
152 F: FnOnce(CallToolRequest) -> Fut,
153 Fut: std::future::Future<Output = McpResult<CallToolResult>> + Send,
154 {
155 let request_id = uuid::Uuid::new_v4().to_string();
156 let mut context = MiddlewareContext::new(request_id);
157
158 for middleware in &self.middlewares {
160 match middleware.before_request(&request, &mut context).await? {
161 MiddlewareResult::Continue => {}
162 MiddlewareResult::Stop(result) => return Ok(result),
163 MiddlewareResult::Error(error) => return Err(error),
164 }
165 }
166
167 let request_clone = request.clone();
169
170 let mut result = match handler(request).await {
172 Ok(response) => response,
173 Err(error) => {
174 for middleware in &self.middlewares {
176 match middleware
177 .on_error(&request_clone, &error, &mut context)
178 .await?
179 {
180 MiddlewareResult::Continue => {}
181 MiddlewareResult::Stop(result) => return Ok(result),
182 MiddlewareResult::Error(middleware_error) => return Err(middleware_error),
183 }
184 }
185 return Err(error);
186 }
187 };
188
189 for middleware in &self.middlewares {
191 match middleware
192 .after_request(&request_clone, &mut result, &mut context)
193 .await?
194 {
195 MiddlewareResult::Continue => {}
196 MiddlewareResult::Stop(new_result) => return Ok(new_result),
197 MiddlewareResult::Error(error) => return Err(error),
198 }
199 }
200
201 Ok(result)
202 }
203
204 #[must_use]
206 pub fn len(&self) -> usize {
207 self.middlewares.len()
208 }
209
210 #[must_use]
212 pub fn is_empty(&self) -> bool {
213 self.middlewares.is_empty()
214 }
215}
216
217impl Default for MiddlewareChain {
218 fn default() -> Self {
219 Self::new()
220 }
221}
222
223mod auth;
224mod config;
225mod logging;
226mod performance;
227mod rate_limit;
228mod validation;
229
230pub use auth::{ApiKeyInfo, AuthenticationMiddleware, JwtClaims, OAuthConfig};
231pub use config::{
232 ApiKeyConfig, AuthenticationConfig, LoggingConfig, MiddlewareConfig, OAuth2Config,
233 PerformanceConfig, RateLimitingConfig, SecurityConfig, ValidationConfig,
234};
235pub use logging::{LogLevel, LoggingMiddleware};
236pub use performance::PerformanceMiddleware;
237pub use rate_limit::RateLimitMiddleware;
238pub use validation::ValidationMiddleware;
239
240#[derive(Debug, thiserror::Error)]
241pub enum MiddlewareError {
242 #[error("Middleware execution failed: {message}")]
243 ExecutionFailed { message: String },
244
245 #[error("Middleware configuration error: {message}")]
246 ConfigurationError { message: String },
247
248 #[error("Middleware chain error: {message}")]
249 ChainError { message: String },
250}
251
252impl From<MiddlewareError> for McpError {
253 fn from(error: MiddlewareError) -> Self {
254 McpError::internal_error(error.to_string())
255 }
256}
257
258#[cfg(test)]
259mod tests {
260 use super::*;
261 use crate::mcp::Content;
262 use std::collections::HashMap;
263
/// Bare-bones middleware for tests; only its priority is configurable.
struct TestMiddleware {
    /// Value reported by `priority()`.
    priority: i32,
}
267
268 #[async_trait::async_trait]
269 impl McpMiddleware for TestMiddleware {
270 fn name(&self) -> &'static str {
271 "test_middleware"
272 }
273
274 fn priority(&self) -> i32 {
275 self.priority
276 }
277 }
278
279 #[tokio::test]
280 async fn test_middleware_chain_creation() {
281 let chain = MiddlewareChain::new()
282 .add_middleware(TestMiddleware { priority: 100 })
283 .add_middleware(TestMiddleware { priority: 50 });
284
285 assert_eq!(chain.len(), 2);
286 assert!(!chain.is_empty());
287 }
288
289 #[tokio::test]
290 async fn test_middleware_priority_ordering() {
291 let chain = MiddlewareChain::new()
292 .add_middleware(TestMiddleware { priority: 10 })
293 .add_middleware(TestMiddleware { priority: 100 });
294
295 assert_eq!(chain.len(), 2);
297 }
298
299 #[tokio::test]
300 async fn test_middleware_execution() {
301 let chain = MiddlewareChain::new()
302 .add_middleware(LoggingMiddleware::info())
303 .add_middleware(ValidationMiddleware::lenient());
304
305 let request = CallToolRequest {
306 name: "test_tool".to_string(),
307 arguments: Some(serde_json::json!({"param": "value"})),
308 };
309
310 let handler = |_req: CallToolRequest| {
311 Box::pin(async move {
312 Ok(CallToolResult {
313 content: vec![Content::Text {
314 text: "Test response".to_string(),
315 }],
316 is_error: false,
317 })
318 })
319 };
320
321 let result = chain.execute(request, handler).await;
322 assert!(result.is_ok());
323 }
324
325 #[tokio::test]
326 async fn test_validation_middleware() {
327 let middleware = ValidationMiddleware::strict();
328 let mut context = MiddlewareContext::new("test".to_string());
329
330 let valid_request = CallToolRequest {
332 name: "valid_tool".to_string(),
333 arguments: Some(serde_json::json!({"param": "value"})),
334 };
335
336 let result = middleware
337 .before_request(&valid_request, &mut context)
338 .await;
339 assert!(matches!(result, Ok(MiddlewareResult::Continue)));
340
341 let invalid_request = CallToolRequest {
343 name: String::new(),
344 arguments: None,
345 };
346
347 let result = middleware
348 .before_request(&invalid_request, &mut context)
349 .await;
350 assert!(matches!(result, Ok(MiddlewareResult::Error(_))));
351 }
352
353 #[tokio::test]
354 async fn test_performance_middleware() {
355 let middleware = PerformanceMiddleware::with_threshold(Duration::from_millis(100));
356 let mut context = MiddlewareContext::new("test".to_string());
357
358 tokio::time::sleep(Duration::from_millis(150)).await;
360
361 let mut response = CallToolResult {
362 content: vec![Content::Text {
363 text: "Test".to_string(),
364 }],
365 is_error: false,
366 };
367
368 let request = CallToolRequest {
369 name: "test".to_string(),
370 arguments: None,
371 };
372
373 let result = middleware
374 .after_request(&request, &mut response, &mut context)
375 .await;
376 assert!(matches!(result, Ok(MiddlewareResult::Continue)));
377
378 assert!(context.get_metadata("duration_ms").is_some());
380 assert!(context.get_metadata("is_slow").is_some());
381 }
382
383 #[tokio::test]
384 async fn test_middleware_config() {
385 let config = MiddlewareConfig {
386 logging: LoggingConfig {
387 enabled: true,
388 level: "debug".to_string(),
389 },
390 validation: ValidationConfig {
391 enabled: true,
392 strict_mode: true,
393 },
394 performance: PerformanceConfig {
395 enabled: true,
396 slow_request_threshold_ms: 500,
397 },
398 security: SecurityConfig::default(),
399 };
400
401 let chain = config.build_chain();
402 assert!(!chain.is_empty());
403 assert!(chain.len() >= 3); }
405
406 #[tokio::test]
407 async fn test_middleware_context_creation() {
408 let context = MiddlewareContext::new("test-request-123".to_string());
409 assert_eq!(context.request_id, "test-request-123");
410 assert!(context.metadata.is_empty());
411 }
412
413 #[tokio::test]
414 async fn test_middleware_context_elapsed() {
415 let context = MiddlewareContext::new("test-request-123".to_string());
416 std::thread::sleep(std::time::Duration::from_millis(10));
417 let elapsed = context.elapsed();
418 assert!(elapsed.as_millis() >= 10);
419 }
420
421 #[tokio::test]
422 async fn test_middleware_context_metadata() {
423 let mut context = MiddlewareContext::new("test-request-123".to_string());
424
425 context.set_metadata(
427 "key1".to_string(),
428 serde_json::Value::String("value1".to_string()),
429 );
430 context.set_metadata(
431 "key2".to_string(),
432 serde_json::Value::Number(serde_json::Number::from(42)),
433 );
434
435 assert_eq!(
437 context.get_metadata("key1"),
438 Some(&serde_json::Value::String("value1".to_string()))
439 );
440 assert_eq!(
441 context.get_metadata("key2"),
442 Some(&serde_json::Value::Number(serde_json::Number::from(42)))
443 );
444 assert_eq!(context.get_metadata("nonexistent"), None);
445 }
446
447 #[tokio::test]
448 async fn test_middleware_result_variants() {
449 let continue_result = MiddlewareResult::Continue;
450 let stop_result = MiddlewareResult::Stop(CallToolResult {
451 content: vec![Content::Text {
452 text: "test".to_string(),
453 }],
454 is_error: false,
455 });
456 let error_result = MiddlewareResult::Error(McpError::tool_not_found("test error"));
457
458 match continue_result {
460 MiddlewareResult::Continue => {}
461 _ => panic!("Expected Continue"),
462 }
463
464 match stop_result {
465 MiddlewareResult::Stop(_) => {}
466 _ => panic!("Expected Stop"),
467 }
468
469 match error_result {
470 MiddlewareResult::Error(_) => {}
471 _ => panic!("Expected Error"),
472 }
473 }
474
475 #[tokio::test]
476 async fn test_logging_middleware_different_levels() {
477 let debug_middleware = LoggingMiddleware::new(LogLevel::Debug);
478 let info_middleware = LoggingMiddleware::new(LogLevel::Info);
479 let warn_middleware = LoggingMiddleware::new(LogLevel::Warn);
480 let error_middleware = LoggingMiddleware::new(LogLevel::Error);
481
482 assert_eq!(debug_middleware.name(), "logging");
483 assert_eq!(info_middleware.name(), "logging");
484 assert_eq!(warn_middleware.name(), "logging");
485 assert_eq!(error_middleware.name(), "logging");
486 }
487
488 #[tokio::test]
489 async fn test_logging_middleware_should_log() {
490 let debug_middleware = LoggingMiddleware::new(LogLevel::Debug);
491 let info_middleware = LoggingMiddleware::new(LogLevel::Info);
492 let warn_middleware = LoggingMiddleware::new(LogLevel::Warn);
493 let error_middleware = LoggingMiddleware::new(LogLevel::Error);
494
495 assert!(debug_middleware.should_log(LogLevel::Debug));
497 assert!(debug_middleware.should_log(LogLevel::Info));
498 assert!(debug_middleware.should_log(LogLevel::Warn));
499 assert!(debug_middleware.should_log(LogLevel::Error));
500
501 assert!(!info_middleware.should_log(LogLevel::Debug));
503 assert!(info_middleware.should_log(LogLevel::Info));
504 assert!(info_middleware.should_log(LogLevel::Warn));
505 assert!(info_middleware.should_log(LogLevel::Error));
506
507 assert!(!warn_middleware.should_log(LogLevel::Debug));
509 assert!(!warn_middleware.should_log(LogLevel::Info));
510 assert!(warn_middleware.should_log(LogLevel::Warn));
511 assert!(warn_middleware.should_log(LogLevel::Error));
512
513 assert!(!error_middleware.should_log(LogLevel::Debug));
515 assert!(!error_middleware.should_log(LogLevel::Info));
516 assert!(!error_middleware.should_log(LogLevel::Warn));
517 assert!(error_middleware.should_log(LogLevel::Error));
518 }
519
520 #[tokio::test]
521 async fn test_validation_middleware_strict_mode() {
522 let strict_middleware = ValidationMiddleware::strict();
523 let lenient_middleware = ValidationMiddleware::lenient();
524
525 assert_eq!(strict_middleware.name(), "validation");
526 assert_eq!(lenient_middleware.name(), "validation");
527 }
528
529 #[tokio::test]
530 async fn test_validation_middleware_creation() {
531 let middleware1 = ValidationMiddleware::new(true);
532 let middleware2 = ValidationMiddleware::new(false);
533
534 assert_eq!(middleware1.name(), "validation");
535 assert_eq!(middleware2.name(), "validation");
536 }
537
538 #[tokio::test]
539 async fn test_performance_middleware_creation() {
540 let middleware1 = PerformanceMiddleware::new(Duration::from_millis(100));
541 let middleware2 = PerformanceMiddleware::with_threshold(Duration::from_millis(200));
542 let middleware3 = PerformanceMiddleware::create_default();
543
544 assert_eq!(middleware1.name(), "performance");
545 assert_eq!(middleware2.name(), "performance");
546 assert_eq!(middleware3.name(), "performance");
547 }
548
549 #[tokio::test]
550 async fn test_middleware_chain_empty() {
551 let chain = MiddlewareChain::new();
552 assert!(chain.is_empty());
553 assert_eq!(chain.len(), 0);
554 }
555
556 #[tokio::test]
557 async fn test_middleware_chain_add_middleware() {
558 let chain = MiddlewareChain::new()
559 .add_middleware(LoggingMiddleware::new(LogLevel::Info))
560 .add_middleware(ValidationMiddleware::new(false));
561
562 assert!(!chain.is_empty());
563 assert_eq!(chain.len(), 2);
564 }
565
566 #[tokio::test]
567 async fn test_middleware_chain_add_arc() {
568 let middleware = Arc::new(LoggingMiddleware::new(LogLevel::Info)) as Arc<dyn McpMiddleware>;
569 let chain = MiddlewareChain::new().add_arc(middleware);
570
571 assert!(!chain.is_empty());
572 assert_eq!(chain.len(), 1);
573 }
574
575 #[tokio::test]
576 async fn test_middleware_chain_execution_with_empty_chain() {
577 let chain = MiddlewareChain::new();
578 let request = CallToolRequest {
579 name: "test_tool".to_string(),
580 arguments: None,
581 };
582
583 let result = chain
584 .execute(request, |_| async {
585 Ok(CallToolResult {
586 content: vec![Content::Text {
587 text: "success".to_string(),
588 }],
589 is_error: false,
590 })
591 })
592 .await;
593
594 assert!(result.is_ok());
595 let result = result.unwrap();
596 assert!(!result.is_error);
597 assert_eq!(result.content.len(), 1);
598 }
599
600 #[tokio::test]
601 async fn test_middleware_chain_execution_with_error() {
602 let chain = MiddlewareChain::new().add_middleware(LoggingMiddleware::new(LogLevel::Info));
603 let request = CallToolRequest {
604 name: "test_tool".to_string(),
605 arguments: None,
606 };
607
608 let result = chain
609 .execute(request, |_| async {
610 Err(McpError::tool_not_found("test error"))
611 })
612 .await;
613
614 assert!(result.is_err());
615 }
616
617 #[tokio::test]
618 async fn test_middleware_chain_execution_with_stop() {
619 struct StopMiddleware;
621 #[async_trait::async_trait]
622 impl McpMiddleware for StopMiddleware {
623 fn name(&self) -> &'static str {
624 "stop"
625 }
626
627 async fn before_request(
628 &self,
629 _request: &CallToolRequest,
630 _context: &mut MiddlewareContext,
631 ) -> McpResult<MiddlewareResult> {
632 Ok(MiddlewareResult::Stop(CallToolResult {
633 content: vec![Content::Text {
634 text: "stopped".to_string(),
635 }],
636 is_error: false,
637 }))
638 }
639
640 async fn after_request(
641 &self,
642 _request: &CallToolRequest,
643 _result: &mut CallToolResult,
644 _context: &mut MiddlewareContext,
645 ) -> McpResult<MiddlewareResult> {
646 Ok(MiddlewareResult::Continue)
647 }
648
649 async fn on_error(
650 &self,
651 _request: &CallToolRequest,
652 _error: &McpError,
653 _context: &mut MiddlewareContext,
654 ) -> McpResult<MiddlewareResult> {
655 Ok(MiddlewareResult::Continue)
656 }
657 }
658
659 let chain = MiddlewareChain::new().add_middleware(LoggingMiddleware::new(LogLevel::Info));
660 let request = CallToolRequest {
661 name: "test_tool".to_string(),
662 arguments: None,
663 };
664
665 let chain = chain.add_middleware(StopMiddleware);
666
667 let result = chain
668 .execute(request, |_| async {
669 Ok(CallToolResult {
670 content: vec![Content::Text {
671 text: "should not reach here".to_string(),
672 }],
673 is_error: false,
674 })
675 })
676 .await;
677
678 assert!(result.is_ok());
679 let result = result.unwrap();
680 let Content::Text { text } = &result.content[0];
681 assert_eq!(text, "stopped");
682 }
683
684 #[tokio::test]
685 async fn test_middleware_chain_execution_with_middleware_error() {
686 struct ErrorMiddleware;
688 #[async_trait::async_trait]
689 impl McpMiddleware for ErrorMiddleware {
690 fn name(&self) -> &'static str {
691 "error"
692 }
693
694 async fn before_request(
695 &self,
696 _request: &CallToolRequest,
697 _context: &mut MiddlewareContext,
698 ) -> McpResult<MiddlewareResult> {
699 Err(McpError::tool_not_found("middleware error"))
700 }
701
702 async fn after_request(
703 &self,
704 _request: &CallToolRequest,
705 _result: &mut CallToolResult,
706 _context: &mut MiddlewareContext,
707 ) -> McpResult<MiddlewareResult> {
708 Ok(MiddlewareResult::Continue)
709 }
710
711 async fn on_error(
712 &self,
713 _request: &CallToolRequest,
714 _error: &McpError,
715 _context: &mut MiddlewareContext,
716 ) -> McpResult<MiddlewareResult> {
717 Ok(MiddlewareResult::Continue)
718 }
719 }
720
721 let chain = MiddlewareChain::new().add_middleware(LoggingMiddleware::new(LogLevel::Info));
722 let request = CallToolRequest {
723 name: "test_tool".to_string(),
724 arguments: None,
725 };
726
727 let chain = chain.add_middleware(ErrorMiddleware);
728
729 let result = chain
730 .execute(request, |_| async {
731 Ok(CallToolResult {
732 content: vec![Content::Text {
733 text: "should not reach here".to_string(),
734 }],
735 is_error: false,
736 })
737 })
738 .await;
739
740 assert!(result.is_err());
741 let error = result.unwrap_err();
742 assert!(matches!(error, McpError::ToolNotFound { tool_name: _ }));
743 }
744
745 #[tokio::test]
746 async fn test_middleware_chain_execution_with_on_error() {
747 struct ErrorHandlerMiddleware;
749 #[async_trait::async_trait]
750 impl McpMiddleware for ErrorHandlerMiddleware {
751 fn name(&self) -> &'static str {
752 "error_handler"
753 }
754
755 async fn before_request(
756 &self,
757 _request: &CallToolRequest,
758 _context: &mut MiddlewareContext,
759 ) -> McpResult<MiddlewareResult> {
760 Ok(MiddlewareResult::Continue)
761 }
762
763 async fn after_request(
764 &self,
765 _request: &CallToolRequest,
766 _result: &mut CallToolResult,
767 _context: &mut MiddlewareContext,
768 ) -> McpResult<MiddlewareResult> {
769 Ok(MiddlewareResult::Continue)
770 }
771
772 async fn on_error(
773 &self,
774 _request: &CallToolRequest,
775 _error: &McpError,
776 _context: &mut MiddlewareContext,
777 ) -> McpResult<MiddlewareResult> {
778 Ok(MiddlewareResult::Stop(CallToolResult {
779 content: vec![Content::Text {
780 text: "error handled".to_string(),
781 }],
782 is_error: false,
783 }))
784 }
785 }
786
787 let chain = MiddlewareChain::new().add_middleware(LoggingMiddleware::new(LogLevel::Info));
788 let request = CallToolRequest {
789 name: "test_tool".to_string(),
790 arguments: None,
791 };
792
793 let chain = chain.add_middleware(ErrorHandlerMiddleware);
794
795 let result = chain
796 .execute(request, |_| async {
797 Err(McpError::tool_not_found("test error"))
798 })
799 .await;
800
801 assert!(result.is_ok());
802 let result = result.unwrap();
803 let Content::Text { text } = &result.content[0];
804 assert_eq!(text, "error handled");
805 }
806
807 #[tokio::test]
808 async fn test_config_structs_creation() {
809 let logging_config = LoggingConfig {
810 enabled: true,
811 level: "debug".to_string(),
812 };
813 let validation_config = ValidationConfig {
814 enabled: true,
815 strict_mode: true,
816 };
817 let performance_config = PerformanceConfig {
818 enabled: true,
819 slow_request_threshold_ms: 1000,
820 };
821
822 assert!(logging_config.enabled);
823 assert_eq!(logging_config.level, "debug");
824 assert!(validation_config.enabled);
825 assert!(validation_config.strict_mode);
826 assert!(performance_config.enabled);
827 assert_eq!(performance_config.slow_request_threshold_ms, 1000);
828 }
829
830 #[tokio::test]
831 async fn test_config_default() {
832 let config = MiddlewareConfig::default();
833 assert!(config.logging.enabled);
834 assert_eq!(config.logging.level, "info");
835 assert!(config.validation.enabled);
836 assert!(!config.validation.strict_mode);
837 assert!(config.performance.enabled);
838 assert_eq!(config.performance.slow_request_threshold_ms, 1000);
839 }
840
841 #[tokio::test]
842 async fn test_config_build_chain_with_disabled_middleware() {
843 let config = MiddlewareConfig {
844 logging: LoggingConfig {
845 enabled: false,
846 level: "debug".to_string(),
847 },
848 validation: ValidationConfig {
849 enabled: false,
850 strict_mode: true,
851 },
852 performance: PerformanceConfig {
853 enabled: false,
854 slow_request_threshold_ms: 1000,
855 },
856 security: SecurityConfig {
857 authentication: AuthenticationConfig {
858 enabled: false,
859 require_auth: false,
860 jwt_secret: "test".to_string(),
861 api_keys: vec![],
862 oauth: None,
863 },
864 rate_limiting: RateLimitingConfig {
865 enabled: false,
866 requests_per_minute: 60,
867 burst_limit: 10,
868 custom_limits: None,
869 },
870 },
871 };
872
873 let chain = config.build_chain();
874 assert!(chain.is_empty());
875 }
876
877 #[tokio::test]
878 async fn test_config_build_chain_with_partial_middleware() {
879 let config = MiddlewareConfig {
880 logging: LoggingConfig {
881 enabled: true,
882 level: "debug".to_string(),
883 },
884 validation: ValidationConfig {
885 enabled: false,
886 strict_mode: true,
887 },
888 performance: PerformanceConfig {
889 enabled: true,
890 slow_request_threshold_ms: 1000,
891 },
892 security: SecurityConfig::default(),
893 };
894
895 let chain = config.build_chain();
896 assert!(!chain.is_empty());
897 assert!(chain.len() >= 2); }
899
900 #[tokio::test]
901 async fn test_config_build_chain_with_invalid_log_level() {
902 let config = MiddlewareConfig {
903 logging: LoggingConfig {
904 enabled: true,
905 level: "invalid".to_string(),
906 },
907 validation: ValidationConfig {
908 enabled: true,
909 strict_mode: true,
910 },
911 performance: PerformanceConfig {
912 enabled: true,
913 slow_request_threshold_ms: 1000,
914 },
915 security: SecurityConfig::default(),
916 };
917
918 let chain = config.build_chain();
919 assert!(!chain.is_empty());
920 }
922
923 #[tokio::test]
924 async fn test_middleware_chain_execution_with_empty_middleware() {
925 let chain = MiddlewareChain::new();
926 let request = CallToolRequest {
927 name: "test_tool".to_string(),
928 arguments: Some(serde_json::json!({"param": "value"})),
929 };
930
931 let result = chain
932 .execute(request, |_| async {
933 Ok(CallToolResult {
934 content: vec![Content::Text {
935 text: "Test response".to_string(),
936 }],
937 is_error: false,
938 })
939 })
940 .await;
941
942 assert!(result.is_ok());
943 let result = result.unwrap();
944 assert!(!result.is_error);
945 assert_eq!(result.content.len(), 1);
946 }
947
948 #[tokio::test]
949 async fn test_middleware_chain_execution_with_multiple_middleware() {
950 let chain = MiddlewareChain::new()
951 .add_middleware(LoggingMiddleware::new(LogLevel::Info))
952 .add_middleware(ValidationMiddleware::new(false))
953 .add_middleware(PerformanceMiddleware::new(Duration::from_millis(100)));
954
955 let request = CallToolRequest {
956 name: "test_tool".to_string(),
957 arguments: Some(serde_json::json!({"param": "value"})),
958 };
959
960 let result = chain
961 .execute(request, |_| async {
962 Ok(CallToolResult {
963 content: vec![Content::Text {
964 text: "Test response".to_string(),
965 }],
966 is_error: false,
967 })
968 })
969 .await;
970
971 assert!(result.is_ok());
972 let result = result.unwrap();
973 assert!(!result.is_error);
974 assert_eq!(result.content.len(), 1);
975 }
976
977 #[tokio::test]
978 async fn test_middleware_chain_execution_with_middleware_stop() {
979 struct StopMiddleware;
980 #[async_trait::async_trait]
981 impl McpMiddleware for StopMiddleware {
982 fn name(&self) -> &'static str {
983 "stop_middleware"
984 }
985
986 fn priority(&self) -> i32 {
987 100
988 }
989
990 async fn before_request(
991 &self,
992 _request: &CallToolRequest,
993 _context: &mut MiddlewareContext,
994 ) -> McpResult<MiddlewareResult> {
995 Ok(MiddlewareResult::Stop(CallToolResult {
996 content: vec![Content::Text {
997 text: "Stopped by middleware".to_string(),
998 }],
999 is_error: false,
1000 }))
1001 }
1002
1003 async fn after_request(
1004 &self,
1005 _request: &CallToolRequest,
1006 _result: &mut CallToolResult,
1007 _context: &mut MiddlewareContext,
1008 ) -> McpResult<MiddlewareResult> {
1009 Ok(MiddlewareResult::Continue)
1010 }
1011
1012 async fn on_error(
1013 &self,
1014 _request: &CallToolRequest,
1015 _error: &McpError,
1016 _context: &mut MiddlewareContext,
1017 ) -> McpResult<MiddlewareResult> {
1018 Ok(MiddlewareResult::Continue)
1019 }
1020 }
1021
1022 let chain = MiddlewareChain::new().add_middleware(StopMiddleware);
1023
1024 let request = CallToolRequest {
1025 name: "test_tool".to_string(),
1026 arguments: None,
1027 };
1028
1029 let result = chain
1030 .execute(request, |_| async {
1031 Ok(CallToolResult {
1032 content: vec![Content::Text {
1033 text: "Should not reach here".to_string(),
1034 }],
1035 is_error: false,
1036 })
1037 })
1038 .await;
1039
1040 assert!(result.is_ok());
1041 let result = result.unwrap();
1042 assert!(!result.is_error);
1043 let Content::Text { text } = &result.content[0];
1044 assert_eq!(text, "Stopped by middleware");
1045 }
1046
1047 #[tokio::test]
1048 async fn test_middleware_chain_execution_with_middleware_error_duplicate() {
1049 struct ErrorMiddleware;
1050 #[async_trait::async_trait]
1051 impl McpMiddleware for ErrorMiddleware {
1052 fn name(&self) -> &'static str {
1053 "error_middleware"
1054 }
1055
1056 fn priority(&self) -> i32 {
1057 100
1058 }
1059
1060 async fn before_request(
1061 &self,
1062 _request: &CallToolRequest,
1063 _context: &mut MiddlewareContext,
1064 ) -> McpResult<MiddlewareResult> {
1065 Err(McpError::internal_error("Middleware error"))
1066 }
1067
1068 async fn after_request(
1069 &self,
1070 _request: &CallToolRequest,
1071 _result: &mut CallToolResult,
1072 _context: &mut MiddlewareContext,
1073 ) -> McpResult<MiddlewareResult> {
1074 Ok(MiddlewareResult::Continue)
1075 }
1076
1077 async fn on_error(
1078 &self,
1079 _request: &CallToolRequest,
1080 _error: &McpError,
1081 _context: &mut MiddlewareContext,
1082 ) -> McpResult<MiddlewareResult> {
1083 Ok(MiddlewareResult::Continue)
1084 }
1085 }
1086
1087 let chain = MiddlewareChain::new().add_middleware(ErrorMiddleware);
1088
1089 let request = CallToolRequest {
1090 name: "test_tool".to_string(),
1091 arguments: None,
1092 };
1093
1094 let result = chain
1095 .execute(request, |_| async {
1096 Ok(CallToolResult {
1097 content: vec![Content::Text {
1098 text: "Should not reach here".to_string(),
1099 }],
1100 is_error: false,
1101 })
1102 })
1103 .await;
1104
1105 assert!(result.is_err());
1106 let error = result.unwrap_err();
1107 assert!(matches!(error, McpError::InternalError { .. }));
1108 }
1109
1110 #[tokio::test]
1112 async fn test_authentication_middleware_permissive() {
1113 let middleware = AuthenticationMiddleware::permissive();
1114 let mut context = MiddlewareContext::new("test".to_string());
1115
1116 let request = CallToolRequest {
1117 name: "test_tool".to_string(),
1118 arguments: None,
1119 };
1120
1121 let result = middleware.before_request(&request, &mut context).await;
1122
1123 assert!(matches!(result, Ok(MiddlewareResult::Continue)));
1124 assert_eq!(
1125 context.get_metadata("auth_required"),
1126 Some(&Value::Bool(false))
1127 );
1128 }
1129
1130 #[tokio::test]
1131 async fn test_authentication_middleware_with_valid_api_key() {
1132 let mut api_keys = HashMap::new();
1133 api_keys.insert(
1134 "test-api-key".to_string(),
1135 ApiKeyInfo {
1136 key_id: "test-key-1".to_string(),
1137 permissions: vec!["read".to_string(), "write".to_string()],
1138 expires_at: None,
1139 },
1140 );
1141
1142 let middleware = AuthenticationMiddleware::new(api_keys, "test-secret".to_string());
1143 let mut context = MiddlewareContext::new("test".to_string());
1144
1145 let request = CallToolRequest {
1146 name: "test_tool".to_string(),
1147 arguments: Some(serde_json::json!({
1148 "api_key": "test-api-key"
1149 })),
1150 };
1151
1152 let result = middleware.before_request(&request, &mut context).await;
1153
1154 assert!(matches!(result, Ok(MiddlewareResult::Continue)));
1155 assert_eq!(
1156 context.get_metadata("auth_type"),
1157 Some(&Value::String("api_key".to_string()))
1158 );
1159 assert_eq!(
1160 context.get_metadata("auth_key_id"),
1161 Some(&Value::String("test-key-1".to_string()))
1162 );
1163 }
1164
1165 #[tokio::test]
1166 async fn test_authentication_middleware_with_invalid_api_key() {
1167 let api_keys = HashMap::new();
1168 let middleware = AuthenticationMiddleware::new(api_keys, "test-secret".to_string());
1169 let mut context = MiddlewareContext::new("test".to_string());
1170
1171 let request = CallToolRequest {
1172 name: "test_tool".to_string(),
1173 arguments: Some(serde_json::json!({
1174 "api_key": "invalid-key"
1175 })),
1176 };
1177
1178 let result = middleware.before_request(&request, &mut context).await;
1179
1180 assert!(matches!(result, Ok(MiddlewareResult::Stop(_))));
1181 }
1182
1183 #[tokio::test]
1184 async fn test_authentication_middleware_with_valid_jwt() {
1185 let api_keys = HashMap::new();
1186 let middleware = AuthenticationMiddleware::new(api_keys, "test-secret".to_string());
1187
1188 let jwt_token = middleware.generate_test_jwt("user123", vec!["read".to_string()]);
1190
1191 let mut context = MiddlewareContext::new("test".to_string());
1192
1193 let request = CallToolRequest {
1194 name: "test_tool".to_string(),
1195 arguments: Some(serde_json::json!({
1196 "jwt_token": jwt_token
1197 })),
1198 };
1199
1200 let result = middleware.before_request(&request, &mut context).await;
1201
1202 assert!(matches!(result, Ok(MiddlewareResult::Continue)));
1203 assert_eq!(
1204 context.get_metadata("auth_type"),
1205 Some(&Value::String("jwt".to_string()))
1206 );
1207 assert_eq!(
1208 context.get_metadata("auth_user_id"),
1209 Some(&Value::String("user123".to_string()))
1210 );
1211 }
1212
1213 #[tokio::test]
1214 async fn test_authentication_middleware_with_invalid_jwt() {
1215 let api_keys = HashMap::new();
1216 let middleware = AuthenticationMiddleware::new(api_keys, "test-secret".to_string());
1217 let mut context = MiddlewareContext::new("test".to_string());
1218
1219 let request = CallToolRequest {
1220 name: "test_tool".to_string(),
1221 arguments: Some(serde_json::json!({
1222 "jwt_token": "invalid.jwt.token"
1223 })),
1224 };
1225
1226 let result = middleware.before_request(&request, &mut context).await;
1227
1228 assert!(matches!(result, Ok(MiddlewareResult::Stop(_))));
1229 }
1230
1231 #[tokio::test]
1232 async fn test_authentication_middleware_no_auth_provided() {
1233 let api_keys = HashMap::new();
1234 let middleware = AuthenticationMiddleware::new(api_keys, "test-secret".to_string());
1235 let mut context = MiddlewareContext::new("test".to_string());
1236
1237 let request = CallToolRequest {
1238 name: "test_tool".to_string(),
1239 arguments: None,
1240 };
1241
1242 let result = middleware.before_request(&request, &mut context).await;
1243
1244 assert!(matches!(result, Ok(MiddlewareResult::Stop(_))));
1245 }
1246
1247 #[tokio::test]
1249 async fn test_rate_limit_middleware_allows_request() {
1250 let middleware = RateLimitMiddleware::new(10, 5);
1251 let mut context = MiddlewareContext::new("test".to_string());
1252
1253 let request = CallToolRequest {
1254 name: "test_tool".to_string(),
1255 arguments: Some(serde_json::json!({
1256 "client_id": "test-client"
1257 })),
1258 };
1259
1260 let result = middleware.before_request(&request, &mut context).await;
1261
1262 assert!(matches!(result, Ok(MiddlewareResult::Continue)));
1263 assert_eq!(
1264 context.get_metadata("rate_limit_client_id"),
1265 Some(&Value::String("client:test-client".to_string()))
1266 );
1267 }
1268
1269 #[tokio::test]
1270 async fn test_rate_limit_middleware_uses_auth_context() {
1271 let middleware = RateLimitMiddleware::new(10, 5);
1272 let mut context = MiddlewareContext::new("test".to_string());
1273
1274 context.set_metadata(
1276 "auth_key_id".to_string(),
1277 Value::String("api-key-123".to_string()),
1278 );
1279
1280 let request = CallToolRequest {
1281 name: "test_tool".to_string(),
1282 arguments: None,
1283 };
1284
1285 let result = middleware.before_request(&request, &mut context).await;
1286
1287 assert!(matches!(result, Ok(MiddlewareResult::Continue)));
1288 assert_eq!(
1289 context.get_metadata("rate_limit_client_id"),
1290 Some(&Value::String("api_key:api-key-123".to_string()))
1291 );
1292 }
1293
1294 #[tokio::test]
1295 async fn test_rate_limit_middleware_uses_jwt_context() {
1296 let middleware = RateLimitMiddleware::new(10, 5);
1297 let mut context = MiddlewareContext::new("test".to_string());
1298
1299 context.set_metadata(
1301 "auth_user_id".to_string(),
1302 Value::String("user-456".to_string()),
1303 );
1304
1305 let request = CallToolRequest {
1306 name: "test_tool".to_string(),
1307 arguments: None,
1308 };
1309
1310 let result = middleware.before_request(&request, &mut context).await;
1311
1312 assert!(matches!(result, Ok(MiddlewareResult::Continue)));
1313 assert_eq!(
1314 context.get_metadata("rate_limit_client_id"),
1315 Some(&Value::String("jwt:user-456".to_string()))
1316 );
1317 }
1318
1319 #[tokio::test]
1321 async fn test_security_config_default() {
1322 let config = SecurityConfig::default();
1323 assert!(config.authentication.enabled);
1324 assert!(!config.authentication.require_auth); assert!(config.rate_limiting.enabled);
1326 assert_eq!(config.rate_limiting.requests_per_minute, 60);
1327 }
1328
1329 #[tokio::test]
1330 async fn test_middleware_config_with_security() {
1331 let config = MiddlewareConfig {
1332 logging: LoggingConfig {
1333 enabled: true,
1334 level: "debug".to_string(),
1335 },
1336 validation: ValidationConfig {
1337 enabled: true,
1338 strict_mode: true,
1339 },
1340 performance: PerformanceConfig {
1341 enabled: true,
1342 slow_request_threshold_ms: 500,
1343 },
1344 security: SecurityConfig {
1345 authentication: AuthenticationConfig {
1346 enabled: true,
1347 require_auth: true,
1348 jwt_secret: "test-secret".to_string(),
1349 api_keys: vec![ApiKeyConfig {
1350 key: "test-key".to_string(),
1351 key_id: "test-id".to_string(),
1352 permissions: vec!["read".to_string()],
1353 expires_at: None,
1354 }],
1355 oauth: None,
1356 },
1357 rate_limiting: RateLimitingConfig {
1358 enabled: true,
1359 requests_per_minute: 30,
1360 burst_limit: 5,
1361 custom_limits: None,
1362 },
1363 },
1364 };
1365
1366 let chain = config.build_chain();
1367 assert!(!chain.is_empty());
1368 assert!(chain.len() >= 5); }
1370
1371 #[tokio::test]
1372 async fn test_middleware_chain_with_security_middleware() {
1373 let mut api_keys = HashMap::new();
1374 api_keys.insert(
1375 "test-key".to_string(),
1376 ApiKeyInfo {
1377 key_id: "test-id".to_string(),
1378 permissions: vec!["read".to_string()],
1379 expires_at: None,
1380 },
1381 );
1382
1383 let chain = MiddlewareChain::new()
1384 .add_middleware(AuthenticationMiddleware::new(
1385 api_keys,
1386 "test-secret".to_string(),
1387 ))
1388 .add_middleware(RateLimitMiddleware::new(10, 5))
1389 .add_middleware(LoggingMiddleware::new(LogLevel::Info));
1390
1391 let request = CallToolRequest {
1392 name: "test_tool".to_string(),
1393 arguments: Some(serde_json::json!({
1394 "api_key": "test-key"
1395 })),
1396 };
1397
1398 let result = chain
1399 .execute(request, |_| async {
1400 Ok(CallToolResult {
1401 content: vec![Content::Text {
1402 text: "success".to_string(),
1403 }],
1404 is_error: false,
1405 })
1406 })
1407 .await;
1408
1409 assert!(result.is_ok());
1410 }
1411}