1use regex::Regex;
10use serde_json::Value;
11use std::collections::HashMap;
12use std::sync::{Arc, Mutex};
13use std::time::{Duration, Instant};
14
15use crate::SessionContext;
16use turul_mcp_protocol::McpError;
17
/// Tunable limits for the per-session [`RateLimiter`].
///
/// `PartialEq`/`Eq` are derived so two configurations can be compared
/// directly (for example in tests or when detecting a configuration change).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Number of requests a session may record inside one window before the
    /// burst allowance is consulted.
    pub max_requests: u32,
    /// Length of the sliding window; recorded requests older than this are
    /// forgotten.
    pub window_duration: Duration,
    /// Number of additional requests tolerated after `max_requests` has been
    /// exceeded.
    pub burst_size: u32,
}
28
29impl Default for RateLimitConfig {
30    fn default() -> Self {
31        Self {
32            max_requests: 100,
33            window_duration: Duration::from_secs(60),
34            burst_size: 10,
35        }
36    }
37}
38
/// Bookkeeping kept for one session: the timestamps of its recorded requests
/// and the number of burst slots it has consumed.
type SessionBucket = (Vec<Instant>, u32);

/// Shared, mutex-guarded map from session id to that session's
/// [`SessionBucket`].
type SessionBuckets = Arc<Mutex<HashMap<String, SessionBucket>>>;
42
43#[derive(Debug)]
44pub struct RateLimiter {
45    config: RateLimitConfig,
46    session_buckets: SessionBuckets,
48}
49
50impl RateLimiter {
51    pub fn new(config: RateLimitConfig) -> Self {
52        Self {
53            config,
54            session_buckets: Arc::new(Mutex::new(HashMap::new())),
55        }
56    }
57
58    pub fn check_rate_limit(&self, session_id: &str) -> Result<(), McpError> {
60        let mut buckets = self.session_buckets.lock().unwrap();
61        let now = Instant::now();
62
63        let (request_times, burst_count) = buckets
64            .entry(session_id.to_string())
65            .or_insert_with(|| (Vec::new(), 0));
66
67        request_times.retain(|&time| now.duration_since(time) < self.config.window_duration);
69
70        request_times.push(now);
72
73        if request_times.len() > self.config.max_requests as usize {
75            if *burst_count < self.config.burst_size {
77                *burst_count += 1;
78                return Ok(());
79            }
80
81            request_times.pop();
83
84            return Err(McpError::param_out_of_range(
85                "request_rate",
86                &format!("{} requests", request_times.len() + 1),
87                &format!(
88                    "max {} requests per {:?}",
89                    self.config.max_requests, self.config.window_duration
90                ),
91            ));
92        }
93
94        if request_times.len() < (self.config.max_requests as f32 * 0.8) as usize {
96            *burst_count = 0;
97        }
98
99        Ok(())
100    }
101
102    pub fn cleanup_expired_sessions(&self) {
104        let mut buckets = self.session_buckets.lock().unwrap();
105        let now = Instant::now();
106
107        buckets.retain(|_, (request_times, _)| {
108            request_times.retain(|&time| now.duration_since(time) < self.config.window_duration);
109            !request_times.is_empty()
110        });
111    }
112}
113
/// Access policy attached to a [`ResourceAccessControl`].
///
/// Every variant's payload is `Eq`, so `Eq` is derived alongside `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AccessLevel {
    /// No session check is performed in this module for this level.
    Public,
    /// A session context must accompany the request.
    SessionRequired,
    /// Caller-defined level identified by a free-form label; this module
    /// attaches no behavior to it.
    Custom(String),
}
124
125#[derive(Debug, Clone)]
127pub struct ResourceAccessControl {
128    pub access_level: AccessLevel,
130    pub allowed_patterns: Vec<Regex>,
132    pub blocked_patterns: Vec<Regex>,
134    pub max_size: Option<u64>,
136    pub allowed_mime_types: Option<Vec<String>>,
138}
139
140impl Default for ResourceAccessControl {
141    fn default() -> Self {
142        Self {
143            access_level: AccessLevel::SessionRequired,
144            allowed_patterns: vec![
145                Regex::new(r"^file:///[a-zA-Z0-9_/-]+\.(json|txt|md|html)$").unwrap(),
146            ],
147            blocked_patterns: vec![
148                Regex::new(r"\.\.").unwrap(),   Regex::new(r"/etc/").unwrap(),  Regex::new(r"/proc/").unwrap(), Regex::new(r"\.exe$").unwrap(), ],
153            max_size: Some(10 * 1024 * 1024), allowed_mime_types: Some(vec![
155                "text/plain".to_string(),
156                "text/markdown".to_string(),
157                "application/json".to_string(),
158                "text/html".to_string(),
159                "image/png".to_string(),
160                "image/jpeg".to_string(),
161            ]),
162        }
163    }
164}
165
166impl ResourceAccessControl {
167    pub fn validate_uri(&self, uri: &str) -> Result<(), McpError> {
169        for blocked_pattern in &self.blocked_patterns {
171            if blocked_pattern.is_match(uri) {
172                return Err(McpError::invalid_param_type(
173                    "uri",
174                    "URI not matching blocked patterns",
175                    uri,
176                ));
177            }
178        }
179
180        if !self.allowed_patterns.is_empty() {
182            let allowed = self
183                .allowed_patterns
184                .iter()
185                .any(|pattern| pattern.is_match(uri));
186
187            if !allowed {
188                return Err(McpError::invalid_param_type(
189                    "uri",
190                    "URI matching allowed patterns",
191                    uri,
192                ));
193            }
194        }
195
196        Ok(())
197    }
198
199    pub fn validate_mime_type(&self, mime_type: &str) -> Result<(), McpError> {
201        if let Some(allowed_types) = &self.allowed_mime_types
202            && !allowed_types.contains(&mime_type.to_string())
203        {
204            return Err(McpError::invalid_param_type(
205                "mime_type",
206                "allowed MIME type",
207                mime_type,
208            ));
209        }
210        Ok(())
211    }
212
213    pub fn validate_size(&self, size: u64) -> Result<(), McpError> {
215        if let Some(max_size) = self.max_size
216            && size > max_size
217        {
218            return Err(McpError::param_out_of_range(
219                "content_size",
220                &format!("{} bytes", size),
221                &format!("max {} bytes", max_size),
222            ));
223        }
224        Ok(())
225    }
226}
227
/// Structural limits applied to incoming JSON values.
///
/// `Debug` and `Clone` are derived so the validator can be logged and
/// duplicated like the other configuration types in this module.
#[derive(Debug, Clone)]
pub struct InputValidator {
    /// Deepest nesting level accepted; the root value is at depth 0.
    max_json_depth: usize,
    /// Longest string or object key accepted, measured in bytes.
    max_string_length: usize,
    /// Largest number of elements accepted in one array or object.
    max_collection_size: usize,
}
237
238impl Default for InputValidator {
239    fn default() -> Self {
240        Self {
241            max_json_depth: 10,
242            max_string_length: 1024 * 1024, max_collection_size: 1000,
244        }
245    }
246}
247
248impl InputValidator {
249    pub fn new(
250        max_json_depth: usize,
251        max_string_length: usize,
252        max_collection_size: usize,
253    ) -> Self {
254        Self {
255            max_json_depth,
256            max_string_length,
257            max_collection_size,
258        }
259    }
260
261    pub fn validate_json(&self, value: &Value) -> Result<(), McpError> {
263        self.validate_json_recursive(value, 0)
264    }
265
266    fn validate_json_recursive(&self, value: &Value, depth: usize) -> Result<(), McpError> {
267        if depth > self.max_json_depth {
268            return Err(McpError::param_out_of_range(
269                "json_depth",
270                &format!("{}", depth),
271                &format!("max {}", self.max_json_depth),
272            ));
273        }
274
275        match value {
276            Value::String(s) => {
277                if s.len() > self.max_string_length {
278                    return Err(McpError::param_out_of_range(
279                        "string_length",
280                        &format!("{}", s.len()),
281                        &format!("max {}", self.max_string_length),
282                    ));
283                }
284
285                if s.contains("../") || s.contains("..\\") {
287                    return Err(McpError::invalid_param_type(
288                        "string_content",
289                        "string without directory traversal sequences",
290                        s,
291                    ));
292                }
293            }
294            Value::Array(arr) => {
295                if arr.len() > self.max_collection_size {
296                    return Err(McpError::param_out_of_range(
297                        "array_size",
298                        &format!("{}", arr.len()),
299                        &format!("max {}", self.max_collection_size),
300                    ));
301                }
302
303                for item in arr {
304                    self.validate_json_recursive(item, depth + 1)?;
305                }
306            }
307            Value::Object(obj) => {
308                if obj.len() > self.max_collection_size {
309                    return Err(McpError::param_out_of_range(
310                        "object_size",
311                        &format!("{}", obj.len()),
312                        &format!("max {}", self.max_collection_size),
313                    ));
314                }
315
316                for (key, val) in obj {
317                    if key.len() > self.max_string_length {
319                        return Err(McpError::param_out_of_range(
320                            "object_key_length",
321                            &format!("{}", key.len()),
322                            &format!("max {}", self.max_string_length),
323                        ));
324                    }
325
326                    self.validate_json_recursive(val, depth + 1)?;
327                }
328            }
329            _ => {} }
331
332        Ok(())
333    }
334
335    pub fn sanitize_string(&self, input: &str) -> String {
337        input
338            .chars()
339            .filter(|c| c.is_ascii() && !c.is_control() || c.is_whitespace())
340            .take(self.max_string_length)
341            .collect()
342    }
343}
344
345pub struct SecurityMiddleware {
347    rate_limiter: Option<RateLimiter>,
348    resource_access_control: ResourceAccessControl,
349    input_validator: InputValidator,
350}
351
352impl SecurityMiddleware {
353    pub fn new() -> Self {
354        Self {
355            rate_limiter: Some(RateLimiter::new(RateLimitConfig::default())),
356            resource_access_control: ResourceAccessControl::default(),
357            input_validator: InputValidator::default(),
358        }
359    }
360
361    pub fn resource_access_control(&self) -> &ResourceAccessControl {
363        &self.resource_access_control
364    }
365
366    pub fn with_rate_limiting(mut self, config: RateLimitConfig) -> Self {
367        self.rate_limiter = Some(RateLimiter::new(config));
368        self
369    }
370
371    pub fn without_rate_limiting(mut self) -> Self {
372        self.rate_limiter = None;
373        self
374    }
375
376    pub fn with_resource_access_control(mut self, config: ResourceAccessControl) -> Self {
377        self.resource_access_control = config;
378        self
379    }
380
381    pub fn with_input_validation(mut self, validator: InputValidator) -> Self {
382        self.input_validator = validator;
383        self
384    }
385
386    pub fn validate_request(
388        &self,
389        method: &str,
390        params: Option<&Value>,
391        session: Option<&SessionContext>,
392    ) -> Result<(), McpError> {
393        if let Some(rate_limiter) = &self.rate_limiter
395            && let Some(session) = session
396        {
397            rate_limiter.check_rate_limit(&session.session_id)?;
398        }
399
400        if let Some(params) = params {
402            self.input_validator.validate_json(params)?;
403        }
404
405        if method == "resources/read" {
407            if let Some(params) = params
408                && let Some(uri) = params.get("uri").and_then(|v| v.as_str())
409            {
410                self.resource_access_control.validate_uri(uri)?;
411            }
412
413            match self.resource_access_control.access_level {
415                AccessLevel::SessionRequired if session.is_none() => {
416                    return Err(McpError::invalid_param_type(
417                        "session",
418                        "valid session context",
419                        "none",
420                    ));
421                }
422                _ => {}
423            }
424        }
425        Ok(())
428    }
429
430    pub fn cleanup(&self) {
432        if let Some(rate_limiter) = &self.rate_limiter {
433            rate_limiter.cleanup_expired_sessions();
434        }
435    }
436}
437
438impl Default for SecurityMiddleware {
439    fn default() -> Self {
440        Self::new()
441    }
442}
443
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_rate_limiter_basic() {
        let config = RateLimitConfig {
            max_requests: 3,
            window_duration: Duration::from_secs(60),
            burst_size: 1,
        };
        let limiter = RateLimiter::new(config);

        // The first three requests are within the limit.
        assert!(limiter.check_rate_limit("session1").is_ok());
        assert!(limiter.check_rate_limit("session1").is_ok());
        assert!(limiter.check_rate_limit("session1").is_ok());

        // The fourth request consumes the single burst slot.
        assert!(limiter.check_rate_limit("session1").is_ok());

        // The fifth request is rejected.
        assert!(limiter.check_rate_limit("session1").is_err());
    }

    #[test]
    fn test_rate_limiter_different_sessions() {
        let config = RateLimitConfig {
            max_requests: 2,
            window_duration: Duration::from_secs(60),
            burst_size: 0,
        };
        let limiter = RateLimiter::new(config);

        // Exhaust the limit for the first session.
        assert!(limiter.check_rate_limit("session1").is_ok());
        assert!(limiter.check_rate_limit("session1").is_ok());
        assert!(limiter.check_rate_limit("session1").is_err());

        // The second session has its own independent budget.
        assert!(limiter.check_rate_limit("session2").is_ok());
        assert!(limiter.check_rate_limit("session2").is_ok());
        assert!(limiter.check_rate_limit("session2").is_err());
    }

    #[test]
    fn test_resource_access_control_uri_validation() {
        let access_control = ResourceAccessControl::default();

        // URIs that match the allow-list.
        assert!(
            access_control
                .validate_uri("file:///data/test.json")
                .is_ok()
        );
        assert!(
            access_control
                .validate_uri("file:///docs/readme.txt")
                .is_ok()
        );

        // URIs that hit a blocked pattern.
        assert!(access_control.validate_uri("file:///etc/passwd").is_err());
        assert!(
            access_control
                .validate_uri("file:///data/../etc/shadow")
                .is_err()
        );
        assert!(
            access_control
                .validate_uri("file:///app/malware.exe")
                .is_err()
        );
    }

    #[test]
    fn test_input_validator_json_depth() {
        let validator = InputValidator::new(3, 1000, 100);

        // The innermost value sits exactly at the maximum depth.
        let valid_json = json!({
            "level1": {
                "level2": {
                    "level3": "value"
                }
            }
        });
        assert!(validator.validate_json(&valid_json).is_ok());

        // Nesting beyond the maximum depth is rejected.
        let deep_json = json!({
            "l1": { "l2": { "l3": { "l4": { "l5": "too deep" } } } }
        });
        assert!(validator.validate_json(&deep_json).is_err());
    }

    #[test]
    fn test_input_validator_string_length() {
        let validator = InputValidator::new(10, 10, 100);

        let valid_json = json!({"key": "short"});
        assert!(validator.validate_json(&valid_json).is_ok());

        let invalid_json = json!({"key": "this string is too long"});
        assert!(validator.validate_json(&invalid_json).is_err());
    }

    #[test]
    fn test_input_validator_directory_traversal() {
        let validator = InputValidator::default();

        let malicious_json = json!({"path": "../../../etc/passwd"});
        assert!(validator.validate_json(&malicious_json).is_err());

        let safe_json = json!({"path": "data/file.txt"});
        assert!(validator.validate_json(&safe_json).is_ok());
    }

    #[test]
    fn test_security_middleware_integration() {
        let session_id = "test-session".to_string();
        // Minimal session context whose callbacks all resolve immediately.
        let session = SessionContext {
            session_id: session_id.clone(),
            get_state: Arc::new(|_| Box::pin(futures::future::ready(None))),
            set_state: Arc::new(|_, _| Box::pin(futures::future::ready(()))),
            remove_state: Arc::new(|_| Box::pin(futures::future::ready(None))),
            is_initialized: Arc::new(|| Box::pin(futures::future::ready(true))),
            send_notification: Arc::new(|_| Box::pin(futures::future::ready(()))),
            broadcaster: None,
        };

        let middleware = SecurityMiddleware::new();

        // An allowed URI with a session passes.
        let params = json!({"uri": "file:///data/test.json"});
        assert!(
            middleware
                .validate_request("resources/read", Some(&params), Some(&session))
                .is_ok()
        );

        // A blocked URI is rejected even with a session.
        let bad_params = json!({"uri": "file:///etc/passwd"});
        assert!(
            middleware
                .validate_request("resources/read", Some(&bad_params), Some(&session))
                .is_err()
        );

        // The default policy requires a session for resource reads.
        assert!(
            middleware
                .validate_request("resources/read", Some(&params), None)
                .is_err()
        );
    }

    #[test]
    fn test_mime_type_validation() {
        let access_control = ResourceAccessControl::default();

        assert!(
            access_control
                .validate_mime_type("application/json")
                .is_ok()
        );
        assert!(access_control.validate_mime_type("text/plain").is_ok());
        assert!(
            access_control
                .validate_mime_type("application/octet-stream")
                .is_err()
        );
        assert!(
            access_control
                .validate_mime_type("application/x-executable")
                .is_err()
        );
    }

    #[test]
    fn test_size_validation() {
        let access_control = ResourceAccessControl::default();

        // 1 KiB and 1 MiB are below the 10 MiB default cap; 20 MiB is above.
        assert!(access_control.validate_size(1024).is_ok());
        assert!(access_control.validate_size(1024 * 1024).is_ok());
        assert!(access_control.validate_size(20 * 1024 * 1024).is_err());
    }
}