Skip to main content

things3_cli/mcp/middleware/
mod.rs

1//! MCP Middleware system for cross-cutting concerns
2
3use crate::mcp::{CallToolRequest, CallToolResult, McpError, McpResult};
4use serde_json::Value;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8/// Middleware execution context
9#[derive(Debug, Clone)]
10pub struct MiddlewareContext {
11    /// Request ID for tracking
12    pub request_id: String,
13    /// Start time of the request
14    pub start_time: Instant,
15    /// Additional metadata
16    pub metadata: std::collections::HashMap<String, Value>,
17}
18
19impl MiddlewareContext {
20    /// Create a new middleware context
21    #[must_use]
22    pub fn new(request_id: String) -> Self {
23        Self {
24            request_id,
25            start_time: Instant::now(),
26            metadata: std::collections::HashMap::new(),
27        }
28    }
29
30    /// Get the elapsed time since request start
31    #[must_use]
32    pub fn elapsed(&self) -> Duration {
33        self.start_time.elapsed()
34    }
35
36    /// Set metadata value
37    pub fn set_metadata(&mut self, key: String, value: Value) {
38        self.metadata.insert(key, value);
39    }
40
41    /// Get metadata value
42    #[must_use]
43    pub fn get_metadata(&self, key: &str) -> Option<&Value> {
44        self.metadata.get(key)
45    }
46}
47
48/// Middleware execution result
49#[derive(Debug)]
50pub enum MiddlewareResult {
51    /// Continue to next middleware or handler
52    Continue,
53    /// Stop execution and return this result
54    Stop(CallToolResult),
55    /// Stop execution with error
56    Error(McpError),
57}
58
59/// MCP Middleware trait for intercepting and controlling server operations
60#[async_trait::async_trait]
61pub trait McpMiddleware: Send + Sync {
62    /// Name of the middleware for identification
63    fn name(&self) -> &str;
64
65    /// Priority/order of execution (lower numbers execute first)
66    fn priority(&self) -> i32 {
67        0
68    }
69
70    /// Called before the request is processed
71    async fn before_request(
72        &self,
73        request: &CallToolRequest,
74        context: &mut MiddlewareContext,
75    ) -> McpResult<MiddlewareResult> {
76        let _ = (request, context);
77        Ok(MiddlewareResult::Continue)
78    }
79
80    /// Called after the request is processed but before response is returned
81    async fn after_request(
82        &self,
83        request: &CallToolRequest,
84        response: &mut CallToolResult,
85        context: &mut MiddlewareContext,
86    ) -> McpResult<MiddlewareResult> {
87        let _ = (request, response, context);
88        Ok(MiddlewareResult::Continue)
89    }
90
91    /// Called when an error occurs during request processing
92    async fn on_error(
93        &self,
94        request: &CallToolRequest,
95        error: &McpError,
96        context: &mut MiddlewareContext,
97    ) -> McpResult<MiddlewareResult> {
98        let _ = (request, error, context);
99        Ok(MiddlewareResult::Continue)
100    }
101}
102
103/// Middleware chain for executing multiple middleware in order
104pub struct MiddlewareChain {
105    middlewares: Vec<Arc<dyn McpMiddleware>>,
106}
107
108impl MiddlewareChain {
109    /// Create a new middleware chain
110    #[must_use]
111    pub fn new() -> Self {
112        Self {
113            middlewares: Vec::new(),
114        }
115    }
116
117    /// Add middleware to the chain
118    #[must_use]
119    pub fn add_middleware<M: McpMiddleware + 'static>(mut self, middleware: M) -> Self {
120        self.middlewares.push(Arc::new(middleware));
121        self.sort_by_priority();
122        self
123    }
124
125    /// Add middleware from Arc
126    #[must_use]
127    pub fn add_arc(mut self, middleware: Arc<dyn McpMiddleware>) -> Self {
128        self.middlewares.push(middleware);
129        self.sort_by_priority();
130        self
131    }
132
133    /// Sort middlewares by priority (lower numbers first)
134    fn sort_by_priority(&mut self) {
135        self.middlewares.sort_by_key(|m| m.priority());
136    }
137
138    /// Execute the middleware chain for a request
139    ///
140    /// # Errors
141    ///
142    /// This function will return an error if:
143    /// - Any middleware in the chain returns an error
144    /// - The main handler function returns an error
145    /// - Any middleware fails during execution
146    pub async fn execute<F, Fut>(
147        &self,
148        request: CallToolRequest,
149        handler: F,
150    ) -> McpResult<CallToolResult>
151    where
152        F: FnOnce(CallToolRequest) -> Fut,
153        Fut: std::future::Future<Output = McpResult<CallToolResult>> + Send,
154    {
155        let request_id = uuid::Uuid::new_v4().to_string();
156        let mut context = MiddlewareContext::new(request_id);
157
158        // Execute before_request hooks
159        for middleware in &self.middlewares {
160            match middleware.before_request(&request, &mut context).await? {
161                MiddlewareResult::Continue => {}
162                MiddlewareResult::Stop(result) => return Ok(result),
163                MiddlewareResult::Error(error) => return Err(error),
164            }
165        }
166
167        // Clone request for use in after_request hooks
168        let request_clone = request.clone();
169
170        // Execute the main handler
171        let mut result = match handler(request).await {
172            Ok(response) => response,
173            Err(error) => {
174                // Execute on_error hooks
175                for middleware in &self.middlewares {
176                    match middleware
177                        .on_error(&request_clone, &error, &mut context)
178                        .await?
179                    {
180                        MiddlewareResult::Continue => {}
181                        MiddlewareResult::Stop(result) => return Ok(result),
182                        MiddlewareResult::Error(middleware_error) => return Err(middleware_error),
183                    }
184                }
185                return Err(error);
186            }
187        };
188
189        // Execute after_request hooks
190        for middleware in &self.middlewares {
191            match middleware
192                .after_request(&request_clone, &mut result, &mut context)
193                .await?
194            {
195                MiddlewareResult::Continue => {}
196                MiddlewareResult::Stop(new_result) => return Ok(new_result),
197                MiddlewareResult::Error(error) => return Err(error),
198            }
199        }
200
201        Ok(result)
202    }
203
204    /// Get the number of middlewares in the chain
205    #[must_use]
206    pub fn len(&self) -> usize {
207        self.middlewares.len()
208    }
209
210    /// Check if the chain is empty
211    #[must_use]
212    pub fn is_empty(&self) -> bool {
213        self.middlewares.is_empty()
214    }
215}
216
217impl Default for MiddlewareChain {
218    fn default() -> Self {
219        Self::new()
220    }
221}
222
223mod auth;
224mod config;
225mod logging;
226mod performance;
227mod rate_limit;
228mod validation;
229
230pub use auth::{ApiKeyInfo, AuthenticationMiddleware, JwtClaims, OAuthConfig};
231pub use config::{
232    ApiKeyConfig, AuthenticationConfig, LoggingConfig, MiddlewareConfig, OAuth2Config,
233    PerformanceConfig, RateLimitingConfig, SecurityConfig, ValidationConfig,
234};
235pub use logging::{LogLevel, LoggingMiddleware};
236pub use performance::PerformanceMiddleware;
237pub use rate_limit::RateLimitMiddleware;
238pub use validation::ValidationMiddleware;
239
240#[derive(Debug, thiserror::Error)]
241pub enum MiddlewareError {
242    #[error("Middleware execution failed: {message}")]
243    ExecutionFailed { message: String },
244
245    #[error("Middleware configuration error: {message}")]
246    ConfigurationError { message: String },
247
248    #[error("Middleware chain error: {message}")]
249    ChainError { message: String },
250}
251
252impl From<MiddlewareError> for McpError {
253    fn from(error: MiddlewareError) -> Self {
254        McpError::internal_error(error.to_string())
255    }
256}
257
258#[cfg(test)]
259mod tests {
260    use super::*;
261    use crate::mcp::Content;
262    use std::collections::HashMap;
263
264    struct TestMiddleware {
265        priority: i32,
266    }
267
268    #[async_trait::async_trait]
269    impl McpMiddleware for TestMiddleware {
270        fn name(&self) -> &'static str {
271            "test_middleware"
272        }
273
274        fn priority(&self) -> i32 {
275            self.priority
276        }
277    }
278
279    #[tokio::test]
280    async fn test_middleware_chain_creation() {
281        let chain = MiddlewareChain::new()
282            .add_middleware(TestMiddleware { priority: 100 })
283            .add_middleware(TestMiddleware { priority: 50 });
284
285        assert_eq!(chain.len(), 2);
286        assert!(!chain.is_empty());
287    }
288
289    #[tokio::test]
290    async fn test_middleware_priority_ordering() {
291        let chain = MiddlewareChain::new()
292            .add_middleware(TestMiddleware { priority: 10 })
293            .add_middleware(TestMiddleware { priority: 100 });
294
295        // The chain should be sorted by priority
296        assert_eq!(chain.len(), 2);
297    }
298
299    #[tokio::test]
300    async fn test_middleware_execution() {
301        let chain = MiddlewareChain::new()
302            .add_middleware(LoggingMiddleware::info())
303            .add_middleware(ValidationMiddleware::lenient());
304
305        let request = CallToolRequest {
306            name: "test_tool".to_string(),
307            arguments: Some(serde_json::json!({"param": "value"})),
308        };
309
310        let handler = |_req: CallToolRequest| {
311            Box::pin(async move {
312                Ok(CallToolResult {
313                    content: vec![Content::Text {
314                        text: "Test response".to_string(),
315                    }],
316                    is_error: false,
317                })
318            })
319        };
320
321        let result = chain.execute(request, handler).await;
322        assert!(result.is_ok());
323    }
324
325    #[tokio::test]
326    async fn test_validation_middleware() {
327        let middleware = ValidationMiddleware::strict();
328        let mut context = MiddlewareContext::new("test".to_string());
329
330        // Valid request
331        let valid_request = CallToolRequest {
332            name: "valid_tool".to_string(),
333            arguments: Some(serde_json::json!({"param": "value"})),
334        };
335
336        let result = middleware
337            .before_request(&valid_request, &mut context)
338            .await;
339        assert!(matches!(result, Ok(MiddlewareResult::Continue)));
340
341        // Invalid request (empty name)
342        let invalid_request = CallToolRequest {
343            name: String::new(),
344            arguments: None,
345        };
346
347        let result = middleware
348            .before_request(&invalid_request, &mut context)
349            .await;
350        assert!(matches!(result, Ok(MiddlewareResult::Error(_))));
351    }
352
353    #[tokio::test]
354    async fn test_performance_middleware() {
355        let middleware = PerformanceMiddleware::with_threshold(Duration::from_millis(100));
356        let mut context = MiddlewareContext::new("test".to_string());
357
358        // Simulate a slow request
359        tokio::time::sleep(Duration::from_millis(150)).await;
360
361        let mut response = CallToolResult {
362            content: vec![Content::Text {
363                text: "Test".to_string(),
364            }],
365            is_error: false,
366        };
367
368        let request = CallToolRequest {
369            name: "test".to_string(),
370            arguments: None,
371        };
372
373        let result = middleware
374            .after_request(&request, &mut response, &mut context)
375            .await;
376        assert!(matches!(result, Ok(MiddlewareResult::Continue)));
377
378        // Check that performance metadata was set
379        assert!(context.get_metadata("duration_ms").is_some());
380        assert!(context.get_metadata("is_slow").is_some());
381    }
382
383    #[tokio::test]
384    async fn test_middleware_config() {
385        let config = MiddlewareConfig {
386            logging: LoggingConfig {
387                enabled: true,
388                level: "debug".to_string(),
389            },
390            validation: ValidationConfig {
391                enabled: true,
392                strict_mode: true,
393            },
394            performance: PerformanceConfig {
395                enabled: true,
396                slow_request_threshold_ms: 500,
397            },
398            security: SecurityConfig::default(),
399        };
400
401        let chain = config.build_chain();
402        assert!(!chain.is_empty());
403        assert!(chain.len() >= 3); // Should have logging, validation, and performance
404    }
405
406    #[tokio::test]
407    async fn test_middleware_context_creation() {
408        let context = MiddlewareContext::new("test-request-123".to_string());
409        assert_eq!(context.request_id, "test-request-123");
410        assert!(context.metadata.is_empty());
411    }
412
413    #[tokio::test]
414    async fn test_middleware_context_elapsed() {
415        let context = MiddlewareContext::new("test-request-123".to_string());
416        std::thread::sleep(std::time::Duration::from_millis(10));
417        let elapsed = context.elapsed();
418        assert!(elapsed.as_millis() >= 10);
419    }
420
421    #[tokio::test]
422    async fn test_middleware_context_metadata() {
423        let mut context = MiddlewareContext::new("test-request-123".to_string());
424
425        // Test setting metadata
426        context.set_metadata(
427            "key1".to_string(),
428            serde_json::Value::String("value1".to_string()),
429        );
430        context.set_metadata(
431            "key2".to_string(),
432            serde_json::Value::Number(serde_json::Number::from(42)),
433        );
434
435        // Test getting metadata
436        assert_eq!(
437            context.get_metadata("key1"),
438            Some(&serde_json::Value::String("value1".to_string()))
439        );
440        assert_eq!(
441            context.get_metadata("key2"),
442            Some(&serde_json::Value::Number(serde_json::Number::from(42)))
443        );
444        assert_eq!(context.get_metadata("nonexistent"), None);
445    }
446
447    #[tokio::test]
448    async fn test_middleware_result_variants() {
449        let continue_result = MiddlewareResult::Continue;
450        let stop_result = MiddlewareResult::Stop(CallToolResult {
451            content: vec![Content::Text {
452                text: "test".to_string(),
453            }],
454            is_error: false,
455        });
456        let error_result = MiddlewareResult::Error(McpError::tool_not_found("test error"));
457
458        // Test that we can create all variants
459        match continue_result {
460            MiddlewareResult::Continue => {}
461            _ => panic!("Expected Continue"),
462        }
463
464        match stop_result {
465            MiddlewareResult::Stop(_) => {}
466            _ => panic!("Expected Stop"),
467        }
468
469        match error_result {
470            MiddlewareResult::Error(_) => {}
471            _ => panic!("Expected Error"),
472        }
473    }
474
475    #[tokio::test]
476    async fn test_logging_middleware_different_levels() {
477        let debug_middleware = LoggingMiddleware::new(LogLevel::Debug);
478        let info_middleware = LoggingMiddleware::new(LogLevel::Info);
479        let warn_middleware = LoggingMiddleware::new(LogLevel::Warn);
480        let error_middleware = LoggingMiddleware::new(LogLevel::Error);
481
482        assert_eq!(debug_middleware.name(), "logging");
483        assert_eq!(info_middleware.name(), "logging");
484        assert_eq!(warn_middleware.name(), "logging");
485        assert_eq!(error_middleware.name(), "logging");
486    }
487
488    #[tokio::test]
489    async fn test_logging_middleware_should_log() {
490        let debug_middleware = LoggingMiddleware::new(LogLevel::Debug);
491        let info_middleware = LoggingMiddleware::new(LogLevel::Info);
492        let warn_middleware = LoggingMiddleware::new(LogLevel::Warn);
493        let error_middleware = LoggingMiddleware::new(LogLevel::Error);
494
495        // Debug should log everything
496        assert!(debug_middleware.should_log(LogLevel::Debug));
497        assert!(debug_middleware.should_log(LogLevel::Info));
498        assert!(debug_middleware.should_log(LogLevel::Warn));
499        assert!(debug_middleware.should_log(LogLevel::Error));
500
501        // Info should log info, warn, error
502        assert!(!info_middleware.should_log(LogLevel::Debug));
503        assert!(info_middleware.should_log(LogLevel::Info));
504        assert!(info_middleware.should_log(LogLevel::Warn));
505        assert!(info_middleware.should_log(LogLevel::Error));
506
507        // Warn should log warn, error
508        assert!(!warn_middleware.should_log(LogLevel::Debug));
509        assert!(!warn_middleware.should_log(LogLevel::Info));
510        assert!(warn_middleware.should_log(LogLevel::Warn));
511        assert!(warn_middleware.should_log(LogLevel::Error));
512
513        // Error should only log error
514        assert!(!error_middleware.should_log(LogLevel::Debug));
515        assert!(!error_middleware.should_log(LogLevel::Info));
516        assert!(!error_middleware.should_log(LogLevel::Warn));
517        assert!(error_middleware.should_log(LogLevel::Error));
518    }
519
520    #[tokio::test]
521    async fn test_validation_middleware_strict_mode() {
522        let strict_middleware = ValidationMiddleware::strict();
523        let lenient_middleware = ValidationMiddleware::lenient();
524
525        assert_eq!(strict_middleware.name(), "validation");
526        assert_eq!(lenient_middleware.name(), "validation");
527    }
528
529    #[tokio::test]
530    async fn test_validation_middleware_creation() {
531        let middleware1 = ValidationMiddleware::new(true);
532        let middleware2 = ValidationMiddleware::new(false);
533
534        assert_eq!(middleware1.name(), "validation");
535        assert_eq!(middleware2.name(), "validation");
536    }
537
538    #[tokio::test]
539    async fn test_performance_middleware_creation() {
540        let middleware1 = PerformanceMiddleware::new(Duration::from_millis(100));
541        let middleware2 = PerformanceMiddleware::with_threshold(Duration::from_millis(200));
542        let middleware3 = PerformanceMiddleware::create_default();
543
544        assert_eq!(middleware1.name(), "performance");
545        assert_eq!(middleware2.name(), "performance");
546        assert_eq!(middleware3.name(), "performance");
547    }
548
549    #[tokio::test]
550    async fn test_middleware_chain_empty() {
551        let chain = MiddlewareChain::new();
552        assert!(chain.is_empty());
553        assert_eq!(chain.len(), 0);
554    }
555
556    #[tokio::test]
557    async fn test_middleware_chain_add_middleware() {
558        let chain = MiddlewareChain::new()
559            .add_middleware(LoggingMiddleware::new(LogLevel::Info))
560            .add_middleware(ValidationMiddleware::new(false));
561
562        assert!(!chain.is_empty());
563        assert_eq!(chain.len(), 2);
564    }
565
566    #[tokio::test]
567    async fn test_middleware_chain_add_arc() {
568        let middleware = Arc::new(LoggingMiddleware::new(LogLevel::Info)) as Arc<dyn McpMiddleware>;
569        let chain = MiddlewareChain::new().add_arc(middleware);
570
571        assert!(!chain.is_empty());
572        assert_eq!(chain.len(), 1);
573    }
574
575    #[tokio::test]
576    async fn test_middleware_chain_execution_with_empty_chain() {
577        let chain = MiddlewareChain::new();
578        let request = CallToolRequest {
579            name: "test_tool".to_string(),
580            arguments: None,
581        };
582
583        let result = chain
584            .execute(request, |_| async {
585                Ok(CallToolResult {
586                    content: vec![Content::Text {
587                        text: "success".to_string(),
588                    }],
589                    is_error: false,
590                })
591            })
592            .await;
593
594        assert!(result.is_ok());
595        let result = result.unwrap();
596        assert!(!result.is_error);
597        assert_eq!(result.content.len(), 1);
598    }
599
600    #[tokio::test]
601    async fn test_middleware_chain_execution_with_error() {
602        let chain = MiddlewareChain::new().add_middleware(LoggingMiddleware::new(LogLevel::Info));
603        let request = CallToolRequest {
604            name: "test_tool".to_string(),
605            arguments: None,
606        };
607
608        let result = chain
609            .execute(request, |_| async {
610                Err(McpError::tool_not_found("test error"))
611            })
612            .await;
613
614        assert!(result.is_err());
615    }
616
617    #[tokio::test]
618    async fn test_middleware_chain_execution_with_stop() {
619        // Create a middleware that stops execution
620        struct StopMiddleware;
621        #[async_trait::async_trait]
622        impl McpMiddleware for StopMiddleware {
623            fn name(&self) -> &'static str {
624                "stop"
625            }
626
627            async fn before_request(
628                &self,
629                _request: &CallToolRequest,
630                _context: &mut MiddlewareContext,
631            ) -> McpResult<MiddlewareResult> {
632                Ok(MiddlewareResult::Stop(CallToolResult {
633                    content: vec![Content::Text {
634                        text: "stopped".to_string(),
635                    }],
636                    is_error: false,
637                }))
638            }
639
640            async fn after_request(
641                &self,
642                _request: &CallToolRequest,
643                _result: &mut CallToolResult,
644                _context: &mut MiddlewareContext,
645            ) -> McpResult<MiddlewareResult> {
646                Ok(MiddlewareResult::Continue)
647            }
648
649            async fn on_error(
650                &self,
651                _request: &CallToolRequest,
652                _error: &McpError,
653                _context: &mut MiddlewareContext,
654            ) -> McpResult<MiddlewareResult> {
655                Ok(MiddlewareResult::Continue)
656            }
657        }
658
659        let chain = MiddlewareChain::new().add_middleware(LoggingMiddleware::new(LogLevel::Info));
660        let request = CallToolRequest {
661            name: "test_tool".to_string(),
662            arguments: None,
663        };
664
665        let chain = chain.add_middleware(StopMiddleware);
666
667        let result = chain
668            .execute(request, |_| async {
669                Ok(CallToolResult {
670                    content: vec![Content::Text {
671                        text: "should not reach here".to_string(),
672                    }],
673                    is_error: false,
674                })
675            })
676            .await;
677
678        assert!(result.is_ok());
679        let result = result.unwrap();
680        let Content::Text { text } = &result.content[0];
681        assert_eq!(text, "stopped");
682    }
683
684    #[tokio::test]
685    async fn test_middleware_chain_execution_with_middleware_error() {
686        // Create a middleware that returns an error
687        struct ErrorMiddleware;
688        #[async_trait::async_trait]
689        impl McpMiddleware for ErrorMiddleware {
690            fn name(&self) -> &'static str {
691                "error"
692            }
693
694            async fn before_request(
695                &self,
696                _request: &CallToolRequest,
697                _context: &mut MiddlewareContext,
698            ) -> McpResult<MiddlewareResult> {
699                Err(McpError::tool_not_found("middleware error"))
700            }
701
702            async fn after_request(
703                &self,
704                _request: &CallToolRequest,
705                _result: &mut CallToolResult,
706                _context: &mut MiddlewareContext,
707            ) -> McpResult<MiddlewareResult> {
708                Ok(MiddlewareResult::Continue)
709            }
710
711            async fn on_error(
712                &self,
713                _request: &CallToolRequest,
714                _error: &McpError,
715                _context: &mut MiddlewareContext,
716            ) -> McpResult<MiddlewareResult> {
717                Ok(MiddlewareResult::Continue)
718            }
719        }
720
721        let chain = MiddlewareChain::new().add_middleware(LoggingMiddleware::new(LogLevel::Info));
722        let request = CallToolRequest {
723            name: "test_tool".to_string(),
724            arguments: None,
725        };
726
727        let chain = chain.add_middleware(ErrorMiddleware);
728
729        let result = chain
730            .execute(request, |_| async {
731                Ok(CallToolResult {
732                    content: vec![Content::Text {
733                        text: "should not reach here".to_string(),
734                    }],
735                    is_error: false,
736                })
737            })
738            .await;
739
740        assert!(result.is_err());
741        let error = result.unwrap_err();
742        assert!(matches!(error, McpError::ToolNotFound { tool_name: _ }));
743    }
744
745    #[tokio::test]
746    async fn test_middleware_chain_execution_with_on_error() {
747        // Create a middleware that handles errors
748        struct ErrorHandlerMiddleware;
749        #[async_trait::async_trait]
750        impl McpMiddleware for ErrorHandlerMiddleware {
751            fn name(&self) -> &'static str {
752                "error_handler"
753            }
754
755            async fn before_request(
756                &self,
757                _request: &CallToolRequest,
758                _context: &mut MiddlewareContext,
759            ) -> McpResult<MiddlewareResult> {
760                Ok(MiddlewareResult::Continue)
761            }
762
763            async fn after_request(
764                &self,
765                _request: &CallToolRequest,
766                _result: &mut CallToolResult,
767                _context: &mut MiddlewareContext,
768            ) -> McpResult<MiddlewareResult> {
769                Ok(MiddlewareResult::Continue)
770            }
771
772            async fn on_error(
773                &self,
774                _request: &CallToolRequest,
775                _error: &McpError,
776                _context: &mut MiddlewareContext,
777            ) -> McpResult<MiddlewareResult> {
778                Ok(MiddlewareResult::Stop(CallToolResult {
779                    content: vec![Content::Text {
780                        text: "error handled".to_string(),
781                    }],
782                    is_error: false,
783                }))
784            }
785        }
786
787        let chain = MiddlewareChain::new().add_middleware(LoggingMiddleware::new(LogLevel::Info));
788        let request = CallToolRequest {
789            name: "test_tool".to_string(),
790            arguments: None,
791        };
792
793        let chain = chain.add_middleware(ErrorHandlerMiddleware);
794
795        let result = chain
796            .execute(request, |_| async {
797                Err(McpError::tool_not_found("test error"))
798            })
799            .await;
800
801        assert!(result.is_ok());
802        let result = result.unwrap();
803        let Content::Text { text } = &result.content[0];
804        assert_eq!(text, "error handled");
805    }
806
807    #[tokio::test]
808    async fn test_config_structs_creation() {
809        let logging_config = LoggingConfig {
810            enabled: true,
811            level: "debug".to_string(),
812        };
813        let validation_config = ValidationConfig {
814            enabled: true,
815            strict_mode: true,
816        };
817        let performance_config = PerformanceConfig {
818            enabled: true,
819            slow_request_threshold_ms: 1000,
820        };
821
822        assert!(logging_config.enabled);
823        assert_eq!(logging_config.level, "debug");
824        assert!(validation_config.enabled);
825        assert!(validation_config.strict_mode);
826        assert!(performance_config.enabled);
827        assert_eq!(performance_config.slow_request_threshold_ms, 1000);
828    }
829
830    #[tokio::test]
831    async fn test_config_default() {
832        let config = MiddlewareConfig::default();
833        assert!(config.logging.enabled);
834        assert_eq!(config.logging.level, "info");
835        assert!(config.validation.enabled);
836        assert!(!config.validation.strict_mode);
837        assert!(config.performance.enabled);
838        assert_eq!(config.performance.slow_request_threshold_ms, 1000);
839    }
840
841    #[tokio::test]
842    async fn test_config_build_chain_with_disabled_middleware() {
843        let config = MiddlewareConfig {
844            logging: LoggingConfig {
845                enabled: false,
846                level: "debug".to_string(),
847            },
848            validation: ValidationConfig {
849                enabled: false,
850                strict_mode: true,
851            },
852            performance: PerformanceConfig {
853                enabled: false,
854                slow_request_threshold_ms: 1000,
855            },
856            security: SecurityConfig {
857                authentication: AuthenticationConfig {
858                    enabled: false,
859                    require_auth: false,
860                    jwt_secret: "test".to_string(),
861                    api_keys: vec![],
862                    oauth: None,
863                },
864                rate_limiting: RateLimitingConfig {
865                    enabled: false,
866                    requests_per_minute: 60,
867                    burst_limit: 10,
868                    custom_limits: None,
869                },
870            },
871        };
872
873        let chain = config.build_chain();
874        assert!(chain.is_empty());
875    }
876
877    #[tokio::test]
878    async fn test_config_build_chain_with_partial_middleware() {
879        let config = MiddlewareConfig {
880            logging: LoggingConfig {
881                enabled: true,
882                level: "debug".to_string(),
883            },
884            validation: ValidationConfig {
885                enabled: false,
886                strict_mode: true,
887            },
888            performance: PerformanceConfig {
889                enabled: true,
890                slow_request_threshold_ms: 1000,
891            },
892            security: SecurityConfig::default(),
893        };
894
895        let chain = config.build_chain();
896        assert!(!chain.is_empty());
897        assert!(chain.len() >= 2); // At least logging and performance
898    }
899
900    #[tokio::test]
901    async fn test_config_build_chain_with_invalid_log_level() {
902        let config = MiddlewareConfig {
903            logging: LoggingConfig {
904                enabled: true,
905                level: "invalid".to_string(),
906            },
907            validation: ValidationConfig {
908                enabled: true,
909                strict_mode: true,
910            },
911            performance: PerformanceConfig {
912                enabled: true,
913                slow_request_threshold_ms: 1000,
914            },
915            security: SecurityConfig::default(),
916        };
917
918        let chain = config.build_chain();
919        assert!(!chain.is_empty());
920        // Should default to info level
921    }
922
923    #[tokio::test]
924    async fn test_middleware_chain_execution_with_empty_middleware() {
925        let chain = MiddlewareChain::new();
926        let request = CallToolRequest {
927            name: "test_tool".to_string(),
928            arguments: Some(serde_json::json!({"param": "value"})),
929        };
930
931        let result = chain
932            .execute(request, |_| async {
933                Ok(CallToolResult {
934                    content: vec![Content::Text {
935                        text: "Test response".to_string(),
936                    }],
937                    is_error: false,
938                })
939            })
940            .await;
941
942        assert!(result.is_ok());
943        let result = result.unwrap();
944        assert!(!result.is_error);
945        assert_eq!(result.content.len(), 1);
946    }
947
948    #[tokio::test]
949    async fn test_middleware_chain_execution_with_multiple_middleware() {
950        let chain = MiddlewareChain::new()
951            .add_middleware(LoggingMiddleware::new(LogLevel::Info))
952            .add_middleware(ValidationMiddleware::new(false))
953            .add_middleware(PerformanceMiddleware::new(Duration::from_millis(100)));
954
955        let request = CallToolRequest {
956            name: "test_tool".to_string(),
957            arguments: Some(serde_json::json!({"param": "value"})),
958        };
959
960        let result = chain
961            .execute(request, |_| async {
962                Ok(CallToolResult {
963                    content: vec![Content::Text {
964                        text: "Test response".to_string(),
965                    }],
966                    is_error: false,
967                })
968            })
969            .await;
970
971        assert!(result.is_ok());
972        let result = result.unwrap();
973        assert!(!result.is_error);
974        assert_eq!(result.content.len(), 1);
975    }
976
977    #[tokio::test]
978    async fn test_middleware_chain_execution_with_middleware_stop() {
979        struct StopMiddleware;
980        #[async_trait::async_trait]
981        impl McpMiddleware for StopMiddleware {
982            fn name(&self) -> &'static str {
983                "stop_middleware"
984            }
985
986            fn priority(&self) -> i32 {
987                100
988            }
989
990            async fn before_request(
991                &self,
992                _request: &CallToolRequest,
993                _context: &mut MiddlewareContext,
994            ) -> McpResult<MiddlewareResult> {
995                Ok(MiddlewareResult::Stop(CallToolResult {
996                    content: vec![Content::Text {
997                        text: "Stopped by middleware".to_string(),
998                    }],
999                    is_error: false,
1000                }))
1001            }
1002
1003            async fn after_request(
1004                &self,
1005                _request: &CallToolRequest,
1006                _result: &mut CallToolResult,
1007                _context: &mut MiddlewareContext,
1008            ) -> McpResult<MiddlewareResult> {
1009                Ok(MiddlewareResult::Continue)
1010            }
1011
1012            async fn on_error(
1013                &self,
1014                _request: &CallToolRequest,
1015                _error: &McpError,
1016                _context: &mut MiddlewareContext,
1017            ) -> McpResult<MiddlewareResult> {
1018                Ok(MiddlewareResult::Continue)
1019            }
1020        }
1021
1022        let chain = MiddlewareChain::new().add_middleware(StopMiddleware);
1023
1024        let request = CallToolRequest {
1025            name: "test_tool".to_string(),
1026            arguments: None,
1027        };
1028
1029        let result = chain
1030            .execute(request, |_| async {
1031                Ok(CallToolResult {
1032                    content: vec![Content::Text {
1033                        text: "Should not reach here".to_string(),
1034                    }],
1035                    is_error: false,
1036                })
1037            })
1038            .await;
1039
1040        assert!(result.is_ok());
1041        let result = result.unwrap();
1042        assert!(!result.is_error);
1043        let Content::Text { text } = &result.content[0];
1044        assert_eq!(text, "Stopped by middleware");
1045    }
1046
1047    #[tokio::test]
1048    async fn test_middleware_chain_execution_with_middleware_error_duplicate() {
1049        struct ErrorMiddleware;
1050        #[async_trait::async_trait]
1051        impl McpMiddleware for ErrorMiddleware {
1052            fn name(&self) -> &'static str {
1053                "error_middleware"
1054            }
1055
1056            fn priority(&self) -> i32 {
1057                100
1058            }
1059
1060            async fn before_request(
1061                &self,
1062                _request: &CallToolRequest,
1063                _context: &mut MiddlewareContext,
1064            ) -> McpResult<MiddlewareResult> {
1065                Err(McpError::internal_error("Middleware error"))
1066            }
1067
1068            async fn after_request(
1069                &self,
1070                _request: &CallToolRequest,
1071                _result: &mut CallToolResult,
1072                _context: &mut MiddlewareContext,
1073            ) -> McpResult<MiddlewareResult> {
1074                Ok(MiddlewareResult::Continue)
1075            }
1076
1077            async fn on_error(
1078                &self,
1079                _request: &CallToolRequest,
1080                _error: &McpError,
1081                _context: &mut MiddlewareContext,
1082            ) -> McpResult<MiddlewareResult> {
1083                Ok(MiddlewareResult::Continue)
1084            }
1085        }
1086
1087        let chain = MiddlewareChain::new().add_middleware(ErrorMiddleware);
1088
1089        let request = CallToolRequest {
1090            name: "test_tool".to_string(),
1091            arguments: None,
1092        };
1093
1094        let result = chain
1095            .execute(request, |_| async {
1096                Ok(CallToolResult {
1097                    content: vec![Content::Text {
1098                        text: "Should not reach here".to_string(),
1099                    }],
1100                    is_error: false,
1101                })
1102            })
1103            .await;
1104
1105        assert!(result.is_err());
1106        let error = result.unwrap_err();
1107        assert!(matches!(error, McpError::InternalError { .. }));
1108    }
1109
1110    // Authentication Middleware Tests
1111    #[tokio::test]
1112    async fn test_authentication_middleware_permissive() {
1113        let middleware = AuthenticationMiddleware::permissive();
1114        let mut context = MiddlewareContext::new("test".to_string());
1115
1116        let request = CallToolRequest {
1117            name: "test_tool".to_string(),
1118            arguments: None,
1119        };
1120
1121        let result = middleware.before_request(&request, &mut context).await;
1122
1123        assert!(matches!(result, Ok(MiddlewareResult::Continue)));
1124        assert_eq!(
1125            context.get_metadata("auth_required"),
1126            Some(&Value::Bool(false))
1127        );
1128    }
1129
1130    #[tokio::test]
1131    async fn test_authentication_middleware_with_valid_api_key() {
1132        let mut api_keys = HashMap::new();
1133        api_keys.insert(
1134            "test-api-key".to_string(),
1135            ApiKeyInfo {
1136                key_id: "test-key-1".to_string(),
1137                permissions: vec!["read".to_string(), "write".to_string()],
1138                expires_at: None,
1139            },
1140        );
1141
1142        let middleware = AuthenticationMiddleware::new(api_keys, "test-secret".to_string());
1143        let mut context = MiddlewareContext::new("test".to_string());
1144
1145        let request = CallToolRequest {
1146            name: "test_tool".to_string(),
1147            arguments: Some(serde_json::json!({
1148                "api_key": "test-api-key"
1149            })),
1150        };
1151
1152        let result = middleware.before_request(&request, &mut context).await;
1153
1154        assert!(matches!(result, Ok(MiddlewareResult::Continue)));
1155        assert_eq!(
1156            context.get_metadata("auth_type"),
1157            Some(&Value::String("api_key".to_string()))
1158        );
1159        assert_eq!(
1160            context.get_metadata("auth_key_id"),
1161            Some(&Value::String("test-key-1".to_string()))
1162        );
1163    }
1164
1165    #[tokio::test]
1166    async fn test_authentication_middleware_with_invalid_api_key() {
1167        let api_keys = HashMap::new();
1168        let middleware = AuthenticationMiddleware::new(api_keys, "test-secret".to_string());
1169        let mut context = MiddlewareContext::new("test".to_string());
1170
1171        let request = CallToolRequest {
1172            name: "test_tool".to_string(),
1173            arguments: Some(serde_json::json!({
1174                "api_key": "invalid-key"
1175            })),
1176        };
1177
1178        let result = middleware.before_request(&request, &mut context).await;
1179
1180        assert!(matches!(result, Ok(MiddlewareResult::Stop(_))));
1181    }
1182
1183    #[tokio::test]
1184    async fn test_authentication_middleware_with_valid_jwt() {
1185        let api_keys = HashMap::new();
1186        let middleware = AuthenticationMiddleware::new(api_keys, "test-secret".to_string());
1187
1188        // Generate a test JWT token
1189        let jwt_token = middleware.generate_test_jwt("user123", vec!["read".to_string()]);
1190
1191        let mut context = MiddlewareContext::new("test".to_string());
1192
1193        let request = CallToolRequest {
1194            name: "test_tool".to_string(),
1195            arguments: Some(serde_json::json!({
1196                "jwt_token": jwt_token
1197            })),
1198        };
1199
1200        let result = middleware.before_request(&request, &mut context).await;
1201
1202        assert!(matches!(result, Ok(MiddlewareResult::Continue)));
1203        assert_eq!(
1204            context.get_metadata("auth_type"),
1205            Some(&Value::String("jwt".to_string()))
1206        );
1207        assert_eq!(
1208            context.get_metadata("auth_user_id"),
1209            Some(&Value::String("user123".to_string()))
1210        );
1211    }
1212
1213    #[tokio::test]
1214    async fn test_authentication_middleware_with_invalid_jwt() {
1215        let api_keys = HashMap::new();
1216        let middleware = AuthenticationMiddleware::new(api_keys, "test-secret".to_string());
1217        let mut context = MiddlewareContext::new("test".to_string());
1218
1219        let request = CallToolRequest {
1220            name: "test_tool".to_string(),
1221            arguments: Some(serde_json::json!({
1222                "jwt_token": "invalid.jwt.token"
1223            })),
1224        };
1225
1226        let result = middleware.before_request(&request, &mut context).await;
1227
1228        assert!(matches!(result, Ok(MiddlewareResult::Stop(_))));
1229    }
1230
1231    #[tokio::test]
1232    async fn test_authentication_middleware_no_auth_provided() {
1233        let api_keys = HashMap::new();
1234        let middleware = AuthenticationMiddleware::new(api_keys, "test-secret".to_string());
1235        let mut context = MiddlewareContext::new("test".to_string());
1236
1237        let request = CallToolRequest {
1238            name: "test_tool".to_string(),
1239            arguments: None,
1240        };
1241
1242        let result = middleware.before_request(&request, &mut context).await;
1243
1244        assert!(matches!(result, Ok(MiddlewareResult::Stop(_))));
1245    }
1246
1247    // Rate Limiting Middleware Tests
1248    #[tokio::test]
1249    async fn test_rate_limit_middleware_allows_request() {
1250        let middleware = RateLimitMiddleware::new(10, 5);
1251        let mut context = MiddlewareContext::new("test".to_string());
1252
1253        let request = CallToolRequest {
1254            name: "test_tool".to_string(),
1255            arguments: Some(serde_json::json!({
1256                "client_id": "test-client"
1257            })),
1258        };
1259
1260        let result = middleware.before_request(&request, &mut context).await;
1261
1262        assert!(matches!(result, Ok(MiddlewareResult::Continue)));
1263        assert_eq!(
1264            context.get_metadata("rate_limit_client_id"),
1265            Some(&Value::String("client:test-client".to_string()))
1266        );
1267    }
1268
1269    #[tokio::test]
1270    async fn test_rate_limit_middleware_uses_auth_context() {
1271        let middleware = RateLimitMiddleware::new(10, 5);
1272        let mut context = MiddlewareContext::new("test".to_string());
1273
1274        // Set up auth context
1275        context.set_metadata(
1276            "auth_key_id".to_string(),
1277            Value::String("api-key-123".to_string()),
1278        );
1279
1280        let request = CallToolRequest {
1281            name: "test_tool".to_string(),
1282            arguments: None,
1283        };
1284
1285        let result = middleware.before_request(&request, &mut context).await;
1286
1287        assert!(matches!(result, Ok(MiddlewareResult::Continue)));
1288        assert_eq!(
1289            context.get_metadata("rate_limit_client_id"),
1290            Some(&Value::String("api_key:api-key-123".to_string()))
1291        );
1292    }
1293
1294    #[tokio::test]
1295    async fn test_rate_limit_middleware_uses_jwt_context() {
1296        let middleware = RateLimitMiddleware::new(10, 5);
1297        let mut context = MiddlewareContext::new("test".to_string());
1298
1299        // Set up JWT context
1300        context.set_metadata(
1301            "auth_user_id".to_string(),
1302            Value::String("user-456".to_string()),
1303        );
1304
1305        let request = CallToolRequest {
1306            name: "test_tool".to_string(),
1307            arguments: None,
1308        };
1309
1310        let result = middleware.before_request(&request, &mut context).await;
1311
1312        assert!(matches!(result, Ok(MiddlewareResult::Continue)));
1313        assert_eq!(
1314            context.get_metadata("rate_limit_client_id"),
1315            Some(&Value::String("jwt:user-456".to_string()))
1316        );
1317    }
1318
1319    // Security Configuration Tests
1320    #[tokio::test]
1321    async fn test_security_config_default() {
1322        let config = SecurityConfig::default();
1323        assert!(config.authentication.enabled);
1324        assert!(!config.authentication.require_auth); // Should be false for easier development
1325        assert!(config.rate_limiting.enabled);
1326        assert_eq!(config.rate_limiting.requests_per_minute, 60);
1327    }
1328
1329    #[tokio::test]
1330    async fn test_middleware_config_with_security() {
1331        let config = MiddlewareConfig {
1332            logging: LoggingConfig {
1333                enabled: true,
1334                level: "debug".to_string(),
1335            },
1336            validation: ValidationConfig {
1337                enabled: true,
1338                strict_mode: true,
1339            },
1340            performance: PerformanceConfig {
1341                enabled: true,
1342                slow_request_threshold_ms: 500,
1343            },
1344            security: SecurityConfig {
1345                authentication: AuthenticationConfig {
1346                    enabled: true,
1347                    require_auth: true,
1348                    jwt_secret: "test-secret".to_string(),
1349                    api_keys: vec![ApiKeyConfig {
1350                        key: "test-key".to_string(),
1351                        key_id: "test-id".to_string(),
1352                        permissions: vec!["read".to_string()],
1353                        expires_at: None,
1354                    }],
1355                    oauth: None,
1356                },
1357                rate_limiting: RateLimitingConfig {
1358                    enabled: true,
1359                    requests_per_minute: 30,
1360                    burst_limit: 5,
1361                    custom_limits: None,
1362                },
1363            },
1364        };
1365
1366        let chain = config.build_chain();
1367        assert!(!chain.is_empty());
1368        assert!(chain.len() >= 5); // Should have auth, rate limiting, logging, validation, and performance
1369    }
1370
1371    #[tokio::test]
1372    async fn test_middleware_chain_with_security_middleware() {
1373        let mut api_keys = HashMap::new();
1374        api_keys.insert(
1375            "test-key".to_string(),
1376            ApiKeyInfo {
1377                key_id: "test-id".to_string(),
1378                permissions: vec!["read".to_string()],
1379                expires_at: None,
1380            },
1381        );
1382
1383        let chain = MiddlewareChain::new()
1384            .add_middleware(AuthenticationMiddleware::new(
1385                api_keys,
1386                "test-secret".to_string(),
1387            ))
1388            .add_middleware(RateLimitMiddleware::new(10, 5))
1389            .add_middleware(LoggingMiddleware::new(LogLevel::Info));
1390
1391        let request = CallToolRequest {
1392            name: "test_tool".to_string(),
1393            arguments: Some(serde_json::json!({
1394                "api_key": "test-key"
1395            })),
1396        };
1397
1398        let result = chain
1399            .execute(request, |_| async {
1400                Ok(CallToolResult {
1401                    content: vec![Content::Text {
1402                        text: "success".to_string(),
1403                    }],
1404                    is_error: false,
1405                })
1406            })
1407            .await;
1408
1409        assert!(result.is_ok());
1410    }
1411}