1use std::collections::HashSet;
10use std::sync::Arc;
11use std::sync::atomic::{AtomicUsize, Ordering};
12use std::time::{Duration, Instant};
13
14use parking_lot::Mutex;
15use serde::{Deserialize, Serialize};
16
17pub use turbomcp_core::SUPPORTED_VERSIONS as SUPPORTED_PROTOCOL_VERSIONS;
19pub use turbomcp_core::types::core::ProtocolVersion;
20
/// Default cap on concurrent connections, applied to every transport by
/// `ConnectionLimits::default`.
pub const DEFAULT_MAX_CONNECTIONS: usize = 1_000;

/// Default number of requests permitted per rate-limit window.
pub const DEFAULT_RATE_LIMIT: u32 = 100;

/// Default rate-limit window length (one second).
pub const DEFAULT_RATE_LIMIT_WINDOW: Duration = Duration::from_millis(1_000);

/// Default maximum message size in bytes (10 MiB).
pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 10 << 20;
32
/// Settings describing which request origins are acceptable.
///
/// Enforcement lives elsewhere; this type only carries the policy.
#[derive(Debug, Clone)]
pub struct OriginValidationConfig {
    /// Origins explicitly allowed (e.g. `https://app.example.com`).
    pub allowed_origins: HashSet<String>,
    /// Whether localhost origins are permitted. Defaults to `true`.
    pub allow_localhost: bool,
    /// Whether any origin is permitted. Defaults to `false`.
    /// NOTE(review): presumably bypasses the allow-list entirely — confirm at the enforcement site.
    pub allow_any: bool,
}

impl Default for OriginValidationConfig {
    /// Empty allow-list, localhost permitted, wildcard disabled.
    fn default() -> Self {
        let allowed_origins = HashSet::new();
        Self {
            allow_any: false,
            allow_localhost: true,
            allowed_origins,
        }
    }
}

impl OriginValidationConfig {
    /// Equivalent to [`OriginValidationConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}
61
62#[derive(Debug, Clone)]
64pub struct ServerConfig {
65 pub protocol: ProtocolConfig,
67 pub rate_limit: Option<RateLimitConfig>,
69 pub connection_limits: ConnectionLimits,
71 pub required_capabilities: RequiredCapabilities,
73 pub max_message_size: usize,
75 pub origin_validation: OriginValidationConfig,
77}
78
79impl Default for ServerConfig {
80 fn default() -> Self {
81 Self {
82 protocol: ProtocolConfig::default(),
83 rate_limit: None,
84 connection_limits: ConnectionLimits::default(),
85 required_capabilities: RequiredCapabilities::default(),
86 max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
87 origin_validation: OriginValidationConfig::default(),
88 }
89 }
90}
91
92impl ServerConfig {
93 #[must_use]
95 pub fn new() -> Self {
96 Self::default()
97 }
98
99 #[must_use]
101 pub fn builder() -> ServerConfigBuilder {
102 ServerConfigBuilder::default()
103 }
104}
105
106#[derive(Debug, Clone, Default)]
108pub struct ServerConfigBuilder {
109 protocol: Option<ProtocolConfig>,
110 rate_limit: Option<RateLimitConfig>,
111 connection_limits: Option<ConnectionLimits>,
112 required_capabilities: Option<RequiredCapabilities>,
113 max_message_size: Option<usize>,
114 origin_validation: Option<OriginValidationConfig>,
115}
116
117impl ServerConfigBuilder {
118 #[must_use]
120 pub fn protocol(mut self, config: ProtocolConfig) -> Self {
121 self.protocol = Some(config);
122 self
123 }
124
125 #[must_use]
127 pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
128 self.rate_limit = Some(config);
129 self
130 }
131
132 #[must_use]
134 pub fn connection_limits(mut self, limits: ConnectionLimits) -> Self {
135 self.connection_limits = Some(limits);
136 self
137 }
138
139 #[must_use]
141 pub fn required_capabilities(mut self, caps: RequiredCapabilities) -> Self {
142 self.required_capabilities = Some(caps);
143 self
144 }
145
146 #[must_use]
151 pub fn max_message_size(mut self, size: usize) -> Self {
152 self.max_message_size = Some(size);
153 self
154 }
155
156 #[must_use]
158 pub fn origin_validation(mut self, config: OriginValidationConfig) -> Self {
159 self.origin_validation = Some(config);
160 self
161 }
162
163 #[must_use]
165 pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
166 self.origin_validation
167 .get_or_insert_with(OriginValidationConfig::default)
168 .allowed_origins
169 .insert(origin.into());
170 self
171 }
172
173 #[must_use]
175 pub fn allow_origins<I, S>(mut self, origins: I) -> Self
176 where
177 I: IntoIterator<Item = S>,
178 S: Into<String>,
179 {
180 let config = self
181 .origin_validation
182 .get_or_insert_with(OriginValidationConfig::default);
183 config
184 .allowed_origins
185 .extend(origins.into_iter().map(Into::into));
186 self
187 }
188
189 #[must_use]
191 pub fn allow_localhost_origins(mut self, allow: bool) -> Self {
192 self.origin_validation
193 .get_or_insert_with(OriginValidationConfig::default)
194 .allow_localhost = allow;
195 self
196 }
197
198 #[must_use]
200 pub fn allow_any_origin(mut self, allow: bool) -> Self {
201 self.origin_validation
202 .get_or_insert_with(OriginValidationConfig::default)
203 .allow_any = allow;
204 self
205 }
206
207 #[must_use]
212 pub fn build(self) -> ServerConfig {
213 ServerConfig {
214 protocol: self.protocol.unwrap_or_default(),
215 rate_limit: self.rate_limit,
216 connection_limits: self.connection_limits.unwrap_or_default(),
217 required_capabilities: self.required_capabilities.unwrap_or_default(),
218 max_message_size: self.max_message_size.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE),
219 origin_validation: self.origin_validation.unwrap_or_default(),
220 }
221 }
222
223 pub fn try_build(self) -> Result<ServerConfig, ConfigValidationError> {
249 let max_message_size = self.max_message_size.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE);
250
251 if max_message_size < 1024 {
253 return Err(ConfigValidationError::InvalidMessageSize {
254 size: max_message_size,
255 min: 1024,
256 });
257 }
258
259 if let Some(ref rate_limit) = self.rate_limit {
261 if rate_limit.max_requests == 0 {
262 return Err(ConfigValidationError::InvalidRateLimit {
263 reason: "max_requests cannot be 0".to_string(),
264 });
265 }
266 if rate_limit.window.is_zero() {
267 return Err(ConfigValidationError::InvalidRateLimit {
268 reason: "rate limit window cannot be zero".to_string(),
269 });
270 }
271 }
272
273 let connection_limits = self.connection_limits.unwrap_or_default();
275 if connection_limits.max_tcp_connections == 0
276 && connection_limits.max_websocket_connections == 0
277 && connection_limits.max_http_concurrent == 0
278 && connection_limits.max_unix_connections == 0
279 {
280 return Err(ConfigValidationError::InvalidConnectionLimits {
281 reason: "at least one connection limit must be non-zero".to_string(),
282 });
283 }
284
285 Ok(ServerConfig {
286 protocol: self.protocol.unwrap_or_default(),
287 rate_limit: self.rate_limit,
288 connection_limits,
289 required_capabilities: self.required_capabilities.unwrap_or_default(),
290 max_message_size,
291 origin_validation: self.origin_validation.unwrap_or_default(),
292 })
293 }
294}
295
296#[derive(Debug, Clone, thiserror::Error)]
298pub enum ConfigValidationError {
299 #[error("Invalid max_message_size: {size} bytes is below minimum of {min} bytes")]
301 InvalidMessageSize {
302 size: usize,
304 min: usize,
306 },
307
308 #[error("Invalid rate limit: {reason}")]
310 InvalidRateLimit {
311 reason: String,
313 },
314
315 #[error("Invalid connection limits: {reason}")]
317 InvalidConnectionLimits {
318 reason: String,
320 },
321}
322
323#[derive(Debug, Clone)]
325pub struct ProtocolConfig {
326 pub preferred_version: ProtocolVersion,
328 pub supported_versions: Vec<ProtocolVersion>,
330 pub allow_fallback: bool,
332}
333
334impl Default for ProtocolConfig {
335 fn default() -> Self {
336 Self {
337 preferred_version: ProtocolVersion::LATEST.clone(),
338 supported_versions: vec![ProtocolVersion::LATEST.clone()],
339 allow_fallback: false,
340 }
341 }
342}
343
344impl ProtocolConfig {
345 #[must_use]
347 pub fn strict(version: impl Into<ProtocolVersion>) -> Self {
348 let v = version.into();
349 Self {
350 preferred_version: v.clone(),
351 supported_versions: vec![v],
352 allow_fallback: false,
353 }
354 }
355
356 #[must_use]
361 pub fn multi_version() -> Self {
362 Self {
363 preferred_version: ProtocolVersion::LATEST.clone(),
364 supported_versions: ProtocolVersion::STABLE.to_vec(),
365 allow_fallback: false,
366 }
367 }
368
369 #[must_use]
371 pub fn is_supported(&self, version: &ProtocolVersion) -> bool {
372 self.supported_versions.contains(version)
373 }
374
375 #[must_use]
379 pub fn negotiate(&self, client_version: Option<&str>) -> Option<ProtocolVersion> {
380 match client_version {
381 Some(version_str) => {
382 let version = ProtocolVersion::from(version_str);
383 if self.is_supported(&version) {
384 Some(version)
385 } else if self.allow_fallback {
386 Some(self.preferred_version.clone())
387 } else {
388 None
389 }
390 }
391 None => Some(self.preferred_version.clone()),
392 }
393 }
394}
395
396#[derive(Debug, Clone)]
398pub struct RateLimitConfig {
399 pub max_requests: u32,
401 pub window: Duration,
403 pub per_client: bool,
405}
406
407impl Default for RateLimitConfig {
408 fn default() -> Self {
409 Self {
410 max_requests: DEFAULT_RATE_LIMIT,
411 window: DEFAULT_RATE_LIMIT_WINDOW,
412 per_client: true,
413 }
414 }
415}
416
417impl RateLimitConfig {
418 #[must_use]
420 pub fn new(max_requests: u32, window: Duration) -> Self {
421 Self {
422 max_requests,
423 window,
424 per_client: true,
425 }
426 }
427
428 #[must_use]
430 pub fn per_client(mut self, enabled: bool) -> Self {
431 self.per_client = enabled;
432 self
433 }
434}
435
436#[derive(Debug, Clone)]
438pub struct ConnectionLimits {
439 pub max_tcp_connections: usize,
441 pub max_websocket_connections: usize,
443 pub max_http_concurrent: usize,
445 pub max_unix_connections: usize,
447}
448
449impl Default for ConnectionLimits {
450 fn default() -> Self {
451 Self {
452 max_tcp_connections: DEFAULT_MAX_CONNECTIONS,
453 max_websocket_connections: DEFAULT_MAX_CONNECTIONS,
454 max_http_concurrent: DEFAULT_MAX_CONNECTIONS,
455 max_unix_connections: DEFAULT_MAX_CONNECTIONS,
456 }
457 }
458}
459
460impl ConnectionLimits {
461 #[must_use]
463 pub fn new(max_connections: usize) -> Self {
464 Self {
465 max_tcp_connections: max_connections,
466 max_websocket_connections: max_connections,
467 max_http_concurrent: max_connections,
468 max_unix_connections: max_connections,
469 }
470 }
471}
472
473#[derive(Debug, Clone, Default, Serialize, Deserialize)]
477pub struct RequiredCapabilities {
478 #[serde(default)]
480 pub roots: bool,
481 #[serde(default)]
483 pub sampling: bool,
484 #[serde(default)]
486 pub extensions: HashSet<String>,
487 #[serde(default)]
489 pub experimental: HashSet<String>,
490}
491
492impl RequiredCapabilities {
493 #[must_use]
495 pub fn none() -> Self {
496 Self::default()
497 }
498
499 #[must_use]
501 pub fn with_roots(mut self) -> Self {
502 self.roots = true;
503 self
504 }
505
506 #[must_use]
508 pub fn with_sampling(mut self) -> Self {
509 self.sampling = true;
510 self
511 }
512
513 #[must_use]
515 pub fn with_extension(mut self, name: impl Into<String>) -> Self {
516 self.extensions.insert(name.into());
517 self
518 }
519
520 #[must_use]
522 pub fn with_experimental(mut self, name: impl Into<String>) -> Self {
523 self.experimental.insert(name.into());
524 self
525 }
526
527 #[must_use]
529 pub fn validate(&self, client_caps: &ClientCapabilities) -> CapabilityValidation {
530 let mut missing = Vec::new();
531
532 if self.roots && !client_caps.roots {
533 missing.push("roots".to_string());
534 }
535
536 if self.sampling && !client_caps.sampling {
537 missing.push("sampling".to_string());
538 }
539
540 for extension in &self.extensions {
541 if !client_caps.extensions.contains(extension) {
542 missing.push(format!("extensions/{}", extension));
543 }
544 }
545
546 for exp in &self.experimental {
547 if !client_caps.experimental.contains(exp) {
548 missing.push(format!("experimental/{}", exp));
549 }
550 }
551
552 if missing.is_empty() {
553 CapabilityValidation::Valid
554 } else {
555 CapabilityValidation::Missing(missing)
556 }
557 }
558}
559
560#[derive(Debug, Clone, Default, Serialize, Deserialize)]
562pub struct ClientCapabilities {
563 #[serde(default)]
565 pub roots: bool,
566 #[serde(default)]
568 pub sampling: bool,
569 #[serde(default)]
571 pub extensions: HashSet<String>,
572 #[serde(default)]
574 pub experimental: HashSet<String>,
575}
576
577impl ClientCapabilities {
578 #[must_use]
580 pub fn from_params(params: &serde_json::Value) -> Self {
581 let caps = params.get("capabilities").cloned().unwrap_or_default();
582
583 Self {
584 roots: caps.get("roots").map(|v| !v.is_null()).unwrap_or(false),
585 sampling: caps.get("sampling").map(|v| !v.is_null()).unwrap_or(false),
586 extensions: caps
587 .get("extensions")
588 .and_then(|v| v.as_object())
589 .map(|obj| obj.keys().cloned().collect())
590 .unwrap_or_default(),
591 experimental: caps
592 .get("experimental")
593 .and_then(|v| v.as_object())
594 .map(|obj| obj.keys().cloned().collect())
595 .unwrap_or_default(),
596 }
597 }
598}
599
/// Result of checking a client against [`RequiredCapabilities`].
#[derive(Debug, Clone)]
pub enum CapabilityValidation {
    /// Every required capability is present.
    Valid,
    /// Names of the required capabilities the client lacks.
    Missing(Vec<String>),
}

impl CapabilityValidation {
    /// `true` exactly for [`CapabilityValidation::Valid`].
    #[must_use]
    pub fn is_valid(&self) -> bool {
        self.missing().is_none()
    }

    /// The missing capability names, or `None` when validation passed.
    #[must_use]
    pub fn missing(&self) -> Option<&[String]> {
        match self {
            Self::Missing(names) => Some(names.as_slice()),
            Self::Valid => None,
        }
    }
}
625
626#[derive(Debug)]
628pub struct RateLimiter {
629 config: RateLimitConfig,
630 global_bucket: Mutex<TokenBucket>,
632 client_buckets: Mutex<std::collections::HashMap<String, TokenBucket>>,
634 last_cleanup: Mutex<Instant>,
636}
637
638impl RateLimiter {
639 #[must_use]
641 pub fn new(config: RateLimitConfig) -> Self {
642 Self {
643 global_bucket: Mutex::new(TokenBucket::new(config.max_requests, config.window)),
644 client_buckets: Mutex::new(std::collections::HashMap::new()),
645 last_cleanup: Mutex::new(Instant::now()),
646 config,
647 }
648 }
649
650 pub fn check(&self, client_id: Option<&str>) -> bool {
654 let needs_cleanup = {
656 let last = self.last_cleanup.lock();
657 last.elapsed() > Duration::from_secs(60)
658 };
659 if needs_cleanup {
660 self.cleanup(Duration::from_secs(300));
661 *self.last_cleanup.lock() = Instant::now();
662 }
663
664 if self.config.per_client {
665 if let Some(id) = client_id {
666 let mut buckets = self.client_buckets.lock();
667 let bucket = buckets.entry(id.to_string()).or_insert_with(|| {
668 TokenBucket::new(self.config.max_requests, self.config.window)
669 });
670 bucket.try_acquire()
671 } else {
672 self.global_bucket.lock().try_acquire()
674 }
675 } else {
676 self.global_bucket.lock().try_acquire()
677 }
678 }
679
680 pub fn cleanup(&self, max_age: Duration) {
682 let mut buckets = self.client_buckets.lock();
683 let now = Instant::now();
684 buckets.retain(|_, bucket| now.duration_since(bucket.last_access) < max_age);
685 }
686
687 #[must_use]
689 pub fn client_bucket_count(&self) -> usize {
690 self.client_buckets.lock().len()
691 }
692}
693
/// One token bucket. It holds up to `max_tokens` tokens, regains `refill_rate`
/// tokens per second, and spends one token per admitted request.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens currently available (may be fractional).
    tokens: f64,
    /// Bucket capacity.
    max_tokens: f64,
    /// Tokens regained per second.
    refill_rate: f64,
    /// When tokens were last credited.
    last_refill: Instant,
    /// When the bucket was last consulted; read by idle-bucket cleanup.
    last_access: Instant,
}

impl TokenBucket {
    /// A full bucket allowing `max_requests` per `window`.
    fn new(max_requests: u32, window: Duration) -> Self {
        let capacity = f64::from(max_requests);
        let created = Instant::now();
        Self {
            tokens: capacity,
            max_tokens: capacity,
            refill_rate: capacity / window.as_secs_f64(),
            last_refill: created,
            last_access: created,
        }
    }

    /// Credits accrued tokens, then spends one if available. Returns `true` on success.
    fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        let since_refill = now.duration_since(self.last_refill);

        // Refill is applied in steps of at least 10 ms. Time below that is not
        // lost: `last_refill` stays put, so it is credited on a later call.
        if since_refill >= Duration::from_millis(10) {
            let refilled = self.tokens + since_refill.as_secs_f64() * self.refill_rate;
            self.tokens = refilled.min(self.max_tokens);
            self.last_refill = now;
        }

        self.last_access = now;

        let has_token = self.tokens >= 1.0;
        if has_token {
            self.tokens -= 1.0;
        }
        has_token
    }
}
738
/// Lock-free counter that caps the number of simultaneously held connection slots.
///
/// Slots are handed out as [`ConnectionGuard`]s; dropping a guard returns its slot.
#[derive(Debug)]
pub struct ConnectionCounter {
    current: AtomicUsize,
    max: usize,
}

impl ConnectionCounter {
    /// A counter allowing at most `max` concurrent slots.
    #[must_use]
    pub fn new(max: usize) -> Self {
        Self {
            current: AtomicUsize::new(0),
            max,
        }
    }

    /// Reserves one slot. Returns a guard that releases it on drop, or `None`
    /// if the counter is already at `max`.
    ///
    /// Contention never causes a spurious `None`. Each failed CAS means another
    /// thread changed the count, so retrying until success or the limit is reached
    /// always makes progress; no iteration cap is needed.
    #[must_use = "dropping the guard immediately releases the connection slot"]
    pub fn try_acquire_arc(self: &Arc<Self>) -> Option<ConnectionGuard> {
        let max = self.max;
        self.current
            .fetch_update(Ordering::SeqCst, Ordering::Relaxed, |held| {
                // `held < max <= usize::MAX`, so `held + 1` cannot overflow.
                if held < max {
                    Some(held + 1)
                } else {
                    None
                }
            })
            .ok()
            .map(|_| ConnectionGuard {
                counter: Arc::clone(self),
            })
    }

    /// Number of slots currently held (a snapshot; may change immediately).
    #[must_use]
    pub fn current(&self) -> usize {
        self.current.load(Ordering::Relaxed)
    }

    /// The configured slot limit.
    #[must_use]
    pub fn max(&self) -> usize {
        self.max
    }

    /// Returns one slot; only called from `ConnectionGuard::drop`.
    fn release(&self) {
        self.current.fetch_sub(1, Ordering::SeqCst);
    }
}

/// A held connection slot; releases it back to its [`ConnectionCounter`] on drop.
#[derive(Debug)]
pub struct ConnectionGuard {
    counter: Arc<ConnectionCounter>,
}

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        self.counter.release();
    }
}
820
#[cfg(test)]
mod tests {
    use super::*;

    // ---- protocol negotiation ----

    #[test]
    fn test_protocol_negotiation_exact_match() {
        let config = ProtocolConfig::default();
        assert_eq!(
            config.negotiate(Some("2025-11-25")),
            Some(ProtocolVersion::V2025_11_25)
        );
    }

    #[test]
    fn test_protocol_negotiation_default_rejects_older_version() {
        // The default config supports only the latest version, with no fallback.
        let config = ProtocolConfig::default();
        assert_eq!(config.negotiate(Some("2025-06-18")), None);
    }

    #[test]
    fn test_protocol_negotiation_multi_version_accepts_older() {
        let config = ProtocolConfig::multi_version();
        assert_eq!(
            config.negotiate(Some("2025-06-18")),
            Some(ProtocolVersion::V2025_06_18)
        );
        assert_eq!(
            config.negotiate(Some("2025-11-25")),
            Some(ProtocolVersion::V2025_11_25)
        );
    }

    #[test]
    fn test_protocol_negotiation_none_returns_preferred() {
        let config = ProtocolConfig::default();
        assert_eq!(config.negotiate(None), Some(ProtocolVersion::V2025_11_25));
    }

    #[test]
    fn test_protocol_negotiation_unknown_version() {
        let config = ProtocolConfig::default();
        assert_eq!(config.negotiate(Some("unknown-version")), None);
    }

    #[test]
    fn test_protocol_negotiation_strict() {
        let config = ProtocolConfig::strict("2025-11-25");
        assert_eq!(config.negotiate(Some("2025-06-18")), None);
    }

    // ---- capabilities ----

    #[test]
    fn test_capability_validation() {
        let required = RequiredCapabilities::none().with_roots();
        let client = ClientCapabilities {
            roots: true,
            ..Default::default()
        };
        assert!(required.validate(&client).is_valid());

        let client_missing = ClientCapabilities::default();
        assert!(!required.validate(&client_missing).is_valid());
    }

    #[test]
    fn test_extension_capability_validation() {
        let required = RequiredCapabilities::none().with_extension("trace");
        let client = ClientCapabilities {
            extensions: ["trace".to_string()].into_iter().collect(),
            ..Default::default()
        };
        assert!(required.validate(&client).is_valid());

        let missing = ClientCapabilities::default();
        let validation = required.validate(&missing);
        assert!(!validation.is_valid());
        assert_eq!(
            validation.missing(),
            Some(&["extensions/trace".to_string()][..])
        );
    }

    #[test]
    fn test_client_capabilities_parse_extensions() {
        let params = serde_json::json!({
            "capabilities": {
                "extensions": {
                    "trace": {"version": "1"},
                    "handoff": {}
                }
            }
        });

        let caps = ClientCapabilities::from_params(&params);
        assert!(caps.extensions.contains("trace"));
        assert!(caps.extensions.contains("handoff"));
    }

    // ---- rate limiting and connection counting ----

    #[test]
    fn test_rate_limiter() {
        let config = RateLimitConfig::new(2, Duration::from_secs(1));
        let limiter = RateLimiter::new(config);

        assert!(limiter.check(None));
        assert!(limiter.check(None));
        // The two-token bucket is now empty. At 2 tokens/s, a full token cannot
        // refill in the time this test takes.
        assert!(!limiter.check(None));
    }

    #[test]
    fn test_connection_counter() {
        let counter = Arc::new(ConnectionCounter::new(2));

        let guard1 = counter.try_acquire_arc();
        assert!(guard1.is_some());
        assert_eq!(counter.current(), 1);

        let guard2 = counter.try_acquire_arc();
        assert!(guard2.is_some());
        assert_eq!(counter.current(), 2);

        let guard3 = counter.try_acquire_arc();
        assert!(guard3.is_none());

        // Dropping a guard frees its slot.
        drop(guard1);
        assert_eq!(counter.current(), 1);

        let guard4 = counter.try_acquire_arc();
        assert!(guard4.is_some());
    }

    // ---- builder ----

    #[test]
    fn test_builder_default_succeeds() {
        let config = ServerConfig::builder().build();
        assert_eq!(config.max_message_size, DEFAULT_MAX_MESSAGE_SIZE);
        assert!(config.origin_validation.allow_localhost);
        assert!(config.origin_validation.allowed_origins.is_empty());
    }

    #[test]
    fn test_builder_origin_validation_overrides() {
        let config = ServerConfig::builder()
            .allow_origin("https://app.example.com")
            .allow_localhost_origins(false)
            .build();

        assert!(!config.origin_validation.allow_localhost);
        assert!(
            config
                .origin_validation
                .allowed_origins
                .contains("https://app.example.com")
        );
    }

    #[test]
    fn test_builder_try_build_valid() {
        let result = ServerConfig::builder()
            .max_message_size(1024 * 1024)
            .try_build();
        assert!(result.is_ok());
    }

    #[test]
    fn test_builder_try_build_invalid_message_size() {
        // 100 bytes is below the 1024-byte minimum enforced by `try_build`.
        let result = ServerConfig::builder()
            .max_message_size(100)
            .try_build();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ConfigValidationError::InvalidMessageSize { .. }
        ));
    }

    #[test]
    fn test_builder_try_build_invalid_rate_limit() {
        let result = ServerConfig::builder()
            .rate_limit(RateLimitConfig {
                // Zero requests per window is rejected.
                max_requests: 0,
                window: Duration::from_secs(1),
                per_client: true,
            })
            .try_build();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ConfigValidationError::InvalidRateLimit { .. }
        ));
    }

    #[test]
    fn test_builder_try_build_zero_window() {
        let result = ServerConfig::builder()
            .rate_limit(RateLimitConfig {
                max_requests: 100,
                // A zero-length window is rejected.
                window: Duration::ZERO,
                per_client: true,
            })
            .try_build();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ConfigValidationError::InvalidRateLimit { .. }
        ));
    }

    #[test]
    fn test_builder_try_build_invalid_connection_limits() {
        let result = ServerConfig::builder()
            .connection_limits(ConnectionLimits {
                max_tcp_connections: 0,
                max_websocket_connections: 0,
                max_http_concurrent: 0,
                max_unix_connections: 0,
            })
            .try_build();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ConfigValidationError::InvalidConnectionLimits { .. }
        ));
    }
}
1049
#[cfg(test)]
mod proptest_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        /// `try_build` must report errors, never panic, for any message size.
        #[test]
        fn config_builder_never_panics(
            max_msg_size in 0usize..10_000_000,
        ) {
            let builder = ServerConfig::builder().max_message_size(max_msg_size);
            let _ = builder.try_build();
        }

        /// Requesting more slots than exist grants exactly `max` of them.
        #[test]
        fn connection_counter_bounded(max in 1usize..10000) {
            let counter = Arc::new(ConnectionCounter::new(max));
            // Keep every guard alive so no slot is released during the loop.
            let guards: Vec<ConnectionGuard> = (0..max + 10)
                .filter_map(|_| counter.try_acquire_arc())
                .collect();
            assert_eq!(guards.len(), max);
            assert_eq!(counter.current(), max);
        }
    }
}