1use regex::Regex;
10use serde_json::Value;
11use std::collections::HashMap;
12use std::sync::{Arc, Mutex};
13use std::time::{Duration, Instant};
14
15use crate::SessionContext;
16use turul_mcp_protocol::McpError;
17
/// Per-session limits enforced by [`RateLimiter`].
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Number of requests a session may make within `window_duration` before
    /// its burst allowance starts being consumed.
    pub max_requests: u32,
    /// Length of the sliding window over which a session's requests are counted.
    pub window_duration: Duration,
    /// Extra requests a session may make beyond `max_requests`. The allowance is
    /// restored once the session's in-window count falls below 80% of
    /// `max_requests`.
    pub burst_size: u32,
}
28
29impl Default for RateLimitConfig {
30 fn default() -> Self {
31 Self {
32 max_requests: 100,
33 window_duration: Duration::from_secs(60),
34 burst_size: 10,
35 }
36 }
37}
38
/// Shared per-session rate-limit state keyed by session id. Each entry holds the
/// timestamps of requests counted in the current window and the number of burst
/// allowances the session has consumed.
type SessionBuckets = Arc<Mutex<HashMap<String, (Vec<Instant>, u32)>>>;
42
/// Sliding-window, per-session request limiter with a small burst allowance.
#[derive(Debug)]
pub struct RateLimiter {
    /// Limits applied to every session.
    config: RateLimitConfig,
    /// Request history for each session, guarded by a mutex.
    session_buckets: SessionBuckets,
}
49
50impl RateLimiter {
51 pub fn new(config: RateLimitConfig) -> Self {
52 Self {
53 config,
54 session_buckets: Arc::new(Mutex::new(HashMap::new())),
55 }
56 }
57
58 pub fn check_rate_limit(&self, session_id: &str) -> Result<(), McpError> {
60 let mut buckets = self.session_buckets.lock().unwrap();
61 let now = Instant::now();
62
63 let (request_times, burst_count) = buckets
64 .entry(session_id.to_string())
65 .or_insert_with(|| (Vec::new(), 0));
66
67 request_times.retain(|&time| now.duration_since(time) < self.config.window_duration);
69
70 request_times.push(now);
72
73 if request_times.len() > self.config.max_requests as usize {
75 if *burst_count < self.config.burst_size {
77 *burst_count += 1;
78 return Ok(());
79 }
80
81 request_times.pop();
83
84 return Err(McpError::param_out_of_range(
85 "request_rate",
86 &format!("{} requests", request_times.len() + 1),
87 &format!(
88 "max {} requests per {:?}",
89 self.config.max_requests, self.config.window_duration
90 ),
91 ));
92 }
93
94 if request_times.len() < (self.config.max_requests as f32 * 0.8) as usize {
96 *burst_count = 0;
97 }
98
99 Ok(())
100 }
101
102 pub fn cleanup_expired_sessions(&self) {
104 let mut buckets = self.session_buckets.lock().unwrap();
105 let now = Instant::now();
106
107 buckets.retain(|_, (request_times, _)| {
108 request_times.retain(|&time| now.duration_since(time) < self.config.window_duration);
109 !request_times.is_empty()
110 });
111 }
112}
113
/// Session requirement applied to `resources/read` requests by
/// `SecurityMiddleware::validate_request`.
#[derive(Debug, Clone, PartialEq)]
pub enum AccessLevel {
    /// Requests are accepted with or without a session context.
    Public,
    /// Requests without a session context are rejected.
    SessionRequired,
    /// Application-defined level carrying an arbitrary string.
    /// NOTE(review): `validate_request` does not enforce this variant; it is
    /// currently treated like `Public`. Confirm that is intended.
    Custom(String),
}
124
/// Policy used to validate resource URIs, MIME types, and content sizes.
#[derive(Debug, Clone)]
pub struct ResourceAccessControl {
    /// Session requirement for `resources/read` requests.
    pub access_level: AccessLevel,
    /// When non-empty, a URI must match at least one of these patterns.
    pub allowed_patterns: Vec<Regex>,
    /// A URI matching any of these patterns is rejected; checked before
    /// `allowed_patterns`.
    pub blocked_patterns: Vec<Regex>,
    /// Maximum content size in bytes; `None` disables the size check.
    pub max_size: Option<u64>,
    /// Permitted MIME types; `None` disables the MIME type check.
    pub allowed_mime_types: Option<Vec<String>>,
}
139
140impl Default for ResourceAccessControl {
141 fn default() -> Self {
142 Self {
143 access_level: AccessLevel::SessionRequired,
144 allowed_patterns: vec![
145 Regex::new(r"^file:///[a-zA-Z0-9_/-]+\.(json|txt|md|html)$").unwrap(),
146 ],
147 blocked_patterns: vec![
148 Regex::new(r"\.\.").unwrap(), Regex::new(r"/etc/").unwrap(), Regex::new(r"/proc/").unwrap(), Regex::new(r"\.exe$").unwrap(), ],
153 max_size: Some(10 * 1024 * 1024), allowed_mime_types: Some(vec![
155 "text/plain".to_string(),
156 "text/markdown".to_string(),
157 "application/json".to_string(),
158 "text/html".to_string(),
159 "image/png".to_string(),
160 "image/jpeg".to_string(),
161 ]),
162 }
163 }
164}
165
166impl ResourceAccessControl {
167 pub fn validate_uri(&self, uri: &str) -> Result<(), McpError> {
169 for blocked_pattern in &self.blocked_patterns {
171 if blocked_pattern.is_match(uri) {
172 return Err(McpError::invalid_param_type(
173 "uri",
174 "URI not matching blocked patterns",
175 uri,
176 ));
177 }
178 }
179
180 if !self.allowed_patterns.is_empty() {
182 let allowed = self
183 .allowed_patterns
184 .iter()
185 .any(|pattern| pattern.is_match(uri));
186
187 if !allowed {
188 return Err(McpError::invalid_param_type(
189 "uri",
190 "URI matching allowed patterns",
191 uri,
192 ));
193 }
194 }
195
196 Ok(())
197 }
198
199 pub fn validate_mime_type(&self, mime_type: &str) -> Result<(), McpError> {
201 if let Some(allowed_types) = &self.allowed_mime_types
202 && !allowed_types.contains(&mime_type.to_string())
203 {
204 return Err(McpError::invalid_param_type(
205 "mime_type",
206 "allowed MIME type",
207 mime_type,
208 ));
209 }
210 Ok(())
211 }
212
213 pub fn validate_size(&self, size: u64) -> Result<(), McpError> {
215 if let Some(max_size) = self.max_size
216 && size > max_size
217 {
218 return Err(McpError::param_out_of_range(
219 "content_size",
220 &format!("{} bytes", size),
221 &format!("max {} bytes", max_size),
222 ));
223 }
224 Ok(())
225 }
226}
227
/// Structural limits applied to incoming JSON request parameters.
///
/// Derives `Debug` (like the other public types in this module) so it can be
/// logged and inspected, and `Clone`/`PartialEq`/`Eq` since it is plain data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InputValidator {
    /// Deepest nesting level accepted; the top-level value is depth 0.
    max_json_depth: usize,
    /// Maximum length in bytes of any string value or object key.
    max_string_length: usize,
    /// Maximum number of elements in an array or entries in an object.
    max_collection_size: usize,
}
237
238impl Default for InputValidator {
239 fn default() -> Self {
240 Self {
241 max_json_depth: 10,
242 max_string_length: 1024 * 1024, max_collection_size: 1000,
244 }
245 }
246}
247
248impl InputValidator {
249 pub fn new(
250 max_json_depth: usize,
251 max_string_length: usize,
252 max_collection_size: usize,
253 ) -> Self {
254 Self {
255 max_json_depth,
256 max_string_length,
257 max_collection_size,
258 }
259 }
260
261 pub fn validate_json(&self, value: &Value) -> Result<(), McpError> {
263 self.validate_json_recursive(value, 0)
264 }
265
266 fn validate_json_recursive(&self, value: &Value, depth: usize) -> Result<(), McpError> {
267 if depth > self.max_json_depth {
268 return Err(McpError::param_out_of_range(
269 "json_depth",
270 &format!("{}", depth),
271 &format!("max {}", self.max_json_depth),
272 ));
273 }
274
275 match value {
276 Value::String(s) => {
277 if s.len() > self.max_string_length {
278 return Err(McpError::param_out_of_range(
279 "string_length",
280 &format!("{}", s.len()),
281 &format!("max {}", self.max_string_length),
282 ));
283 }
284
285 if s.contains("../") || s.contains("..\\") {
287 return Err(McpError::invalid_param_type(
288 "string_content",
289 "string without directory traversal sequences",
290 s,
291 ));
292 }
293 }
294 Value::Array(arr) => {
295 if arr.len() > self.max_collection_size {
296 return Err(McpError::param_out_of_range(
297 "array_size",
298 &format!("{}", arr.len()),
299 &format!("max {}", self.max_collection_size),
300 ));
301 }
302
303 for item in arr {
304 self.validate_json_recursive(item, depth + 1)?;
305 }
306 }
307 Value::Object(obj) => {
308 if obj.len() > self.max_collection_size {
309 return Err(McpError::param_out_of_range(
310 "object_size",
311 &format!("{}", obj.len()),
312 &format!("max {}", self.max_collection_size),
313 ));
314 }
315
316 for (key, val) in obj {
317 if key.len() > self.max_string_length {
319 return Err(McpError::param_out_of_range(
320 "object_key_length",
321 &format!("{}", key.len()),
322 &format!("max {}", self.max_string_length),
323 ));
324 }
325
326 self.validate_json_recursive(val, depth + 1)?;
327 }
328 }
329 _ => {} }
331
332 Ok(())
333 }
334
335 pub fn sanitize_string(&self, input: &str) -> String {
337 input
338 .chars()
339 .filter(|c| c.is_ascii() && !c.is_control() || c.is_whitespace())
340 .take(self.max_string_length)
341 .collect()
342 }
343}
344
/// Request-level security checks:
///
/// - per-session rate limiting;
/// - JSON parameter validation;
/// - URI and session checks for `resources/read`.
pub struct SecurityMiddleware {
    /// `None` disables rate limiting.
    rate_limiter: Option<RateLimiter>,
    /// Policy applied to `resources/read` requests.
    resource_access_control: ResourceAccessControl,
    /// Limits applied to every request's `params`.
    input_validator: InputValidator,
}
351
352impl SecurityMiddleware {
353 pub fn new() -> Self {
354 Self {
355 rate_limiter: Some(RateLimiter::new(RateLimitConfig::default())),
356 resource_access_control: ResourceAccessControl::default(),
357 input_validator: InputValidator::default(),
358 }
359 }
360
361 pub fn resource_access_control(&self) -> &ResourceAccessControl {
363 &self.resource_access_control
364 }
365
366 pub fn with_rate_limiting(mut self, config: RateLimitConfig) -> Self {
367 self.rate_limiter = Some(RateLimiter::new(config));
368 self
369 }
370
371 pub fn without_rate_limiting(mut self) -> Self {
372 self.rate_limiter = None;
373 self
374 }
375
376 pub fn with_resource_access_control(mut self, config: ResourceAccessControl) -> Self {
377 self.resource_access_control = config;
378 self
379 }
380
381 pub fn with_input_validation(mut self, validator: InputValidator) -> Self {
382 self.input_validator = validator;
383 self
384 }
385
386 pub fn validate_request(
388 &self,
389 method: &str,
390 params: Option<&Value>,
391 session: Option<&SessionContext>,
392 ) -> Result<(), McpError> {
393 if let Some(rate_limiter) = &self.rate_limiter
395 && let Some(session) = session
396 {
397 rate_limiter.check_rate_limit(&session.session_id)?;
398 }
399
400 if let Some(params) = params {
402 self.input_validator.validate_json(params)?;
403 }
404
405 if method == "resources/read" {
407 if let Some(params) = params
408 && let Some(uri) = params.get("uri").and_then(|v| v.as_str())
409 {
410 self.resource_access_control.validate_uri(uri)?;
411 }
412
413 match self.resource_access_control.access_level {
415 AccessLevel::SessionRequired if session.is_none() => {
416 return Err(McpError::invalid_param_type(
417 "session",
418 "valid session context",
419 "none",
420 ));
421 }
422 _ => {}
423 }
424 }
425 Ok(())
428 }
429
430 pub fn cleanup(&self) {
432 if let Some(rate_limiter) = &self.rate_limiter {
433 rate_limiter.cleanup_expired_sessions();
434 }
435 }
436}
437
438impl Default for SecurityMiddleware {
439 fn default() -> Self {
440 Self::new()
441 }
442}
443
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_rate_limiter_basic() {
        let config = RateLimitConfig {
            max_requests: 3,
            window_duration: Duration::from_secs(60),
            burst_size: 1,
        };
        let limiter = RateLimiter::new(config);

        // The first `max_requests` requests fit in the window.
        assert!(limiter.check_rate_limit("session1").is_ok());
        assert!(limiter.check_rate_limit("session1").is_ok());
        assert!(limiter.check_rate_limit("session1").is_ok());

        // One more is absorbed by the burst allowance.
        assert!(limiter.check_rate_limit("session1").is_ok());

        // Window and burst both exhausted.
        assert!(limiter.check_rate_limit("session1").is_err());
    }

    #[test]
    fn test_rate_limiter_different_sessions() {
        let config = RateLimitConfig {
            max_requests: 2,
            window_duration: Duration::from_secs(60),
            burst_size: 0,
        };
        let limiter = RateLimiter::new(config);

        // Sessions are limited independently.
        assert!(limiter.check_rate_limit("session1").is_ok());
        assert!(limiter.check_rate_limit("session1").is_ok());
        assert!(limiter.check_rate_limit("session1").is_err());

        assert!(limiter.check_rate_limit("session2").is_ok());
        assert!(limiter.check_rate_limit("session2").is_ok());
        assert!(limiter.check_rate_limit("session2").is_err());
    }

    #[test]
    fn test_rate_limiter_cleanup_removes_idle_sessions() {
        // A zero-length window means every recorded request is already expired.
        let limiter = RateLimiter::new(RateLimitConfig {
            max_requests: 1,
            window_duration: Duration::ZERO,
            burst_size: 0,
        });

        assert!(limiter.check_rate_limit("session1").is_ok());
        limiter.cleanup_expired_sessions();
        assert!(limiter.session_buckets.lock().unwrap().is_empty());
    }

    #[test]
    fn test_resource_access_control_uri_validation() {
        let access_control = ResourceAccessControl::default();

        // Allowed URIs.
        assert!(
            access_control
                .validate_uri("file:///data/test.json")
                .is_ok()
        );
        assert!(
            access_control
                .validate_uri("file:///docs/readme.txt")
                .is_ok()
        );

        // Blocked URIs.
        assert!(access_control.validate_uri("file:///etc/passwd").is_err());
        assert!(
            access_control
                .validate_uri("file:///data/../etc/shadow")
                .is_err()
        );
        assert!(
            access_control
                .validate_uri("file:///app/malware.exe")
                .is_err()
        );
    }

    #[test]
    fn test_input_validator_json_depth() {
        let validator = InputValidator::new(3, 1000, 100);

        // Deepest value sits at depth 3: allowed.
        let valid_json = json!({
            "level1": {
                "level2": {
                    "level3": "value"
                }
            }
        });
        assert!(validator.validate_json(&valid_json).is_ok());

        // Deepest value sits at depth 5: rejected.
        let deep_json = json!({
            "l1": { "l2": { "l3": { "l4": { "l5": "too deep" } } } }
        });
        assert!(validator.validate_json(&deep_json).is_err());
    }

    #[test]
    fn test_input_validator_string_length() {
        let validator = InputValidator::new(10, 10, 100);

        let valid_json = json!({"key": "short"});
        assert!(validator.validate_json(&valid_json).is_ok());

        let invalid_json = json!({"key": "this string is too long"});
        assert!(validator.validate_json(&invalid_json).is_err());
    }

    #[test]
    fn test_input_validator_directory_traversal() {
        let validator = InputValidator::default();

        let malicious_json = json!({"path": "../../../etc/passwd"});
        assert!(validator.validate_json(&malicious_json).is_err());

        let safe_json = json!({"path": "data/file.txt"});
        assert!(validator.validate_json(&safe_json).is_ok());
    }

    #[test]
    fn test_sanitize_string() {
        let validator = InputValidator::new(10, 5, 100);

        // ASCII control characters are dropped, whitespace controls are kept.
        assert_eq!(validator.sanitize_string("a\u{7}b\tc"), "ab\tc");
        // Non-ASCII, non-whitespace characters are dropped.
        assert_eq!(validator.sanitize_string("héllo"), "hllo");
        // Output is truncated to `max_string_length` characters.
        assert_eq!(validator.sanitize_string("abcdefgh"), "abcde");
    }

    #[test]
    fn test_security_middleware_integration() {
        let session_id = "test-session".to_string();
        let session = SessionContext {
            session_id: session_id.clone(),
            get_state: Arc::new(|_| Box::pin(futures::future::ready(None))),
            set_state: Arc::new(|_, _| Box::pin(futures::future::ready(()))),
            remove_state: Arc::new(|_| Box::pin(futures::future::ready(None))),
            is_initialized: Arc::new(|| Box::pin(futures::future::ready(true))),
            send_notification: Arc::new(|_| Box::pin(futures::future::ready(()))),
            broadcaster: None,
        };

        let middleware = SecurityMiddleware::new();

        // Allowed URI with a session.
        let params = json!({"uri": "file:///data/test.json"});
        assert!(
            middleware
                .validate_request("resources/read", Some(&params), Some(&session))
                .is_ok()
        );

        // Blocked URI.
        let bad_params = json!({"uri": "file:///etc/passwd"});
        assert!(
            middleware
                .validate_request("resources/read", Some(&bad_params), Some(&session))
                .is_err()
        );

        // Default policy requires a session for resources/read.
        assert!(
            middleware
                .validate_request("resources/read", Some(&params), None)
                .is_err()
        );
    }

    #[test]
    fn test_resource_rules_apply_only_to_resources_read() {
        let middleware = SecurityMiddleware::new().without_rate_limiting();
        let params = json!({"uri": "file:///etc/passwd"});

        // URI and session checks are specific to `resources/read`.
        assert!(
            middleware
                .validate_request("tools/call", Some(&params), None)
                .is_ok()
        );
    }

    #[test]
    fn test_mime_type_validation() {
        let access_control = ResourceAccessControl::default();

        assert!(
            access_control
                .validate_mime_type("application/json")
                .is_ok()
        );
        assert!(access_control.validate_mime_type("text/plain").is_ok());
        assert!(
            access_control
                .validate_mime_type("application/octet-stream")
                .is_err()
        );
        assert!(
            access_control
                .validate_mime_type("application/x-executable")
                .is_err()
        );
    }

    #[test]
    fn test_size_validation() {
        let access_control = ResourceAccessControl::default();

        assert!(access_control.validate_size(1024).is_ok());
        assert!(access_control.validate_size(1024 * 1024).is_ok());
        assert!(access_control.validate_size(20 * 1024 * 1024).is_err());
    }

    #[test]
    fn test_size_validation_boundary() {
        let access_control = ResourceAccessControl::default();

        // The limit itself is inclusive.
        assert!(access_control.validate_size(10 * 1024 * 1024).is_ok());
        assert!(access_control.validate_size(10 * 1024 * 1024 + 1).is_err());
    }
}