Skip to main content

turbomcp_server/
config.rs

1//! Server Configuration
2//!
//! This module provides configuration options for MCP servers, including:
//! - Protocol version negotiation
//! - Rate limiting
//! - Connection limits
//! - Capability requirements
//! - HTTP origin validation and maximum message size
8
9use std::collections::HashSet;
10use std::sync::Arc;
11use std::sync::atomic::{AtomicUsize, Ordering};
12use std::time::{Duration, Instant};
13
14use parking_lot::Mutex;
15use serde::{Deserialize, Serialize};
16
17// Re-export from core (single source of truth - DRY)
18pub use turbomcp_core::SUPPORTED_VERSIONS as SUPPORTED_PROTOCOL_VERSIONS;
19pub use turbomcp_core::types::core::ProtocolVersion;
20
/// Default cap on concurrent connections per transport.
///
/// [`ConnectionLimits::default`] applies this value to the TCP, WebSocket,
/// HTTP, and Unix socket limits alike.
pub const DEFAULT_MAX_CONNECTIONS: usize = 1_000;

/// Default number of requests permitted per rate-limit window.
pub const DEFAULT_RATE_LIMIT: u32 = 100;

/// Default rate-limit window length (one second).
pub const DEFAULT_RATE_LIMIT_WINDOW: Duration = Duration::from_millis(1_000);

/// Default upper bound on a single message, in bytes (10 MiB).
pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 10 << 20;
32
/// Origin validation settings for HTTP transports.
///
/// The default has an empty allow-list, accepts localhost origins, and keeps
/// origin checking enabled (`allow_any == false`).
#[derive(Debug, Clone)]
pub struct OriginValidationConfig {
    /// Origins that are explicitly allowed.
    pub allowed_origins: HashSet<String>,
    /// Whether localhost / browser-dev origins are accepted.
    pub allow_localhost: bool,
    /// Whether origin checks are disabled entirely.
    pub allow_any: bool,
}

impl Default for OriginValidationConfig {
    fn default() -> Self {
        let allowed_origins = HashSet::default();
        Self {
            allow_any: false,
            allow_localhost: true,
            allowed_origins,
        }
    }
}

impl OriginValidationConfig {
    /// Equivalent to [`OriginValidationConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }
}
61
62/// Server configuration.
63#[derive(Debug, Clone)]
64pub struct ServerConfig {
65    /// Protocol version configuration.
66    pub protocol: ProtocolConfig,
67    /// Rate limiting configuration.
68    pub rate_limit: Option<RateLimitConfig>,
69    /// Connection limits.
70    pub connection_limits: ConnectionLimits,
71    /// Required client capabilities.
72    pub required_capabilities: RequiredCapabilities,
73    /// Maximum message size in bytes (default: 10MB).
74    pub max_message_size: usize,
75    /// HTTP origin validation policy.
76    pub origin_validation: OriginValidationConfig,
77}
78
79impl Default for ServerConfig {
80    fn default() -> Self {
81        Self {
82            protocol: ProtocolConfig::default(),
83            rate_limit: None,
84            connection_limits: ConnectionLimits::default(),
85            required_capabilities: RequiredCapabilities::default(),
86            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
87            origin_validation: OriginValidationConfig::default(),
88        }
89    }
90}
91
92impl ServerConfig {
93    /// Create a new server configuration with defaults.
94    #[must_use]
95    pub fn new() -> Self {
96        Self::default()
97    }
98
99    /// Create a builder for server configuration.
100    #[must_use]
101    pub fn builder() -> ServerConfigBuilder {
102        ServerConfigBuilder::default()
103    }
104}
105
106/// Builder for server configuration.
107#[derive(Debug, Clone, Default)]
108pub struct ServerConfigBuilder {
109    protocol: Option<ProtocolConfig>,
110    rate_limit: Option<RateLimitConfig>,
111    connection_limits: Option<ConnectionLimits>,
112    required_capabilities: Option<RequiredCapabilities>,
113    max_message_size: Option<usize>,
114    origin_validation: Option<OriginValidationConfig>,
115}
116
117impl ServerConfigBuilder {
118    /// Set protocol configuration.
119    #[must_use]
120    pub fn protocol(mut self, config: ProtocolConfig) -> Self {
121        self.protocol = Some(config);
122        self
123    }
124
125    /// Set rate limiting configuration.
126    #[must_use]
127    pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
128        self.rate_limit = Some(config);
129        self
130    }
131
132    /// Set connection limits.
133    #[must_use]
134    pub fn connection_limits(mut self, limits: ConnectionLimits) -> Self {
135        self.connection_limits = Some(limits);
136        self
137    }
138
139    /// Set required client capabilities.
140    #[must_use]
141    pub fn required_capabilities(mut self, caps: RequiredCapabilities) -> Self {
142        self.required_capabilities = Some(caps);
143        self
144    }
145
146    /// Set maximum message size in bytes.
147    ///
148    /// Messages exceeding this size will be rejected.
149    /// Default: 10MB.
150    #[must_use]
151    pub fn max_message_size(mut self, size: usize) -> Self {
152        self.max_message_size = Some(size);
153        self
154    }
155
156    /// Set HTTP origin validation configuration.
157    #[must_use]
158    pub fn origin_validation(mut self, config: OriginValidationConfig) -> Self {
159        self.origin_validation = Some(config);
160        self
161    }
162
163    /// Add a single allowed origin for HTTP transports.
164    #[must_use]
165    pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
166        self.origin_validation
167            .get_or_insert_with(OriginValidationConfig::default)
168            .allowed_origins
169            .insert(origin.into());
170        self
171    }
172
173    /// Add multiple allowed origins for HTTP transports.
174    #[must_use]
175    pub fn allow_origins<I, S>(mut self, origins: I) -> Self
176    where
177        I: IntoIterator<Item = S>,
178        S: Into<String>,
179    {
180        let config = self
181            .origin_validation
182            .get_or_insert_with(OriginValidationConfig::default);
183        config
184            .allowed_origins
185            .extend(origins.into_iter().map(Into::into));
186        self
187    }
188
189    /// Control whether localhost origins are accepted.
190    #[must_use]
191    pub fn allow_localhost_origins(mut self, allow: bool) -> Self {
192        self.origin_validation
193            .get_or_insert_with(OriginValidationConfig::default)
194            .allow_localhost = allow;
195        self
196    }
197
198    /// Disable origin checks entirely.
199    #[must_use]
200    pub fn allow_any_origin(mut self, allow: bool) -> Self {
201        self.origin_validation
202            .get_or_insert_with(OriginValidationConfig::default)
203            .allow_any = allow;
204        self
205    }
206
207    /// Build the server configuration with sensible defaults.
208    ///
209    /// This method always succeeds and uses defaults for any unset fields.
210    /// For strict validation, use [`try_build()`](Self::try_build).
211    #[must_use]
212    pub fn build(self) -> ServerConfig {
213        ServerConfig {
214            protocol: self.protocol.unwrap_or_default(),
215            rate_limit: self.rate_limit,
216            connection_limits: self.connection_limits.unwrap_or_default(),
217            required_capabilities: self.required_capabilities.unwrap_or_default(),
218            max_message_size: self.max_message_size.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE),
219            origin_validation: self.origin_validation.unwrap_or_default(),
220        }
221    }
222
223    /// Build the server configuration with validation.
224    ///
225    /// This method validates the configuration and returns an error if any
226    /// constraints are violated. Use this for stricter configuration checking
227    /// in enterprise deployments.
228    ///
229    /// # Errors
230    ///
231    /// Returns an error if:
232    /// - `max_message_size` is less than 1024 bytes (minimum viable message size)
233    /// - Rate limit `max_requests` is 0
234    /// - Rate limit `window` is zero
235    /// - Connection limits have all values set to 0
236    ///
237    /// # Example
238    ///
239    /// ```rust
240    /// use turbomcp_server::ServerConfig;
241    ///
242    /// // Validated build - catches configuration errors
243    /// let config = ServerConfig::builder()
244    ///     .max_message_size(1024 * 1024) // 1MB
245    ///     .try_build()
246    ///     .expect("Invalid configuration");
247    /// ```
248    pub fn try_build(self) -> Result<ServerConfig, ConfigValidationError> {
249        let max_message_size = self.max_message_size.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE);
250
251        // Validate message size
252        if max_message_size < 1024 {
253            return Err(ConfigValidationError::InvalidMessageSize {
254                size: max_message_size,
255                min: 1024,
256            });
257        }
258
259        // Validate rate limit if provided
260        if let Some(ref rate_limit) = self.rate_limit {
261            if rate_limit.max_requests == 0 {
262                return Err(ConfigValidationError::InvalidRateLimit {
263                    reason: "max_requests cannot be 0".to_string(),
264                });
265            }
266            if rate_limit.window.is_zero() {
267                return Err(ConfigValidationError::InvalidRateLimit {
268                    reason: "rate limit window cannot be zero".to_string(),
269                });
270            }
271        }
272
273        // Validate connection limits
274        let connection_limits = self.connection_limits.unwrap_or_default();
275        if connection_limits.max_tcp_connections == 0
276            && connection_limits.max_websocket_connections == 0
277            && connection_limits.max_http_concurrent == 0
278            && connection_limits.max_unix_connections == 0
279        {
280            return Err(ConfigValidationError::InvalidConnectionLimits {
281                reason: "at least one connection limit must be non-zero".to_string(),
282            });
283        }
284
285        Ok(ServerConfig {
286            protocol: self.protocol.unwrap_or_default(),
287            rate_limit: self.rate_limit,
288            connection_limits,
289            required_capabilities: self.required_capabilities.unwrap_or_default(),
290            max_message_size,
291            origin_validation: self.origin_validation.unwrap_or_default(),
292        })
293    }
294}
295
296/// Errors that can occur during configuration validation.
297#[derive(Debug, Clone, thiserror::Error)]
298pub enum ConfigValidationError {
299    /// Invalid message size configuration.
300    #[error("Invalid max_message_size: {size} bytes is below minimum of {min} bytes")]
301    InvalidMessageSize {
302        /// The configured size.
303        size: usize,
304        /// The minimum allowed size.
305        min: usize,
306    },
307
308    /// Invalid rate limit configuration.
309    #[error("Invalid rate limit: {reason}")]
310    InvalidRateLimit {
311        /// Description of the validation failure.
312        reason: String,
313    },
314
315    /// Invalid connection limits configuration.
316    #[error("Invalid connection limits: {reason}")]
317    InvalidConnectionLimits {
318        /// Description of the validation failure.
319        reason: String,
320    },
321}
322
323/// Protocol version configuration.
324#[derive(Debug, Clone)]
325pub struct ProtocolConfig {
326    /// Preferred protocol version.
327    pub preferred_version: ProtocolVersion,
328    /// Supported protocol versions.
329    pub supported_versions: Vec<ProtocolVersion>,
330    /// Allow fallback to server's preferred version if client's is unsupported.
331    pub allow_fallback: bool,
332}
333
334impl Default for ProtocolConfig {
335    fn default() -> Self {
336        Self {
337            preferred_version: ProtocolVersion::LATEST.clone(),
338            supported_versions: vec![ProtocolVersion::LATEST.clone()],
339            allow_fallback: false,
340        }
341    }
342}
343
344impl ProtocolConfig {
345    /// Create a strict configuration that only accepts the specified version.
346    #[must_use]
347    pub fn strict(version: impl Into<ProtocolVersion>) -> Self {
348        let v = version.into();
349        Self {
350            preferred_version: v.clone(),
351            supported_versions: vec![v],
352            allow_fallback: false,
353        }
354    }
355
356    /// Create a multi-version configuration that accepts all stable versions.
357    ///
358    /// The preferred version is the latest stable. Older clients are accepted
359    /// and responses are filtered through the appropriate version adapter.
360    #[must_use]
361    pub fn multi_version() -> Self {
362        Self {
363            preferred_version: ProtocolVersion::LATEST.clone(),
364            supported_versions: ProtocolVersion::STABLE.to_vec(),
365            allow_fallback: false,
366        }
367    }
368
369    /// Check if a protocol version is supported.
370    #[must_use]
371    pub fn is_supported(&self, version: &ProtocolVersion) -> bool {
372        self.supported_versions.contains(version)
373    }
374
375    /// Negotiate protocol version with client.
376    ///
377    /// Returns the negotiated version or None if no compatible version found.
378    #[must_use]
379    pub fn negotiate(&self, client_version: Option<&str>) -> Option<ProtocolVersion> {
380        match client_version {
381            Some(version_str) => {
382                let version = ProtocolVersion::from(version_str);
383                if self.is_supported(&version) {
384                    Some(version)
385                } else if self.allow_fallback {
386                    Some(self.preferred_version.clone())
387                } else {
388                    None
389                }
390            }
391            None => Some(self.preferred_version.clone()),
392        }
393    }
394}
395
396/// Rate limiting configuration.
397#[derive(Debug, Clone)]
398pub struct RateLimitConfig {
399    /// Maximum requests per window.
400    pub max_requests: u32,
401    /// Time window for rate limiting.
402    pub window: Duration,
403    /// Whether to rate limit per client (by user_id or IP).
404    pub per_client: bool,
405}
406
407impl Default for RateLimitConfig {
408    fn default() -> Self {
409        Self {
410            max_requests: DEFAULT_RATE_LIMIT,
411            window: DEFAULT_RATE_LIMIT_WINDOW,
412            per_client: true,
413        }
414    }
415}
416
417impl RateLimitConfig {
418    /// Create a new rate limit configuration.
419    #[must_use]
420    pub fn new(max_requests: u32, window: Duration) -> Self {
421        Self {
422            max_requests,
423            window,
424            per_client: true,
425        }
426    }
427
428    /// Set per-client rate limiting.
429    #[must_use]
430    pub fn per_client(mut self, enabled: bool) -> Self {
431        self.per_client = enabled;
432        self
433    }
434}
435
436/// Connection limits.
437#[derive(Debug, Clone)]
438pub struct ConnectionLimits {
439    /// Maximum concurrent TCP connections.
440    pub max_tcp_connections: usize,
441    /// Maximum concurrent WebSocket connections.
442    pub max_websocket_connections: usize,
443    /// Maximum concurrent HTTP requests.
444    pub max_http_concurrent: usize,
445    /// Maximum concurrent Unix socket connections.
446    pub max_unix_connections: usize,
447}
448
449impl Default for ConnectionLimits {
450    fn default() -> Self {
451        Self {
452            max_tcp_connections: DEFAULT_MAX_CONNECTIONS,
453            max_websocket_connections: DEFAULT_MAX_CONNECTIONS,
454            max_http_concurrent: DEFAULT_MAX_CONNECTIONS,
455            max_unix_connections: DEFAULT_MAX_CONNECTIONS,
456        }
457    }
458}
459
460impl ConnectionLimits {
461    /// Create a new connection limits configuration.
462    #[must_use]
463    pub fn new(max_connections: usize) -> Self {
464        Self {
465            max_tcp_connections: max_connections,
466            max_websocket_connections: max_connections,
467            max_http_concurrent: max_connections,
468            max_unix_connections: max_connections,
469        }
470    }
471}
472
473/// Required client capabilities.
474///
475/// Specifies which client capabilities the server requires.
476#[derive(Debug, Clone, Default, Serialize, Deserialize)]
477pub struct RequiredCapabilities {
478    /// Require roots capability.
479    #[serde(default)]
480    pub roots: bool,
481    /// Require sampling capability.
482    #[serde(default)]
483    pub sampling: bool,
484    /// Require draft extensions.
485    #[serde(default)]
486    pub extensions: HashSet<String>,
487    /// Require experimental capabilities.
488    #[serde(default)]
489    pub experimental: HashSet<String>,
490}
491
492impl RequiredCapabilities {
493    /// Create empty required capabilities (no requirements).
494    #[must_use]
495    pub fn none() -> Self {
496        Self::default()
497    }
498
499    /// Require roots capability.
500    #[must_use]
501    pub fn with_roots(mut self) -> Self {
502        self.roots = true;
503        self
504    }
505
506    /// Require sampling capability.
507    #[must_use]
508    pub fn with_sampling(mut self) -> Self {
509        self.sampling = true;
510        self
511    }
512
513    /// Require a draft extension.
514    #[must_use]
515    pub fn with_extension(mut self, name: impl Into<String>) -> Self {
516        self.extensions.insert(name.into());
517        self
518    }
519
520    /// Require an experimental capability.
521    #[must_use]
522    pub fn with_experimental(mut self, name: impl Into<String>) -> Self {
523        self.experimental.insert(name.into());
524        self
525    }
526
527    /// Check if all required capabilities are present in client capabilities.
528    #[must_use]
529    pub fn validate(&self, client_caps: &ClientCapabilities) -> CapabilityValidation {
530        let mut missing = Vec::new();
531
532        if self.roots && !client_caps.roots {
533            missing.push("roots".to_string());
534        }
535
536        if self.sampling && !client_caps.sampling {
537            missing.push("sampling".to_string());
538        }
539
540        for extension in &self.extensions {
541            if !client_caps.extensions.contains(extension) {
542                missing.push(format!("extensions/{}", extension));
543            }
544        }
545
546        for exp in &self.experimental {
547            if !client_caps.experimental.contains(exp) {
548                missing.push(format!("experimental/{}", exp));
549            }
550        }
551
552        if missing.is_empty() {
553            CapabilityValidation::Valid
554        } else {
555            CapabilityValidation::Missing(missing)
556        }
557    }
558}
559
560/// Client capabilities received during initialization.
561#[derive(Debug, Clone, Default, Serialize, Deserialize)]
562pub struct ClientCapabilities {
563    /// Client supports roots.
564    #[serde(default)]
565    pub roots: bool,
566    /// Client supports sampling.
567    #[serde(default)]
568    pub sampling: bool,
569    /// Client draft extensions.
570    #[serde(default)]
571    pub extensions: HashSet<String>,
572    /// Client experimental capabilities.
573    #[serde(default)]
574    pub experimental: HashSet<String>,
575}
576
577impl ClientCapabilities {
578    /// Parse client capabilities from initialize request params.
579    #[must_use]
580    pub fn from_params(params: &serde_json::Value) -> Self {
581        let caps = params.get("capabilities").cloned().unwrap_or_default();
582
583        Self {
584            roots: caps.get("roots").map(|v| !v.is_null()).unwrap_or(false),
585            sampling: caps.get("sampling").map(|v| !v.is_null()).unwrap_or(false),
586            extensions: caps
587                .get("extensions")
588                .and_then(|v| v.as_object())
589                .map(|obj| obj.keys().cloned().collect())
590                .unwrap_or_default(),
591            experimental: caps
592                .get("experimental")
593                .and_then(|v| v.as_object())
594                .map(|obj| obj.keys().cloned().collect())
595                .unwrap_or_default(),
596        }
597    }
598}
599
/// Outcome of [`RequiredCapabilities::validate`].
#[derive(Debug, Clone)]
pub enum CapabilityValidation {
    /// All required capabilities are present.
    Valid,
    /// Some required capabilities are missing; holds their names.
    Missing(Vec<String>),
}

impl CapabilityValidation {
    /// `true` only for [`CapabilityValidation::Valid`].
    #[must_use]
    pub fn is_valid(&self) -> bool {
        self.missing().is_none()
    }

    /// The missing capability names, or `None` when validation passed.
    #[must_use]
    pub fn missing(&self) -> Option<&[String]> {
        if let Self::Missing(names) = self {
            Some(names.as_slice())
        } else {
            None
        }
    }
}
625
626/// Rate limiter using token bucket algorithm.
627#[derive(Debug)]
628pub struct RateLimiter {
629    config: RateLimitConfig,
630    /// Global bucket for non-per-client limiting.
631    global_bucket: Mutex<TokenBucket>,
632    /// Per-client buckets (keyed by client ID).
633    client_buckets: Mutex<std::collections::HashMap<String, TokenBucket>>,
634    /// Last cleanup timestamp for automatic cleanup.
635    last_cleanup: Mutex<Instant>,
636}
637
638impl RateLimiter {
639    /// Create a new rate limiter.
640    #[must_use]
641    pub fn new(config: RateLimitConfig) -> Self {
642        Self {
643            global_bucket: Mutex::new(TokenBucket::new(config.max_requests, config.window)),
644            client_buckets: Mutex::new(std::collections::HashMap::new()),
645            last_cleanup: Mutex::new(Instant::now()),
646            config,
647        }
648    }
649
650    /// Check if a request is allowed.
651    ///
652    /// Returns `true` if allowed, `false` if rate limited.
653    pub fn check(&self, client_id: Option<&str>) -> bool {
654        // Periodic cleanup of stale client buckets (avoid unbounded growth)
655        let needs_cleanup = {
656            let last = self.last_cleanup.lock();
657            last.elapsed() > Duration::from_secs(60)
658        };
659        if needs_cleanup {
660            self.cleanup(Duration::from_secs(300));
661            *self.last_cleanup.lock() = Instant::now();
662        }
663
664        if self.config.per_client {
665            if let Some(id) = client_id {
666                let mut buckets = self.client_buckets.lock();
667                let bucket = buckets.entry(id.to_string()).or_insert_with(|| {
668                    TokenBucket::new(self.config.max_requests, self.config.window)
669                });
670                bucket.try_acquire()
671            } else {
672                // No client ID, use global bucket
673                self.global_bucket.lock().try_acquire()
674            }
675        } else {
676            self.global_bucket.lock().try_acquire()
677        }
678    }
679
680    /// Clean up old client buckets to prevent memory growth.
681    pub fn cleanup(&self, max_age: Duration) {
682        let mut buckets = self.client_buckets.lock();
683        let now = Instant::now();
684        buckets.retain(|_, bucket| now.duration_since(bucket.last_access) < max_age);
685    }
686
687    /// Get the current number of tracked client buckets.
688    #[must_use]
689    pub fn client_bucket_count(&self) -> usize {
690        self.client_buckets.lock().len()
691    }
692}
693
/// Single token bucket: holds up to `max_tokens` tokens, refilled continuously
/// at `refill_rate` tokens per second; each admitted request spends one token.
#[derive(Debug)]
struct TokenBucket {
    tokens: f64,
    max_tokens: f64,
    refill_rate: f64, // tokens per second
    last_refill: Instant,
    last_access: Instant,
}

impl TokenBucket {
    /// Start full, refilling `max_requests` tokens over each `window`.
    fn new(max_requests: u32, window: Duration) -> Self {
        let capacity = f64::from(max_requests);
        let created = Instant::now();
        Self {
            tokens: capacity,
            max_tokens: capacity,
            refill_rate: capacity / window.as_secs_f64(),
            last_refill: created,
            last_access: created,
        }
    }

    /// Spend one token if available, returning whether the request is admitted.
    fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        let since_refill = now.duration_since(self.last_refill);

        // Credit accrued tokens in steps of at least 10 ms. Shorter gaps are not
        // lost: `last_refill` is left untouched, so the time keeps accumulating.
        if since_refill >= Duration::from_millis(10) {
            let refilled = self.tokens + since_refill.as_secs_f64() * self.refill_rate;
            self.tokens = refilled.min(self.max_tokens);
            self.last_refill = now;
        }

        self.last_access = now;

        let admitted = self.tokens >= 1.0;
        if admitted {
            self.tokens -= 1.0;
        }
        admitted
    }
}
738
/// Connection counter for tracking active connections.
///
/// This is designed to be wrapped in `Arc` and shared across async tasks.
/// Use `try_acquire_arc` to get a guard that can be moved into spawned tasks.
#[derive(Debug)]
pub struct ConnectionCounter {
    current: AtomicUsize,
    max: usize,
}

impl ConnectionCounter {
    /// Create a counter that admits at most `max` concurrent connections.
    #[must_use]
    pub fn new(max: usize) -> Self {
        Self {
            current: AtomicUsize::new(0),
            max,
        }
    }

    /// Try to acquire a connection slot (for use when counter is in Arc).
    ///
    /// Returns a guard that releases the slot when dropped, or `None` if the
    /// counter is already at capacity. The guard only holds an `Arc` to this
    /// counter, so it is `Send + 'static` and can be moved into spawned async
    /// tasks.
    #[must_use = "dropping the guard immediately releases the connection slot"]
    pub fn try_acquire_arc(self: &Arc<Self>) -> Option<ConnectionGuard> {
        // `fetch_update` retries the compare-and-swap until it succeeds or the
        // closure declines (at capacity). A failed CAS means another thread
        // changed the count, so the loop is lock-free. Unlike a fixed retry
        // budget, it never reports "full" spuriously under contention.
        self.current
            .fetch_update(Ordering::SeqCst, Ordering::Relaxed, |current| {
                if current < self.max {
                    Some(current + 1)
                } else {
                    None
                }
            })
            .ok()
            .map(|_| ConnectionGuard {
                counter: Arc::clone(self),
            })
    }

    /// Get current connection count.
    #[must_use]
    pub fn current(&self) -> usize {
        self.current.load(Ordering::Relaxed)
    }

    /// Get maximum connections.
    #[must_use]
    pub fn max(&self) -> usize {
        self.max
    }

    /// Return one slot; called only from `ConnectionGuard::drop`.
    fn release(&self) {
        self.current.fetch_sub(1, Ordering::SeqCst);
    }
}

/// Guard that releases a connection slot when dropped.
///
/// This guard is `Send + 'static` and can be safely moved into spawned async tasks.
#[derive(Debug)]
pub struct ConnectionGuard {
    counter: Arc<ConnectionCounter>,
}

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        self.counter.release();
    }
}
820
821#[cfg(test)]
822mod tests {
823    use super::*;
824
825    #[test]
826    fn test_protocol_negotiation_exact_match() {
827        let config = ProtocolConfig::default();
828        assert_eq!(
829            config.negotiate(Some("2025-11-25")),
830            Some(ProtocolVersion::V2025_11_25)
831        );
832    }
833
834    #[test]
835    fn test_protocol_negotiation_default_rejects_older_version() {
836        // Default config is strict latest-only
837        let config = ProtocolConfig::default();
838        assert_eq!(config.negotiate(Some("2025-06-18")), None);
839    }
840
841    #[test]
842    fn test_protocol_negotiation_multi_version_accepts_older() {
843        let config = ProtocolConfig::multi_version();
844        assert_eq!(
845            config.negotiate(Some("2025-06-18")),
846            Some(ProtocolVersion::V2025_06_18)
847        );
848        assert_eq!(
849            config.negotiate(Some("2025-11-25")),
850            Some(ProtocolVersion::V2025_11_25)
851        );
852    }
853
854    #[test]
855    fn test_protocol_negotiation_none_returns_preferred() {
856        let config = ProtocolConfig::default();
857        assert_eq!(config.negotiate(None), Some(ProtocolVersion::V2025_11_25));
858    }
859
860    #[test]
861    fn test_protocol_negotiation_unknown_version() {
862        let config = ProtocolConfig::default();
863        assert_eq!(config.negotiate(Some("unknown-version")), None);
864    }
865
866    #[test]
867    fn test_protocol_negotiation_strict() {
868        let config = ProtocolConfig::strict("2025-11-25");
869        assert_eq!(config.negotiate(Some("2025-06-18")), None);
870    }
871
872    #[test]
873    fn test_capability_validation() {
874        let required = RequiredCapabilities::none().with_roots();
875        let client = ClientCapabilities {
876            roots: true,
877            ..Default::default()
878        };
879        assert!(required.validate(&client).is_valid());
880
881        let client_missing = ClientCapabilities::default();
882        assert!(!required.validate(&client_missing).is_valid());
883    }
884
885    #[test]
886    fn test_extension_capability_validation() {
887        let required = RequiredCapabilities::none().with_extension("trace");
888        let client = ClientCapabilities {
889            extensions: ["trace".to_string()].into_iter().collect(),
890            ..Default::default()
891        };
892        assert!(required.validate(&client).is_valid());
893
894        let missing = ClientCapabilities::default();
895        let validation = required.validate(&missing);
896        assert!(!validation.is_valid());
897        assert_eq!(
898            validation.missing(),
899            Some(&["extensions/trace".to_string()][..])
900        );
901    }
902
903    #[test]
904    fn test_client_capabilities_parse_extensions() {
905        let params = serde_json::json!({
906            "capabilities": {
907                "extensions": {
908                    "trace": {"version": "1"},
909                    "handoff": {}
910                }
911            }
912        });
913
914        let caps = ClientCapabilities::from_params(&params);
915        assert!(caps.extensions.contains("trace"));
916        assert!(caps.extensions.contains("handoff"));
917    }
918
919    #[test]
920    fn test_rate_limiter() {
921        let config = RateLimitConfig::new(2, Duration::from_secs(1));
922        let limiter = RateLimiter::new(config);
923
924        assert!(limiter.check(None));
925        assert!(limiter.check(None));
926        assert!(!limiter.check(None)); // Should be rate limited
927    }
928
929    #[test]
930    fn test_connection_counter() {
931        let counter = Arc::new(ConnectionCounter::new(2));
932
933        let guard1 = counter.try_acquire_arc();
934        assert!(guard1.is_some());
935        assert_eq!(counter.current(), 1);
936
937        let guard2 = counter.try_acquire_arc();
938        assert!(guard2.is_some());
939        assert_eq!(counter.current(), 2);
940
941        let guard3 = counter.try_acquire_arc();
942        assert!(guard3.is_none()); // At capacity
943
944        drop(guard1);
945        assert_eq!(counter.current(), 1);
946
947        let guard4 = counter.try_acquire_arc();
948        assert!(guard4.is_some());
949    }
950
951    // =========================================================================
952    // Builder validation tests
953    // =========================================================================
954
955    #[test]
956    fn test_builder_default_succeeds() {
957        // Default configuration should always succeed
958        let config = ServerConfig::builder().build();
959        assert_eq!(config.max_message_size, DEFAULT_MAX_MESSAGE_SIZE);
960        assert!(config.origin_validation.allow_localhost);
961        assert!(config.origin_validation.allowed_origins.is_empty());
962    }
963
964    #[test]
965    fn test_builder_origin_validation_overrides() {
966        let config = ServerConfig::builder()
967            .allow_origin("https://app.example.com")
968            .allow_localhost_origins(false)
969            .build();
970
971        assert!(!config.origin_validation.allow_localhost);
972        assert!(
973            config
974                .origin_validation
975                .allowed_origins
976                .contains("https://app.example.com")
977        );
978    }
979
980    #[test]
981    fn test_builder_try_build_valid() {
982        let result = ServerConfig::builder()
983            .max_message_size(1024 * 1024)
984            .try_build();
985        assert!(result.is_ok());
986    }
987
988    #[test]
989    fn test_builder_try_build_invalid_message_size() {
990        let result = ServerConfig::builder()
991            .max_message_size(100) // Below minimum
992            .try_build();
993        assert!(result.is_err());
994        assert!(matches!(
995            result.unwrap_err(),
996            ConfigValidationError::InvalidMessageSize { .. }
997        ));
998    }
999
1000    #[test]
1001    fn test_builder_try_build_invalid_rate_limit() {
1002        let result = ServerConfig::builder()
1003            .rate_limit(RateLimitConfig {
1004                max_requests: 0, // Invalid
1005                window: Duration::from_secs(1),
1006                per_client: true,
1007            })
1008            .try_build();
1009        assert!(result.is_err());
1010        assert!(matches!(
1011            result.unwrap_err(),
1012            ConfigValidationError::InvalidRateLimit { .. }
1013        ));
1014    }
1015
1016    #[test]
1017    fn test_builder_try_build_zero_window() {
1018        let result = ServerConfig::builder()
1019            .rate_limit(RateLimitConfig {
1020                max_requests: 100,
1021                window: Duration::ZERO, // Invalid
1022                per_client: true,
1023            })
1024            .try_build();
1025        assert!(result.is_err());
1026        assert!(matches!(
1027            result.unwrap_err(),
1028            ConfigValidationError::InvalidRateLimit { .. }
1029        ));
1030    }
1031
1032    #[test]
1033    fn test_builder_try_build_invalid_connection_limits() {
1034        let result = ServerConfig::builder()
1035            .connection_limits(ConnectionLimits {
1036                max_tcp_connections: 0,
1037                max_websocket_connections: 0,
1038                max_http_concurrent: 0,
1039                max_unix_connections: 0,
1040            })
1041            .try_build();
1042        assert!(result.is_err());
1043        assert!(matches!(
1044            result.unwrap_err(),
1045            ConfigValidationError::InvalidConnectionLimits { .. }
1046        ));
1047    }
1048}
1049
#[cfg(test)]
mod proptest_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Arbitrary message sizes must be reported through `try_build`'s
        // `Result`, never by panicking.
        #[test]
        fn config_builder_never_panics(
            max_msg_size in 0usize..10_000_000,
        ) {
            // The outcome is irrelevant here; the property is only that this returns.
            let _ = ServerConfig::builder()
                .max_message_size(max_msg_size)
                .try_build();
        }

        // The counter never hands out more than `max` guards, and every slot
        // is returned once the guards are dropped.
        #[test]
        fn connection_counter_bounded(max in 1usize..10000) {
            let counter = Arc::new(ConnectionCounter::new(max));

            // Attempt more acquisitions than there are slots; the extras must fail.
            let guards: Vec<_> = (0..max + 10)
                .filter_map(|_| counter.try_acquire_arc())
                .collect();
            prop_assert_eq!(guards.len(), max);
            prop_assert_eq!(counter.current(), max);

            // Dropping all guards must release every slot.
            drop(guards);
            prop_assert_eq!(counter.current(), 0);
        }
    }
}