turbomcp_protocol/
capabilities.rs

1//! # Capability Negotiation
2//!
3//! This module provides sophisticated capability negotiation and feature detection
4//! for MCP protocol implementations.
5
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet};
8
9use crate::types::{ClientCapabilities, ServerCapabilities};
10
11/// Capability matcher for negotiating features between client and server
12#[derive(Debug, Clone)]
13pub struct CapabilityMatcher {
14    /// Feature compatibility rules
15    compatibility_rules: HashMap<String, CompatibilityRule>,
16    /// Default feature states
17    defaults: HashMap<String, bool>,
18}
19
20/// Compatibility rule for a feature
21#[derive(Debug, Clone)]
22pub enum CompatibilityRule {
23    /// Feature requires both client and server support
24    RequireBoth,
25    /// Feature requires only client support
26    RequireClient,
27    /// Feature requires only server support  
28    RequireServer,
29    /// Feature is optional (either side can enable)
30    Optional,
31    /// Custom compatibility function
32    Custom(fn(&ClientCapabilities, &ServerCapabilities) -> bool),
33}
34
35/// Negotiated capability set
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct CapabilitySet {
38    /// Enabled features
39    pub enabled_features: HashSet<String>,
40    /// Negotiated client capabilities
41    pub client_capabilities: ClientCapabilities,
42    /// Negotiated server capabilities
43    pub server_capabilities: ServerCapabilities,
44    /// Additional metadata from negotiation
45    pub metadata: HashMap<String, serde_json::Value>,
46}
47
48/// Capability negotiator for handling the negotiation process
49#[derive(Debug, Clone)]
50pub struct CapabilityNegotiator {
51    /// Capability matcher
52    matcher: CapabilityMatcher,
53    /// Strict mode (fail on incompatible features)
54    strict_mode: bool,
55}
56
57impl Default for CapabilityMatcher {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63impl CapabilityMatcher {
64    /// Create a new capability matcher with default MCP rules
65    pub fn new() -> Self {
66        let mut matcher = Self {
67            compatibility_rules: HashMap::new(),
68            defaults: HashMap::new(),
69        };
70
71        // Set up default MCP capability rules
72        matcher.add_rule("tools", CompatibilityRule::RequireServer);
73        matcher.add_rule("prompts", CompatibilityRule::RequireServer);
74        matcher.add_rule("resources", CompatibilityRule::RequireServer);
75        matcher.add_rule("logging", CompatibilityRule::RequireServer);
76        matcher.add_rule("sampling", CompatibilityRule::RequireClient);
77        matcher.add_rule("roots", CompatibilityRule::RequireClient);
78        matcher.add_rule("progress", CompatibilityRule::Optional);
79
80        // Set defaults
81        matcher.set_default("progress", true);
82
83        matcher
84    }
85
86    /// Add a compatibility rule for a feature
87    pub fn add_rule(&mut self, feature: &str, rule: CompatibilityRule) {
88        self.compatibility_rules.insert(feature.to_string(), rule);
89    }
90
91    /// Set default state for a feature
92    pub fn set_default(&mut self, feature: &str, enabled: bool) {
93        self.defaults.insert(feature.to_string(), enabled);
94    }
95
96    /// Check if a feature is compatible between client and server
97    pub fn is_compatible(
98        &self,
99        feature: &str,
100        client: &ClientCapabilities,
101        server: &ServerCapabilities,
102    ) -> bool {
103        self.compatibility_rules.get(feature).map_or_else(
104            || {
105                // Unknown feature - check if either side supports it
106                Self::client_has_feature(feature, client)
107                    || Self::server_has_feature(feature, server)
108            },
109            |rule| match rule {
110                CompatibilityRule::RequireBoth => {
111                    Self::client_has_feature(feature, client)
112                        && Self::server_has_feature(feature, server)
113                }
114                CompatibilityRule::RequireClient => Self::client_has_feature(feature, client),
115                CompatibilityRule::RequireServer => Self::server_has_feature(feature, server),
116                CompatibilityRule::Optional => true,
117                CompatibilityRule::Custom(func) => func(client, server),
118            },
119        )
120    }
121
122    /// Check if client has a specific feature
123    fn client_has_feature(feature: &str, client: &ClientCapabilities) -> bool {
124        match feature {
125            "sampling" => client.sampling.is_some(),
126            "roots" => client.roots.is_some(),
127            _ => {
128                // Check experimental features
129                client
130                    .experimental
131                    .as_ref()
132                    .is_some_and(|experimental| experimental.contains_key(feature))
133            }
134        }
135    }
136
137    /// Check if server has a specific feature
138    fn server_has_feature(feature: &str, server: &ServerCapabilities) -> bool {
139        match feature {
140            "tools" => server.tools.is_some(),
141            "prompts" => server.prompts.is_some(),
142            "resources" => server.resources.is_some(),
143            "logging" => server.logging.is_some(),
144            _ => {
145                // Check experimental features
146                server
147                    .experimental
148                    .as_ref()
149                    .is_some_and(|experimental| experimental.contains_key(feature))
150            }
151        }
152    }
153
154    /// Get all features from both client and server
155    fn get_all_features(
156        &self,
157        client: &ClientCapabilities,
158        server: &ServerCapabilities,
159    ) -> HashSet<String> {
160        let mut features = HashSet::new();
161
162        // Standard client features
163        if client.sampling.is_some() {
164            features.insert("sampling".to_string());
165        }
166        if client.roots.is_some() {
167            features.insert("roots".to_string());
168        }
169
170        // Standard server features
171        if server.tools.is_some() {
172            features.insert("tools".to_string());
173        }
174        if server.prompts.is_some() {
175            features.insert("prompts".to_string());
176        }
177        if server.resources.is_some() {
178            features.insert("resources".to_string());
179        }
180        if server.logging.is_some() {
181            features.insert("logging".to_string());
182        }
183
184        // Experimental features
185        if let Some(experimental) = &client.experimental {
186            features.extend(experimental.keys().cloned());
187        }
188        if let Some(experimental) = &server.experimental {
189            features.extend(experimental.keys().cloned());
190        }
191
192        // Add default features
193        features.extend(self.defaults.keys().cloned());
194
195        features
196    }
197
198    /// Negotiate capabilities between client and server
199    pub fn negotiate(
200        &self,
201        client: &ClientCapabilities,
202        server: &ServerCapabilities,
203    ) -> Result<CapabilitySet, CapabilityError> {
204        let all_features = self.get_all_features(client, server);
205        let mut enabled_features = HashSet::new();
206        let mut incompatible_features = Vec::new();
207
208        for feature in &all_features {
209            if self.is_compatible(feature, client, server) {
210                enabled_features.insert(feature.clone());
211            } else {
212                incompatible_features.push(feature.clone());
213            }
214        }
215
216        if !incompatible_features.is_empty() {
217            return Err(CapabilityError::IncompatibleFeatures(incompatible_features));
218        }
219
220        // Apply defaults for features not explicitly enabled
221        for (feature, enabled) in &self.defaults {
222            if *enabled && !enabled_features.contains(feature) && all_features.contains(feature) {
223                enabled_features.insert(feature.clone());
224            }
225        }
226
227        Ok(CapabilitySet {
228            enabled_features,
229            client_capabilities: client.clone(),
230            server_capabilities: server.clone(),
231            metadata: HashMap::new(),
232        })
233    }
234}
235
236impl CapabilityNegotiator {
237    /// Create a new capability negotiator
238    pub const fn new(matcher: CapabilityMatcher) -> Self {
239        Self {
240            matcher,
241            strict_mode: false,
242        }
243    }
244
245    /// Enable strict mode (fail on any incompatible feature)
246    pub const fn with_strict_mode(mut self) -> Self {
247        self.strict_mode = true;
248        self
249    }
250
251    /// Negotiate capabilities between client and server
252    pub fn negotiate(
253        &self,
254        client: &ClientCapabilities,
255        server: &ServerCapabilities,
256    ) -> Result<CapabilitySet, CapabilityError> {
257        match self.matcher.negotiate(client, server) {
258            Ok(capability_set) => Ok(capability_set),
259            Err(CapabilityError::IncompatibleFeatures(features)) if !self.strict_mode => {
260                // In non-strict mode, just log the incompatible features and continue
261                tracing::warn!(
262                    "Some features are incompatible and will be disabled: {:?}",
263                    features
264                );
265
266                // Create a capability set with only compatible features
267                let all_features = self.matcher.get_all_features(client, server);
268                let mut enabled_features = HashSet::new();
269
270                for feature in &all_features {
271                    if self.matcher.is_compatible(feature, client, server) {
272                        enabled_features.insert(feature.clone());
273                    }
274                }
275
276                Ok(CapabilitySet {
277                    enabled_features,
278                    client_capabilities: client.clone(),
279                    server_capabilities: server.clone(),
280                    metadata: HashMap::new(),
281                })
282            }
283            Err(err) => Err(err),
284        }
285    }
286
287    /// Check if a specific feature is enabled in the capability set
288    pub fn is_feature_enabled(capability_set: &CapabilitySet, feature: &str) -> bool {
289        capability_set.enabled_features.contains(feature)
290    }
291
292    /// Get all enabled features as a sorted vector
293    pub fn get_enabled_features(capability_set: &CapabilitySet) -> Vec<String> {
294        let mut features: Vec<String> = capability_set.enabled_features.iter().cloned().collect();
295        features.sort();
296        features
297    }
298}
299
300impl Default for CapabilityNegotiator {
301    fn default() -> Self {
302        Self::new(CapabilityMatcher::new())
303    }
304}
305
306impl CapabilitySet {
307    /// Create a new empty capability set
308    pub fn empty() -> Self {
309        Self {
310            enabled_features: HashSet::new(),
311            client_capabilities: ClientCapabilities::default(),
312            server_capabilities: ServerCapabilities::default(),
313            metadata: HashMap::new(),
314        }
315    }
316
317    /// Check if a feature is enabled
318    pub fn has_feature(&self, feature: &str) -> bool {
319        self.enabled_features.contains(feature)
320    }
321
322    /// Add a feature to the enabled set
323    pub fn enable_feature(&mut self, feature: String) {
324        self.enabled_features.insert(feature);
325    }
326
327    /// Remove a feature from the enabled set
328    pub fn disable_feature(&mut self, feature: &str) {
329        self.enabled_features.remove(feature);
330    }
331
332    /// Get the number of enabled features
333    pub fn feature_count(&self) -> usize {
334        self.enabled_features.len()
335    }
336
337    /// Add metadata
338    pub fn add_metadata(&mut self, key: String, value: serde_json::Value) {
339        self.metadata.insert(key, value);
340    }
341
342    /// Get metadata
343    pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
344        self.metadata.get(key)
345    }
346
347    /// Create a summary of enabled capabilities
348    pub fn summary(&self) -> CapabilitySummary {
349        CapabilitySummary {
350            total_features: self.enabled_features.len(),
351            client_features: self.count_client_features(),
352            server_features: self.count_server_features(),
353            enabled_features: self.enabled_features.iter().cloned().collect(),
354        }
355    }
356
357    fn count_client_features(&self) -> usize {
358        let mut count = 0;
359        if self.client_capabilities.sampling.is_some() {
360            count += 1;
361        }
362        if self.client_capabilities.roots.is_some() {
363            count += 1;
364        }
365        if let Some(experimental) = &self.client_capabilities.experimental {
366            count += experimental.len();
367        }
368        count
369    }
370
371    fn count_server_features(&self) -> usize {
372        let mut count = 0;
373        if self.server_capabilities.tools.is_some() {
374            count += 1;
375        }
376        if self.server_capabilities.prompts.is_some() {
377            count += 1;
378        }
379        if self.server_capabilities.resources.is_some() {
380            count += 1;
381        }
382        if self.server_capabilities.logging.is_some() {
383            count += 1;
384        }
385        if let Some(experimental) = &self.server_capabilities.experimental {
386            count += experimental.len();
387        }
388        count
389    }
390}
391
392/// Capability negotiation errors
393#[derive(Debug, Clone, thiserror::Error)]
394pub enum CapabilityError {
395    /// Features are incompatible between client and server
396    #[error("Incompatible features: {0:?}")]
397    IncompatibleFeatures(Vec<String>),
398    /// Required feature is missing
399    #[error("Required feature missing: {0}")]
400    RequiredFeatureMissing(String),
401    /// Protocol version mismatch
402    #[error("Protocol version mismatch: client={client}, server={server}")]
403    VersionMismatch {
404        /// Client version string
405        client: String,
406        /// Server version string
407        server: String,
408    },
409    /// Capability negotiation failed
410    #[error("Capability negotiation failed: {0}")]
411    NegotiationFailed(String),
412}
413
414/// Summary of capability negotiation results
415#[derive(Debug, Clone, Serialize, Deserialize)]
416pub struct CapabilitySummary {
417    /// Total number of enabled features
418    pub total_features: usize,
419    /// Number of client-side features
420    pub client_features: usize,
421    /// Number of server-side features
422    pub server_features: usize,
423    /// List of enabled features
424    pub enabled_features: Vec<String>,
425}
426
427/// Utility functions for capability management
428pub mod utils {
429    use super::*;
430
431    /// Create a minimal client capability set
432    pub fn minimal_client_capabilities() -> ClientCapabilities {
433        ClientCapabilities::default()
434    }
435
436    /// Create a minimal server capability set
437    pub fn minimal_server_capabilities() -> ServerCapabilities {
438        ServerCapabilities::default()
439    }
440
441    /// Create a full-featured client capability set
442    pub fn full_client_capabilities() -> ClientCapabilities {
443        ClientCapabilities {
444            sampling: Some(Default::default()),
445            roots: Some(Default::default()),
446            elicitation: Some(Default::default()),
447            experimental: None,
448        }
449    }
450
451    /// Create a full-featured server capability set
452    pub fn full_server_capabilities() -> ServerCapabilities {
453        ServerCapabilities {
454            tools: Some(Default::default()),
455            prompts: Some(Default::default()),
456            resources: Some(Default::default()),
457            completions: Some(Default::default()),
458            logging: Some(Default::default()),
459            experimental: None,
460        }
461    }
462
463    /// Check if two capability sets are compatible
464    pub fn are_compatible(client: &ClientCapabilities, server: &ServerCapabilities) -> bool {
465        let matcher = CapabilityMatcher::new();
466        matcher.negotiate(client, server).is_ok()
467    }
468}
469
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::*;

    #[test]
    fn test_capability_matcher() {
        let client = ClientCapabilities {
            sampling: Some(SamplingCapabilities),
            roots: None,
            elicitation: None,
            experimental: None,
        };
        let server = ServerCapabilities {
            tools: Some(ToolsCapabilities::default()),
            prompts: None,
            resources: None,
            logging: None,
            completions: None,
            experimental: None,
        };
        let matcher = CapabilityMatcher::new();

        // Client-side feature the client advertises.
        assert!(matcher.is_compatible("sampling", &client, &server));
        // Server-side feature the server advertises.
        assert!(matcher.is_compatible("tools", &client, &server));
        // Client-side feature the client does not advertise.
        assert!(!matcher.is_compatible("roots", &client, &server));
    }

    #[test]
    fn test_capability_negotiation() {
        let client = utils::full_client_capabilities();
        let server = utils::full_server_capabilities();

        let capability_set = CapabilityNegotiator::default()
            .negotiate(&client, &server)
            .expect("full client and server capabilities should negotiate");

        for feature in ["sampling", "tools", "roots"] {
            assert!(capability_set.has_feature(feature), "missing feature: {feature}");
        }
    }

    #[test]
    fn test_strict_mode() {
        let negotiator = CapabilityNegotiator::default().with_strict_mode();

        // Even with nothing advertised, strict negotiation should succeed.
        let outcome =
            negotiator.negotiate(&ClientCapabilities::default(), &ServerCapabilities::default());
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_capability_summary() {
        let mut capability_set = CapabilitySet::empty();
        for feature in ["tools", "prompts"] {
            capability_set.enable_feature(feature.to_string());
        }

        let summary = capability_set.summary();
        assert_eq!(summary.total_features, 2);
        assert!(summary.enabled_features.iter().any(|name| name == "tools"));
    }
}