Skip to main content

turbomcp_server/
config.rs

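//! Server Configuration
//!
//! This module provides configuration options for MCP servers, including:
//! - Protocol version negotiation
//! - Rate limiting
//! - Connection limits
//! - Capability requirements
//! - Maximum message size
//! - HTTP origin validation
//!
//! It also provides the runtime helpers that enforce some of these limits:
//! [`RateLimiter`] and [`ConnectionCounter`].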
8
9use std::collections::HashSet;
10use std::sync::Arc;
11use std::sync::atomic::{AtomicUsize, Ordering};
12use std::time::{Duration, Instant};
13
14use parking_lot::Mutex;
15use serde::{Deserialize, Serialize};
16
17// Re-export from core (single source of truth - DRY)
18pub use turbomcp_core::SUPPORTED_VERSIONS as SUPPORTED_PROTOCOL_VERSIONS;
19pub use turbomcp_types::ProtocolVersion;
20
/// Default cap on concurrent connections.
///
/// `ConnectionLimits::default` applies it to every transport (TCP, WebSocket,
/// HTTP and Unix sockets).
pub const DEFAULT_MAX_CONNECTIONS: usize = 1_000;

/// Default number of requests allowed per rate-limit window.
pub const DEFAULT_RATE_LIMIT: u32 = 100;

/// Default rate-limit window: one second.
pub const DEFAULT_RATE_LIMIT_WINDOW: Duration = Duration::from_millis(1_000);

/// Default maximum message size: 10 MiB (10 × 1024 × 1024 bytes).
pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 10 << 20;
32
/// Origin-checking policy for HTTP transports.
///
/// Also records which reverse proxies may report the real client address
/// through forwarding headers.
#[derive(Debug, Clone)]
pub struct OriginValidationConfig {
    /// Origins that are always accepted.
    pub allowed_origins: HashSet<String>,
    /// Accept localhost / browser-dev origins.
    pub allow_localhost: bool,
    /// Skip origin checks altogether.
    pub allow_any: bool,
    /// Trusted reverse-proxy IPs / CIDRs whose `X-Forwarded-For`,
    /// `X-Real-IP`, `CF-Connecting-IP`, and `X-Client-IP` headers we will
    /// honour. **Empty means trust nothing** — direct clients can no longer
    /// spoof their source IP via headers. Entries accept CIDR notation
    /// (`10.0.0.0/8`) or bare addresses (`10.0.0.5`).
    pub trusted_proxies: Vec<String>,
}

impl Default for OriginValidationConfig {
    /// Localhost origins allowed; no explicit origins, no wildcard, no trusted proxies.
    fn default() -> Self {
        let allowed_origins = HashSet::default();
        let trusted_proxies = Vec::default();
        Self {
            allowed_origins,
            allow_localhost: true,
            allow_any: false,
            trusted_proxies,
        }
    }
}

impl OriginValidationConfig {
    /// Equivalent to [`OriginValidationConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }
}
68
69/// Server configuration.
70#[derive(Debug, Clone)]
71pub struct ServerConfig {
72    /// Protocol version configuration.
73    pub protocol: ProtocolConfig,
74    /// Rate limiting configuration.
75    pub rate_limit: Option<RateLimitConfig>,
76    /// Connection limits.
77    pub connection_limits: ConnectionLimits,
78    /// Required client capabilities.
79    pub required_capabilities: RequiredCapabilities,
80    /// Maximum message size in bytes (default: 10MB).
81    pub max_message_size: usize,
82    /// HTTP origin validation policy.
83    pub origin_validation: OriginValidationConfig,
84}
85
86impl Default for ServerConfig {
87    fn default() -> Self {
88        Self {
89            protocol: ProtocolConfig::default(),
90            rate_limit: None,
91            connection_limits: ConnectionLimits::default(),
92            required_capabilities: RequiredCapabilities::default(),
93            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
94            origin_validation: OriginValidationConfig::default(),
95        }
96    }
97}
98
99impl ServerConfig {
100    /// Create a new server configuration with defaults.
101    #[must_use]
102    pub fn new() -> Self {
103        Self::default()
104    }
105
106    /// Create a builder for server configuration.
107    #[must_use]
108    pub fn builder() -> ServerConfigBuilder {
109        ServerConfigBuilder::default()
110    }
111}
112
113/// Builder for server configuration.
114#[derive(Debug, Clone, Default)]
115pub struct ServerConfigBuilder {
116    protocol: Option<ProtocolConfig>,
117    rate_limit: Option<RateLimitConfig>,
118    connection_limits: Option<ConnectionLimits>,
119    required_capabilities: Option<RequiredCapabilities>,
120    max_message_size: Option<usize>,
121    origin_validation: Option<OriginValidationConfig>,
122}
123
124impl ServerConfigBuilder {
125    /// Set protocol configuration.
126    #[must_use]
127    pub fn protocol(mut self, config: ProtocolConfig) -> Self {
128        self.protocol = Some(config);
129        self
130    }
131
132    /// Set rate limiting configuration.
133    #[must_use]
134    pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
135        self.rate_limit = Some(config);
136        self
137    }
138
139    /// Set connection limits.
140    #[must_use]
141    pub fn connection_limits(mut self, limits: ConnectionLimits) -> Self {
142        self.connection_limits = Some(limits);
143        self
144    }
145
146    /// Set required client capabilities.
147    #[must_use]
148    pub fn required_capabilities(mut self, caps: RequiredCapabilities) -> Self {
149        self.required_capabilities = Some(caps);
150        self
151    }
152
153    /// Set maximum message size in bytes.
154    ///
155    /// Messages exceeding this size will be rejected.
156    /// Default: 10MB.
157    #[must_use]
158    pub fn max_message_size(mut self, size: usize) -> Self {
159        self.max_message_size = Some(size);
160        self
161    }
162
163    /// Set HTTP origin validation configuration.
164    #[must_use]
165    pub fn origin_validation(mut self, config: OriginValidationConfig) -> Self {
166        self.origin_validation = Some(config);
167        self
168    }
169
170    /// Add a single allowed origin for HTTP transports.
171    #[must_use]
172    pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
173        self.origin_validation
174            .get_or_insert_with(OriginValidationConfig::default)
175            .allowed_origins
176            .insert(origin.into());
177        self
178    }
179
180    /// Add multiple allowed origins for HTTP transports.
181    #[must_use]
182    pub fn allow_origins<I, S>(mut self, origins: I) -> Self
183    where
184        I: IntoIterator<Item = S>,
185        S: Into<String>,
186    {
187        let config = self
188            .origin_validation
189            .get_or_insert_with(OriginValidationConfig::default);
190        config
191            .allowed_origins
192            .extend(origins.into_iter().map(Into::into));
193        self
194    }
195
196    /// Control whether localhost origins are accepted.
197    #[must_use]
198    pub fn allow_localhost_origins(mut self, allow: bool) -> Self {
199        self.origin_validation
200            .get_or_insert_with(OriginValidationConfig::default)
201            .allow_localhost = allow;
202        self
203    }
204
205    /// Disable origin checks entirely.
206    #[must_use]
207    pub fn allow_any_origin(mut self, allow: bool) -> Self {
208        self.origin_validation
209            .get_or_insert_with(OriginValidationConfig::default)
210            .allow_any = allow;
211        self
212    }
213
214    /// Build the server configuration with sensible defaults.
215    ///
216    /// This method always succeeds and uses defaults for any unset fields.
217    /// For strict validation, use [`try_build()`](Self::try_build).
218    #[must_use]
219    pub fn build(self) -> ServerConfig {
220        ServerConfig {
221            protocol: self.protocol.unwrap_or_default(),
222            rate_limit: self.rate_limit,
223            connection_limits: self.connection_limits.unwrap_or_default(),
224            required_capabilities: self.required_capabilities.unwrap_or_default(),
225            max_message_size: self.max_message_size.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE),
226            origin_validation: self.origin_validation.unwrap_or_default(),
227        }
228    }
229
230    /// Build the server configuration with validation.
231    ///
232    /// This method validates the configuration and returns an error if any
233    /// constraints are violated. Use this for stricter configuration checking
234    /// in enterprise deployments.
235    ///
236    /// # Errors
237    ///
238    /// Returns an error if:
239    /// - `max_message_size` is less than 1024 bytes (minimum viable message size)
240    /// - Rate limit `max_requests` is 0
241    /// - Rate limit `window` is zero
242    /// - Connection limits have all values set to 0
243    ///
244    /// # Example
245    ///
246    /// ```rust
247    /// use turbomcp_server::ServerConfig;
248    ///
249    /// // Validated build - catches configuration errors
250    /// let config = ServerConfig::builder()
251    ///     .max_message_size(1024 * 1024) // 1MB
252    ///     .try_build()
253    ///     .expect("Invalid configuration");
254    /// ```
255    pub fn try_build(self) -> Result<ServerConfig, ConfigValidationError> {
256        let max_message_size = self.max_message_size.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE);
257
258        // Validate message size
259        if max_message_size < 1024 {
260            return Err(ConfigValidationError::InvalidMessageSize {
261                size: max_message_size,
262                min: 1024,
263            });
264        }
265
266        // Validate rate limit if provided
267        if let Some(ref rate_limit) = self.rate_limit {
268            if rate_limit.max_requests == 0 {
269                return Err(ConfigValidationError::InvalidRateLimit {
270                    reason: "max_requests cannot be 0".to_string(),
271                });
272            }
273            if rate_limit.window.is_zero() {
274                return Err(ConfigValidationError::InvalidRateLimit {
275                    reason: "rate limit window cannot be zero".to_string(),
276                });
277            }
278        }
279
280        // Validate connection limits
281        let connection_limits = self.connection_limits.unwrap_or_default();
282        if connection_limits.max_tcp_connections == 0
283            && connection_limits.max_websocket_connections == 0
284            && connection_limits.max_http_concurrent == 0
285            && connection_limits.max_unix_connections == 0
286        {
287            return Err(ConfigValidationError::InvalidConnectionLimits {
288                reason: "at least one connection limit must be non-zero".to_string(),
289            });
290        }
291
292        Ok(ServerConfig {
293            protocol: self.protocol.unwrap_or_default(),
294            rate_limit: self.rate_limit,
295            connection_limits,
296            required_capabilities: self.required_capabilities.unwrap_or_default(),
297            max_message_size,
298            origin_validation: self.origin_validation.unwrap_or_default(),
299        })
300    }
301}
302
303/// Errors that can occur during configuration validation.
304#[derive(Debug, Clone, thiserror::Error)]
305pub enum ConfigValidationError {
306    /// Invalid message size configuration.
307    #[error("Invalid max_message_size: {size} bytes is below minimum of {min} bytes")]
308    InvalidMessageSize {
309        /// The configured size.
310        size: usize,
311        /// The minimum allowed size.
312        min: usize,
313    },
314
315    /// Invalid rate limit configuration.
316    #[error("Invalid rate limit: {reason}")]
317    InvalidRateLimit {
318        /// Description of the validation failure.
319        reason: String,
320    },
321
322    /// Invalid connection limits configuration.
323    #[error("Invalid connection limits: {reason}")]
324    InvalidConnectionLimits {
325        /// Description of the validation failure.
326        reason: String,
327    },
328}
329
330/// Protocol version configuration.
331#[derive(Debug, Clone)]
332pub struct ProtocolConfig {
333    /// Preferred protocol version.
334    pub preferred_version: ProtocolVersion,
335    /// Supported protocol versions.
336    pub supported_versions: Vec<ProtocolVersion>,
337    /// Allow fallback to server's preferred version if client's is unsupported.
338    pub allow_fallback: bool,
339}
340
341impl Default for ProtocolConfig {
342    /// Multi-version by default in v3.1.
343    ///
344    /// Pre-3.1 the default was exact-match against `LATEST`, which silently
345    /// rejected clients on older stable spec versions even though we ship
346    /// version adapters for them. The spec permits the server to choose a
347    /// different version than the client requested; defaulting to the full
348    /// stable set unblocks older clients while still preferring the newest.
349    /// Use [`Self::strict`] to restore the old single-version behavior.
350    fn default() -> Self {
351        Self {
352            preferred_version: ProtocolVersion::LATEST.clone(),
353            supported_versions: ProtocolVersion::STABLE.to_vec(),
354            allow_fallback: false,
355        }
356    }
357}
358
359impl ProtocolConfig {
360    /// Create a strict configuration that only accepts the specified version.
361    #[must_use]
362    pub fn strict(version: impl Into<ProtocolVersion>) -> Self {
363        let v = version.into();
364        Self {
365            preferred_version: v.clone(),
366            supported_versions: vec![v],
367            allow_fallback: false,
368        }
369    }
370
371    /// Create a multi-version configuration that accepts all stable versions.
372    ///
373    /// The preferred version is the latest stable. Older clients are accepted
374    /// and responses are filtered through the appropriate version adapter.
375    #[must_use]
376    pub fn multi_version() -> Self {
377        Self {
378            preferred_version: ProtocolVersion::LATEST.clone(),
379            supported_versions: ProtocolVersion::STABLE.to_vec(),
380            allow_fallback: false,
381        }
382    }
383
384    /// Check if a protocol version is supported.
385    #[must_use]
386    pub fn is_supported(&self, version: &ProtocolVersion) -> bool {
387        self.supported_versions.contains(version)
388    }
389
390    /// Negotiate protocol version with client.
391    ///
392    /// Returns the negotiated version or None if no compatible version found.
393    #[must_use]
394    pub fn negotiate(&self, client_version: Option<&str>) -> Option<ProtocolVersion> {
395        match client_version {
396            Some(version_str) => {
397                let version = ProtocolVersion::from(version_str);
398                if self.is_supported(&version) {
399                    Some(version)
400                } else if self.allow_fallback {
401                    Some(self.preferred_version.clone())
402                } else {
403                    None
404                }
405            }
406            None => Some(self.preferred_version.clone()),
407        }
408    }
409}
410
411/// Rate limiting configuration.
412#[derive(Debug, Clone)]
413pub struct RateLimitConfig {
414    /// Maximum requests per window.
415    pub max_requests: u32,
416    /// Time window for rate limiting.
417    pub window: Duration,
418    /// Whether to rate limit per client (by user_id or IP).
419    pub per_client: bool,
420}
421
422impl Default for RateLimitConfig {
423    fn default() -> Self {
424        Self {
425            max_requests: DEFAULT_RATE_LIMIT,
426            window: DEFAULT_RATE_LIMIT_WINDOW,
427            per_client: true,
428        }
429    }
430}
431
432impl RateLimitConfig {
433    /// Create a new rate limit configuration.
434    #[must_use]
435    pub fn new(max_requests: u32, window: Duration) -> Self {
436        Self {
437            max_requests,
438            window,
439            per_client: true,
440        }
441    }
442
443    /// Set per-client rate limiting.
444    #[must_use]
445    pub fn per_client(mut self, enabled: bool) -> Self {
446        self.per_client = enabled;
447        self
448    }
449}
450
451/// Connection limits.
452#[derive(Debug, Clone)]
453pub struct ConnectionLimits {
454    /// Maximum concurrent TCP connections.
455    pub max_tcp_connections: usize,
456    /// Maximum concurrent WebSocket connections.
457    pub max_websocket_connections: usize,
458    /// Maximum concurrent HTTP requests.
459    pub max_http_concurrent: usize,
460    /// Maximum concurrent Unix socket connections.
461    pub max_unix_connections: usize,
462}
463
464impl Default for ConnectionLimits {
465    fn default() -> Self {
466        Self {
467            max_tcp_connections: DEFAULT_MAX_CONNECTIONS,
468            max_websocket_connections: DEFAULT_MAX_CONNECTIONS,
469            max_http_concurrent: DEFAULT_MAX_CONNECTIONS,
470            max_unix_connections: DEFAULT_MAX_CONNECTIONS,
471        }
472    }
473}
474
475impl ConnectionLimits {
476    /// Create a new connection limits configuration.
477    #[must_use]
478    pub fn new(max_connections: usize) -> Self {
479        Self {
480            max_tcp_connections: max_connections,
481            max_websocket_connections: max_connections,
482            max_http_concurrent: max_connections,
483            max_unix_connections: max_connections,
484        }
485    }
486}
487
488/// Required client capabilities.
489///
490/// Specifies which client capabilities the server requires.
491#[derive(Debug, Clone, Default, Serialize, Deserialize)]
492pub struct RequiredCapabilities {
493    /// Require roots capability.
494    #[serde(default)]
495    pub roots: bool,
496    /// Require sampling capability.
497    #[serde(default)]
498    pub sampling: bool,
499    /// Require draft extensions.
500    #[serde(default)]
501    pub extensions: HashSet<String>,
502    /// Require experimental capabilities.
503    #[serde(default)]
504    pub experimental: HashSet<String>,
505}
506
507impl RequiredCapabilities {
508    /// Create empty required capabilities (no requirements).
509    #[must_use]
510    pub fn none() -> Self {
511        Self::default()
512    }
513
514    /// Require roots capability.
515    #[must_use]
516    pub fn with_roots(mut self) -> Self {
517        self.roots = true;
518        self
519    }
520
521    /// Require sampling capability.
522    #[must_use]
523    pub fn with_sampling(mut self) -> Self {
524        self.sampling = true;
525        self
526    }
527
528    /// Require a draft extension.
529    #[must_use]
530    pub fn with_extension(mut self, name: impl Into<String>) -> Self {
531        self.extensions.insert(name.into());
532        self
533    }
534
535    /// Require an experimental capability.
536    #[must_use]
537    pub fn with_experimental(mut self, name: impl Into<String>) -> Self {
538        self.experimental.insert(name.into());
539        self
540    }
541
542    /// Check if all required capabilities are present in client capabilities.
543    #[must_use]
544    pub fn validate(&self, client_caps: &ClientCapabilities) -> CapabilityValidation {
545        let mut missing = Vec::new();
546
547        if self.roots && !client_caps.roots {
548            missing.push("roots".to_string());
549        }
550
551        if self.sampling && !client_caps.sampling {
552            missing.push("sampling".to_string());
553        }
554
555        for extension in &self.extensions {
556            if !client_caps.extensions.contains(extension) {
557                missing.push(format!("extensions/{}", extension));
558            }
559        }
560
561        for exp in &self.experimental {
562            if !client_caps.experimental.contains(exp) {
563                missing.push(format!("experimental/{}", exp));
564            }
565        }
566
567        if missing.is_empty() {
568            CapabilityValidation::Valid
569        } else {
570            CapabilityValidation::Missing(missing)
571        }
572    }
573}
574
575/// Client capabilities received during initialization.
576#[derive(Debug, Clone, Default, Serialize, Deserialize)]
577pub struct ClientCapabilities {
578    /// Client supports roots.
579    #[serde(default)]
580    pub roots: bool,
581    /// Client supports sampling.
582    #[serde(default)]
583    pub sampling: bool,
584    /// Client draft extensions.
585    #[serde(default)]
586    pub extensions: HashSet<String>,
587    /// Client experimental capabilities.
588    #[serde(default)]
589    pub experimental: HashSet<String>,
590}
591
592impl ClientCapabilities {
593    /// Parse client capabilities from initialize request params.
594    #[must_use]
595    pub fn from_params(params: &serde_json::Value) -> Self {
596        let caps = params.get("capabilities").cloned().unwrap_or_default();
597
598        Self {
599            roots: caps.get("roots").map(|v| !v.is_null()).unwrap_or(false),
600            sampling: caps.get("sampling").map(|v| !v.is_null()).unwrap_or(false),
601            extensions: caps
602                .get("extensions")
603                .and_then(|v| v.as_object())
604                .map(|obj| obj.keys().cloned().collect())
605                .unwrap_or_default(),
606            experimental: caps
607                .get("experimental")
608                .and_then(|v| v.as_object())
609                .map(|obj| obj.keys().cloned().collect())
610                .unwrap_or_default(),
611        }
612    }
613}
614
/// Outcome of checking a client's capabilities against the server's requirements.
#[derive(Debug, Clone)]
pub enum CapabilityValidation {
    /// All required capabilities are present.
    Valid,
    /// Some required capabilities are missing.
    Missing(Vec<String>),
}

impl CapabilityValidation {
    /// `true` only for [`CapabilityValidation::Valid`].
    #[must_use]
    pub fn is_valid(&self) -> bool {
        self.missing().is_none()
    }

    /// Names of the missing capabilities, or `None` when validation passed.
    #[must_use]
    pub fn missing(&self) -> Option<&[String]> {
        if let Self::Missing(names) = self {
            Some(names.as_slice())
        } else {
            None
        }
    }
}
640
641/// Rate limiter using token bucket algorithm.
642#[derive(Debug)]
643pub struct RateLimiter {
644    config: RateLimitConfig,
645    /// Global bucket for non-per-client limiting.
646    global_bucket: Mutex<TokenBucket>,
647    /// Per-client buckets (keyed by client ID).
648    client_buckets: Mutex<std::collections::HashMap<String, TokenBucket>>,
649    /// Last cleanup timestamp for automatic cleanup.
650    last_cleanup: Mutex<Instant>,
651}
652
653impl RateLimiter {
654    /// Create a new rate limiter.
655    #[must_use]
656    pub fn new(config: RateLimitConfig) -> Self {
657        Self {
658            global_bucket: Mutex::new(TokenBucket::new(config.max_requests, config.window)),
659            client_buckets: Mutex::new(std::collections::HashMap::new()),
660            last_cleanup: Mutex::new(Instant::now()),
661            config,
662        }
663    }
664
665    /// Check if a request is allowed.
666    ///
667    /// Returns `true` if allowed, `false` if rate limited.
668    pub fn check(&self, client_id: Option<&str>) -> bool {
669        // Periodic cleanup of stale client buckets (avoid unbounded growth)
670        let needs_cleanup = {
671            let last = self.last_cleanup.lock();
672            last.elapsed() > Duration::from_secs(60)
673        };
674        if needs_cleanup {
675            self.cleanup(Duration::from_secs(300));
676            *self.last_cleanup.lock() = Instant::now();
677        }
678
679        if self.config.per_client {
680            if let Some(id) = client_id {
681                let mut buckets = self.client_buckets.lock();
682                let bucket = buckets.entry(id.to_string()).or_insert_with(|| {
683                    TokenBucket::new(self.config.max_requests, self.config.window)
684                });
685                bucket.try_acquire()
686            } else {
687                // No client ID, use global bucket
688                self.global_bucket.lock().try_acquire()
689            }
690        } else {
691            self.global_bucket.lock().try_acquire()
692        }
693    }
694
695    /// Clean up old client buckets to prevent memory growth.
696    pub fn cleanup(&self, max_age: Duration) {
697        let mut buckets = self.client_buckets.lock();
698        let now = Instant::now();
699        buckets.retain(|_, bucket| now.duration_since(bucket.last_access) < max_age);
700    }
701
702    /// Get the current number of tracked client buckets.
703    #[must_use]
704    pub fn client_bucket_count(&self) -> usize {
705        self.client_buckets.lock().len()
706    }
707}
708
/// Token bucket for rate limiting.
///
/// The bucket behaves as follows:
/// - It starts full, holding `max_requests` tokens.
/// - It refills continuously at `max_requests / window` tokens per second,
///   capped at `max_requests`.
/// - Each successful acquire spends one token.
#[derive(Debug)]
struct TokenBucket {
    tokens: f64,
    max_tokens: f64,
    refill_rate: f64, // tokens per second
    last_refill: Instant,
    last_access: Instant,
}

impl TokenBucket {
    fn new(max_requests: u32, window: Duration) -> Self {
        // `u32 -> f64` is lossless.
        let max_tokens = f64::from(max_requests);
        let now = Instant::now();
        Self {
            tokens: max_tokens,
            max_tokens,
            // NOTE(review): a zero `window` makes this rate infinite (or NaN when
            // `max_requests` is also 0). `try_build` rejects a zero window, but
            // `build` does not.
            refill_rate: max_tokens / window.as_secs_f64(),
            last_refill: now,
            last_access: now,
        }
    }

    /// Spend one token if available. Returns `true` on success.
    fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill);

        // Apply refills in steps of at least 10ms to skip the float math on
        // tight bursts. The clock is still read on every call. `last_refill`
        // only advances when a refill is applied, so skipped time is credited
        // on the next refill rather than lost.
        if elapsed >= Duration::from_millis(10) {
            self.tokens =
                (self.tokens + elapsed.as_secs_f64() * self.refill_rate).min(self.max_tokens);
            self.last_refill = now;
        }

        self.last_access = now;

        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
753
/// Bounded, lock-free counter of active connections.
///
/// Meant to be shared behind an [`Arc`]. Slots are handed out as
/// [`ConnectionGuard`]s by [`try_acquire_arc`](Self::try_acquire_arc) and are
/// given back automatically when a guard is dropped.
#[derive(Debug)]
pub struct ConnectionCounter {
    current: AtomicUsize,
    max: usize,
}

impl ConnectionCounter {
    /// Create a counter with no active connections and a cap of `max`.
    #[must_use]
    pub fn new(max: usize) -> Self {
        let current = AtomicUsize::new(0);
        Self { current, max }
    }

    /// Claim one connection slot, or return `None` when the counter is full.
    ///
    /// The returned guard owns an `Arc` to this counter, so it is
    /// `Send + 'static` and can travel into spawned async tasks.
    ///
    /// `fetch_update` retries its compare-and-swap until it either commits an
    /// increment or the closure declines because capacity has been reached.
    /// Every failed attempt means another thread changed the count, so the
    /// system as a whole always makes progress.
    pub fn try_acquire_arc(self: &Arc<Self>) -> Option<ConnectionGuard> {
        let limit = self.max;
        let claimed = self
            .current
            .fetch_update(Ordering::SeqCst, Ordering::Relaxed, |taken| {
                if taken < limit {
                    Some(taken + 1)
                } else {
                    None
                }
            });
        match claimed {
            Ok(_) => Some(ConnectionGuard {
                counter: Arc::clone(self),
            }),
            Err(_) => None,
        }
    }

    /// Number of slots currently held.
    #[must_use]
    pub fn current(&self) -> usize {
        self.current.load(Ordering::Relaxed)
    }

    /// Configured capacity.
    #[must_use]
    pub fn max(&self) -> usize {
        self.max
    }

    /// Return one slot; invoked only from `ConnectionGuard::drop`.
    fn release(&self) {
        self.current.fetch_sub(1, Ordering::SeqCst);
    }
}

/// RAII handle for one slot of a [`ConnectionCounter`].
///
/// Holds an `Arc` to the counter, so it is `Send + 'static`. Dropping it
/// releases the slot.
#[derive(Debug)]
pub struct ConnectionGuard {
    counter: Arc<ConnectionCounter>,
}

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        self.counter.release();
    }
}
833
834#[cfg(test)]
835mod tests {
836    use super::*;
837
838    #[test]
839    fn test_protocol_negotiation_exact_match() {
840        let config = ProtocolConfig::default();
841        assert_eq!(
842            config.negotiate(Some("2025-11-25")),
843            Some(ProtocolVersion::V2025_11_25)
844        );
845    }
846
847    #[test]
848    fn test_protocol_negotiation_default_accepts_stable_versions() {
849        // v3.1: default is multi-version (all stable versions). Older spec
850        // versions are accepted; older clients route through version adapters.
851        // Use ProtocolConfig::strict(LATEST) to restore exact-match behavior.
852        let config = ProtocolConfig::default();
853        assert_eq!(
854            config.negotiate(Some("2025-06-18")),
855            Some(ProtocolVersion::V2025_06_18)
856        );
857    }
858
859    #[test]
860    fn test_protocol_negotiation_strict_rejects_older_version() {
861        let config = ProtocolConfig::strict(ProtocolVersion::LATEST.clone());
862        assert_eq!(config.negotiate(Some("2025-06-18")), None);
863    }
864
865    #[test]
866    fn test_protocol_negotiation_multi_version_accepts_older() {
867        let config = ProtocolConfig::multi_version();
868        assert_eq!(
869            config.negotiate(Some("2025-06-18")),
870            Some(ProtocolVersion::V2025_06_18)
871        );
872        assert_eq!(
873            config.negotiate(Some("2025-11-25")),
874            Some(ProtocolVersion::V2025_11_25)
875        );
876    }
877
878    #[test]
879    fn test_protocol_negotiation_none_returns_preferred() {
880        let config = ProtocolConfig::default();
881        assert_eq!(config.negotiate(None), Some(ProtocolVersion::V2025_11_25));
882    }
883
884    #[test]
885    fn test_protocol_negotiation_unknown_version() {
886        let config = ProtocolConfig::default();
887        assert_eq!(config.negotiate(Some("unknown-version")), None);
888    }
889
890    #[test]
891    fn test_protocol_negotiation_strict() {
892        let config = ProtocolConfig::strict("2025-11-25");
893        assert_eq!(config.negotiate(Some("2025-06-18")), None);
894    }
895
896    #[test]
897    fn test_capability_validation() {
898        let required = RequiredCapabilities::none().with_roots();
899        let client = ClientCapabilities {
900            roots: true,
901            ..Default::default()
902        };
903        assert!(required.validate(&client).is_valid());
904
905        let client_missing = ClientCapabilities::default();
906        assert!(!required.validate(&client_missing).is_valid());
907    }
908
909    #[test]
910    fn test_extension_capability_validation() {
911        let required = RequiredCapabilities::none().with_extension("trace");
912        let client = ClientCapabilities {
913            extensions: ["trace".to_string()].into_iter().collect(),
914            ..Default::default()
915        };
916        assert!(required.validate(&client).is_valid());
917
918        let missing = ClientCapabilities::default();
919        let validation = required.validate(&missing);
920        assert!(!validation.is_valid());
921        assert_eq!(
922            validation.missing(),
923            Some(&["extensions/trace".to_string()][..])
924        );
925    }
926
927    #[test]
928    fn test_client_capabilities_parse_extensions() {
929        let params = serde_json::json!({
930            "capabilities": {
931                "extensions": {
932                    "trace": {"version": "1"},
933                    "handoff": {}
934                }
935            }
936        });
937
938        let caps = ClientCapabilities::from_params(&params);
939        assert!(caps.extensions.contains("trace"));
940        assert!(caps.extensions.contains("handoff"));
941    }
942
943    #[test]
944    fn test_rate_limiter() {
945        let config = RateLimitConfig::new(2, Duration::from_secs(1));
946        let limiter = RateLimiter::new(config);
947
948        assert!(limiter.check(None));
949        assert!(limiter.check(None));
950        assert!(!limiter.check(None)); // Should be rate limited
951    }
952
953    #[test]
954    fn test_connection_counter() {
955        let counter = Arc::new(ConnectionCounter::new(2));
956
957        let guard1 = counter.try_acquire_arc();
958        assert!(guard1.is_some());
959        assert_eq!(counter.current(), 1);
960
961        let guard2 = counter.try_acquire_arc();
962        assert!(guard2.is_some());
963        assert_eq!(counter.current(), 2);
964
965        let guard3 = counter.try_acquire_arc();
966        assert!(guard3.is_none()); // At capacity
967
968        drop(guard1);
969        assert_eq!(counter.current(), 1);
970
971        let guard4 = counter.try_acquire_arc();
972        assert!(guard4.is_some());
973    }
974
975    // =========================================================================
976    // Builder validation tests
977    // =========================================================================
978
979    #[test]
980    fn test_builder_default_succeeds() {
981        // Default configuration should always succeed
982        let config = ServerConfig::builder().build();
983        assert_eq!(config.max_message_size, DEFAULT_MAX_MESSAGE_SIZE);
984        assert!(config.origin_validation.allow_localhost);
985        assert!(config.origin_validation.allowed_origins.is_empty());
986    }
987
988    #[test]
989    fn test_builder_origin_validation_overrides() {
990        let config = ServerConfig::builder()
991            .allow_origin("https://app.example.com")
992            .allow_localhost_origins(false)
993            .build();
994
995        assert!(!config.origin_validation.allow_localhost);
996        assert!(
997            config
998                .origin_validation
999                .allowed_origins
1000                .contains("https://app.example.com")
1001        );
1002    }
1003
1004    #[test]
1005    fn test_builder_try_build_valid() {
1006        let result = ServerConfig::builder()
1007            .max_message_size(1024 * 1024)
1008            .try_build();
1009        assert!(result.is_ok());
1010    }
1011
1012    #[test]
1013    fn test_builder_try_build_invalid_message_size() {
1014        let result = ServerConfig::builder()
1015            .max_message_size(100) // Below minimum
1016            .try_build();
1017        assert!(result.is_err());
1018        assert!(matches!(
1019            result.unwrap_err(),
1020            ConfigValidationError::InvalidMessageSize { .. }
1021        ));
1022    }
1023
1024    #[test]
1025    fn test_builder_try_build_invalid_rate_limit() {
1026        let result = ServerConfig::builder()
1027            .rate_limit(RateLimitConfig {
1028                max_requests: 0, // Invalid
1029                window: Duration::from_secs(1),
1030                per_client: true,
1031            })
1032            .try_build();
1033        assert!(result.is_err());
1034        assert!(matches!(
1035            result.unwrap_err(),
1036            ConfigValidationError::InvalidRateLimit { .. }
1037        ));
1038    }
1039
1040    #[test]
1041    fn test_builder_try_build_zero_window() {
1042        let result = ServerConfig::builder()
1043            .rate_limit(RateLimitConfig {
1044                max_requests: 100,
1045                window: Duration::ZERO, // Invalid
1046                per_client: true,
1047            })
1048            .try_build();
1049        assert!(result.is_err());
1050        assert!(matches!(
1051            result.unwrap_err(),
1052            ConfigValidationError::InvalidRateLimit { .. }
1053        ));
1054    }
1055
1056    #[test]
1057    fn test_builder_try_build_invalid_connection_limits() {
1058        let result = ServerConfig::builder()
1059            .connection_limits(ConnectionLimits {
1060                max_tcp_connections: 0,
1061                max_websocket_connections: 0,
1062                max_http_concurrent: 0,
1063                max_unix_connections: 0,
1064            })
1065            .try_build();
1066        assert!(result.is_err());
1067        assert!(matches!(
1068            result.unwrap_err(),
1069            ConfigValidationError::InvalidConnectionLimits { .. }
1070        ));
1071    }
1072}
1073
/// Property-based tests for configuration building and connection counting.
#[cfg(test)]
mod proptest_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // `try_build` must report bad message sizes as errors rather than
        // panicking, for any size in the sampled range.
        #[test]
        fn config_builder_never_panics(
            max_msg_size in 0usize..10_000_000,
        ) {
            // Builder should never panic, just return errors for invalid inputs
            let _ = ServerConfig::builder()
                .max_message_size(max_msg_size)
                .try_build();
        }

        // A counter with capacity `max` hands out exactly `max` guards no
        // matter how many acquisitions are attempted. Releasing every guard
        // must bring the count back to zero.
        #[test]
        fn connection_counter_bounded(max in 1usize..10000) {
            let counter = Arc::new(ConnectionCounter::new(max));

            // Attempt more acquisitions than there is capacity for; only
            // `max` of them may succeed.
            let guards: Vec<_> = (0..max + 10)
                .filter_map(|_| counter.try_acquire_arc())
                .collect();
            // `prop_assert_eq!` returns a TestCaseError, which proptest
            // reports and shrinks, instead of panicking mid-case.
            prop_assert_eq!(guards.len(), max);
            prop_assert_eq!(counter.current(), max);

            // Dropping the guards must release every slot.
            drop(guards);
            prop_assert_eq!(counter.current(), 0);
        }
    }
}