things3_cli/mcp/
middleware.rs

1//! MCP Middleware system for cross-cutting concerns
2
3use crate::mcp::{CallToolRequest, CallToolResult, McpError, McpResult};
4use governor::clock::DefaultClock;
5use governor::{state::keyed::DefaultKeyedStateStore, Quota, RateLimiter};
6#[allow(unused_imports)]
7use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
8use nonzero_ext::nonzero;
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use thiserror::Error;
15
16/// Middleware execution context
17#[derive(Debug, Clone)]
18pub struct MiddlewareContext {
19    /// Request ID for tracking
20    pub request_id: String,
21    /// Start time of the request
22    pub start_time: Instant,
23    /// Additional metadata
24    pub metadata: std::collections::HashMap<String, Value>,
25}
26
27impl MiddlewareContext {
28    /// Create a new middleware context
29    #[must_use]
30    pub fn new(request_id: String) -> Self {
31        Self {
32            request_id,
33            start_time: Instant::now(),
34            metadata: std::collections::HashMap::new(),
35        }
36    }
37
38    /// Get the elapsed time since request start
39    #[must_use]
40    pub fn elapsed(&self) -> Duration {
41        self.start_time.elapsed()
42    }
43
44    /// Set metadata value
45    pub fn set_metadata(&mut self, key: String, value: Value) {
46        self.metadata.insert(key, value);
47    }
48
49    /// Get metadata value
50    #[must_use]
51    pub fn get_metadata(&self, key: &str) -> Option<&Value> {
52        self.metadata.get(key)
53    }
54}
55
56/// Middleware execution result
57#[derive(Debug)]
58pub enum MiddlewareResult {
59    /// Continue to next middleware or handler
60    Continue,
61    /// Stop execution and return this result
62    Stop(CallToolResult),
63    /// Stop execution with error
64    Error(McpError),
65}
66
67/// MCP Middleware trait for intercepting and controlling server operations
68#[async_trait::async_trait]
69pub trait McpMiddleware: Send + Sync {
70    /// Name of the middleware for identification
71    fn name(&self) -> &str;
72
73    /// Priority/order of execution (lower numbers execute first)
74    fn priority(&self) -> i32 {
75        0
76    }
77
78    /// Called before the request is processed
79    async fn before_request(
80        &self,
81        request: &CallToolRequest,
82        context: &mut MiddlewareContext,
83    ) -> McpResult<MiddlewareResult> {
84        let _ = (request, context);
85        Ok(MiddlewareResult::Continue)
86    }
87
88    /// Called after the request is processed but before response is returned
89    async fn after_request(
90        &self,
91        request: &CallToolRequest,
92        response: &mut CallToolResult,
93        context: &mut MiddlewareContext,
94    ) -> McpResult<MiddlewareResult> {
95        let _ = (request, response, context);
96        Ok(MiddlewareResult::Continue)
97    }
98
99    /// Called when an error occurs during request processing
100    async fn on_error(
101        &self,
102        request: &CallToolRequest,
103        error: &McpError,
104        context: &mut MiddlewareContext,
105    ) -> McpResult<MiddlewareResult> {
106        let _ = (request, error, context);
107        Ok(MiddlewareResult::Continue)
108    }
109}
110
111/// Middleware chain for executing multiple middleware in order
112pub struct MiddlewareChain {
113    middlewares: Vec<Arc<dyn McpMiddleware>>,
114}
115
116impl MiddlewareChain {
117    /// Create a new middleware chain
118    #[must_use]
119    pub fn new() -> Self {
120        Self {
121            middlewares: Vec::new(),
122        }
123    }
124
125    /// Add middleware to the chain
126    #[must_use]
127    pub fn add_middleware<M: McpMiddleware + 'static>(mut self, middleware: M) -> Self {
128        self.middlewares.push(Arc::new(middleware));
129        self.sort_by_priority();
130        self
131    }
132
133    /// Add middleware from Arc
134    #[must_use]
135    pub fn add_arc(mut self, middleware: Arc<dyn McpMiddleware>) -> Self {
136        self.middlewares.push(middleware);
137        self.sort_by_priority();
138        self
139    }
140
141    /// Sort middlewares by priority (lower numbers first)
142    fn sort_by_priority(&mut self) {
143        self.middlewares.sort_by_key(|m| m.priority());
144    }
145
146    /// Execute the middleware chain for a request
147    ///
148    /// # Errors
149    ///
150    /// This function will return an error if:
151    /// - Any middleware in the chain returns an error
152    /// - The main handler function returns an error
153    /// - Any middleware fails during execution
154    pub async fn execute<F, Fut>(
155        &self,
156        request: CallToolRequest,
157        handler: F,
158    ) -> McpResult<CallToolResult>
159    where
160        F: FnOnce(CallToolRequest) -> Fut,
161        Fut: std::future::Future<Output = McpResult<CallToolResult>> + Send,
162    {
163        let request_id = uuid::Uuid::new_v4().to_string();
164        let mut context = MiddlewareContext::new(request_id);
165
166        // Execute before_request hooks
167        for middleware in &self.middlewares {
168            match middleware.before_request(&request, &mut context).await? {
169                MiddlewareResult::Continue => {}
170                MiddlewareResult::Stop(result) => return Ok(result),
171                MiddlewareResult::Error(error) => return Err(error),
172            }
173        }
174
175        // Clone request for use in after_request hooks
176        let request_clone = request.clone();
177
178        // Execute the main handler
179        let mut result = match handler(request).await {
180            Ok(response) => response,
181            Err(error) => {
182                // Execute on_error hooks
183                for middleware in &self.middlewares {
184                    match middleware
185                        .on_error(&request_clone, &error, &mut context)
186                        .await?
187                    {
188                        MiddlewareResult::Continue => {}
189                        MiddlewareResult::Stop(result) => return Ok(result),
190                        MiddlewareResult::Error(middleware_error) => return Err(middleware_error),
191                    }
192                }
193                return Err(error);
194            }
195        };
196
197        // Execute after_request hooks
198        for middleware in &self.middlewares {
199            match middleware
200                .after_request(&request_clone, &mut result, &mut context)
201                .await?
202            {
203                MiddlewareResult::Continue => {}
204                MiddlewareResult::Stop(new_result) => return Ok(new_result),
205                MiddlewareResult::Error(error) => return Err(error),
206            }
207        }
208
209        Ok(result)
210    }
211
212    /// Get the number of middlewares in the chain
213    #[must_use]
214    pub fn len(&self) -> usize {
215        self.middlewares.len()
216    }
217
218    /// Check if the chain is empty
219    #[must_use]
220    pub fn is_empty(&self) -> bool {
221        self.middlewares.is_empty()
222    }
223}
224
225impl Default for MiddlewareChain {
226    fn default() -> Self {
227        Self::new()
228    }
229}
230
231/// Built-in logging middleware
232pub struct LoggingMiddleware {
233    level: LogLevel,
234}
235
/// Severity levels understood by `LoggingMiddleware`, from most to least verbose.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogLevel {
    /// Verbose diagnostic output
    Debug,
    /// Routine request lifecycle events
    Info,
    /// Suspicious but non-fatal conditions
    Warn,
    /// Failures
    Error,
}
243
244impl LoggingMiddleware {
245    /// Create a new logging middleware
246    #[must_use]
247    pub fn new(level: LogLevel) -> Self {
248        Self { level }
249    }
250
251    /// Create with debug level
252    #[must_use]
253    pub fn debug() -> Self {
254        Self::new(LogLevel::Debug)
255    }
256
257    /// Create with info level
258    #[must_use]
259    pub fn info() -> Self {
260        Self::new(LogLevel::Info)
261    }
262
263    /// Create with warn level
264    #[must_use]
265    pub fn warn() -> Self {
266        Self::new(LogLevel::Warn)
267    }
268
269    /// Create with error level
270    #[must_use]
271    pub fn error() -> Self {
272        Self::new(LogLevel::Error)
273    }
274
275    fn should_log(&self, level: LogLevel) -> bool {
276        matches!(
277            (self.level, level),
278            (LogLevel::Debug, _)
279                | (
280                    LogLevel::Info,
281                    LogLevel::Info | LogLevel::Warn | LogLevel::Error
282                )
283                | (LogLevel::Warn, LogLevel::Warn | LogLevel::Error)
284                | (LogLevel::Error, LogLevel::Error)
285        )
286    }
287
288    fn log(&self, level: LogLevel, message: &str) {
289        if self.should_log(level) {
290            match level {
291                LogLevel::Debug => println!("[DEBUG] {message}"),
292                LogLevel::Info => println!("[INFO] {message}"),
293                LogLevel::Warn => println!("[WARN] {message}"),
294                LogLevel::Error => println!("[ERROR] {message}"),
295            }
296        }
297    }
298}
299
300#[async_trait::async_trait]
301impl McpMiddleware for LoggingMiddleware {
302    fn name(&self) -> &'static str {
303        "logging"
304    }
305
306    fn priority(&self) -> i32 {
307        100 // Low priority to run early
308    }
309
310    async fn before_request(
311        &self,
312        request: &CallToolRequest,
313        context: &mut MiddlewareContext,
314    ) -> McpResult<MiddlewareResult> {
315        self.log(
316            LogLevel::Info,
317            &format!(
318                "Request started: {} (ID: {})",
319                request.name, context.request_id
320            ),
321        );
322        Ok(MiddlewareResult::Continue)
323    }
324
325    async fn after_request(
326        &self,
327        request: &CallToolRequest,
328        response: &mut CallToolResult,
329        context: &mut MiddlewareContext,
330    ) -> McpResult<MiddlewareResult> {
331        let elapsed = context.elapsed();
332        let status = if response.is_error {
333            "ERROR"
334        } else {
335            "SUCCESS"
336        };
337
338        self.log(
339            LogLevel::Info,
340            &format!(
341                "Request completed: {} (ID: {}) - {} in {:?}",
342                request.name, context.request_id, status, elapsed
343            ),
344        );
345        Ok(MiddlewareResult::Continue)
346    }
347
348    async fn on_error(
349        &self,
350        request: &CallToolRequest,
351        error: &McpError,
352        context: &mut MiddlewareContext,
353    ) -> McpResult<MiddlewareResult> {
354        self.log(
355            LogLevel::Error,
356            &format!(
357                "Request failed: {} (ID: {}) - {}",
358                request.name, context.request_id, error
359            ),
360        );
361        Ok(MiddlewareResult::Continue)
362    }
363}
364
/// Built-in middleware that rejects malformed tool-call requests up front.
pub struct ValidationMiddleware {
    /// When set, arguments (if present) must also be a JSON object
    strict_mode: bool,
}
369
370impl ValidationMiddleware {
371    /// Create a new validation middleware
372    #[must_use]
373    pub fn new(strict_mode: bool) -> Self {
374        Self { strict_mode }
375    }
376
377    /// Create with strict mode enabled
378    #[must_use]
379    pub fn strict() -> Self {
380        Self::new(true)
381    }
382
383    /// Create with strict mode disabled
384    #[must_use]
385    pub fn lenient() -> Self {
386        Self::new(false)
387    }
388
389    fn validate_request(&self, request: &CallToolRequest) -> McpResult<()> {
390        // Basic validation
391        if request.name.is_empty() {
392            return Err(McpError::validation_error("Tool name cannot be empty"));
393        }
394
395        // Validate tool name format (alphanumeric and underscores only)
396        if !request
397            .name
398            .chars()
399            .all(|c| c.is_alphanumeric() || c == '_')
400        {
401            return Err(McpError::validation_error(
402                "Tool name must contain only alphanumeric characters and underscores",
403            ));
404        }
405
406        // In strict mode, validate arguments structure
407        if self.strict_mode {
408            if let Some(args) = &request.arguments {
409                if !args.is_object() {
410                    return Err(McpError::validation_error(
411                        "Arguments must be a JSON object",
412                    ));
413                }
414            }
415        }
416
417        Ok(())
418    }
419}
420
421#[async_trait::async_trait]
422impl McpMiddleware for ValidationMiddleware {
423    fn name(&self) -> &'static str {
424        "validation"
425    }
426
427    fn priority(&self) -> i32 {
428        50 // Medium priority
429    }
430
431    async fn before_request(
432        &self,
433        request: &CallToolRequest,
434        context: &mut MiddlewareContext,
435    ) -> McpResult<MiddlewareResult> {
436        if let Err(error) = self.validate_request(request) {
437            context.set_metadata(
438                "validation_error".to_string(),
439                serde_json::Value::String(error.to_string()),
440            );
441            return Ok(MiddlewareResult::Error(error));
442        }
443
444        context.set_metadata("validated".to_string(), serde_json::Value::Bool(true));
445        Ok(MiddlewareResult::Continue)
446    }
447}
448
/// Built-in middleware that times requests and flags slow ones.
pub struct PerformanceMiddleware {
    /// Requests taking strictly longer than this are reported as slow
    slow_request_threshold: Duration,
}
453
454impl PerformanceMiddleware {
455    /// Create a new performance middleware
456    #[must_use]
457    pub fn new(slow_request_threshold: Duration) -> Self {
458        Self {
459            slow_request_threshold,
460        }
461    }
462
463    /// Create with default threshold (1 second)
464    #[must_use]
465    pub fn create_default() -> Self {
466        Self::new(Duration::from_secs(1))
467    }
468
469    /// Create with custom threshold
470    #[must_use]
471    pub fn with_threshold(threshold: Duration) -> Self {
472        Self::new(threshold)
473    }
474}
475
476#[async_trait::async_trait]
477impl McpMiddleware for PerformanceMiddleware {
478    fn name(&self) -> &'static str {
479        "performance"
480    }
481
482    fn priority(&self) -> i32 {
483        200 // High priority to run late
484    }
485
486    async fn after_request(
487        &self,
488        request: &CallToolRequest,
489        _response: &mut CallToolResult,
490        context: &mut MiddlewareContext,
491    ) -> McpResult<MiddlewareResult> {
492        let elapsed = context.elapsed();
493
494        // Record performance metrics
495        context.set_metadata(
496            "duration_ms".to_string(),
497            serde_json::Value::Number(serde_json::Number::from(
498                u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
499            )),
500        );
501
502        context.set_metadata(
503            "is_slow".to_string(),
504            serde_json::Value::Bool(elapsed > self.slow_request_threshold),
505        );
506
507        // Log slow requests
508        if elapsed > self.slow_request_threshold {
509            println!(
510                "[PERF] Slow request detected: {} took {:?} (threshold: {:?})",
511                request.name, elapsed, self.slow_request_threshold
512            );
513        }
514
515        Ok(MiddlewareResult::Continue)
516    }
517}
518
519/// Authentication middleware for API key and OAuth 2.0 support
520pub struct AuthenticationMiddleware {
521    api_keys: HashMap<String, ApiKeyInfo>,
522    jwt_secret: String,
523    #[allow(dead_code)]
524    oauth_config: Option<OAuthConfig>,
525    require_auth: bool,
526}
527
528#[derive(Debug, Clone)]
529pub struct ApiKeyInfo {
530    pub key_id: String,
531    pub permissions: Vec<String>,
532    pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
533}
534
/// Runtime OAuth 2.0 client settings held by `AuthenticationMiddleware`.
#[derive(Clone, Debug)]
pub struct OAuthConfig {
    /// OAuth client identifier
    pub client_id: String,
    /// OAuth client secret
    pub client_secret: String,
    /// URL of the token endpoint
    pub token_endpoint: String,
    /// Scopes to request
    pub scope: Vec<String>,
}
542
543#[derive(Debug, Serialize, Deserialize)]
544pub struct JwtClaims {
545    pub sub: String, // Subject (user ID)
546    pub exp: usize,  // Expiration time
547    pub iat: usize,  // Issued at
548    pub permissions: Vec<String>,
549}
550
551impl AuthenticationMiddleware {
552    /// Create a new authentication middleware
553    #[must_use]
554    pub fn new(api_keys: HashMap<String, ApiKeyInfo>, jwt_secret: String) -> Self {
555        Self {
556            api_keys,
557            jwt_secret,
558            oauth_config: None,
559            require_auth: true,
560        }
561    }
562
563    /// Create with OAuth 2.0 support
564    #[must_use]
565    pub fn with_oauth(
566        api_keys: HashMap<String, ApiKeyInfo>,
567        jwt_secret: String,
568        oauth_config: OAuthConfig,
569    ) -> Self {
570        Self {
571            api_keys,
572            jwt_secret,
573            oauth_config: Some(oauth_config),
574            require_auth: true,
575        }
576    }
577
578    /// Create without requiring authentication (for testing)
579    #[must_use]
580    pub fn permissive() -> Self {
581        Self {
582            api_keys: HashMap::new(),
583            jwt_secret: "test-secret".to_string(),
584            oauth_config: None,
585            require_auth: false,
586        }
587    }
588
589    /// Extract API key from request headers or arguments
590    fn extract_api_key(request: &CallToolRequest) -> Option<String> {
591        // Check if API key is in request arguments
592        if let Some(args) = &request.arguments {
593            if let Some(api_key) = args.get("api_key").and_then(|v| v.as_str()) {
594                return Some(api_key.to_string());
595            }
596        }
597        None
598    }
599
600    /// Extract JWT token from request headers or arguments
601    fn extract_jwt_token(request: &CallToolRequest) -> Option<String> {
602        // Check if JWT token is in request arguments
603        if let Some(args) = &request.arguments {
604            if let Some(token) = args.get("jwt_token").and_then(|v| v.as_str()) {
605                return Some(token.to_string());
606            }
607        }
608        None
609    }
610
611    /// Validate API key
612    fn validate_api_key(&self, api_key: &str) -> McpResult<ApiKeyInfo> {
613        self.api_keys
614            .get(api_key)
615            .cloned()
616            .ok_or_else(|| McpError::validation_error("Invalid API key"))
617    }
618
619    /// Validate JWT token
620    fn validate_jwt_token(&self, token: &str) -> McpResult<JwtClaims> {
621        let validation = Validation::new(Algorithm::HS256);
622        let key = DecodingKey::from_secret(self.jwt_secret.as_ref());
623
624        let token_data = decode::<JwtClaims>(token, &key, &validation)
625            .map_err(|_| McpError::validation_error("Invalid JWT token"))?;
626
627        // Check if token is expired
628        let now = chrono::Utc::now().timestamp().try_into().unwrap_or(0);
629        if token_data.claims.exp < now {
630            return Err(McpError::validation_error("JWT token has expired"));
631        }
632
633        Ok(token_data.claims)
634    }
635
636    /// Generate JWT token for testing
637    ///
638    /// # Panics
639    /// Panics if JWT encoding fails
640    #[cfg(test)]
641    #[must_use]
642    pub fn generate_test_jwt(&self, user_id: &str, permissions: Vec<String>) -> String {
643        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
644        let now = chrono::Utc::now().timestamp() as usize;
645        let claims = JwtClaims {
646            sub: user_id.to_string(),
647            exp: now + 3600, // 1 hour
648            iat: now,
649            permissions,
650        };
651
652        let header = Header::new(Algorithm::HS256);
653        let key = EncodingKey::from_secret(self.jwt_secret.as_ref());
654        encode(&header, &claims, &key).unwrap()
655    }
656}
657
658#[async_trait::async_trait]
659impl McpMiddleware for AuthenticationMiddleware {
660    fn name(&self) -> &'static str {
661        "authentication"
662    }
663
664    fn priority(&self) -> i32 {
665        10 // High priority to run early
666    }
667
668    async fn before_request(
669        &self,
670        request: &CallToolRequest,
671        context: &mut MiddlewareContext,
672    ) -> McpResult<MiddlewareResult> {
673        if !self.require_auth {
674            context.set_metadata("auth_required".to_string(), Value::Bool(false));
675            return Ok(MiddlewareResult::Continue);
676        }
677
678        // Try API key authentication first
679        if let Some(api_key) = Self::extract_api_key(request) {
680            if let Ok(api_key_info) = self.validate_api_key(&api_key) {
681                context.set_metadata(
682                    "auth_type".to_string(),
683                    Value::String("api_key".to_string()),
684                );
685                context.set_metadata(
686                    "auth_key_id".to_string(),
687                    Value::String(api_key_info.key_id),
688                );
689                context.set_metadata(
690                    "auth_permissions".to_string(),
691                    serde_json::to_value(api_key_info.permissions).unwrap_or(Value::Array(vec![])),
692                );
693                context.set_metadata("auth_required".to_string(), Value::Bool(true));
694                return Ok(MiddlewareResult::Continue);
695            }
696            // API key failed, try JWT
697        }
698
699        // Try JWT authentication
700        if let Some(jwt_token) = Self::extract_jwt_token(request) {
701            if let Ok(claims) = self.validate_jwt_token(&jwt_token) {
702                context.set_metadata("auth_type".to_string(), Value::String("jwt".to_string()));
703                context.set_metadata("auth_user_id".to_string(), Value::String(claims.sub));
704                context.set_metadata(
705                    "auth_permissions".to_string(),
706                    serde_json::to_value(claims.permissions).unwrap_or(Value::Array(vec![])),
707                );
708                context.set_metadata("auth_required".to_string(), Value::Bool(true));
709                return Ok(MiddlewareResult::Continue);
710            }
711            // JWT failed
712        }
713
714        // No valid authentication found
715        let error_result = CallToolResult {
716            content: vec![crate::mcp::Content::Text {
717                text: "Authentication required. Please provide a valid API key or JWT token."
718                    .to_string(),
719            }],
720            is_error: true,
721        };
722
723        Ok(MiddlewareResult::Stop(error_result))
724    }
725}
726
727/// Rate limiting middleware with per-client limits
728pub struct RateLimitMiddleware {
729    rate_limiter: Arc<RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>>,
730    default_limit: u32,
731    #[allow(dead_code)]
732    burst_limit: u32,
733}
734
735impl RateLimitMiddleware {
736    /// Create a new rate limiting middleware
737    #[must_use]
738    pub fn new(requests_per_minute: u32, burst_limit: u32) -> Self {
739        let quota = Quota::per_minute(nonzero!(60u32)); // Use a constant for now
740        let rate_limiter = Arc::new(RateLimiter::keyed(quota));
741
742        Self {
743            rate_limiter,
744            default_limit: requests_per_minute,
745            burst_limit,
746        }
747    }
748
749    /// Create with custom limits
750    #[must_use]
751    pub fn with_limits(requests_per_minute: u32, burst_limit: u32) -> Self {
752        Self::new(requests_per_minute, burst_limit)
753    }
754
755    /// Create with default limits (60 requests per minute, burst of 10)
756    #[allow(clippy::should_implement_trait)]
757    #[must_use]
758    pub fn default() -> Self {
759        Self::new(60, 10)
760    }
761
762    /// Extract client identifier from request
763    fn extract_client_id(request: &CallToolRequest, context: &MiddlewareContext) -> String {
764        // Try to get from authentication context first
765        if let Some(auth_key_id) = context.get_metadata("auth_key_id").and_then(|v| v.as_str()) {
766            return format!("api_key:{auth_key_id}");
767        }
768
769        if let Some(auth_user_id) = context
770            .get_metadata("auth_user_id")
771            .and_then(|v| v.as_str())
772        {
773            return format!("jwt:{auth_user_id}");
774        }
775
776        // Fallback to request-based identifier
777        if let Some(args) = &request.arguments {
778            if let Some(client_id) = args.get("client_id").and_then(|v| v.as_str()) {
779                return format!("client:{client_id}");
780            }
781        }
782
783        // Use request ID as fallback
784        format!("request:{}", context.request_id)
785    }
786
787    /// Check if request is within rate limits
788    fn check_rate_limit(&self, client_id: &str) -> bool {
789        self.rate_limiter.check_key(&client_id.to_string()).is_ok()
790    }
791
792    /// Get remaining requests for client
793    fn get_remaining_requests(&self, _client_id: &str) -> u32 {
794        // This is a simplified implementation
795        // In a real implementation, you'd want to track remaining requests more precisely
796        self.default_limit
797    }
798}
799
800#[async_trait::async_trait]
801impl McpMiddleware for RateLimitMiddleware {
802    fn name(&self) -> &'static str {
803        "rate_limiting"
804    }
805
806    fn priority(&self) -> i32 {
807        20 // Run after authentication but before other middleware
808    }
809
810    async fn before_request(
811        &self,
812        request: &CallToolRequest,
813        context: &mut MiddlewareContext,
814    ) -> McpResult<MiddlewareResult> {
815        let client_id = Self::extract_client_id(request, context);
816
817        if !self.check_rate_limit(&client_id) {
818            let error_result = CallToolResult {
819                content: vec![crate::mcp::Content::Text {
820                    text: format!(
821                        "Rate limit exceeded. Limit: {} requests per minute. Please try again later.",
822                        self.default_limit
823                    ),
824                }],
825                is_error: true,
826            };
827
828            context.set_metadata("rate_limited".to_string(), Value::Bool(true));
829            context.set_metadata("rate_limit_client_id".to_string(), Value::String(client_id));
830
831            return Ok(MiddlewareResult::Stop(error_result));
832        }
833
834        let remaining = self.get_remaining_requests(&client_id);
835        context.set_metadata(
836            "rate_limit_remaining".to_string(),
837            Value::Number(serde_json::Number::from(remaining)),
838        );
839        context.set_metadata("rate_limit_client_id".to_string(), Value::String(client_id));
840
841        Ok(MiddlewareResult::Continue)
842    }
843}
844
845/// Security configuration
846#[derive(Debug, Clone, Serialize, Deserialize)]
847pub struct SecurityConfig {
848    /// Authentication configuration
849    pub authentication: AuthenticationConfig,
850    /// Rate limiting configuration
851    pub rate_limiting: RateLimitingConfig,
852}
853
854#[derive(Debug, Clone, Serialize, Deserialize)]
855pub struct AuthenticationConfig {
856    /// Enable authentication middleware
857    pub enabled: bool,
858    /// Require authentication for all requests
859    pub require_auth: bool,
860    /// JWT secret for token validation
861    pub jwt_secret: String,
862    /// API keys configuration
863    pub api_keys: Vec<ApiKeyConfig>,
864    /// OAuth 2.0 configuration
865    pub oauth: Option<OAuth2Config>,
866}
867
868#[derive(Debug, Clone, Serialize, Deserialize)]
869pub struct ApiKeyConfig {
870    /// API key value
871    pub key: String,
872    /// Key identifier
873    pub key_id: String,
874    /// Permissions for this key
875    pub permissions: Vec<String>,
876    /// Optional expiration date
877    pub expires_at: Option<String>,
878}
879
880#[derive(Debug, Clone, Serialize, Deserialize)]
881pub struct OAuth2Config {
882    /// OAuth client ID
883    pub client_id: String,
884    /// OAuth client secret
885    pub client_secret: String,
886    /// Token endpoint URL
887    pub token_endpoint: String,
888    /// Required scopes
889    pub scopes: Vec<String>,
890}
891
892#[derive(Debug, Clone, Serialize, Deserialize)]
893pub struct RateLimitingConfig {
894    /// Enable rate limiting middleware
895    pub enabled: bool,
896    /// Requests per minute limit
897    pub requests_per_minute: u32,
898    /// Burst limit for short bursts
899    pub burst_limit: u32,
900    /// Custom limits per client type
901    pub custom_limits: Option<HashMap<String, u32>>,
902}
903
904impl Default for SecurityConfig {
905    fn default() -> Self {
906        Self {
907            authentication: AuthenticationConfig {
908                enabled: true,
909                require_auth: false, // Start with auth disabled for easier development
910                jwt_secret: "your-secret-key-change-this-in-production".to_string(),
911                api_keys: vec![],
912                oauth: None,
913            },
914            rate_limiting: RateLimitingConfig {
915                enabled: true,
916                requests_per_minute: 60,
917                burst_limit: 10,
918                custom_limits: None,
919            },
920        }
921    }
922}
923
924/// Middleware configuration
925#[derive(Debug, Clone, Serialize, Deserialize)]
926pub struct MiddlewareConfig {
927    /// Logging configuration
928    pub logging: LoggingConfig,
929    /// Validation configuration
930    pub validation: ValidationConfig,
931    /// Performance monitoring configuration
932    pub performance: PerformanceConfig,
933    /// Security configuration
934    pub security: SecurityConfig,
935}
936
937/// Logging middleware configuration
938#[derive(Debug, Clone, Serialize, Deserialize)]
939pub struct LoggingConfig {
940    /// Enable logging middleware
941    pub enabled: bool,
942    /// Log level for logging middleware
943    pub level: String,
944}
945
946/// Validation middleware configuration
947#[derive(Debug, Clone, Serialize, Deserialize)]
948pub struct ValidationConfig {
949    /// Enable validation middleware
950    pub enabled: bool,
951    /// Use strict validation mode
952    pub strict_mode: bool,
953}
954
955/// Performance monitoring configuration
956#[derive(Debug, Clone, Serialize, Deserialize)]
957pub struct PerformanceConfig {
958    /// Enable performance monitoring
959    pub enabled: bool,
960    /// Slow request threshold in milliseconds
961    pub slow_request_threshold_ms: u64,
962}
963
964impl Default for MiddlewareConfig {
965    fn default() -> Self {
966        Self {
967            logging: LoggingConfig {
968                enabled: true,
969                level: "info".to_string(),
970            },
971            validation: ValidationConfig {
972                enabled: true,
973                strict_mode: false,
974            },
975            performance: PerformanceConfig {
976                enabled: true,
977                slow_request_threshold_ms: 1000,
978            },
979            security: SecurityConfig::default(),
980        }
981    }
982}
983
984impl MiddlewareConfig {
985    /// Create a new middleware configuration
986    #[must_use]
987    pub fn new() -> Self {
988        Self::default()
989    }
990
991    /// Build a middleware chain from this configuration
992    #[must_use]
993    pub fn build_chain(self) -> MiddlewareChain {
994        let mut chain = MiddlewareChain::new();
995
996        // Security middleware (highest priority)
997        if self.security.authentication.enabled {
998            let api_keys: HashMap<String, ApiKeyInfo> = self
999                .security
1000                .authentication
1001                .api_keys
1002                .into_iter()
1003                .map(|config| {
1004                    let expires_at = config.expires_at.and_then(|date_str| {
1005                        chrono::DateTime::parse_from_rfc3339(&date_str)
1006                            .ok()
1007                            .map(|dt| dt.with_timezone(&chrono::Utc))
1008                    });
1009
1010                    let api_key_info = ApiKeyInfo {
1011                        key_id: config.key_id,
1012                        permissions: config.permissions,
1013                        expires_at,
1014                    };
1015
1016                    (config.key, api_key_info)
1017                })
1018                .collect();
1019
1020            let auth_middleware = if self.security.authentication.require_auth {
1021                if let Some(oauth_config) = self.security.authentication.oauth {
1022                    let oauth = OAuthConfig {
1023                        client_id: oauth_config.client_id,
1024                        client_secret: oauth_config.client_secret,
1025                        token_endpoint: oauth_config.token_endpoint,
1026                        scope: oauth_config.scopes,
1027                    };
1028                    AuthenticationMiddleware::with_oauth(
1029                        api_keys,
1030                        self.security.authentication.jwt_secret,
1031                        oauth,
1032                    )
1033                } else {
1034                    AuthenticationMiddleware::new(api_keys, self.security.authentication.jwt_secret)
1035                }
1036            } else {
1037                AuthenticationMiddleware::permissive()
1038            };
1039
1040            chain = chain.add_middleware(auth_middleware);
1041        }
1042
1043        if self.security.rate_limiting.enabled {
1044            let rate_limit_middleware = RateLimitMiddleware::with_limits(
1045                self.security.rate_limiting.requests_per_minute,
1046                self.security.rate_limiting.burst_limit,
1047            );
1048            chain = chain.add_middleware(rate_limit_middleware);
1049        }
1050
1051        if self.logging.enabled {
1052            let log_level = match self.logging.level.to_lowercase().as_str() {
1053                "debug" => LogLevel::Debug,
1054                "warn" => LogLevel::Warn,
1055                "error" => LogLevel::Error,
1056                _ => LogLevel::Info,
1057            };
1058            chain = chain.add_middleware(LoggingMiddleware::new(log_level));
1059        }
1060
1061        if self.validation.enabled {
1062            chain = chain.add_middleware(ValidationMiddleware::new(self.validation.strict_mode));
1063        }
1064
1065        if self.performance.enabled {
1066            let threshold = Duration::from_millis(self.performance.slow_request_threshold_ms);
1067            chain = chain.add_middleware(PerformanceMiddleware::with_threshold(threshold));
1068        }
1069
1070        chain
1071    }
1072}
1073
/// Middleware-specific errors
///
/// Each variant carries a free-form `message` that is embedded in its
/// `Display` output. The `From<MiddlewareError> for McpError` impl in this
/// module turns any variant into an MCP internal error.
#[derive(Error, Debug)]
pub enum MiddlewareError {
    /// Displayed as `Middleware execution failed: {message}`.
    #[error("Middleware execution failed: {message}")]
    ExecutionFailed { message: String },

    /// Displayed as `Middleware configuration error: {message}`.
    #[error("Middleware configuration error: {message}")]
    ConfigurationError { message: String },

    /// Displayed as `Middleware chain error: {message}`.
    #[error("Middleware chain error: {message}")]
    ChainError { message: String },
}
1086
1087impl From<MiddlewareError> for McpError {
1088    fn from(error: MiddlewareError) -> Self {
1089        McpError::internal_error(error.to_string())
1090    }
1091}
1092
1093#[cfg(test)]
1094mod tests {
1095    use super::*;
1096    use crate::mcp::Content;
1097
1098    struct TestMiddleware {
1099        priority: i32,
1100    }
1101
1102    #[async_trait::async_trait]
1103    impl McpMiddleware for TestMiddleware {
1104        fn name(&self) -> &'static str {
1105            "test_middleware"
1106        }
1107
1108        fn priority(&self) -> i32 {
1109            self.priority
1110        }
1111    }
1112
1113    #[tokio::test]
1114    async fn test_middleware_chain_creation() {
1115        let chain = MiddlewareChain::new()
1116            .add_middleware(TestMiddleware { priority: 100 })
1117            .add_middleware(TestMiddleware { priority: 50 });
1118
1119        assert_eq!(chain.len(), 2);
1120        assert!(!chain.is_empty());
1121    }
1122
1123    #[tokio::test]
1124    async fn test_middleware_priority_ordering() {
1125        let chain = MiddlewareChain::new()
1126            .add_middleware(TestMiddleware { priority: 10 })
1127            .add_middleware(TestMiddleware { priority: 100 });
1128
1129        // The chain should be sorted by priority
1130        assert_eq!(chain.len(), 2);
1131    }
1132
1133    #[tokio::test]
1134    async fn test_middleware_execution() {
1135        let chain = MiddlewareChain::new()
1136            .add_middleware(LoggingMiddleware::info())
1137            .add_middleware(ValidationMiddleware::lenient());
1138
1139        let request = CallToolRequest {
1140            name: "test_tool".to_string(),
1141            arguments: Some(serde_json::json!({"param": "value"})),
1142        };
1143
1144        let handler = |_req: CallToolRequest| {
1145            Box::pin(async move {
1146                Ok(CallToolResult {
1147                    content: vec![Content::Text {
1148                        text: "Test response".to_string(),
1149                    }],
1150                    is_error: false,
1151                })
1152            })
1153        };
1154
1155        let result = chain.execute(request, handler).await;
1156        assert!(result.is_ok());
1157    }
1158
1159    #[tokio::test]
1160    async fn test_validation_middleware() {
1161        let middleware = ValidationMiddleware::strict();
1162        let mut context = MiddlewareContext::new("test".to_string());
1163
1164        // Valid request
1165        let valid_request = CallToolRequest {
1166            name: "valid_tool".to_string(),
1167            arguments: Some(serde_json::json!({"param": "value"})),
1168        };
1169
1170        let result = middleware
1171            .before_request(&valid_request, &mut context)
1172            .await;
1173        assert!(matches!(result, Ok(MiddlewareResult::Continue)));
1174
1175        // Invalid request (empty name)
1176        let invalid_request = CallToolRequest {
1177            name: String::new(),
1178            arguments: None,
1179        };
1180
1181        let result = middleware
1182            .before_request(&invalid_request, &mut context)
1183            .await;
1184        assert!(matches!(result, Ok(MiddlewareResult::Error(_))));
1185    }
1186
1187    #[tokio::test]
1188    async fn test_performance_middleware() {
1189        let middleware = PerformanceMiddleware::with_threshold(Duration::from_millis(100));
1190        let mut context = MiddlewareContext::new("test".to_string());
1191
1192        // Simulate a slow request
1193        tokio::time::sleep(Duration::from_millis(150)).await;
1194
1195        let mut response = CallToolResult {
1196            content: vec![Content::Text {
1197                text: "Test".to_string(),
1198            }],
1199            is_error: false,
1200        };
1201
1202        let request = CallToolRequest {
1203            name: "test".to_string(),
1204            arguments: None,
1205        };
1206
1207        let result = middleware
1208            .after_request(&request, &mut response, &mut context)
1209            .await;
1210        assert!(matches!(result, Ok(MiddlewareResult::Continue)));
1211
1212        // Check that performance metadata was set
1213        assert!(context.get_metadata("duration_ms").is_some());
1214        assert!(context.get_metadata("is_slow").is_some());
1215    }
1216
1217    #[tokio::test]
1218    async fn test_middleware_config() {
1219        let config = MiddlewareConfig {
1220            logging: LoggingConfig {
1221                enabled: true,
1222                level: "debug".to_string(),
1223            },
1224            validation: ValidationConfig {
1225                enabled: true,
1226                strict_mode: true,
1227            },
1228            performance: PerformanceConfig {
1229                enabled: true,
1230                slow_request_threshold_ms: 500,
1231            },
1232            security: SecurityConfig::default(),
1233        };
1234
1235        let chain = config.build_chain();
1236        assert!(!chain.is_empty());
1237        assert!(chain.len() >= 3); // Should have logging, validation, and performance
1238    }
1239
1240    #[tokio::test]
1241    async fn test_middleware_context_creation() {
1242        let context = MiddlewareContext::new("test-request-123".to_string());
1243        assert_eq!(context.request_id, "test-request-123");
1244        assert!(context.metadata.is_empty());
1245    }
1246
1247    #[tokio::test]
1248    async fn test_middleware_context_elapsed() {
1249        let context = MiddlewareContext::new("test-request-123".to_string());
1250        std::thread::sleep(std::time::Duration::from_millis(10));
1251        let elapsed = context.elapsed();
1252        assert!(elapsed.as_millis() >= 10);
1253    }
1254
1255    #[tokio::test]
1256    async fn test_middleware_context_metadata() {
1257        let mut context = MiddlewareContext::new("test-request-123".to_string());
1258
1259        // Test setting metadata
1260        context.set_metadata(
1261            "key1".to_string(),
1262            serde_json::Value::String("value1".to_string()),
1263        );
1264        context.set_metadata(
1265            "key2".to_string(),
1266            serde_json::Value::Number(serde_json::Number::from(42)),
1267        );
1268
1269        // Test getting metadata
1270        assert_eq!(
1271            context.get_metadata("key1"),
1272            Some(&serde_json::Value::String("value1".to_string()))
1273        );
1274        assert_eq!(
1275            context.get_metadata("key2"),
1276            Some(&serde_json::Value::Number(serde_json::Number::from(42)))
1277        );
1278        assert_eq!(context.get_metadata("nonexistent"), None);
1279    }
1280
1281    #[tokio::test]
1282    async fn test_middleware_result_variants() {
1283        let continue_result = MiddlewareResult::Continue;
1284        let stop_result = MiddlewareResult::Stop(CallToolResult {
1285            content: vec![Content::Text {
1286                text: "test".to_string(),
1287            }],
1288            is_error: false,
1289        });
1290        let error_result = MiddlewareResult::Error(McpError::tool_not_found("test error"));
1291
1292        // Test that we can create all variants
1293        match continue_result {
1294            MiddlewareResult::Continue => {}
1295            _ => panic!("Expected Continue"),
1296        }
1297
1298        match stop_result {
1299            MiddlewareResult::Stop(_) => {}
1300            _ => panic!("Expected Stop"),
1301        }
1302
1303        match error_result {
1304            MiddlewareResult::Error(_) => {}
1305            _ => panic!("Expected Error"),
1306        }
1307    }
1308
1309    #[tokio::test]
1310    async fn test_logging_middleware_different_levels() {
1311        let debug_middleware = LoggingMiddleware::new(LogLevel::Debug);
1312        let info_middleware = LoggingMiddleware::new(LogLevel::Info);
1313        let warn_middleware = LoggingMiddleware::new(LogLevel::Warn);
1314        let error_middleware = LoggingMiddleware::new(LogLevel::Error);
1315
1316        assert_eq!(debug_middleware.name(), "logging");
1317        assert_eq!(info_middleware.name(), "logging");
1318        assert_eq!(warn_middleware.name(), "logging");
1319        assert_eq!(error_middleware.name(), "logging");
1320    }
1321
1322    #[tokio::test]
1323    async fn test_logging_middleware_should_log() {
1324        let debug_middleware = LoggingMiddleware::new(LogLevel::Debug);
1325        let info_middleware = LoggingMiddleware::new(LogLevel::Info);
1326        let warn_middleware = LoggingMiddleware::new(LogLevel::Warn);
1327        let error_middleware = LoggingMiddleware::new(LogLevel::Error);
1328
1329        // Debug should log everything
1330        assert!(debug_middleware.should_log(LogLevel::Debug));
1331        assert!(debug_middleware.should_log(LogLevel::Info));
1332        assert!(debug_middleware.should_log(LogLevel::Warn));
1333        assert!(debug_middleware.should_log(LogLevel::Error));
1334
1335        // Info should log info, warn, error
1336        assert!(!info_middleware.should_log(LogLevel::Debug));
1337        assert!(info_middleware.should_log(LogLevel::Info));
1338        assert!(info_middleware.should_log(LogLevel::Warn));
1339        assert!(info_middleware.should_log(LogLevel::Error));
1340
1341        // Warn should log warn, error
1342        assert!(!warn_middleware.should_log(LogLevel::Debug));
1343        assert!(!warn_middleware.should_log(LogLevel::Info));
1344        assert!(warn_middleware.should_log(LogLevel::Warn));
1345        assert!(warn_middleware.should_log(LogLevel::Error));
1346
1347        // Error should only log error
1348        assert!(!error_middleware.should_log(LogLevel::Debug));
1349        assert!(!error_middleware.should_log(LogLevel::Info));
1350        assert!(!error_middleware.should_log(LogLevel::Warn));
1351        assert!(error_middleware.should_log(LogLevel::Error));
1352    }
1353
1354    #[tokio::test]
1355    async fn test_validation_middleware_strict_mode() {
1356        let strict_middleware = ValidationMiddleware::strict();
1357        let lenient_middleware = ValidationMiddleware::lenient();
1358
1359        assert_eq!(strict_middleware.name(), "validation");
1360        assert_eq!(lenient_middleware.name(), "validation");
1361    }
1362
1363    #[tokio::test]
1364    async fn test_validation_middleware_creation() {
1365        let middleware1 = ValidationMiddleware::new(true);
1366        let middleware2 = ValidationMiddleware::new(false);
1367
1368        assert_eq!(middleware1.name(), "validation");
1369        assert_eq!(middleware2.name(), "validation");
1370    }
1371
1372    #[tokio::test]
1373    async fn test_performance_middleware_creation() {
1374        let middleware1 = PerformanceMiddleware::new(Duration::from_millis(100));
1375        let middleware2 = PerformanceMiddleware::with_threshold(Duration::from_millis(200));
1376        let middleware3 = PerformanceMiddleware::create_default();
1377
1378        assert_eq!(middleware1.name(), "performance");
1379        assert_eq!(middleware2.name(), "performance");
1380        assert_eq!(middleware3.name(), "performance");
1381    }
1382
1383    #[tokio::test]
1384    async fn test_middleware_chain_empty() {
1385        let chain = MiddlewareChain::new();
1386        assert!(chain.is_empty());
1387        assert_eq!(chain.len(), 0);
1388    }
1389
1390    #[tokio::test]
1391    async fn test_middleware_chain_add_middleware() {
1392        let chain = MiddlewareChain::new()
1393            .add_middleware(LoggingMiddleware::new(LogLevel::Info))
1394            .add_middleware(ValidationMiddleware::new(false));
1395
1396        assert!(!chain.is_empty());
1397        assert_eq!(chain.len(), 2);
1398    }
1399
1400    #[tokio::test]
1401    async fn test_middleware_chain_add_arc() {
1402        let middleware = Arc::new(LoggingMiddleware::new(LogLevel::Info)) as Arc<dyn McpMiddleware>;
1403        let chain = MiddlewareChain::new().add_arc(middleware);
1404
1405        assert!(!chain.is_empty());
1406        assert_eq!(chain.len(), 1);
1407    }
1408
1409    #[tokio::test]
1410    async fn test_middleware_chain_execution_with_empty_chain() {
1411        let chain = MiddlewareChain::new();
1412        let request = CallToolRequest {
1413            name: "test_tool".to_string(),
1414            arguments: None,
1415        };
1416
1417        let result = chain
1418            .execute(request, |_| async {
1419                Ok(CallToolResult {
1420                    content: vec![Content::Text {
1421                        text: "success".to_string(),
1422                    }],
1423                    is_error: false,
1424                })
1425            })
1426            .await;
1427
1428        assert!(result.is_ok());
1429        let result = result.unwrap();
1430        assert!(!result.is_error);
1431        assert_eq!(result.content.len(), 1);
1432    }
1433
1434    #[tokio::test]
1435    async fn test_middleware_chain_execution_with_error() {
1436        let chain = MiddlewareChain::new().add_middleware(LoggingMiddleware::new(LogLevel::Info));
1437        let request = CallToolRequest {
1438            name: "test_tool".to_string(),
1439            arguments: None,
1440        };
1441
1442        let result = chain
1443            .execute(request, |_| async {
1444                Err(McpError::tool_not_found("test error"))
1445            })
1446            .await;
1447
1448        assert!(result.is_err());
1449    }
1450
1451    #[tokio::test]
1452    async fn test_middleware_chain_execution_with_stop() {
1453        // Create a middleware that stops execution
1454        struct StopMiddleware;
1455        #[async_trait::async_trait]
1456        impl McpMiddleware for StopMiddleware {
1457            fn name(&self) -> &'static str {
1458                "stop"
1459            }
1460
1461            async fn before_request(
1462                &self,
1463                _request: &CallToolRequest,
1464                _context: &mut MiddlewareContext,
1465            ) -> McpResult<MiddlewareResult> {
1466                Ok(MiddlewareResult::Stop(CallToolResult {
1467                    content: vec![Content::Text {
1468                        text: "stopped".to_string(),
1469                    }],
1470                    is_error: false,
1471                }))
1472            }
1473
1474            async fn after_request(
1475                &self,
1476                _request: &CallToolRequest,
1477                _result: &mut CallToolResult,
1478                _context: &mut MiddlewareContext,
1479            ) -> McpResult<MiddlewareResult> {
1480                Ok(MiddlewareResult::Continue)
1481            }
1482
1483            async fn on_error(
1484                &self,
1485                _request: &CallToolRequest,
1486                _error: &McpError,
1487                _context: &mut MiddlewareContext,
1488            ) -> McpResult<MiddlewareResult> {
1489                Ok(MiddlewareResult::Continue)
1490            }
1491        }
1492
1493        let chain = MiddlewareChain::new().add_middleware(LoggingMiddleware::new(LogLevel::Info));
1494        let request = CallToolRequest {
1495            name: "test_tool".to_string(),
1496            arguments: None,
1497        };
1498
1499        let chain = chain.add_middleware(StopMiddleware);
1500
1501        let result = chain
1502            .execute(request, |_| async {
1503                Ok(CallToolResult {
1504                    content: vec![Content::Text {
1505                        text: "should not reach here".to_string(),
1506                    }],
1507                    is_error: false,
1508                })
1509            })
1510            .await;
1511
1512        assert!(result.is_ok());
1513        let result = result.unwrap();
1514        let Content::Text { text } = &result.content[0];
1515        assert_eq!(text, "stopped");
1516    }
1517
1518    #[tokio::test]
1519    async fn test_middleware_chain_execution_with_middleware_error() {
1520        // Create a middleware that returns an error
1521        struct ErrorMiddleware;
1522        #[async_trait::async_trait]
1523        impl McpMiddleware for ErrorMiddleware {
1524            fn name(&self) -> &'static str {
1525                "error"
1526            }
1527
1528            async fn before_request(
1529                &self,
1530                _request: &CallToolRequest,
1531                _context: &mut MiddlewareContext,
1532            ) -> McpResult<MiddlewareResult> {
1533                Err(McpError::tool_not_found("middleware error"))
1534            }
1535
1536            async fn after_request(
1537                &self,
1538                _request: &CallToolRequest,
1539                _result: &mut CallToolResult,
1540                _context: &mut MiddlewareContext,
1541            ) -> McpResult<MiddlewareResult> {
1542                Ok(MiddlewareResult::Continue)
1543            }
1544
1545            async fn on_error(
1546                &self,
1547                _request: &CallToolRequest,
1548                _error: &McpError,
1549                _context: &mut MiddlewareContext,
1550            ) -> McpResult<MiddlewareResult> {
1551                Ok(MiddlewareResult::Continue)
1552            }
1553        }
1554
1555        let chain = MiddlewareChain::new().add_middleware(LoggingMiddleware::new(LogLevel::Info));
1556        let request = CallToolRequest {
1557            name: "test_tool".to_string(),
1558            arguments: None,
1559        };
1560
1561        let chain = chain.add_middleware(ErrorMiddleware);
1562
1563        let result = chain
1564            .execute(request, |_| async {
1565                Ok(CallToolResult {
1566                    content: vec![Content::Text {
1567                        text: "should not reach here".to_string(),
1568                    }],
1569                    is_error: false,
1570                })
1571            })
1572            .await;
1573
1574        assert!(result.is_err());
1575        let error = result.unwrap_err();
1576        assert!(matches!(error, McpError::ToolNotFound { tool_name: _ }));
1577    }
1578
1579    #[tokio::test]
1580    async fn test_middleware_chain_execution_with_on_error() {
1581        // Create a middleware that handles errors
1582        struct ErrorHandlerMiddleware;
1583        #[async_trait::async_trait]
1584        impl McpMiddleware for ErrorHandlerMiddleware {
1585            fn name(&self) -> &'static str {
1586                "error_handler"
1587            }
1588
1589            async fn before_request(
1590                &self,
1591                _request: &CallToolRequest,
1592                _context: &mut MiddlewareContext,
1593            ) -> McpResult<MiddlewareResult> {
1594                Ok(MiddlewareResult::Continue)
1595            }
1596
1597            async fn after_request(
1598                &self,
1599                _request: &CallToolRequest,
1600                _result: &mut CallToolResult,
1601                _context: &mut MiddlewareContext,
1602            ) -> McpResult<MiddlewareResult> {
1603                Ok(MiddlewareResult::Continue)
1604            }
1605
1606            async fn on_error(
1607                &self,
1608                _request: &CallToolRequest,
1609                _error: &McpError,
1610                _context: &mut MiddlewareContext,
1611            ) -> McpResult<MiddlewareResult> {
1612                Ok(MiddlewareResult::Stop(CallToolResult {
1613                    content: vec![Content::Text {
1614                        text: "error handled".to_string(),
1615                    }],
1616                    is_error: false,
1617                }))
1618            }
1619        }
1620
1621        let chain = MiddlewareChain::new().add_middleware(LoggingMiddleware::new(LogLevel::Info));
1622        let request = CallToolRequest {
1623            name: "test_tool".to_string(),
1624            arguments: None,
1625        };
1626
1627        let chain = chain.add_middleware(ErrorHandlerMiddleware);
1628
1629        let result = chain
1630            .execute(request, |_| async {
1631                Err(McpError::tool_not_found("test error"))
1632            })
1633            .await;
1634
1635        assert!(result.is_ok());
1636        let result = result.unwrap();
1637        let Content::Text { text } = &result.content[0];
1638        assert_eq!(text, "error handled");
1639    }
1640
1641    #[tokio::test]
1642    async fn test_config_structs_creation() {
1643        let logging_config = LoggingConfig {
1644            enabled: true,
1645            level: "debug".to_string(),
1646        };
1647        let validation_config = ValidationConfig {
1648            enabled: true,
1649            strict_mode: true,
1650        };
1651        let performance_config = PerformanceConfig {
1652            enabled: true,
1653            slow_request_threshold_ms: 1000,
1654        };
1655
1656        assert!(logging_config.enabled);
1657        assert_eq!(logging_config.level, "debug");
1658        assert!(validation_config.enabled);
1659        assert!(validation_config.strict_mode);
1660        assert!(performance_config.enabled);
1661        assert_eq!(performance_config.slow_request_threshold_ms, 1000);
1662    }
1663
1664    #[tokio::test]
1665    async fn test_config_default() {
1666        let config = MiddlewareConfig::default();
1667        assert!(config.logging.enabled);
1668        assert_eq!(config.logging.level, "info");
1669        assert!(config.validation.enabled);
1670        assert!(!config.validation.strict_mode);
1671        assert!(config.performance.enabled);
1672        assert_eq!(config.performance.slow_request_threshold_ms, 1000);
1673    }
1674
1675    #[tokio::test]
1676    async fn test_config_build_chain_with_disabled_middleware() {
1677        let config = MiddlewareConfig {
1678            logging: LoggingConfig {
1679                enabled: false,
1680                level: "debug".to_string(),
1681            },
1682            validation: ValidationConfig {
1683                enabled: false,
1684                strict_mode: true,
1685            },
1686            performance: PerformanceConfig {
1687                enabled: false,
1688                slow_request_threshold_ms: 1000,
1689            },
1690            security: SecurityConfig {
1691                authentication: AuthenticationConfig {
1692                    enabled: false,
1693                    require_auth: false,
1694                    jwt_secret: "test".to_string(),
1695                    api_keys: vec![],
1696                    oauth: None,
1697                },
1698                rate_limiting: RateLimitingConfig {
1699                    enabled: false,
1700                    requests_per_minute: 60,
1701                    burst_limit: 10,
1702                    custom_limits: None,
1703                },
1704            },
1705        };
1706
1707        let chain = config.build_chain();
1708        assert!(chain.is_empty());
1709    }
1710
1711    #[tokio::test]
1712    async fn test_config_build_chain_with_partial_middleware() {
1713        let config = MiddlewareConfig {
1714            logging: LoggingConfig {
1715                enabled: true,
1716                level: "debug".to_string(),
1717            },
1718            validation: ValidationConfig {
1719                enabled: false,
1720                strict_mode: true,
1721            },
1722            performance: PerformanceConfig {
1723                enabled: true,
1724                slow_request_threshold_ms: 1000,
1725            },
1726            security: SecurityConfig::default(),
1727        };
1728
1729        let chain = config.build_chain();
1730        assert!(!chain.is_empty());
1731        assert!(chain.len() >= 2); // At least logging and performance
1732    }
1733
1734    #[tokio::test]
1735    async fn test_config_build_chain_with_invalid_log_level() {
1736        let config = MiddlewareConfig {
1737            logging: LoggingConfig {
1738                enabled: true,
1739                level: "invalid".to_string(),
1740            },
1741            validation: ValidationConfig {
1742                enabled: true,
1743                strict_mode: true,
1744            },
1745            performance: PerformanceConfig {
1746                enabled: true,
1747                slow_request_threshold_ms: 1000,
1748            },
1749            security: SecurityConfig::default(),
1750        };
1751
1752        let chain = config.build_chain();
1753        assert!(!chain.is_empty());
1754        // Should default to info level
1755    }
1756
1757    #[tokio::test]
1758    async fn test_middleware_chain_execution_with_empty_middleware() {
1759        let chain = MiddlewareChain::new();
1760        let request = CallToolRequest {
1761            name: "test_tool".to_string(),
1762            arguments: Some(serde_json::json!({"param": "value"})),
1763        };
1764
1765        let result = chain
1766            .execute(request, |_| async {
1767                Ok(CallToolResult {
1768                    content: vec![Content::Text {
1769                        text: "Test response".to_string(),
1770                    }],
1771                    is_error: false,
1772                })
1773            })
1774            .await;
1775
1776        assert!(result.is_ok());
1777        let result = result.unwrap();
1778        assert!(!result.is_error);
1779        assert_eq!(result.content.len(), 1);
1780    }
1781
1782    #[tokio::test]
1783    async fn test_middleware_chain_execution_with_multiple_middleware() {
1784        let chain = MiddlewareChain::new()
1785            .add_middleware(LoggingMiddleware::new(LogLevel::Info))
1786            .add_middleware(ValidationMiddleware::new(false))
1787            .add_middleware(PerformanceMiddleware::new(Duration::from_millis(100)));
1788
1789        let request = CallToolRequest {
1790            name: "test_tool".to_string(),
1791            arguments: Some(serde_json::json!({"param": "value"})),
1792        };
1793
1794        let result = chain
1795            .execute(request, |_| async {
1796                Ok(CallToolResult {
1797                    content: vec![Content::Text {
1798                        text: "Test response".to_string(),
1799                    }],
1800                    is_error: false,
1801                })
1802            })
1803            .await;
1804
1805        assert!(result.is_ok());
1806        let result = result.unwrap();
1807        assert!(!result.is_error);
1808        assert_eq!(result.content.len(), 1);
1809    }
1810
1811    #[tokio::test]
1812    async fn test_middleware_chain_execution_with_middleware_stop() {
1813        struct StopMiddleware;
1814        #[async_trait::async_trait]
1815        impl McpMiddleware for StopMiddleware {
1816            fn name(&self) -> &'static str {
1817                "stop_middleware"
1818            }
1819
1820            fn priority(&self) -> i32 {
1821                100
1822            }
1823
1824            async fn before_request(
1825                &self,
1826                _request: &CallToolRequest,
1827                _context: &mut MiddlewareContext,
1828            ) -> McpResult<MiddlewareResult> {
1829                Ok(MiddlewareResult::Stop(CallToolResult {
1830                    content: vec![Content::Text {
1831                        text: "Stopped by middleware".to_string(),
1832                    }],
1833                    is_error: false,
1834                }))
1835            }
1836
1837            async fn after_request(
1838                &self,
1839                _request: &CallToolRequest,
1840                _result: &mut CallToolResult,
1841                _context: &mut MiddlewareContext,
1842            ) -> McpResult<MiddlewareResult> {
1843                Ok(MiddlewareResult::Continue)
1844            }
1845
1846            async fn on_error(
1847                &self,
1848                _request: &CallToolRequest,
1849                _error: &McpError,
1850                _context: &mut MiddlewareContext,
1851            ) -> McpResult<MiddlewareResult> {
1852                Ok(MiddlewareResult::Continue)
1853            }
1854        }
1855
1856        let chain = MiddlewareChain::new().add_middleware(StopMiddleware);
1857
1858        let request = CallToolRequest {
1859            name: "test_tool".to_string(),
1860            arguments: None,
1861        };
1862
1863        let result = chain
1864            .execute(request, |_| async {
1865                Ok(CallToolResult {
1866                    content: vec![Content::Text {
1867                        text: "Should not reach here".to_string(),
1868                    }],
1869                    is_error: false,
1870                })
1871            })
1872            .await;
1873
1874        assert!(result.is_ok());
1875        let result = result.unwrap();
1876        assert!(!result.is_error);
1877        let Content::Text { text } = &result.content[0];
1878        assert_eq!(text, "Stopped by middleware");
1879    }
1880
1881    #[tokio::test]
1882    async fn test_middleware_chain_execution_with_middleware_error_duplicate() {
1883        struct ErrorMiddleware;
1884        #[async_trait::async_trait]
1885        impl McpMiddleware for ErrorMiddleware {
1886            fn name(&self) -> &'static str {
1887                "error_middleware"
1888            }
1889
1890            fn priority(&self) -> i32 {
1891                100
1892            }
1893
1894            async fn before_request(
1895                &self,
1896                _request: &CallToolRequest,
1897                _context: &mut MiddlewareContext,
1898            ) -> McpResult<MiddlewareResult> {
1899                Err(McpError::internal_error("Middleware error"))
1900            }
1901
1902            async fn after_request(
1903                &self,
1904                _request: &CallToolRequest,
1905                _result: &mut CallToolResult,
1906                _context: &mut MiddlewareContext,
1907            ) -> McpResult<MiddlewareResult> {
1908                Ok(MiddlewareResult::Continue)
1909            }
1910
1911            async fn on_error(
1912                &self,
1913                _request: &CallToolRequest,
1914                _error: &McpError,
1915                _context: &mut MiddlewareContext,
1916            ) -> McpResult<MiddlewareResult> {
1917                Ok(MiddlewareResult::Continue)
1918            }
1919        }
1920
1921        let chain = MiddlewareChain::new().add_middleware(ErrorMiddleware);
1922
1923        let request = CallToolRequest {
1924            name: "test_tool".to_string(),
1925            arguments: None,
1926        };
1927
1928        let result = chain
1929            .execute(request, |_| async {
1930                Ok(CallToolResult {
1931                    content: vec![Content::Text {
1932                        text: "Should not reach here".to_string(),
1933                    }],
1934                    is_error: false,
1935                })
1936            })
1937            .await;
1938
1939        assert!(result.is_err());
1940        let error = result.unwrap_err();
1941        assert!(matches!(error, McpError::InternalError { .. }));
1942    }
1943
1944    // Authentication Middleware Tests
1945    #[tokio::test]
1946    async fn test_authentication_middleware_permissive() {
1947        let middleware = AuthenticationMiddleware::permissive();
1948        let mut context = MiddlewareContext::new("test".to_string());
1949
1950        let request = CallToolRequest {
1951            name: "test_tool".to_string(),
1952            arguments: None,
1953        };
1954
1955        let result = middleware.before_request(&request, &mut context).await;
1956
1957        assert!(matches!(result, Ok(MiddlewareResult::Continue)));
1958        assert_eq!(
1959            context.get_metadata("auth_required"),
1960            Some(&Value::Bool(false))
1961        );
1962    }
1963
1964    #[tokio::test]
1965    async fn test_authentication_middleware_with_valid_api_key() {
1966        let mut api_keys = HashMap::new();
1967        api_keys.insert(
1968            "test-api-key".to_string(),
1969            ApiKeyInfo {
1970                key_id: "test-key-1".to_string(),
1971                permissions: vec!["read".to_string(), "write".to_string()],
1972                expires_at: None,
1973            },
1974        );
1975
1976        let middleware = AuthenticationMiddleware::new(api_keys, "test-secret".to_string());
1977        let mut context = MiddlewareContext::new("test".to_string());
1978
1979        let request = CallToolRequest {
1980            name: "test_tool".to_string(),
1981            arguments: Some(serde_json::json!({
1982                "api_key": "test-api-key"
1983            })),
1984        };
1985
1986        let result = middleware.before_request(&request, &mut context).await;
1987
1988        assert!(matches!(result, Ok(MiddlewareResult::Continue)));
1989        assert_eq!(
1990            context.get_metadata("auth_type"),
1991            Some(&Value::String("api_key".to_string()))
1992        );
1993        assert_eq!(
1994            context.get_metadata("auth_key_id"),
1995            Some(&Value::String("test-key-1".to_string()))
1996        );
1997    }
1998
1999    #[tokio::test]
2000    async fn test_authentication_middleware_with_invalid_api_key() {
2001        let api_keys = HashMap::new();
2002        let middleware = AuthenticationMiddleware::new(api_keys, "test-secret".to_string());
2003        let mut context = MiddlewareContext::new("test".to_string());
2004
2005        let request = CallToolRequest {
2006            name: "test_tool".to_string(),
2007            arguments: Some(serde_json::json!({
2008                "api_key": "invalid-key"
2009            })),
2010        };
2011
2012        let result = middleware.before_request(&request, &mut context).await;
2013
2014        assert!(matches!(result, Ok(MiddlewareResult::Stop(_))));
2015    }
2016
2017    #[tokio::test]
2018    async fn test_authentication_middleware_with_valid_jwt() {
2019        let api_keys = HashMap::new();
2020        let middleware = AuthenticationMiddleware::new(api_keys, "test-secret".to_string());
2021
2022        // Generate a test JWT token
2023        let jwt_token = middleware.generate_test_jwt("user123", vec!["read".to_string()]);
2024
2025        let mut context = MiddlewareContext::new("test".to_string());
2026
2027        let request = CallToolRequest {
2028            name: "test_tool".to_string(),
2029            arguments: Some(serde_json::json!({
2030                "jwt_token": jwt_token
2031            })),
2032        };
2033
2034        let result = middleware.before_request(&request, &mut context).await;
2035
2036        assert!(matches!(result, Ok(MiddlewareResult::Continue)));
2037        assert_eq!(
2038            context.get_metadata("auth_type"),
2039            Some(&Value::String("jwt".to_string()))
2040        );
2041        assert_eq!(
2042            context.get_metadata("auth_user_id"),
2043            Some(&Value::String("user123".to_string()))
2044        );
2045    }
2046
2047    #[tokio::test]
2048    async fn test_authentication_middleware_with_invalid_jwt() {
2049        let api_keys = HashMap::new();
2050        let middleware = AuthenticationMiddleware::new(api_keys, "test-secret".to_string());
2051        let mut context = MiddlewareContext::new("test".to_string());
2052
2053        let request = CallToolRequest {
2054            name: "test_tool".to_string(),
2055            arguments: Some(serde_json::json!({
2056                "jwt_token": "invalid.jwt.token"
2057            })),
2058        };
2059
2060        let result = middleware.before_request(&request, &mut context).await;
2061
2062        assert!(matches!(result, Ok(MiddlewareResult::Stop(_))));
2063    }
2064
2065    #[tokio::test]
2066    async fn test_authentication_middleware_no_auth_provided() {
2067        let api_keys = HashMap::new();
2068        let middleware = AuthenticationMiddleware::new(api_keys, "test-secret".to_string());
2069        let mut context = MiddlewareContext::new("test".to_string());
2070
2071        let request = CallToolRequest {
2072            name: "test_tool".to_string(),
2073            arguments: None,
2074        };
2075
2076        let result = middleware.before_request(&request, &mut context).await;
2077
2078        assert!(matches!(result, Ok(MiddlewareResult::Stop(_))));
2079    }
2080
2081    // Rate Limiting Middleware Tests
2082    #[tokio::test]
2083    async fn test_rate_limit_middleware_allows_request() {
2084        let middleware = RateLimitMiddleware::new(10, 5);
2085        let mut context = MiddlewareContext::new("test".to_string());
2086
2087        let request = CallToolRequest {
2088            name: "test_tool".to_string(),
2089            arguments: Some(serde_json::json!({
2090                "client_id": "test-client"
2091            })),
2092        };
2093
2094        let result = middleware.before_request(&request, &mut context).await;
2095
2096        assert!(matches!(result, Ok(MiddlewareResult::Continue)));
2097        assert_eq!(
2098            context.get_metadata("rate_limit_client_id"),
2099            Some(&Value::String("client:test-client".to_string()))
2100        );
2101    }
2102
2103    #[tokio::test]
2104    async fn test_rate_limit_middleware_uses_auth_context() {
2105        let middleware = RateLimitMiddleware::new(10, 5);
2106        let mut context = MiddlewareContext::new("test".to_string());
2107
2108        // Set up auth context
2109        context.set_metadata(
2110            "auth_key_id".to_string(),
2111            Value::String("api-key-123".to_string()),
2112        );
2113
2114        let request = CallToolRequest {
2115            name: "test_tool".to_string(),
2116            arguments: None,
2117        };
2118
2119        let result = middleware.before_request(&request, &mut context).await;
2120
2121        assert!(matches!(result, Ok(MiddlewareResult::Continue)));
2122        assert_eq!(
2123            context.get_metadata("rate_limit_client_id"),
2124            Some(&Value::String("api_key:api-key-123".to_string()))
2125        );
2126    }
2127
2128    #[tokio::test]
2129    async fn test_rate_limit_middleware_uses_jwt_context() {
2130        let middleware = RateLimitMiddleware::new(10, 5);
2131        let mut context = MiddlewareContext::new("test".to_string());
2132
2133        // Set up JWT context
2134        context.set_metadata(
2135            "auth_user_id".to_string(),
2136            Value::String("user-456".to_string()),
2137        );
2138
2139        let request = CallToolRequest {
2140            name: "test_tool".to_string(),
2141            arguments: None,
2142        };
2143
2144        let result = middleware.before_request(&request, &mut context).await;
2145
2146        assert!(matches!(result, Ok(MiddlewareResult::Continue)));
2147        assert_eq!(
2148            context.get_metadata("rate_limit_client_id"),
2149            Some(&Value::String("jwt:user-456".to_string()))
2150        );
2151    }
2152
2153    // Security Configuration Tests
2154    #[tokio::test]
2155    async fn test_security_config_default() {
2156        let config = SecurityConfig::default();
2157        assert!(config.authentication.enabled);
2158        assert!(!config.authentication.require_auth); // Should be false for easier development
2159        assert!(config.rate_limiting.enabled);
2160        assert_eq!(config.rate_limiting.requests_per_minute, 60);
2161    }
2162
2163    #[tokio::test]
2164    async fn test_middleware_config_with_security() {
2165        let config = MiddlewareConfig {
2166            logging: LoggingConfig {
2167                enabled: true,
2168                level: "debug".to_string(),
2169            },
2170            validation: ValidationConfig {
2171                enabled: true,
2172                strict_mode: true,
2173            },
2174            performance: PerformanceConfig {
2175                enabled: true,
2176                slow_request_threshold_ms: 500,
2177            },
2178            security: SecurityConfig {
2179                authentication: AuthenticationConfig {
2180                    enabled: true,
2181                    require_auth: true,
2182                    jwt_secret: "test-secret".to_string(),
2183                    api_keys: vec![ApiKeyConfig {
2184                        key: "test-key".to_string(),
2185                        key_id: "test-id".to_string(),
2186                        permissions: vec!["read".to_string()],
2187                        expires_at: None,
2188                    }],
2189                    oauth: None,
2190                },
2191                rate_limiting: RateLimitingConfig {
2192                    enabled: true,
2193                    requests_per_minute: 30,
2194                    burst_limit: 5,
2195                    custom_limits: None,
2196                },
2197            },
2198        };
2199
2200        let chain = config.build_chain();
2201        assert!(!chain.is_empty());
2202        assert!(chain.len() >= 5); // Should have auth, rate limiting, logging, validation, and performance
2203    }
2204
2205    #[tokio::test]
2206    async fn test_middleware_chain_with_security_middleware() {
2207        let mut api_keys = HashMap::new();
2208        api_keys.insert(
2209            "test-key".to_string(),
2210            ApiKeyInfo {
2211                key_id: "test-id".to_string(),
2212                permissions: vec!["read".to_string()],
2213                expires_at: None,
2214            },
2215        );
2216
2217        let chain = MiddlewareChain::new()
2218            .add_middleware(AuthenticationMiddleware::new(
2219                api_keys,
2220                "test-secret".to_string(),
2221            ))
2222            .add_middleware(RateLimitMiddleware::new(10, 5))
2223            .add_middleware(LoggingMiddleware::new(LogLevel::Info));
2224
2225        let request = CallToolRequest {
2226            name: "test_tool".to_string(),
2227            arguments: Some(serde_json::json!({
2228                "api_key": "test-key"
2229            })),
2230        };
2231
2232        let result = chain
2233            .execute(request, |_| async {
2234                Ok(CallToolResult {
2235                    content: vec![Content::Text {
2236                        text: "success".to_string(),
2237                    }],
2238                    is_error: false,
2239                })
2240            })
2241            .await;
2242
2243        assert!(result.is_ok());
2244    }
2245}