1use crate::mcp::{CallToolRequest, CallToolResult, McpError, McpResult};
4use governor::clock::DefaultClock;
5use governor::{state::keyed::DefaultKeyedStateStore, Quota, RateLimiter};
6#[allow(unused_imports)]
7use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
8use nonzero_ext::nonzero;
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use thiserror::Error;
15
16#[derive(Debug, Clone)]
18pub struct MiddlewareContext {
19 pub request_id: String,
21 pub start_time: Instant,
23 pub metadata: std::collections::HashMap<String, Value>,
25}
26
27impl MiddlewareContext {
28 #[must_use]
30 pub fn new(request_id: String) -> Self {
31 Self {
32 request_id,
33 start_time: Instant::now(),
34 metadata: std::collections::HashMap::new(),
35 }
36 }
37
38 #[must_use]
40 pub fn elapsed(&self) -> Duration {
41 self.start_time.elapsed()
42 }
43
44 pub fn set_metadata(&mut self, key: String, value: Value) {
46 self.metadata.insert(key, value);
47 }
48
49 #[must_use]
51 pub fn get_metadata(&self, key: &str) -> Option<&Value> {
52 self.metadata.get(key)
53 }
54}
55
/// Outcome of a middleware hook, telling `MiddlewareChain::execute` how to proceed.
#[derive(Debug)]
pub enum MiddlewareResult {
    /// Hand control to the next middleware (or the next phase of the chain).
    Continue,
    /// Stop the chain immediately and return this result to the caller as `Ok`.
    Stop(CallToolResult),
    /// Stop the chain immediately and return this error to the caller as `Err`.
    Error(McpError),
}
66
67#[async_trait::async_trait]
69pub trait McpMiddleware: Send + Sync {
70 fn name(&self) -> &str;
72
73 fn priority(&self) -> i32 {
75 0
76 }
77
78 async fn before_request(
80 &self,
81 request: &CallToolRequest,
82 context: &mut MiddlewareContext,
83 ) -> McpResult<MiddlewareResult> {
84 let _ = (request, context);
85 Ok(MiddlewareResult::Continue)
86 }
87
88 async fn after_request(
90 &self,
91 request: &CallToolRequest,
92 response: &mut CallToolResult,
93 context: &mut MiddlewareContext,
94 ) -> McpResult<MiddlewareResult> {
95 let _ = (request, response, context);
96 Ok(MiddlewareResult::Continue)
97 }
98
99 async fn on_error(
101 &self,
102 request: &CallToolRequest,
103 error: &McpError,
104 context: &mut MiddlewareContext,
105 ) -> McpResult<MiddlewareResult> {
106 let _ = (request, error, context);
107 Ok(MiddlewareResult::Continue)
108 }
109}
110
111pub struct MiddlewareChain {
113 middlewares: Vec<Arc<dyn McpMiddleware>>,
114}
115
116impl MiddlewareChain {
117 #[must_use]
119 pub fn new() -> Self {
120 Self {
121 middlewares: Vec::new(),
122 }
123 }
124
125 #[must_use]
127 pub fn add_middleware<M: McpMiddleware + 'static>(mut self, middleware: M) -> Self {
128 self.middlewares.push(Arc::new(middleware));
129 self.sort_by_priority();
130 self
131 }
132
133 #[must_use]
135 pub fn add_arc(mut self, middleware: Arc<dyn McpMiddleware>) -> Self {
136 self.middlewares.push(middleware);
137 self.sort_by_priority();
138 self
139 }
140
141 fn sort_by_priority(&mut self) {
143 self.middlewares.sort_by_key(|m| m.priority());
144 }
145
146 pub async fn execute<F, Fut>(
155 &self,
156 request: CallToolRequest,
157 handler: F,
158 ) -> McpResult<CallToolResult>
159 where
160 F: FnOnce(CallToolRequest) -> Fut,
161 Fut: std::future::Future<Output = McpResult<CallToolResult>> + Send,
162 {
163 let request_id = uuid::Uuid::new_v4().to_string();
164 let mut context = MiddlewareContext::new(request_id);
165
166 for middleware in &self.middlewares {
168 match middleware.before_request(&request, &mut context).await? {
169 MiddlewareResult::Continue => {}
170 MiddlewareResult::Stop(result) => return Ok(result),
171 MiddlewareResult::Error(error) => return Err(error),
172 }
173 }
174
175 let request_clone = request.clone();
177
178 let mut result = match handler(request).await {
180 Ok(response) => response,
181 Err(error) => {
182 for middleware in &self.middlewares {
184 match middleware
185 .on_error(&request_clone, &error, &mut context)
186 .await?
187 {
188 MiddlewareResult::Continue => {}
189 MiddlewareResult::Stop(result) => return Ok(result),
190 MiddlewareResult::Error(middleware_error) => return Err(middleware_error),
191 }
192 }
193 return Err(error);
194 }
195 };
196
197 for middleware in &self.middlewares {
199 match middleware
200 .after_request(&request_clone, &mut result, &mut context)
201 .await?
202 {
203 MiddlewareResult::Continue => {}
204 MiddlewareResult::Stop(new_result) => return Ok(new_result),
205 MiddlewareResult::Error(error) => return Err(error),
206 }
207 }
208
209 Ok(result)
210 }
211
212 #[must_use]
214 pub fn len(&self) -> usize {
215 self.middlewares.len()
216 }
217
218 #[must_use]
220 pub fn is_empty(&self) -> bool {
221 self.middlewares.is_empty()
222 }
223}
224
225impl Default for MiddlewareChain {
226 fn default() -> Self {
227 Self::new()
228 }
229}
230
/// Middleware that writes request lifecycle messages to standard error.
#[derive(Debug, Clone, Copy)]
pub struct LoggingMiddleware {
    /// Minimum severity that will be emitted.
    level: LogLevel,
}

/// Log severities.
///
/// Variants are declared from least to most severe, so the derived `Ord` orders them by severity
/// (`Debug < Info < Warn < Error`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum LogLevel {
    Debug,
    Info,
    Warn,
    Error,
}

impl LoggingMiddleware {
    /// Creates a logger that emits messages at `level` or above.
    #[must_use]
    pub fn new(level: LogLevel) -> Self {
        Self { level }
    }

    /// Logger that emits everything.
    #[must_use]
    pub fn debug() -> Self {
        Self::new(LogLevel::Debug)
    }

    /// Logger that emits info, warn and error messages.
    #[must_use]
    pub fn info() -> Self {
        Self::new(LogLevel::Info)
    }

    /// Logger that emits warn and error messages.
    #[must_use]
    pub fn warn() -> Self {
        Self::new(LogLevel::Warn)
    }

    /// Logger that emits only error messages.
    #[must_use]
    pub fn error() -> Self {
        Self::new(LogLevel::Error)
    }

    /// Returns `true` when a message at `level` is at least as severe as the configured minimum.
    fn should_log(&self, level: LogLevel) -> bool {
        level >= self.level
    }

    /// Writes `message` with a severity tag, if `level` passes the configured filter.
    ///
    /// Output goes to stderr, not stdout. An MCP server on the stdio transport must write only
    /// protocol messages to stdout, and stray log lines there would corrupt the JSON-RPC stream.
    fn log(&self, level: LogLevel, message: &str) {
        if !self.should_log(level) {
            return;
        }
        let tag = match level {
            LogLevel::Debug => "DEBUG",
            LogLevel::Info => "INFO",
            LogLevel::Warn => "WARN",
            LogLevel::Error => "ERROR",
        };
        eprintln!("[{tag}] {message}");
    }
}
299
300#[async_trait::async_trait]
301impl McpMiddleware for LoggingMiddleware {
302 fn name(&self) -> &'static str {
303 "logging"
304 }
305
306 fn priority(&self) -> i32 {
307 100 }
309
310 async fn before_request(
311 &self,
312 request: &CallToolRequest,
313 context: &mut MiddlewareContext,
314 ) -> McpResult<MiddlewareResult> {
315 self.log(
316 LogLevel::Info,
317 &format!(
318 "Request started: {} (ID: {})",
319 request.name, context.request_id
320 ),
321 );
322 Ok(MiddlewareResult::Continue)
323 }
324
325 async fn after_request(
326 &self,
327 request: &CallToolRequest,
328 response: &mut CallToolResult,
329 context: &mut MiddlewareContext,
330 ) -> McpResult<MiddlewareResult> {
331 let elapsed = context.elapsed();
332 let status = if response.is_error {
333 "ERROR"
334 } else {
335 "SUCCESS"
336 };
337
338 self.log(
339 LogLevel::Info,
340 &format!(
341 "Request completed: {} (ID: {}) - {} in {:?}",
342 request.name, context.request_id, status, elapsed
343 ),
344 );
345 Ok(MiddlewareResult::Continue)
346 }
347
348 async fn on_error(
349 &self,
350 request: &CallToolRequest,
351 error: &McpError,
352 context: &mut MiddlewareContext,
353 ) -> McpResult<MiddlewareResult> {
354 self.log(
355 LogLevel::Error,
356 &format!(
357 "Request failed: {} (ID: {}) - {}",
358 request.name, context.request_id, error
359 ),
360 );
361 Ok(MiddlewareResult::Continue)
362 }
363}
364
365pub struct ValidationMiddleware {
367 strict_mode: bool,
368}
369
370impl ValidationMiddleware {
371 #[must_use]
373 pub fn new(strict_mode: bool) -> Self {
374 Self { strict_mode }
375 }
376
377 #[must_use]
379 pub fn strict() -> Self {
380 Self::new(true)
381 }
382
383 #[must_use]
385 pub fn lenient() -> Self {
386 Self::new(false)
387 }
388
389 fn validate_request(&self, request: &CallToolRequest) -> McpResult<()> {
390 if request.name.is_empty() {
392 return Err(McpError::validation_error("Tool name cannot be empty"));
393 }
394
395 if !request
397 .name
398 .chars()
399 .all(|c| c.is_alphanumeric() || c == '_')
400 {
401 return Err(McpError::validation_error(
402 "Tool name must contain only alphanumeric characters and underscores",
403 ));
404 }
405
406 if self.strict_mode {
408 if let Some(args) = &request.arguments {
409 if !args.is_object() {
410 return Err(McpError::validation_error(
411 "Arguments must be a JSON object",
412 ));
413 }
414 }
415 }
416
417 Ok(())
418 }
419}
420
421#[async_trait::async_trait]
422impl McpMiddleware for ValidationMiddleware {
423 fn name(&self) -> &'static str {
424 "validation"
425 }
426
427 fn priority(&self) -> i32 {
428 50 }
430
431 async fn before_request(
432 &self,
433 request: &CallToolRequest,
434 context: &mut MiddlewareContext,
435 ) -> McpResult<MiddlewareResult> {
436 if let Err(error) = self.validate_request(request) {
437 context.set_metadata(
438 "validation_error".to_string(),
439 serde_json::Value::String(error.to_string()),
440 );
441 return Ok(MiddlewareResult::Error(error));
442 }
443
444 context.set_metadata("validated".to_string(), serde_json::Value::Bool(true));
445 Ok(MiddlewareResult::Continue)
446 }
447}
448
/// Middleware that records request duration and flags requests slower than a threshold.
pub struct PerformanceMiddleware {
    /// Requests taking strictly longer than this are reported as slow.
    slow_request_threshold: Duration,
}

impl PerformanceMiddleware {
    /// Threshold used by [`PerformanceMiddleware::create_default`]: one second.
    const DEFAULT_THRESHOLD: Duration = Duration::from_secs(1);

    /// Creates a middleware that flags requests slower than `slow_request_threshold`.
    #[must_use]
    pub fn new(slow_request_threshold: Duration) -> Self {
        Self {
            slow_request_threshold,
        }
    }

    /// Creates a middleware with a one-second slow-request threshold.
    #[must_use]
    pub fn create_default() -> Self {
        Self::with_threshold(Self::DEFAULT_THRESHOLD)
    }

    /// Alias of [`PerformanceMiddleware::new`].
    #[must_use]
    pub fn with_threshold(threshold: Duration) -> Self {
        Self {
            slow_request_threshold: threshold,
        }
    }
}
475
476#[async_trait::async_trait]
477impl McpMiddleware for PerformanceMiddleware {
478 fn name(&self) -> &'static str {
479 "performance"
480 }
481
482 fn priority(&self) -> i32 {
483 200 }
485
486 async fn after_request(
487 &self,
488 request: &CallToolRequest,
489 _response: &mut CallToolResult,
490 context: &mut MiddlewareContext,
491 ) -> McpResult<MiddlewareResult> {
492 let elapsed = context.elapsed();
493
494 context.set_metadata(
496 "duration_ms".to_string(),
497 serde_json::Value::Number(serde_json::Number::from(
498 u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
499 )),
500 );
501
502 context.set_metadata(
503 "is_slow".to_string(),
504 serde_json::Value::Bool(elapsed > self.slow_request_threshold),
505 );
506
507 if elapsed > self.slow_request_threshold {
509 println!(
510 "[PERF] Slow request detected: {} took {:?} (threshold: {:?})",
511 request.name, elapsed, self.slow_request_threshold
512 );
513 }
514
515 Ok(MiddlewareResult::Continue)
516 }
517}
518
/// Middleware that authenticates requests by API key or JWT before they reach the handler.
pub struct AuthenticationMiddleware {
    /// Accepted API keys, looked up by the raw key string taken from the `api_key` argument.
    api_keys: HashMap<String, ApiKeyInfo>,
    /// Shared secret used to verify HS256-signed JWTs taken from the `jwt_token` argument.
    jwt_secret: String,
    /// OAuth client settings; stored but not read anywhere in this module (hence `dead_code`).
    #[allow(dead_code)]
    oauth_config: Option<OAuthConfig>,
    /// When `false`, every request is let through without checking credentials.
    require_auth: bool,
}

/// Metadata attached to an accepted API key.
#[derive(Debug, Clone)]
pub struct ApiKeyInfo {
    /// Identifier recorded as `auth_key_id` request metadata in place of the raw key.
    pub key_id: String,
    /// Permission names recorded as `auth_permissions` request metadata.
    pub permissions: Vec<String>,
    /// Optional expiry instant, in UTC.
    pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
}
534
/// OAuth client settings held by `AuthenticationMiddleware`.
///
/// `Debug` is implemented by hand so that `client_secret` never ends up in logs or panic
/// messages; a derived impl would print it verbatim.
#[derive(Clone)]
pub struct OAuthConfig {
    /// OAuth client identifier.
    pub client_id: String,
    /// OAuth client secret (redacted in `Debug` output).
    pub client_secret: String,
    /// Token endpoint URL.
    pub token_endpoint: String,
    /// Requested scopes.
    pub scope: Vec<String>,
}

impl std::fmt::Debug for OAuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthConfig")
            .field("client_id", &self.client_id)
            .field("client_secret", &"<redacted>")
            .field("token_endpoint", &self.token_endpoint)
            .field("scope", &self.scope)
            .finish()
    }
}
542
/// Claims carried by the JWTs that `AuthenticationMiddleware` accepts.
#[derive(Debug, Serialize, Deserialize)]
pub struct JwtClaims {
    /// Subject; recorded as `auth_user_id` request metadata.
    pub sub: String,
    /// Expiry time, in seconds since the Unix epoch.
    pub exp: usize,
    /// Issued-at time, in seconds since the Unix epoch.
    pub iat: usize,
    /// Permission names; recorded as `auth_permissions` request metadata.
    pub permissions: Vec<String>,
}
550
551impl AuthenticationMiddleware {
552 #[must_use]
554 pub fn new(api_keys: HashMap<String, ApiKeyInfo>, jwt_secret: String) -> Self {
555 Self {
556 api_keys,
557 jwt_secret,
558 oauth_config: None,
559 require_auth: true,
560 }
561 }
562
563 #[must_use]
565 pub fn with_oauth(
566 api_keys: HashMap<String, ApiKeyInfo>,
567 jwt_secret: String,
568 oauth_config: OAuthConfig,
569 ) -> Self {
570 Self {
571 api_keys,
572 jwt_secret,
573 oauth_config: Some(oauth_config),
574 require_auth: true,
575 }
576 }
577
578 #[must_use]
580 pub fn permissive() -> Self {
581 Self {
582 api_keys: HashMap::new(),
583 jwt_secret: "test-secret".to_string(),
584 oauth_config: None,
585 require_auth: false,
586 }
587 }
588
589 fn extract_api_key(request: &CallToolRequest) -> Option<String> {
591 if let Some(args) = &request.arguments {
593 if let Some(api_key) = args.get("api_key").and_then(|v| v.as_str()) {
594 return Some(api_key.to_string());
595 }
596 }
597 None
598 }
599
600 fn extract_jwt_token(request: &CallToolRequest) -> Option<String> {
602 if let Some(args) = &request.arguments {
604 if let Some(token) = args.get("jwt_token").and_then(|v| v.as_str()) {
605 return Some(token.to_string());
606 }
607 }
608 None
609 }
610
611 fn validate_api_key(&self, api_key: &str) -> McpResult<ApiKeyInfo> {
613 self.api_keys
614 .get(api_key)
615 .cloned()
616 .ok_or_else(|| McpError::validation_error("Invalid API key"))
617 }
618
619 fn validate_jwt_token(&self, token: &str) -> McpResult<JwtClaims> {
621 let validation = Validation::new(Algorithm::HS256);
622 let key = DecodingKey::from_secret(self.jwt_secret.as_ref());
623
624 let token_data = decode::<JwtClaims>(token, &key, &validation)
625 .map_err(|_| McpError::validation_error("Invalid JWT token"))?;
626
627 let now = chrono::Utc::now().timestamp().try_into().unwrap_or(0);
629 if token_data.claims.exp < now {
630 return Err(McpError::validation_error("JWT token has expired"));
631 }
632
633 Ok(token_data.claims)
634 }
635
636 #[cfg(test)]
641 #[must_use]
642 pub fn generate_test_jwt(&self, user_id: &str, permissions: Vec<String>) -> String {
643 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
644 let now = chrono::Utc::now().timestamp() as usize;
645 let claims = JwtClaims {
646 sub: user_id.to_string(),
647 exp: now + 3600, iat: now,
649 permissions,
650 };
651
652 let header = Header::new(Algorithm::HS256);
653 let key = EncodingKey::from_secret(self.jwt_secret.as_ref());
654 encode(&header, &claims, &key).unwrap()
655 }
656}
657
658#[async_trait::async_trait]
659impl McpMiddleware for AuthenticationMiddleware {
660 fn name(&self) -> &'static str {
661 "authentication"
662 }
663
664 fn priority(&self) -> i32 {
665 10 }
667
668 async fn before_request(
669 &self,
670 request: &CallToolRequest,
671 context: &mut MiddlewareContext,
672 ) -> McpResult<MiddlewareResult> {
673 if !self.require_auth {
674 context.set_metadata("auth_required".to_string(), Value::Bool(false));
675 return Ok(MiddlewareResult::Continue);
676 }
677
678 if let Some(api_key) = Self::extract_api_key(request) {
680 if let Ok(api_key_info) = self.validate_api_key(&api_key) {
681 context.set_metadata(
682 "auth_type".to_string(),
683 Value::String("api_key".to_string()),
684 );
685 context.set_metadata(
686 "auth_key_id".to_string(),
687 Value::String(api_key_info.key_id),
688 );
689 context.set_metadata(
690 "auth_permissions".to_string(),
691 serde_json::to_value(api_key_info.permissions).unwrap_or(Value::Array(vec![])),
692 );
693 context.set_metadata("auth_required".to_string(), Value::Bool(true));
694 return Ok(MiddlewareResult::Continue);
695 }
696 }
698
699 if let Some(jwt_token) = Self::extract_jwt_token(request) {
701 if let Ok(claims) = self.validate_jwt_token(&jwt_token) {
702 context.set_metadata("auth_type".to_string(), Value::String("jwt".to_string()));
703 context.set_metadata("auth_user_id".to_string(), Value::String(claims.sub));
704 context.set_metadata(
705 "auth_permissions".to_string(),
706 serde_json::to_value(claims.permissions).unwrap_or(Value::Array(vec![])),
707 );
708 context.set_metadata("auth_required".to_string(), Value::Bool(true));
709 return Ok(MiddlewareResult::Continue);
710 }
711 }
713
714 let error_result = CallToolResult {
716 content: vec![crate::mcp::Content::Text {
717 text: "Authentication required. Please provide a valid API key or JWT token."
718 .to_string(),
719 }],
720 is_error: true,
721 };
722
723 Ok(MiddlewareResult::Stop(error_result))
724 }
725}
726
/// Middleware that throttles requests per client identity using a keyed `governor` limiter.
pub struct RateLimitMiddleware {
    /// Keyed limiter; keys are the client-id strings built by `extract_client_id`.
    rate_limiter: Arc<RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>>,
    /// Requests-per-minute value given to the constructor; reported in rate-limit messages.
    default_limit: u32,
    /// Burst value given to the constructor; not read after construction (hence `dead_code`).
    #[allow(dead_code)]
    burst_limit: u32,
}
734
735impl RateLimitMiddleware {
736 #[must_use]
738 pub fn new(requests_per_minute: u32, burst_limit: u32) -> Self {
739 let quota = Quota::per_minute(nonzero!(60u32)); let rate_limiter = Arc::new(RateLimiter::keyed(quota));
741
742 Self {
743 rate_limiter,
744 default_limit: requests_per_minute,
745 burst_limit,
746 }
747 }
748
749 #[must_use]
751 pub fn with_limits(requests_per_minute: u32, burst_limit: u32) -> Self {
752 Self::new(requests_per_minute, burst_limit)
753 }
754
755 #[allow(clippy::should_implement_trait)]
757 #[must_use]
758 pub fn default() -> Self {
759 Self::new(60, 10)
760 }
761
762 fn extract_client_id(request: &CallToolRequest, context: &MiddlewareContext) -> String {
764 if let Some(auth_key_id) = context.get_metadata("auth_key_id").and_then(|v| v.as_str()) {
766 return format!("api_key:{auth_key_id}");
767 }
768
769 if let Some(auth_user_id) = context
770 .get_metadata("auth_user_id")
771 .and_then(|v| v.as_str())
772 {
773 return format!("jwt:{auth_user_id}");
774 }
775
776 if let Some(args) = &request.arguments {
778 if let Some(client_id) = args.get("client_id").and_then(|v| v.as_str()) {
779 return format!("client:{client_id}");
780 }
781 }
782
783 format!("request:{}", context.request_id)
785 }
786
787 fn check_rate_limit(&self, client_id: &str) -> bool {
789 self.rate_limiter.check_key(&client_id.to_string()).is_ok()
790 }
791
792 fn get_remaining_requests(&self, _client_id: &str) -> u32 {
794 self.default_limit
797 }
798}
799
800#[async_trait::async_trait]
801impl McpMiddleware for RateLimitMiddleware {
802 fn name(&self) -> &'static str {
803 "rate_limiting"
804 }
805
806 fn priority(&self) -> i32 {
807 20 }
809
810 async fn before_request(
811 &self,
812 request: &CallToolRequest,
813 context: &mut MiddlewareContext,
814 ) -> McpResult<MiddlewareResult> {
815 let client_id = Self::extract_client_id(request, context);
816
817 if !self.check_rate_limit(&client_id) {
818 let error_result = CallToolResult {
819 content: vec![crate::mcp::Content::Text {
820 text: format!(
821 "Rate limit exceeded. Limit: {} requests per minute. Please try again later.",
822 self.default_limit
823 ),
824 }],
825 is_error: true,
826 };
827
828 context.set_metadata("rate_limited".to_string(), Value::Bool(true));
829 context.set_metadata("rate_limit_client_id".to_string(), Value::String(client_id));
830
831 return Ok(MiddlewareResult::Stop(error_result));
832 }
833
834 let remaining = self.get_remaining_requests(&client_id);
835 context.set_metadata(
836 "rate_limit_remaining".to_string(),
837 Value::Number(serde_json::Number::from(remaining)),
838 );
839 context.set_metadata("rate_limit_client_id".to_string(), Value::String(client_id));
840
841 Ok(MiddlewareResult::Continue)
842 }
843}
844
/// Security section of the middleware configuration: authentication and rate limiting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityConfig {
    /// Authentication middleware settings.
    pub authentication: AuthenticationConfig,
    /// Rate-limiting middleware settings.
    pub rate_limiting: RateLimitingConfig,
}
853
854#[derive(Debug, Clone, Serialize, Deserialize)]
855pub struct AuthenticationConfig {
856 pub enabled: bool,
858 pub require_auth: bool,
860 pub jwt_secret: String,
862 pub api_keys: Vec<ApiKeyConfig>,
864 pub oauth: Option<OAuth2Config>,
866}
867
868#[derive(Debug, Clone, Serialize, Deserialize)]
869pub struct ApiKeyConfig {
870 pub key: String,
872 pub key_id: String,
874 pub permissions: Vec<String>,
876 pub expires_at: Option<String>,
878}
879
880#[derive(Debug, Clone, Serialize, Deserialize)]
881pub struct OAuth2Config {
882 pub client_id: String,
884 pub client_secret: String,
886 pub token_endpoint: String,
888 pub scopes: Vec<String>,
890}
891
/// Serializable settings for the rate-limiting middleware.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitingConfig {
    /// Whether the rate-limiting middleware is added to the chain.
    pub enabled: bool,
    /// Requests-per-minute value passed to `RateLimitMiddleware::with_limits`.
    pub requests_per_minute: u32,
    /// Burst value passed to `RateLimitMiddleware::with_limits`.
    pub burst_limit: u32,
    /// Per-key limit overrides.
    /// NOTE(review): not read by `MiddlewareConfig::build_chain`; confirm intended use.
    pub custom_limits: Option<HashMap<String, u32>>,
}
903
904impl Default for SecurityConfig {
905 fn default() -> Self {
906 Self {
907 authentication: AuthenticationConfig {
908 enabled: true,
909 require_auth: false, jwt_secret: "your-secret-key-change-this-in-production".to_string(),
911 api_keys: vec![],
912 oauth: None,
913 },
914 rate_limiting: RateLimitingConfig {
915 enabled: true,
916 requests_per_minute: 60,
917 burst_limit: 10,
918 custom_limits: None,
919 },
920 }
921 }
922}
923
/// Top-level middleware configuration, turned into a chain by `MiddlewareConfig::build_chain`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MiddlewareConfig {
    /// Logging middleware settings.
    pub logging: LoggingConfig,
    /// Validation middleware settings.
    pub validation: ValidationConfig,
    /// Performance middleware settings.
    pub performance: PerformanceConfig,
    /// Authentication and rate-limiting settings.
    pub security: SecurityConfig,
}

/// Serializable settings for the logging middleware.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoggingConfig {
    /// Whether the logging middleware is added to the chain.
    pub enabled: bool,
    /// Minimum level name. `build_chain` matches "debug", "warn" and "error" case-insensitively
    /// and treats any other value as info.
    pub level: String,
}

/// Serializable settings for the validation middleware.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationConfig {
    /// Whether the validation middleware is added to the chain.
    pub enabled: bool,
    /// When `true`, tool arguments (if present) must be a JSON object.
    pub strict_mode: bool,
}

/// Serializable settings for the performance middleware.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceConfig {
    /// Whether the performance middleware is added to the chain.
    pub enabled: bool,
    /// Requests taking longer than this many milliseconds are reported as slow.
    pub slow_request_threshold_ms: u64,
}
963
964impl Default for MiddlewareConfig {
965 fn default() -> Self {
966 Self {
967 logging: LoggingConfig {
968 enabled: true,
969 level: "info".to_string(),
970 },
971 validation: ValidationConfig {
972 enabled: true,
973 strict_mode: false,
974 },
975 performance: PerformanceConfig {
976 enabled: true,
977 slow_request_threshold_ms: 1000,
978 },
979 security: SecurityConfig::default(),
980 }
981 }
982}
983
984impl MiddlewareConfig {
985 #[must_use]
987 pub fn new() -> Self {
988 Self::default()
989 }
990
991 #[must_use]
993 pub fn build_chain(self) -> MiddlewareChain {
994 let mut chain = MiddlewareChain::new();
995
996 if self.security.authentication.enabled {
998 let api_keys: HashMap<String, ApiKeyInfo> = self
999 .security
1000 .authentication
1001 .api_keys
1002 .into_iter()
1003 .map(|config| {
1004 let expires_at = config.expires_at.and_then(|date_str| {
1005 chrono::DateTime::parse_from_rfc3339(&date_str)
1006 .ok()
1007 .map(|dt| dt.with_timezone(&chrono::Utc))
1008 });
1009
1010 let api_key_info = ApiKeyInfo {
1011 key_id: config.key_id,
1012 permissions: config.permissions,
1013 expires_at,
1014 };
1015
1016 (config.key, api_key_info)
1017 })
1018 .collect();
1019
1020 let auth_middleware = if self.security.authentication.require_auth {
1021 if let Some(oauth_config) = self.security.authentication.oauth {
1022 let oauth = OAuthConfig {
1023 client_id: oauth_config.client_id,
1024 client_secret: oauth_config.client_secret,
1025 token_endpoint: oauth_config.token_endpoint,
1026 scope: oauth_config.scopes,
1027 };
1028 AuthenticationMiddleware::with_oauth(
1029 api_keys,
1030 self.security.authentication.jwt_secret,
1031 oauth,
1032 )
1033 } else {
1034 AuthenticationMiddleware::new(api_keys, self.security.authentication.jwt_secret)
1035 }
1036 } else {
1037 AuthenticationMiddleware::permissive()
1038 };
1039
1040 chain = chain.add_middleware(auth_middleware);
1041 }
1042
1043 if self.security.rate_limiting.enabled {
1044 let rate_limit_middleware = RateLimitMiddleware::with_limits(
1045 self.security.rate_limiting.requests_per_minute,
1046 self.security.rate_limiting.burst_limit,
1047 );
1048 chain = chain.add_middleware(rate_limit_middleware);
1049 }
1050
1051 if self.logging.enabled {
1052 let log_level = match self.logging.level.to_lowercase().as_str() {
1053 "debug" => LogLevel::Debug,
1054 "warn" => LogLevel::Warn,
1055 "error" => LogLevel::Error,
1056 _ => LogLevel::Info,
1057 };
1058 chain = chain.add_middleware(LoggingMiddleware::new(log_level));
1059 }
1060
1061 if self.validation.enabled {
1062 chain = chain.add_middleware(ValidationMiddleware::new(self.validation.strict_mode));
1063 }
1064
1065 if self.performance.enabled {
1066 let threshold = Duration::from_millis(self.performance.slow_request_threshold_ms);
1067 chain = chain.add_middleware(PerformanceMiddleware::with_threshold(threshold));
1068 }
1069
1070 chain
1071 }
1072}
1073
/// Errors raised by the middleware machinery.
///
/// The `From` impl in this module converts them into `McpError::internal_error`, carrying the
/// `Display` text.
#[derive(Error, Debug)]
pub enum MiddlewareError {
    /// Displays as `Middleware execution failed: {message}`.
    #[error("Middleware execution failed: {message}")]
    ExecutionFailed { message: String },

    /// Displays as `Middleware configuration error: {message}`.
    #[error("Middleware configuration error: {message}")]
    ConfigurationError { message: String },

    /// Displays as `Middleware chain error: {message}`.
    #[error("Middleware chain error: {message}")]
    ChainError { message: String },
}
1086
1087impl From<MiddlewareError> for McpError {
1088 fn from(error: MiddlewareError) -> Self {
1089 McpError::internal_error(error.to_string())
1090 }
1091}
1092
1093#[cfg(test)]
1094mod tests {
1095 use super::*;
1096 use crate::mcp::Content;
1097
1098 struct TestMiddleware {
1099 priority: i32,
1100 }
1101
1102 #[async_trait::async_trait]
1103 impl McpMiddleware for TestMiddleware {
1104 fn name(&self) -> &'static str {
1105 "test_middleware"
1106 }
1107
1108 fn priority(&self) -> i32 {
1109 self.priority
1110 }
1111 }
1112
1113 #[tokio::test]
1114 async fn test_middleware_chain_creation() {
1115 let chain = MiddlewareChain::new()
1116 .add_middleware(TestMiddleware { priority: 100 })
1117 .add_middleware(TestMiddleware { priority: 50 });
1118
1119 assert_eq!(chain.len(), 2);
1120 assert!(!chain.is_empty());
1121 }
1122
1123 #[tokio::test]
1124 async fn test_middleware_priority_ordering() {
1125 let chain = MiddlewareChain::new()
1126 .add_middleware(TestMiddleware { priority: 10 })
1127 .add_middleware(TestMiddleware { priority: 100 });
1128
1129 assert_eq!(chain.len(), 2);
1131 }
1132
1133 #[tokio::test]
1134 async fn test_middleware_execution() {
1135 let chain = MiddlewareChain::new()
1136 .add_middleware(LoggingMiddleware::info())
1137 .add_middleware(ValidationMiddleware::lenient());
1138
1139 let request = CallToolRequest {
1140 name: "test_tool".to_string(),
1141 arguments: Some(serde_json::json!({"param": "value"})),
1142 };
1143
1144 let handler = |_req: CallToolRequest| {
1145 Box::pin(async move {
1146 Ok(CallToolResult {
1147 content: vec![Content::Text {
1148 text: "Test response".to_string(),
1149 }],
1150 is_error: false,
1151 })
1152 })
1153 };
1154
1155 let result = chain.execute(request, handler).await;
1156 assert!(result.is_ok());
1157 }
1158
1159 #[tokio::test]
1160 async fn test_validation_middleware() {
1161 let middleware = ValidationMiddleware::strict();
1162 let mut context = MiddlewareContext::new("test".to_string());
1163
1164 let valid_request = CallToolRequest {
1166 name: "valid_tool".to_string(),
1167 arguments: Some(serde_json::json!({"param": "value"})),
1168 };
1169
1170 let result = middleware
1171 .before_request(&valid_request, &mut context)
1172 .await;
1173 assert!(matches!(result, Ok(MiddlewareResult::Continue)));
1174
1175 let invalid_request = CallToolRequest {
1177 name: String::new(),
1178 arguments: None,
1179 };
1180
1181 let result = middleware
1182 .before_request(&invalid_request, &mut context)
1183 .await;
1184 assert!(matches!(result, Ok(MiddlewareResult::Error(_))));
1185 }
1186
1187 #[tokio::test]
1188 async fn test_performance_middleware() {
1189 let middleware = PerformanceMiddleware::with_threshold(Duration::from_millis(100));
1190 let mut context = MiddlewareContext::new("test".to_string());
1191
1192 tokio::time::sleep(Duration::from_millis(150)).await;
1194
1195 let mut response = CallToolResult {
1196 content: vec![Content::Text {
1197 text: "Test".to_string(),
1198 }],
1199 is_error: false,
1200 };
1201
1202 let request = CallToolRequest {
1203 name: "test".to_string(),
1204 arguments: None,
1205 };
1206
1207 let result = middleware
1208 .after_request(&request, &mut response, &mut context)
1209 .await;
1210 assert!(matches!(result, Ok(MiddlewareResult::Continue)));
1211
1212 assert!(context.get_metadata("duration_ms").is_some());
1214 assert!(context.get_metadata("is_slow").is_some());
1215 }
1216
1217 #[tokio::test]
1218 async fn test_middleware_config() {
1219 let config = MiddlewareConfig {
1220 logging: LoggingConfig {
1221 enabled: true,
1222 level: "debug".to_string(),
1223 },
1224 validation: ValidationConfig {
1225 enabled: true,
1226 strict_mode: true,
1227 },
1228 performance: PerformanceConfig {
1229 enabled: true,
1230 slow_request_threshold_ms: 500,
1231 },
1232 security: SecurityConfig::default(),
1233 };
1234
1235 let chain = config.build_chain();
1236 assert!(!chain.is_empty());
1237 assert!(chain.len() >= 3); }
1239
1240 #[tokio::test]
1241 async fn test_middleware_context_creation() {
1242 let context = MiddlewareContext::new("test-request-123".to_string());
1243 assert_eq!(context.request_id, "test-request-123");
1244 assert!(context.metadata.is_empty());
1245 }
1246
1247 #[tokio::test]
1248 async fn test_middleware_context_elapsed() {
1249 let context = MiddlewareContext::new("test-request-123".to_string());
1250 std::thread::sleep(std::time::Duration::from_millis(10));
1251 let elapsed = context.elapsed();
1252 assert!(elapsed.as_millis() >= 10);
1253 }
1254
1255 #[tokio::test]
1256 async fn test_middleware_context_metadata() {
1257 let mut context = MiddlewareContext::new("test-request-123".to_string());
1258
1259 context.set_metadata(
1261 "key1".to_string(),
1262 serde_json::Value::String("value1".to_string()),
1263 );
1264 context.set_metadata(
1265 "key2".to_string(),
1266 serde_json::Value::Number(serde_json::Number::from(42)),
1267 );
1268
1269 assert_eq!(
1271 context.get_metadata("key1"),
1272 Some(&serde_json::Value::String("value1".to_string()))
1273 );
1274 assert_eq!(
1275 context.get_metadata("key2"),
1276 Some(&serde_json::Value::Number(serde_json::Number::from(42)))
1277 );
1278 assert_eq!(context.get_metadata("nonexistent"), None);
1279 }
1280
1281 #[tokio::test]
1282 async fn test_middleware_result_variants() {
1283 let continue_result = MiddlewareResult::Continue;
1284 let stop_result = MiddlewareResult::Stop(CallToolResult {
1285 content: vec![Content::Text {
1286 text: "test".to_string(),
1287 }],
1288 is_error: false,
1289 });
1290 let error_result = MiddlewareResult::Error(McpError::tool_not_found("test error"));
1291
1292 match continue_result {
1294 MiddlewareResult::Continue => {}
1295 _ => panic!("Expected Continue"),
1296 }
1297
1298 match stop_result {
1299 MiddlewareResult::Stop(_) => {}
1300 _ => panic!("Expected Stop"),
1301 }
1302
1303 match error_result {
1304 MiddlewareResult::Error(_) => {}
1305 _ => panic!("Expected Error"),
1306 }
1307 }
1308
1309 #[tokio::test]
1310 async fn test_logging_middleware_different_levels() {
1311 let debug_middleware = LoggingMiddleware::new(LogLevel::Debug);
1312 let info_middleware = LoggingMiddleware::new(LogLevel::Info);
1313 let warn_middleware = LoggingMiddleware::new(LogLevel::Warn);
1314 let error_middleware = LoggingMiddleware::new(LogLevel::Error);
1315
1316 assert_eq!(debug_middleware.name(), "logging");
1317 assert_eq!(info_middleware.name(), "logging");
1318 assert_eq!(warn_middleware.name(), "logging");
1319 assert_eq!(error_middleware.name(), "logging");
1320 }
1321
1322 #[tokio::test]
1323 async fn test_logging_middleware_should_log() {
1324 let debug_middleware = LoggingMiddleware::new(LogLevel::Debug);
1325 let info_middleware = LoggingMiddleware::new(LogLevel::Info);
1326 let warn_middleware = LoggingMiddleware::new(LogLevel::Warn);
1327 let error_middleware = LoggingMiddleware::new(LogLevel::Error);
1328
1329 assert!(debug_middleware.should_log(LogLevel::Debug));
1331 assert!(debug_middleware.should_log(LogLevel::Info));
1332 assert!(debug_middleware.should_log(LogLevel::Warn));
1333 assert!(debug_middleware.should_log(LogLevel::Error));
1334
1335 assert!(!info_middleware.should_log(LogLevel::Debug));
1337 assert!(info_middleware.should_log(LogLevel::Info));
1338 assert!(info_middleware.should_log(LogLevel::Warn));
1339 assert!(info_middleware.should_log(LogLevel::Error));
1340
1341 assert!(!warn_middleware.should_log(LogLevel::Debug));
1343 assert!(!warn_middleware.should_log(LogLevel::Info));
1344 assert!(warn_middleware.should_log(LogLevel::Warn));
1345 assert!(warn_middleware.should_log(LogLevel::Error));
1346
1347 assert!(!error_middleware.should_log(LogLevel::Debug));
1349 assert!(!error_middleware.should_log(LogLevel::Info));
1350 assert!(!error_middleware.should_log(LogLevel::Warn));
1351 assert!(error_middleware.should_log(LogLevel::Error));
1352 }
1353
1354 #[tokio::test]
1355 async fn test_validation_middleware_strict_mode() {
1356 let strict_middleware = ValidationMiddleware::strict();
1357 let lenient_middleware = ValidationMiddleware::lenient();
1358
1359 assert_eq!(strict_middleware.name(), "validation");
1360 assert_eq!(lenient_middleware.name(), "validation");
1361 }
1362
1363 #[tokio::test]
1364 async fn test_validation_middleware_creation() {
1365 let middleware1 = ValidationMiddleware::new(true);
1366 let middleware2 = ValidationMiddleware::new(false);
1367
1368 assert_eq!(middleware1.name(), "validation");
1369 assert_eq!(middleware2.name(), "validation");
1370 }
1371
1372 #[tokio::test]
1373 async fn test_performance_middleware_creation() {
1374 let middleware1 = PerformanceMiddleware::new(Duration::from_millis(100));
1375 let middleware2 = PerformanceMiddleware::with_threshold(Duration::from_millis(200));
1376 let middleware3 = PerformanceMiddleware::create_default();
1377
1378 assert_eq!(middleware1.name(), "performance");
1379 assert_eq!(middleware2.name(), "performance");
1380 assert_eq!(middleware3.name(), "performance");
1381 }
1382
1383 #[tokio::test]
1384 async fn test_middleware_chain_empty() {
1385 let chain = MiddlewareChain::new();
1386 assert!(chain.is_empty());
1387 assert_eq!(chain.len(), 0);
1388 }
1389
1390 #[tokio::test]
1391 async fn test_middleware_chain_add_middleware() {
1392 let chain = MiddlewareChain::new()
1393 .add_middleware(LoggingMiddleware::new(LogLevel::Info))
1394 .add_middleware(ValidationMiddleware::new(false));
1395
1396 assert!(!chain.is_empty());
1397 assert_eq!(chain.len(), 2);
1398 }
1399
1400 #[tokio::test]
1401 async fn test_middleware_chain_add_arc() {
1402 let middleware = Arc::new(LoggingMiddleware::new(LogLevel::Info)) as Arc<dyn McpMiddleware>;
1403 let chain = MiddlewareChain::new().add_arc(middleware);
1404
1405 assert!(!chain.is_empty());
1406 assert_eq!(chain.len(), 1);
1407 }
1408
1409 #[tokio::test]
1410 async fn test_middleware_chain_execution_with_empty_chain() {
1411 let chain = MiddlewareChain::new();
1412 let request = CallToolRequest {
1413 name: "test_tool".to_string(),
1414 arguments: None,
1415 };
1416
1417 let result = chain
1418 .execute(request, |_| async {
1419 Ok(CallToolResult {
1420 content: vec![Content::Text {
1421 text: "success".to_string(),
1422 }],
1423 is_error: false,
1424 })
1425 })
1426 .await;
1427
1428 assert!(result.is_ok());
1429 let result = result.unwrap();
1430 assert!(!result.is_error);
1431 assert_eq!(result.content.len(), 1);
1432 }
1433
1434 #[tokio::test]
1435 async fn test_middleware_chain_execution_with_error() {
1436 let chain = MiddlewareChain::new().add_middleware(LoggingMiddleware::new(LogLevel::Info));
1437 let request = CallToolRequest {
1438 name: "test_tool".to_string(),
1439 arguments: None,
1440 };
1441
1442 let result = chain
1443 .execute(request, |_| async {
1444 Err(McpError::tool_not_found("test error"))
1445 })
1446 .await;
1447
1448 assert!(result.is_err());
1449 }
1450
1451 #[tokio::test]
1452 async fn test_middleware_chain_execution_with_stop() {
1453 struct StopMiddleware;
1455 #[async_trait::async_trait]
1456 impl McpMiddleware for StopMiddleware {
1457 fn name(&self) -> &'static str {
1458 "stop"
1459 }
1460
1461 async fn before_request(
1462 &self,
1463 _request: &CallToolRequest,
1464 _context: &mut MiddlewareContext,
1465 ) -> McpResult<MiddlewareResult> {
1466 Ok(MiddlewareResult::Stop(CallToolResult {
1467 content: vec![Content::Text {
1468 text: "stopped".to_string(),
1469 }],
1470 is_error: false,
1471 }))
1472 }
1473
1474 async fn after_request(
1475 &self,
1476 _request: &CallToolRequest,
1477 _result: &mut CallToolResult,
1478 _context: &mut MiddlewareContext,
1479 ) -> McpResult<MiddlewareResult> {
1480 Ok(MiddlewareResult::Continue)
1481 }
1482
1483 async fn on_error(
1484 &self,
1485 _request: &CallToolRequest,
1486 _error: &McpError,
1487 _context: &mut MiddlewareContext,
1488 ) -> McpResult<MiddlewareResult> {
1489 Ok(MiddlewareResult::Continue)
1490 }
1491 }
1492
1493 let chain = MiddlewareChain::new().add_middleware(LoggingMiddleware::new(LogLevel::Info));
1494 let request = CallToolRequest {
1495 name: "test_tool".to_string(),
1496 arguments: None,
1497 };
1498
1499 let chain = chain.add_middleware(StopMiddleware);
1500
1501 let result = chain
1502 .execute(request, |_| async {
1503 Ok(CallToolResult {
1504 content: vec![Content::Text {
1505 text: "should not reach here".to_string(),
1506 }],
1507 is_error: false,
1508 })
1509 })
1510 .await;
1511
1512 assert!(result.is_ok());
1513 let result = result.unwrap();
1514 let Content::Text { text } = &result.content[0];
1515 assert_eq!(text, "stopped");
1516 }
1517
1518 #[tokio::test]
1519 async fn test_middleware_chain_execution_with_middleware_error() {
1520 struct ErrorMiddleware;
1522 #[async_trait::async_trait]
1523 impl McpMiddleware for ErrorMiddleware {
1524 fn name(&self) -> &'static str {
1525 "error"
1526 }
1527
1528 async fn before_request(
1529 &self,
1530 _request: &CallToolRequest,
1531 _context: &mut MiddlewareContext,
1532 ) -> McpResult<MiddlewareResult> {
1533 Err(McpError::tool_not_found("middleware error"))
1534 }
1535
1536 async fn after_request(
1537 &self,
1538 _request: &CallToolRequest,
1539 _result: &mut CallToolResult,
1540 _context: &mut MiddlewareContext,
1541 ) -> McpResult<MiddlewareResult> {
1542 Ok(MiddlewareResult::Continue)
1543 }
1544
1545 async fn on_error(
1546 &self,
1547 _request: &CallToolRequest,
1548 _error: &McpError,
1549 _context: &mut MiddlewareContext,
1550 ) -> McpResult<MiddlewareResult> {
1551 Ok(MiddlewareResult::Continue)
1552 }
1553 }
1554
1555 let chain = MiddlewareChain::new().add_middleware(LoggingMiddleware::new(LogLevel::Info));
1556 let request = CallToolRequest {
1557 name: "test_tool".to_string(),
1558 arguments: None,
1559 };
1560
1561 let chain = chain.add_middleware(ErrorMiddleware);
1562
1563 let result = chain
1564 .execute(request, |_| async {
1565 Ok(CallToolResult {
1566 content: vec![Content::Text {
1567 text: "should not reach here".to_string(),
1568 }],
1569 is_error: false,
1570 })
1571 })
1572 .await;
1573
1574 assert!(result.is_err());
1575 let error = result.unwrap_err();
1576 assert!(matches!(error, McpError::ToolNotFound { tool_name: _ }));
1577 }
1578
1579 #[tokio::test]
1580 async fn test_middleware_chain_execution_with_on_error() {
1581 struct ErrorHandlerMiddleware;
1583 #[async_trait::async_trait]
1584 impl McpMiddleware for ErrorHandlerMiddleware {
1585 fn name(&self) -> &'static str {
1586 "error_handler"
1587 }
1588
1589 async fn before_request(
1590 &self,
1591 _request: &CallToolRequest,
1592 _context: &mut MiddlewareContext,
1593 ) -> McpResult<MiddlewareResult> {
1594 Ok(MiddlewareResult::Continue)
1595 }
1596
1597 async fn after_request(
1598 &self,
1599 _request: &CallToolRequest,
1600 _result: &mut CallToolResult,
1601 _context: &mut MiddlewareContext,
1602 ) -> McpResult<MiddlewareResult> {
1603 Ok(MiddlewareResult::Continue)
1604 }
1605
1606 async fn on_error(
1607 &self,
1608 _request: &CallToolRequest,
1609 _error: &McpError,
1610 _context: &mut MiddlewareContext,
1611 ) -> McpResult<MiddlewareResult> {
1612 Ok(MiddlewareResult::Stop(CallToolResult {
1613 content: vec![Content::Text {
1614 text: "error handled".to_string(),
1615 }],
1616 is_error: false,
1617 }))
1618 }
1619 }
1620
1621 let chain = MiddlewareChain::new().add_middleware(LoggingMiddleware::new(LogLevel::Info));
1622 let request = CallToolRequest {
1623 name: "test_tool".to_string(),
1624 arguments: None,
1625 };
1626
1627 let chain = chain.add_middleware(ErrorHandlerMiddleware);
1628
1629 let result = chain
1630 .execute(request, |_| async {
1631 Err(McpError::tool_not_found("test error"))
1632 })
1633 .await;
1634
1635 assert!(result.is_ok());
1636 let result = result.unwrap();
1637 let Content::Text { text } = &result.content[0];
1638 assert_eq!(text, "error handled");
1639 }
1640
1641 #[tokio::test]
1642 async fn test_config_structs_creation() {
1643 let logging_config = LoggingConfig {
1644 enabled: true,
1645 level: "debug".to_string(),
1646 };
1647 let validation_config = ValidationConfig {
1648 enabled: true,
1649 strict_mode: true,
1650 };
1651 let performance_config = PerformanceConfig {
1652 enabled: true,
1653 slow_request_threshold_ms: 1000,
1654 };
1655
1656 assert!(logging_config.enabled);
1657 assert_eq!(logging_config.level, "debug");
1658 assert!(validation_config.enabled);
1659 assert!(validation_config.strict_mode);
1660 assert!(performance_config.enabled);
1661 assert_eq!(performance_config.slow_request_threshold_ms, 1000);
1662 }
1663
1664 #[tokio::test]
1665 async fn test_config_default() {
1666 let config = MiddlewareConfig::default();
1667 assert!(config.logging.enabled);
1668 assert_eq!(config.logging.level, "info");
1669 assert!(config.validation.enabled);
1670 assert!(!config.validation.strict_mode);
1671 assert!(config.performance.enabled);
1672 assert_eq!(config.performance.slow_request_threshold_ms, 1000);
1673 }
1674
1675 #[tokio::test]
1676 async fn test_config_build_chain_with_disabled_middleware() {
1677 let config = MiddlewareConfig {
1678 logging: LoggingConfig {
1679 enabled: false,
1680 level: "debug".to_string(),
1681 },
1682 validation: ValidationConfig {
1683 enabled: false,
1684 strict_mode: true,
1685 },
1686 performance: PerformanceConfig {
1687 enabled: false,
1688 slow_request_threshold_ms: 1000,
1689 },
1690 security: SecurityConfig {
1691 authentication: AuthenticationConfig {
1692 enabled: false,
1693 require_auth: false,
1694 jwt_secret: "test".to_string(),
1695 api_keys: vec![],
1696 oauth: None,
1697 },
1698 rate_limiting: RateLimitingConfig {
1699 enabled: false,
1700 requests_per_minute: 60,
1701 burst_limit: 10,
1702 custom_limits: None,
1703 },
1704 },
1705 };
1706
1707 let chain = config.build_chain();
1708 assert!(chain.is_empty());
1709 }
1710
1711 #[tokio::test]
1712 async fn test_config_build_chain_with_partial_middleware() {
1713 let config = MiddlewareConfig {
1714 logging: LoggingConfig {
1715 enabled: true,
1716 level: "debug".to_string(),
1717 },
1718 validation: ValidationConfig {
1719 enabled: false,
1720 strict_mode: true,
1721 },
1722 performance: PerformanceConfig {
1723 enabled: true,
1724 slow_request_threshold_ms: 1000,
1725 },
1726 security: SecurityConfig::default(),
1727 };
1728
1729 let chain = config.build_chain();
1730 assert!(!chain.is_empty());
1731 assert!(chain.len() >= 2); }
1733
1734 #[tokio::test]
1735 async fn test_config_build_chain_with_invalid_log_level() {
1736 let config = MiddlewareConfig {
1737 logging: LoggingConfig {
1738 enabled: true,
1739 level: "invalid".to_string(),
1740 },
1741 validation: ValidationConfig {
1742 enabled: true,
1743 strict_mode: true,
1744 },
1745 performance: PerformanceConfig {
1746 enabled: true,
1747 slow_request_threshold_ms: 1000,
1748 },
1749 security: SecurityConfig::default(),
1750 };
1751
1752 let chain = config.build_chain();
1753 assert!(!chain.is_empty());
1754 }
1756
1757 #[tokio::test]
1758 async fn test_middleware_chain_execution_with_empty_middleware() {
1759 let chain = MiddlewareChain::new();
1760 let request = CallToolRequest {
1761 name: "test_tool".to_string(),
1762 arguments: Some(serde_json::json!({"param": "value"})),
1763 };
1764
1765 let result = chain
1766 .execute(request, |_| async {
1767 Ok(CallToolResult {
1768 content: vec![Content::Text {
1769 text: "Test response".to_string(),
1770 }],
1771 is_error: false,
1772 })
1773 })
1774 .await;
1775
1776 assert!(result.is_ok());
1777 let result = result.unwrap();
1778 assert!(!result.is_error);
1779 assert_eq!(result.content.len(), 1);
1780 }
1781
1782 #[tokio::test]
1783 async fn test_middleware_chain_execution_with_multiple_middleware() {
1784 let chain = MiddlewareChain::new()
1785 .add_middleware(LoggingMiddleware::new(LogLevel::Info))
1786 .add_middleware(ValidationMiddleware::new(false))
1787 .add_middleware(PerformanceMiddleware::new(Duration::from_millis(100)));
1788
1789 let request = CallToolRequest {
1790 name: "test_tool".to_string(),
1791 arguments: Some(serde_json::json!({"param": "value"})),
1792 };
1793
1794 let result = chain
1795 .execute(request, |_| async {
1796 Ok(CallToolResult {
1797 content: vec![Content::Text {
1798 text: "Test response".to_string(),
1799 }],
1800 is_error: false,
1801 })
1802 })
1803 .await;
1804
1805 assert!(result.is_ok());
1806 let result = result.unwrap();
1807 assert!(!result.is_error);
1808 assert_eq!(result.content.len(), 1);
1809 }
1810
1811 #[tokio::test]
1812 async fn test_middleware_chain_execution_with_middleware_stop() {
1813 struct StopMiddleware;
1814 #[async_trait::async_trait]
1815 impl McpMiddleware for StopMiddleware {
1816 fn name(&self) -> &'static str {
1817 "stop_middleware"
1818 }
1819
1820 fn priority(&self) -> i32 {
1821 100
1822 }
1823
1824 async fn before_request(
1825 &self,
1826 _request: &CallToolRequest,
1827 _context: &mut MiddlewareContext,
1828 ) -> McpResult<MiddlewareResult> {
1829 Ok(MiddlewareResult::Stop(CallToolResult {
1830 content: vec![Content::Text {
1831 text: "Stopped by middleware".to_string(),
1832 }],
1833 is_error: false,
1834 }))
1835 }
1836
1837 async fn after_request(
1838 &self,
1839 _request: &CallToolRequest,
1840 _result: &mut CallToolResult,
1841 _context: &mut MiddlewareContext,
1842 ) -> McpResult<MiddlewareResult> {
1843 Ok(MiddlewareResult::Continue)
1844 }
1845
1846 async fn on_error(
1847 &self,
1848 _request: &CallToolRequest,
1849 _error: &McpError,
1850 _context: &mut MiddlewareContext,
1851 ) -> McpResult<MiddlewareResult> {
1852 Ok(MiddlewareResult::Continue)
1853 }
1854 }
1855
1856 let chain = MiddlewareChain::new().add_middleware(StopMiddleware);
1857
1858 let request = CallToolRequest {
1859 name: "test_tool".to_string(),
1860 arguments: None,
1861 };
1862
1863 let result = chain
1864 .execute(request, |_| async {
1865 Ok(CallToolResult {
1866 content: vec![Content::Text {
1867 text: "Should not reach here".to_string(),
1868 }],
1869 is_error: false,
1870 })
1871 })
1872 .await;
1873
1874 assert!(result.is_ok());
1875 let result = result.unwrap();
1876 assert!(!result.is_error);
1877 let Content::Text { text } = &result.content[0];
1878 assert_eq!(text, "Stopped by middleware");
1879 }
1880
1881 #[tokio::test]
1882 async fn test_middleware_chain_execution_with_middleware_error_duplicate() {
1883 struct ErrorMiddleware;
1884 #[async_trait::async_trait]
1885 impl McpMiddleware for ErrorMiddleware {
1886 fn name(&self) -> &'static str {
1887 "error_middleware"
1888 }
1889
1890 fn priority(&self) -> i32 {
1891 100
1892 }
1893
1894 async fn before_request(
1895 &self,
1896 _request: &CallToolRequest,
1897 _context: &mut MiddlewareContext,
1898 ) -> McpResult<MiddlewareResult> {
1899 Err(McpError::internal_error("Middleware error"))
1900 }
1901
1902 async fn after_request(
1903 &self,
1904 _request: &CallToolRequest,
1905 _result: &mut CallToolResult,
1906 _context: &mut MiddlewareContext,
1907 ) -> McpResult<MiddlewareResult> {
1908 Ok(MiddlewareResult::Continue)
1909 }
1910
1911 async fn on_error(
1912 &self,
1913 _request: &CallToolRequest,
1914 _error: &McpError,
1915 _context: &mut MiddlewareContext,
1916 ) -> McpResult<MiddlewareResult> {
1917 Ok(MiddlewareResult::Continue)
1918 }
1919 }
1920
1921 let chain = MiddlewareChain::new().add_middleware(ErrorMiddleware);
1922
1923 let request = CallToolRequest {
1924 name: "test_tool".to_string(),
1925 arguments: None,
1926 };
1927
1928 let result = chain
1929 .execute(request, |_| async {
1930 Ok(CallToolResult {
1931 content: vec![Content::Text {
1932 text: "Should not reach here".to_string(),
1933 }],
1934 is_error: false,
1935 })
1936 })
1937 .await;
1938
1939 assert!(result.is_err());
1940 let error = result.unwrap_err();
1941 assert!(matches!(error, McpError::InternalError { .. }));
1942 }
1943
1944 #[tokio::test]
1946 async fn test_authentication_middleware_permissive() {
1947 let middleware = AuthenticationMiddleware::permissive();
1948 let mut context = MiddlewareContext::new("test".to_string());
1949
1950 let request = CallToolRequest {
1951 name: "test_tool".to_string(),
1952 arguments: None,
1953 };
1954
1955 let result = middleware.before_request(&request, &mut context).await;
1956
1957 assert!(matches!(result, Ok(MiddlewareResult::Continue)));
1958 assert_eq!(
1959 context.get_metadata("auth_required"),
1960 Some(&Value::Bool(false))
1961 );
1962 }
1963
1964 #[tokio::test]
1965 async fn test_authentication_middleware_with_valid_api_key() {
1966 let mut api_keys = HashMap::new();
1967 api_keys.insert(
1968 "test-api-key".to_string(),
1969 ApiKeyInfo {
1970 key_id: "test-key-1".to_string(),
1971 permissions: vec!["read".to_string(), "write".to_string()],
1972 expires_at: None,
1973 },
1974 );
1975
1976 let middleware = AuthenticationMiddleware::new(api_keys, "test-secret".to_string());
1977 let mut context = MiddlewareContext::new("test".to_string());
1978
1979 let request = CallToolRequest {
1980 name: "test_tool".to_string(),
1981 arguments: Some(serde_json::json!({
1982 "api_key": "test-api-key"
1983 })),
1984 };
1985
1986 let result = middleware.before_request(&request, &mut context).await;
1987
1988 assert!(matches!(result, Ok(MiddlewareResult::Continue)));
1989 assert_eq!(
1990 context.get_metadata("auth_type"),
1991 Some(&Value::String("api_key".to_string()))
1992 );
1993 assert_eq!(
1994 context.get_metadata("auth_key_id"),
1995 Some(&Value::String("test-key-1".to_string()))
1996 );
1997 }
1998
1999 #[tokio::test]
2000 async fn test_authentication_middleware_with_invalid_api_key() {
2001 let api_keys = HashMap::new();
2002 let middleware = AuthenticationMiddleware::new(api_keys, "test-secret".to_string());
2003 let mut context = MiddlewareContext::new("test".to_string());
2004
2005 let request = CallToolRequest {
2006 name: "test_tool".to_string(),
2007 arguments: Some(serde_json::json!({
2008 "api_key": "invalid-key"
2009 })),
2010 };
2011
2012 let result = middleware.before_request(&request, &mut context).await;
2013
2014 assert!(matches!(result, Ok(MiddlewareResult::Stop(_))));
2015 }
2016
2017 #[tokio::test]
2018 async fn test_authentication_middleware_with_valid_jwt() {
2019 let api_keys = HashMap::new();
2020 let middleware = AuthenticationMiddleware::new(api_keys, "test-secret".to_string());
2021
2022 let jwt_token = middleware.generate_test_jwt("user123", vec!["read".to_string()]);
2024
2025 let mut context = MiddlewareContext::new("test".to_string());
2026
2027 let request = CallToolRequest {
2028 name: "test_tool".to_string(),
2029 arguments: Some(serde_json::json!({
2030 "jwt_token": jwt_token
2031 })),
2032 };
2033
2034 let result = middleware.before_request(&request, &mut context).await;
2035
2036 assert!(matches!(result, Ok(MiddlewareResult::Continue)));
2037 assert_eq!(
2038 context.get_metadata("auth_type"),
2039 Some(&Value::String("jwt".to_string()))
2040 );
2041 assert_eq!(
2042 context.get_metadata("auth_user_id"),
2043 Some(&Value::String("user123".to_string()))
2044 );
2045 }
2046
2047 #[tokio::test]
2048 async fn test_authentication_middleware_with_invalid_jwt() {
2049 let api_keys = HashMap::new();
2050 let middleware = AuthenticationMiddleware::new(api_keys, "test-secret".to_string());
2051 let mut context = MiddlewareContext::new("test".to_string());
2052
2053 let request = CallToolRequest {
2054 name: "test_tool".to_string(),
2055 arguments: Some(serde_json::json!({
2056 "jwt_token": "invalid.jwt.token"
2057 })),
2058 };
2059
2060 let result = middleware.before_request(&request, &mut context).await;
2061
2062 assert!(matches!(result, Ok(MiddlewareResult::Stop(_))));
2063 }
2064
2065 #[tokio::test]
2066 async fn test_authentication_middleware_no_auth_provided() {
2067 let api_keys = HashMap::new();
2068 let middleware = AuthenticationMiddleware::new(api_keys, "test-secret".to_string());
2069 let mut context = MiddlewareContext::new("test".to_string());
2070
2071 let request = CallToolRequest {
2072 name: "test_tool".to_string(),
2073 arguments: None,
2074 };
2075
2076 let result = middleware.before_request(&request, &mut context).await;
2077
2078 assert!(matches!(result, Ok(MiddlewareResult::Stop(_))));
2079 }
2080
2081 #[tokio::test]
2083 async fn test_rate_limit_middleware_allows_request() {
2084 let middleware = RateLimitMiddleware::new(10, 5);
2085 let mut context = MiddlewareContext::new("test".to_string());
2086
2087 let request = CallToolRequest {
2088 name: "test_tool".to_string(),
2089 arguments: Some(serde_json::json!({
2090 "client_id": "test-client"
2091 })),
2092 };
2093
2094 let result = middleware.before_request(&request, &mut context).await;
2095
2096 assert!(matches!(result, Ok(MiddlewareResult::Continue)));
2097 assert_eq!(
2098 context.get_metadata("rate_limit_client_id"),
2099 Some(&Value::String("client:test-client".to_string()))
2100 );
2101 }
2102
2103 #[tokio::test]
2104 async fn test_rate_limit_middleware_uses_auth_context() {
2105 let middleware = RateLimitMiddleware::new(10, 5);
2106 let mut context = MiddlewareContext::new("test".to_string());
2107
2108 context.set_metadata(
2110 "auth_key_id".to_string(),
2111 Value::String("api-key-123".to_string()),
2112 );
2113
2114 let request = CallToolRequest {
2115 name: "test_tool".to_string(),
2116 arguments: None,
2117 };
2118
2119 let result = middleware.before_request(&request, &mut context).await;
2120
2121 assert!(matches!(result, Ok(MiddlewareResult::Continue)));
2122 assert_eq!(
2123 context.get_metadata("rate_limit_client_id"),
2124 Some(&Value::String("api_key:api-key-123".to_string()))
2125 );
2126 }
2127
2128 #[tokio::test]
2129 async fn test_rate_limit_middleware_uses_jwt_context() {
2130 let middleware = RateLimitMiddleware::new(10, 5);
2131 let mut context = MiddlewareContext::new("test".to_string());
2132
2133 context.set_metadata(
2135 "auth_user_id".to_string(),
2136 Value::String("user-456".to_string()),
2137 );
2138
2139 let request = CallToolRequest {
2140 name: "test_tool".to_string(),
2141 arguments: None,
2142 };
2143
2144 let result = middleware.before_request(&request, &mut context).await;
2145
2146 assert!(matches!(result, Ok(MiddlewareResult::Continue)));
2147 assert_eq!(
2148 context.get_metadata("rate_limit_client_id"),
2149 Some(&Value::String("jwt:user-456".to_string()))
2150 );
2151 }
2152
2153 #[tokio::test]
2155 async fn test_security_config_default() {
2156 let config = SecurityConfig::default();
2157 assert!(config.authentication.enabled);
2158 assert!(!config.authentication.require_auth); assert!(config.rate_limiting.enabled);
2160 assert_eq!(config.rate_limiting.requests_per_minute, 60);
2161 }
2162
2163 #[tokio::test]
2164 async fn test_middleware_config_with_security() {
2165 let config = MiddlewareConfig {
2166 logging: LoggingConfig {
2167 enabled: true,
2168 level: "debug".to_string(),
2169 },
2170 validation: ValidationConfig {
2171 enabled: true,
2172 strict_mode: true,
2173 },
2174 performance: PerformanceConfig {
2175 enabled: true,
2176 slow_request_threshold_ms: 500,
2177 },
2178 security: SecurityConfig {
2179 authentication: AuthenticationConfig {
2180 enabled: true,
2181 require_auth: true,
2182 jwt_secret: "test-secret".to_string(),
2183 api_keys: vec![ApiKeyConfig {
2184 key: "test-key".to_string(),
2185 key_id: "test-id".to_string(),
2186 permissions: vec!["read".to_string()],
2187 expires_at: None,
2188 }],
2189 oauth: None,
2190 },
2191 rate_limiting: RateLimitingConfig {
2192 enabled: true,
2193 requests_per_minute: 30,
2194 burst_limit: 5,
2195 custom_limits: None,
2196 },
2197 },
2198 };
2199
2200 let chain = config.build_chain();
2201 assert!(!chain.is_empty());
2202 assert!(chain.len() >= 5); }
2204
2205 #[tokio::test]
2206 async fn test_middleware_chain_with_security_middleware() {
2207 let mut api_keys = HashMap::new();
2208 api_keys.insert(
2209 "test-key".to_string(),
2210 ApiKeyInfo {
2211 key_id: "test-id".to_string(),
2212 permissions: vec!["read".to_string()],
2213 expires_at: None,
2214 },
2215 );
2216
2217 let chain = MiddlewareChain::new()
2218 .add_middleware(AuthenticationMiddleware::new(
2219 api_keys,
2220 "test-secret".to_string(),
2221 ))
2222 .add_middleware(RateLimitMiddleware::new(10, 5))
2223 .add_middleware(LoggingMiddleware::new(LogLevel::Info));
2224
2225 let request = CallToolRequest {
2226 name: "test_tool".to_string(),
2227 arguments: Some(serde_json::json!({
2228 "api_key": "test-key"
2229 })),
2230 };
2231
2232 let result = chain
2233 .execute(request, |_| async {
2234 Ok(CallToolResult {
2235 content: vec![Content::Text {
2236 text: "success".to_string(),
2237 }],
2238 is_error: false,
2239 })
2240 })
2241 .await;
2242
2243 assert!(result.is_ok());
2244 }
2245}