1use std::collections::HashSet;
10use std::sync::Arc;
11use std::sync::atomic::{AtomicUsize, Ordering};
12use std::time::{Duration, Instant};
13
14use parking_lot::Mutex;
15use serde::{Deserialize, Serialize};
16
17pub use turbomcp_core::SUPPORTED_VERSIONS as SUPPORTED_PROTOCOL_VERSIONS;
19
/// Default per-transport connection cap used by `ConnectionLimits::default()`.
pub const DEFAULT_MAX_CONNECTIONS: usize = 1_000;

/// Default number of requests admitted per `DEFAULT_RATE_LIMIT_WINDOW`.
pub const DEFAULT_RATE_LIMIT: u32 = 100;

/// Default rate-limit window: one second.
pub const DEFAULT_RATE_LIMIT_WINDOW: Duration = Duration::from_millis(1_000);

/// Default maximum message size in bytes: 10 MiB.
pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 10 << 20;
31
32#[derive(Debug, Clone)]
34pub struct ServerConfig {
35 pub protocol: ProtocolConfig,
37 pub rate_limit: Option<RateLimitConfig>,
39 pub connection_limits: ConnectionLimits,
41 pub required_capabilities: RequiredCapabilities,
43 pub max_message_size: usize,
45}
46
47impl Default for ServerConfig {
48 fn default() -> Self {
49 Self {
50 protocol: ProtocolConfig::default(),
51 rate_limit: None,
52 connection_limits: ConnectionLimits::default(),
53 required_capabilities: RequiredCapabilities::default(),
54 max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
55 }
56 }
57}
58
59impl ServerConfig {
60 #[must_use]
62 pub fn new() -> Self {
63 Self::default()
64 }
65
66 #[must_use]
68 pub fn builder() -> ServerConfigBuilder {
69 ServerConfigBuilder::default()
70 }
71}
72
73#[derive(Debug, Clone, Default)]
75pub struct ServerConfigBuilder {
76 protocol: Option<ProtocolConfig>,
77 rate_limit: Option<RateLimitConfig>,
78 connection_limits: Option<ConnectionLimits>,
79 required_capabilities: Option<RequiredCapabilities>,
80 max_message_size: Option<usize>,
81}
82
83impl ServerConfigBuilder {
84 #[must_use]
86 pub fn protocol(mut self, config: ProtocolConfig) -> Self {
87 self.protocol = Some(config);
88 self
89 }
90
91 #[must_use]
93 pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
94 self.rate_limit = Some(config);
95 self
96 }
97
98 #[must_use]
100 pub fn connection_limits(mut self, limits: ConnectionLimits) -> Self {
101 self.connection_limits = Some(limits);
102 self
103 }
104
105 #[must_use]
107 pub fn required_capabilities(mut self, caps: RequiredCapabilities) -> Self {
108 self.required_capabilities = Some(caps);
109 self
110 }
111
112 #[must_use]
117 pub fn max_message_size(mut self, size: usize) -> Self {
118 self.max_message_size = Some(size);
119 self
120 }
121
122 #[must_use]
127 pub fn build(self) -> ServerConfig {
128 ServerConfig {
129 protocol: self.protocol.unwrap_or_default(),
130 rate_limit: self.rate_limit,
131 connection_limits: self.connection_limits.unwrap_or_default(),
132 required_capabilities: self.required_capabilities.unwrap_or_default(),
133 max_message_size: self.max_message_size.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE),
134 }
135 }
136
137 pub fn try_build(self) -> Result<ServerConfig, ConfigValidationError> {
163 let max_message_size = self.max_message_size.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE);
164
165 if max_message_size < 1024 {
167 return Err(ConfigValidationError::InvalidMessageSize {
168 size: max_message_size,
169 min: 1024,
170 });
171 }
172
173 if let Some(ref rate_limit) = self.rate_limit {
175 if rate_limit.max_requests == 0 {
176 return Err(ConfigValidationError::InvalidRateLimit {
177 reason: "max_requests cannot be 0".to_string(),
178 });
179 }
180 if rate_limit.window.is_zero() {
181 return Err(ConfigValidationError::InvalidRateLimit {
182 reason: "rate limit window cannot be zero".to_string(),
183 });
184 }
185 }
186
187 let connection_limits = self.connection_limits.unwrap_or_default();
189 if connection_limits.max_tcp_connections == 0
190 && connection_limits.max_websocket_connections == 0
191 && connection_limits.max_http_concurrent == 0
192 && connection_limits.max_unix_connections == 0
193 {
194 return Err(ConfigValidationError::InvalidConnectionLimits {
195 reason: "at least one connection limit must be non-zero".to_string(),
196 });
197 }
198
199 Ok(ServerConfig {
200 protocol: self.protocol.unwrap_or_default(),
201 rate_limit: self.rate_limit,
202 connection_limits,
203 required_capabilities: self.required_capabilities.unwrap_or_default(),
204 max_message_size,
205 })
206 }
207}
208
209#[derive(Debug, Clone, thiserror::Error)]
211pub enum ConfigValidationError {
212 #[error("Invalid max_message_size: {size} bytes is below minimum of {min} bytes")]
214 InvalidMessageSize {
215 size: usize,
217 min: usize,
219 },
220
221 #[error("Invalid rate limit: {reason}")]
223 InvalidRateLimit {
224 reason: String,
226 },
227
228 #[error("Invalid connection limits: {reason}")]
230 InvalidConnectionLimits {
231 reason: String,
233 },
234}
235
236#[derive(Debug, Clone)]
238pub struct ProtocolConfig {
239 pub preferred_version: String,
241 pub supported_versions: Vec<String>,
243 pub allow_fallback: bool,
245}
246
247impl Default for ProtocolConfig {
248 fn default() -> Self {
249 Self {
250 preferred_version: SUPPORTED_PROTOCOL_VERSIONS[0].to_string(),
251 supported_versions: SUPPORTED_PROTOCOL_VERSIONS
252 .iter()
253 .map(|s| s.to_string())
254 .collect(),
255 allow_fallback: false,
256 }
257 }
258}
259
260impl ProtocolConfig {
261 #[must_use]
263 pub fn strict(version: &str) -> Self {
264 Self {
265 preferred_version: version.to_string(),
266 supported_versions: vec![version.to_string()],
267 allow_fallback: false,
268 }
269 }
270
271 #[must_use]
273 pub fn is_supported(&self, version: &str) -> bool {
274 self.supported_versions.iter().any(|v| v == version)
275 }
276
277 #[must_use]
281 pub fn negotiate(&self, client_version: Option<&str>) -> Option<String> {
282 match client_version {
283 Some(version) if self.is_supported(version) => Some(version.to_string()),
284 Some(_) if self.allow_fallback => Some(self.preferred_version.clone()),
285 Some(_) => None,
286 None => Some(self.preferred_version.clone()),
287 }
288 }
289}
290
291#[derive(Debug, Clone)]
293pub struct RateLimitConfig {
294 pub max_requests: u32,
296 pub window: Duration,
298 pub per_client: bool,
300}
301
302impl Default for RateLimitConfig {
303 fn default() -> Self {
304 Self {
305 max_requests: DEFAULT_RATE_LIMIT,
306 window: DEFAULT_RATE_LIMIT_WINDOW,
307 per_client: true,
308 }
309 }
310}
311
312impl RateLimitConfig {
313 #[must_use]
315 pub fn new(max_requests: u32, window: Duration) -> Self {
316 Self {
317 max_requests,
318 window,
319 per_client: true,
320 }
321 }
322
323 #[must_use]
325 pub fn per_client(mut self, enabled: bool) -> Self {
326 self.per_client = enabled;
327 self
328 }
329}
330
331#[derive(Debug, Clone)]
333pub struct ConnectionLimits {
334 pub max_tcp_connections: usize,
336 pub max_websocket_connections: usize,
338 pub max_http_concurrent: usize,
340 pub max_unix_connections: usize,
342}
343
344impl Default for ConnectionLimits {
345 fn default() -> Self {
346 Self {
347 max_tcp_connections: DEFAULT_MAX_CONNECTIONS,
348 max_websocket_connections: DEFAULT_MAX_CONNECTIONS,
349 max_http_concurrent: DEFAULT_MAX_CONNECTIONS,
350 max_unix_connections: DEFAULT_MAX_CONNECTIONS,
351 }
352 }
353}
354
355impl ConnectionLimits {
356 #[must_use]
358 pub fn new(max_connections: usize) -> Self {
359 Self {
360 max_tcp_connections: max_connections,
361 max_websocket_connections: max_connections,
362 max_http_concurrent: max_connections,
363 max_unix_connections: max_connections,
364 }
365 }
366}
367
368#[derive(Debug, Clone, Default, Serialize, Deserialize)]
372pub struct RequiredCapabilities {
373 #[serde(default)]
375 pub roots: bool,
376 #[serde(default)]
378 pub sampling: bool,
379 #[serde(default)]
381 pub experimental: HashSet<String>,
382}
383
384impl RequiredCapabilities {
385 #[must_use]
387 pub fn none() -> Self {
388 Self::default()
389 }
390
391 #[must_use]
393 pub fn with_roots(mut self) -> Self {
394 self.roots = true;
395 self
396 }
397
398 #[must_use]
400 pub fn with_sampling(mut self) -> Self {
401 self.sampling = true;
402 self
403 }
404
405 #[must_use]
407 pub fn with_experimental(mut self, name: impl Into<String>) -> Self {
408 self.experimental.insert(name.into());
409 self
410 }
411
412 #[must_use]
414 pub fn validate(&self, client_caps: &ClientCapabilities) -> CapabilityValidation {
415 let mut missing = Vec::new();
416
417 if self.roots && !client_caps.roots {
418 missing.push("roots".to_string());
419 }
420
421 if self.sampling && !client_caps.sampling {
422 missing.push("sampling".to_string());
423 }
424
425 for exp in &self.experimental {
426 if !client_caps.experimental.contains(exp) {
427 missing.push(format!("experimental/{}", exp));
428 }
429 }
430
431 if missing.is_empty() {
432 CapabilityValidation::Valid
433 } else {
434 CapabilityValidation::Missing(missing)
435 }
436 }
437}
438
439#[derive(Debug, Clone, Default, Serialize, Deserialize)]
441pub struct ClientCapabilities {
442 #[serde(default)]
444 pub roots: bool,
445 #[serde(default)]
447 pub sampling: bool,
448 #[serde(default)]
450 pub experimental: HashSet<String>,
451}
452
453impl ClientCapabilities {
454 #[must_use]
456 pub fn from_params(params: &serde_json::Value) -> Self {
457 let caps = params.get("capabilities").cloned().unwrap_or_default();
458
459 Self {
460 roots: caps.get("roots").map(|v| !v.is_null()).unwrap_or(false),
461 sampling: caps.get("sampling").map(|v| !v.is_null()).unwrap_or(false),
462 experimental: caps
463 .get("experimental")
464 .and_then(|v| v.as_object())
465 .map(|obj| obj.keys().cloned().collect())
466 .unwrap_or_default(),
467 }
468 }
469}
470
/// Outcome of checking a client's capabilities against `RequiredCapabilities`.
#[derive(Debug, Clone)]
pub enum CapabilityValidation {
    /// Every requirement is satisfied.
    Valid,
    /// The listed capabilities were required but not declared.
    Missing(Vec<String>),
}

impl CapabilityValidation {
    /// `true` only for [`CapabilityValidation::Valid`]; a `Missing` with an empty list is
    /// still not valid.
    #[must_use]
    pub fn is_valid(&self) -> bool {
        self.missing().is_none()
    }

    /// The missing capability names, or `None` when validation passed.
    #[must_use]
    pub fn missing(&self) -> Option<&[String]> {
        if let Self::Missing(names) = self {
            Some(names.as_slice())
        } else {
            None
        }
    }
}
496
497#[derive(Debug)]
499pub struct RateLimiter {
500 config: RateLimitConfig,
501 global_bucket: Mutex<TokenBucket>,
503 client_buckets: Mutex<std::collections::HashMap<String, TokenBucket>>,
505 last_cleanup: Mutex<Instant>,
507}
508
509impl RateLimiter {
510 #[must_use]
512 pub fn new(config: RateLimitConfig) -> Self {
513 Self {
514 global_bucket: Mutex::new(TokenBucket::new(config.max_requests, config.window)),
515 client_buckets: Mutex::new(std::collections::HashMap::new()),
516 last_cleanup: Mutex::new(Instant::now()),
517 config,
518 }
519 }
520
521 pub fn check(&self, client_id: Option<&str>) -> bool {
525 let needs_cleanup = {
527 let last = self.last_cleanup.lock();
528 last.elapsed() > Duration::from_secs(60)
529 };
530 if needs_cleanup {
531 self.cleanup(Duration::from_secs(300));
532 *self.last_cleanup.lock() = Instant::now();
533 }
534
535 if self.config.per_client {
536 if let Some(id) = client_id {
537 let mut buckets = self.client_buckets.lock();
538 let bucket = buckets.entry(id.to_string()).or_insert_with(|| {
539 TokenBucket::new(self.config.max_requests, self.config.window)
540 });
541 bucket.try_acquire()
542 } else {
543 self.global_bucket.lock().try_acquire()
545 }
546 } else {
547 self.global_bucket.lock().try_acquire()
548 }
549 }
550
551 pub fn cleanup(&self, max_age: Duration) {
553 let mut buckets = self.client_buckets.lock();
554 let now = Instant::now();
555 buckets.retain(|_, bucket| now.duration_since(bucket.last_access) < max_age);
556 }
557
558 #[must_use]
560 pub fn client_bucket_count(&self) -> usize {
561 self.client_buckets.lock().len()
562 }
563}
564
/// One token bucket.
///
/// It starts full with `max_tokens` tokens and refills continuously at `refill_rate` tokens
/// per second, never exceeding `max_tokens`. Each admitted request consumes one token.
#[derive(Debug)]
struct TokenBucket {
    tokens: f64,
    max_tokens: f64,
    /// Tokens added per second: `max_tokens / window`.
    refill_rate: f64,
    /// Time up to which refill has been credited.
    last_refill: Instant,
    /// Time of the most recent `try_acquire` call, used for idle eviction.
    last_access: Instant,
}

impl TokenBucket {
    /// A full bucket holding `max_requests` tokens that refills completely over `window`.
    fn new(max_requests: u32, window: Duration) -> Self {
        let capacity = f64::from(max_requests);
        let created = Instant::now();
        TokenBucket {
            tokens: capacity,
            max_tokens: capacity,
            refill_rate: capacity / window.as_secs_f64(),
            last_refill: created,
            last_access: created,
        }
    }

    /// Refills by the elapsed time, then takes one token if available.
    ///
    /// Returns `true` if a token was taken.
    fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        let since_refill = now.duration_since(self.last_refill);

        // Refill is credited only in steps of at least 10ms. Because `last_refill` is not
        // advanced otherwise, shorter gaps keep accumulating rather than being lost.
        if since_refill >= Duration::from_millis(10) {
            let refilled = self.tokens + since_refill.as_secs_f64() * self.refill_rate;
            self.tokens = refilled.min(self.max_tokens);
            self.last_refill = now;
        }

        self.last_access = now;

        let admitted = self.tokens >= 1.0;
        if admitted {
            self.tokens -= 1.0;
        }
        admitted
    }
}
609
/// Lock-free counter that caps the number of simultaneously held connection slots.
///
/// Claim a slot with [`ConnectionCounter::try_acquire_arc`]. The slot is released when the
/// returned [`ConnectionGuard`] is dropped.
#[derive(Debug)]
pub struct ConnectionCounter {
    current: AtomicUsize,
    max: usize,
}

impl ConnectionCounter {
    /// Creates a counter that admits at most `max` slots at once.
    #[must_use]
    pub fn new(max: usize) -> Self {
        Self {
            current: AtomicUsize::new(0),
            max,
        }
    }

    /// Tries to claim one slot.
    ///
    /// Returns `None` only when `max` slots are already held. Otherwise it returns a guard
    /// that gives the slot back when dropped.
    #[must_use = "dropping the guard immediately releases the connection slot"]
    pub fn try_acquire_arc(self: &Arc<Self>) -> Option<ConnectionGuard> {
        // `fetch_update` retries the compare-exchange internally. A retry only happens when
        // another thread changed the count, so the loop is lock-free and needs no iteration
        // cap. A cap could wrongly reject a connection while the counter is below `max`.
        self.current
            .fetch_update(Ordering::SeqCst, Ordering::Relaxed, |n| {
                // `n < max` guarantees `n + 1 <= max`, so this cannot overflow.
                if n < self.max {
                    Some(n + 1)
                } else {
                    None
                }
            })
            .ok()
            .map(|_| ConnectionGuard {
                counter: Arc::clone(self),
            })
    }

    /// Number of slots currently held.
    ///
    /// This is a relaxed snapshot that may already be stale under concurrency.
    #[must_use]
    pub fn current(&self) -> usize {
        self.current.load(Ordering::Relaxed)
    }

    /// The configured maximum number of slots.
    #[must_use]
    pub fn max(&self) -> usize {
        self.max
    }

    /// Gives back one slot; called only from `ConnectionGuard::drop`.
    fn release(&self) {
        self.current.fetch_sub(1, Ordering::SeqCst);
    }
}

/// Holds one slot of a [`ConnectionCounter`] and releases it on drop.
#[derive(Debug)]
pub struct ConnectionGuard {
    counter: Arc<ConnectionCounter>,
}

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        self.counter.release();
    }
}
691
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_protocol_negotiation_exact_match() {
        let config = ProtocolConfig::default();
        assert_eq!(
            config.negotiate(Some("2025-11-25")),
            Some("2025-11-25".to_string())
        );
    }

    #[test]
    fn test_protocol_negotiation_no_fallback_by_default() {
        // Fallback is disabled by default, so an unknown version is rejected.
        let config = ProtocolConfig::default();
        assert_eq!(config.negotiate(Some("unknown-version")), None);
    }

    #[test]
    fn test_protocol_negotiation_fallback_enabled() {
        let config = ProtocolConfig {
            allow_fallback: true,
            ..ProtocolConfig::strict("2025-11-25")
        };
        assert_eq!(
            config.negotiate(Some("unknown-version")),
            Some("2025-11-25".to_string())
        );
    }

    #[test]
    fn test_protocol_negotiation_without_client_version() {
        let config = ProtocolConfig::strict("2025-11-25");
        assert_eq!(config.negotiate(None), Some("2025-11-25".to_string()));
    }

    #[test]
    fn test_protocol_negotiation_strict() {
        let config = ProtocolConfig::strict("2025-11-25");
        assert_eq!(config.negotiate(Some("2025-06-18")), None);
    }

    #[test]
    fn test_capability_validation() {
        let required = RequiredCapabilities::none().with_roots();
        let client = ClientCapabilities {
            roots: true,
            ..Default::default()
        };
        assert!(required.validate(&client).is_valid());

        let client_missing = ClientCapabilities::default();
        assert!(!required.validate(&client_missing).is_valid());
    }

    #[test]
    fn test_capability_validation_lists_missing() {
        let required = RequiredCapabilities::none()
            .with_roots()
            .with_sampling()
            .with_experimental("x");
        let result = required.validate(&ClientCapabilities::default());
        let expected = [
            "roots".to_string(),
            "sampling".to_string(),
            "experimental/x".to_string(),
        ];
        assert_eq!(result.missing(), Some(&expected[..]));
    }

    #[test]
    fn test_client_capabilities_from_params() {
        let params = serde_json::json!({
            "capabilities": {
                "roots": { "listChanged": true },
                "experimental": { "foo": {} }
            }
        });
        let caps = ClientCapabilities::from_params(&params);
        assert!(caps.roots);
        assert!(!caps.sampling);
        assert!(caps.experimental.contains("foo"));
        assert_eq!(caps.experimental.len(), 1);

        let empty = ClientCapabilities::from_params(&serde_json::json!({}));
        assert!(!empty.roots);
        assert!(!empty.sampling);
        assert!(empty.experimental.is_empty());
    }

    #[test]
    fn test_rate_limiter() {
        // A long window keeps refill negligible, so a slow test run cannot flake.
        let config = RateLimitConfig::new(2, Duration::from_secs(60));
        let limiter = RateLimiter::new(config);

        assert!(limiter.check(None));
        assert!(limiter.check(None));
        // The third request exceeds the burst of 2.
        assert!(!limiter.check(None));
    }

    #[test]
    fn test_rate_limiter_per_client_buckets() {
        let limiter = RateLimiter::new(RateLimitConfig::new(1, Duration::from_secs(60)));

        assert!(limiter.check(Some("alice")));
        assert!(!limiter.check(Some("alice")));
        // A different client has its own, untouched bucket.
        assert!(limiter.check(Some("bob")));
        assert_eq!(limiter.client_bucket_count(), 2);

        // Requests without a client id share the global bucket.
        assert!(limiter.check(None));
        assert!(!limiter.check(None));
    }

    #[test]
    fn test_connection_counter() {
        let counter = Arc::new(ConnectionCounter::new(2));

        let guard1 = counter.try_acquire_arc();
        assert!(guard1.is_some());
        assert_eq!(counter.current(), 1);

        let guard2 = counter.try_acquire_arc();
        assert!(guard2.is_some());
        assert_eq!(counter.current(), 2);

        let guard3 = counter.try_acquire_arc();
        assert!(guard3.is_none());

        drop(guard1);
        assert_eq!(counter.current(), 1);

        let guard4 = counter.try_acquire_arc();
        assert!(guard4.is_some());
    }

    #[test]
    fn test_builder_default_succeeds() {
        // `build` never validates; with nothing set it yields the defaults.
        let config = ServerConfig::builder().build();
        assert_eq!(config.max_message_size, DEFAULT_MAX_MESSAGE_SIZE);
    }

    #[test]
    fn test_builder_try_build_valid() {
        let result = ServerConfig::builder()
            .max_message_size(1024 * 1024)
            .try_build();
        assert!(result.is_ok());
    }

    #[test]
    fn test_builder_try_build_invalid_message_size() {
        let result = ServerConfig::builder()
            .max_message_size(100) // below the 1024-byte minimum
            .try_build();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ConfigValidationError::InvalidMessageSize { .. }
        ));
    }

    #[test]
    fn test_builder_try_build_invalid_rate_limit() {
        let result = ServerConfig::builder()
            .rate_limit(RateLimitConfig {
                max_requests: 0,
                window: Duration::from_secs(1),
                per_client: true,
            })
            .try_build();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ConfigValidationError::InvalidRateLimit { .. }
        ));
    }

    #[test]
    fn test_builder_try_build_zero_window() {
        let result = ServerConfig::builder()
            .rate_limit(RateLimitConfig {
                max_requests: 100,
                window: Duration::ZERO,
                per_client: true,
            })
            .try_build();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ConfigValidationError::InvalidRateLimit { .. }
        ));
    }

    #[test]
    fn test_builder_try_build_invalid_connection_limits() {
        let result = ServerConfig::builder()
            .connection_limits(ConnectionLimits {
                max_tcp_connections: 0,
                max_websocket_connections: 0,
                max_http_concurrent: 0,
                max_unix_connections: 0,
            })
            .try_build();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ConfigValidationError::InvalidConnectionLimits { .. }
        ));
    }
}
842
#[cfg(test)]
mod proptest_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // `try_build` never panics, and it accepts a message size exactly when the size is
        // at least the 1024-byte minimum.
        #[test]
        fn config_builder_validates_message_size(
            max_msg_size in 0usize..10_000_000,
        ) {
            let result = ServerConfig::builder()
                .max_message_size(max_msg_size)
                .try_build();
            prop_assert_eq!(result.is_ok(), max_msg_size >= 1024);
        }

        // The counter admits exactly `max` holders, and dropping every guard returns it to zero.
        #[test]
        fn connection_counter_bounded(max in 1usize..10000) {
            let counter = Arc::new(ConnectionCounter::new(max));
            let guards: Vec<_> = (0..max + 10)
                .filter_map(|_| counter.try_acquire_arc())
                .collect();
            prop_assert_eq!(guards.len(), max);
            prop_assert_eq!(counter.current(), max);

            drop(guards);
            prop_assert_eq!(counter.current(), 0);
        }
    }
}