1use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet};
8
9use crate::types::{ClientCapabilities, ServerCapabilities};
10
11#[derive(Debug, Clone)]
13pub struct CapabilityMatcher {
14 compatibility_rules: HashMap<String, CompatibilityRule>,
16 defaults: HashMap<String, bool>,
18}
19
20#[derive(Debug, Clone)]
22pub enum CompatibilityRule {
23 RequireBoth,
25 RequireClient,
27 RequireServer,
29 Optional,
31 Custom(fn(&ClientCapabilities, &ServerCapabilities) -> bool),
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct CapabilitySet {
38 pub enabled_features: HashSet<String>,
40 pub client_capabilities: ClientCapabilities,
42 pub server_capabilities: ServerCapabilities,
44 pub metadata: HashMap<String, serde_json::Value>,
46}
47
48#[derive(Debug, Clone)]
50pub struct CapabilityNegotiator {
51 matcher: CapabilityMatcher,
53 strict_mode: bool,
55}
56
57impl Default for CapabilityMatcher {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63impl CapabilityMatcher {
64 pub fn new() -> Self {
66 let mut matcher = Self {
67 compatibility_rules: HashMap::new(),
68 defaults: HashMap::new(),
69 };
70
71 matcher.add_rule("tools", CompatibilityRule::RequireServer);
73 matcher.add_rule("prompts", CompatibilityRule::RequireServer);
74 matcher.add_rule("resources", CompatibilityRule::RequireServer);
75 matcher.add_rule("logging", CompatibilityRule::RequireServer);
76 matcher.add_rule("sampling", CompatibilityRule::RequireClient);
77 matcher.add_rule("roots", CompatibilityRule::RequireClient);
78 matcher.add_rule("progress", CompatibilityRule::Optional);
79
80 matcher.set_default("progress", true);
82
83 matcher
84 }
85
86 pub fn add_rule(&mut self, feature: &str, rule: CompatibilityRule) {
88 self.compatibility_rules.insert(feature.to_string(), rule);
89 }
90
91 pub fn set_default(&mut self, feature: &str, enabled: bool) {
93 self.defaults.insert(feature.to_string(), enabled);
94 }
95
96 pub fn is_compatible(
98 &self,
99 feature: &str,
100 client: &ClientCapabilities,
101 server: &ServerCapabilities,
102 ) -> bool {
103 self.compatibility_rules.get(feature).map_or_else(
104 || {
105 Self::client_has_feature(feature, client)
107 || Self::server_has_feature(feature, server)
108 },
109 |rule| match rule {
110 CompatibilityRule::RequireBoth => {
111 Self::client_has_feature(feature, client)
112 && Self::server_has_feature(feature, server)
113 }
114 CompatibilityRule::RequireClient => Self::client_has_feature(feature, client),
115 CompatibilityRule::RequireServer => Self::server_has_feature(feature, server),
116 CompatibilityRule::Optional => true,
117 CompatibilityRule::Custom(func) => func(client, server),
118 },
119 )
120 }
121
122 fn client_has_feature(feature: &str, client: &ClientCapabilities) -> bool {
124 match feature {
125 "sampling" => client.sampling.is_some(),
126 "roots" => client.roots.is_some(),
127 _ => {
128 client
130 .experimental
131 .as_ref()
132 .is_some_and(|experimental| experimental.contains_key(feature))
133 }
134 }
135 }
136
137 fn server_has_feature(feature: &str, server: &ServerCapabilities) -> bool {
139 match feature {
140 "tools" => server.tools.is_some(),
141 "prompts" => server.prompts.is_some(),
142 "resources" => server.resources.is_some(),
143 "logging" => server.logging.is_some(),
144 _ => {
145 server
147 .experimental
148 .as_ref()
149 .is_some_and(|experimental| experimental.contains_key(feature))
150 }
151 }
152 }
153
154 fn get_all_features(
156 &self,
157 client: &ClientCapabilities,
158 server: &ServerCapabilities,
159 ) -> HashSet<String> {
160 let mut features = HashSet::new();
161
162 if client.sampling.is_some() {
164 features.insert("sampling".to_string());
165 }
166 if client.roots.is_some() {
167 features.insert("roots".to_string());
168 }
169
170 if server.tools.is_some() {
172 features.insert("tools".to_string());
173 }
174 if server.prompts.is_some() {
175 features.insert("prompts".to_string());
176 }
177 if server.resources.is_some() {
178 features.insert("resources".to_string());
179 }
180 if server.logging.is_some() {
181 features.insert("logging".to_string());
182 }
183
184 if let Some(experimental) = &client.experimental {
186 features.extend(experimental.keys().cloned());
187 }
188 if let Some(experimental) = &server.experimental {
189 features.extend(experimental.keys().cloned());
190 }
191
192 features.extend(self.defaults.keys().cloned());
194
195 features
196 }
197
198 pub fn negotiate(
200 &self,
201 client: &ClientCapabilities,
202 server: &ServerCapabilities,
203 ) -> Result<CapabilitySet, CapabilityError> {
204 let all_features = self.get_all_features(client, server);
205 let mut enabled_features = HashSet::new();
206 let mut incompatible_features = Vec::new();
207
208 for feature in &all_features {
209 if self.is_compatible(feature, client, server) {
210 enabled_features.insert(feature.clone());
211 } else {
212 incompatible_features.push(feature.clone());
213 }
214 }
215
216 if !incompatible_features.is_empty() {
217 return Err(CapabilityError::IncompatibleFeatures(incompatible_features));
218 }
219
220 for (feature, enabled) in &self.defaults {
222 if *enabled && !enabled_features.contains(feature) && all_features.contains(feature) {
223 enabled_features.insert(feature.clone());
224 }
225 }
226
227 Ok(CapabilitySet {
228 enabled_features,
229 client_capabilities: client.clone(),
230 server_capabilities: server.clone(),
231 metadata: HashMap::new(),
232 })
233 }
234}
235
236impl CapabilityNegotiator {
237 pub const fn new(matcher: CapabilityMatcher) -> Self {
239 Self {
240 matcher,
241 strict_mode: false,
242 }
243 }
244
245 pub const fn with_strict_mode(mut self) -> Self {
247 self.strict_mode = true;
248 self
249 }
250
251 pub fn negotiate(
253 &self,
254 client: &ClientCapabilities,
255 server: &ServerCapabilities,
256 ) -> Result<CapabilitySet, CapabilityError> {
257 match self.matcher.negotiate(client, server) {
258 Ok(capability_set) => Ok(capability_set),
259 Err(CapabilityError::IncompatibleFeatures(features)) if !self.strict_mode => {
260 tracing::warn!(
262 "Some features are incompatible and will be disabled: {:?}",
263 features
264 );
265
266 let all_features = self.matcher.get_all_features(client, server);
268 let mut enabled_features = HashSet::new();
269
270 for feature in &all_features {
271 if self.matcher.is_compatible(feature, client, server) {
272 enabled_features.insert(feature.clone());
273 }
274 }
275
276 Ok(CapabilitySet {
277 enabled_features,
278 client_capabilities: client.clone(),
279 server_capabilities: server.clone(),
280 metadata: HashMap::new(),
281 })
282 }
283 Err(err) => Err(err),
284 }
285 }
286
287 pub fn is_feature_enabled(capability_set: &CapabilitySet, feature: &str) -> bool {
289 capability_set.enabled_features.contains(feature)
290 }
291
292 pub fn get_enabled_features(capability_set: &CapabilitySet) -> Vec<String> {
294 let mut features: Vec<String> = capability_set.enabled_features.iter().cloned().collect();
295 features.sort();
296 features
297 }
298}
299
300impl Default for CapabilityNegotiator {
301 fn default() -> Self {
302 Self::new(CapabilityMatcher::new())
303 }
304}
305
306impl CapabilitySet {
307 pub fn empty() -> Self {
309 Self {
310 enabled_features: HashSet::new(),
311 client_capabilities: ClientCapabilities::default(),
312 server_capabilities: ServerCapabilities::default(),
313 metadata: HashMap::new(),
314 }
315 }
316
317 pub fn has_feature(&self, feature: &str) -> bool {
319 self.enabled_features.contains(feature)
320 }
321
322 pub fn enable_feature(&mut self, feature: String) {
324 self.enabled_features.insert(feature);
325 }
326
327 pub fn disable_feature(&mut self, feature: &str) {
329 self.enabled_features.remove(feature);
330 }
331
332 pub fn feature_count(&self) -> usize {
334 self.enabled_features.len()
335 }
336
337 pub fn add_metadata(&mut self, key: String, value: serde_json::Value) {
339 self.metadata.insert(key, value);
340 }
341
342 pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
344 self.metadata.get(key)
345 }
346
347 pub fn summary(&self) -> CapabilitySummary {
349 CapabilitySummary {
350 total_features: self.enabled_features.len(),
351 client_features: self.count_client_features(),
352 server_features: self.count_server_features(),
353 enabled_features: self.enabled_features.iter().cloned().collect(),
354 }
355 }
356
357 fn count_client_features(&self) -> usize {
358 let mut count = 0;
359 if self.client_capabilities.sampling.is_some() {
360 count += 1;
361 }
362 if self.client_capabilities.roots.is_some() {
363 count += 1;
364 }
365 if let Some(experimental) = &self.client_capabilities.experimental {
366 count += experimental.len();
367 }
368 count
369 }
370
371 fn count_server_features(&self) -> usize {
372 let mut count = 0;
373 if self.server_capabilities.tools.is_some() {
374 count += 1;
375 }
376 if self.server_capabilities.prompts.is_some() {
377 count += 1;
378 }
379 if self.server_capabilities.resources.is_some() {
380 count += 1;
381 }
382 if self.server_capabilities.logging.is_some() {
383 count += 1;
384 }
385 if let Some(experimental) = &self.server_capabilities.experimental {
386 count += experimental.len();
387 }
388 count
389 }
390}
391
392#[derive(Debug, Clone, thiserror::Error)]
394pub enum CapabilityError {
395 #[error("Incompatible features: {0:?}")]
397 IncompatibleFeatures(Vec<String>),
398 #[error("Required feature missing: {0}")]
400 RequiredFeatureMissing(String),
401 #[error("Protocol version mismatch: client={client}, server={server}")]
403 VersionMismatch {
404 client: String,
406 server: String,
408 },
409 #[error("Capability negotiation failed: {0}")]
411 NegotiationFailed(String),
412}
413
414#[derive(Debug, Clone, Serialize, Deserialize)]
416pub struct CapabilitySummary {
417 pub total_features: usize,
419 pub client_features: usize,
421 pub server_features: usize,
423 pub enabled_features: Vec<String>,
425}
426
427pub mod utils {
429 use super::*;
430
431 pub fn minimal_client_capabilities() -> ClientCapabilities {
433 ClientCapabilities::default()
434 }
435
436 pub fn minimal_server_capabilities() -> ServerCapabilities {
438 ServerCapabilities::default()
439 }
440
441 pub fn full_client_capabilities() -> ClientCapabilities {
443 ClientCapabilities {
444 sampling: Some(Default::default()),
445 roots: Some(Default::default()),
446 elicitation: Some(Default::default()),
447 experimental: None,
448 }
449 }
450
451 pub fn full_server_capabilities() -> ServerCapabilities {
453 ServerCapabilities {
454 tools: Some(Default::default()),
455 prompts: Some(Default::default()),
456 resources: Some(Default::default()),
457 completions: Some(Default::default()),
458 logging: Some(Default::default()),
459 experimental: None,
460 }
461 }
462
463 pub fn are_compatible(client: &ClientCapabilities, server: &ServerCapabilities) -> bool {
465 let matcher = CapabilityMatcher::new();
466 matcher.negotiate(client, server).is_ok()
467 }
468}
469
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::*;

    #[test]
    fn test_capability_matcher() {
        let matcher = CapabilityMatcher::new();

        let client = ClientCapabilities {
            sampling: Some(SamplingCapabilities),
            roots: None,
            elicitation: None,
            experimental: None,
        };
        let server = ServerCapabilities {
            tools: Some(ToolsCapabilities::default()),
            prompts: None,
            resources: None,
            logging: None,
            completions: None,
            experimental: None,
        };

        // `sampling` needs the client, `tools` needs the server, `roots` is undeclared.
        let cases = [("sampling", true), ("tools", true), ("roots", false)];
        for (feature, expected) in cases {
            assert_eq!(
                matcher.is_compatible(feature, &client, &server),
                expected,
                "unexpected compatibility for {feature}"
            );
        }
    }

    #[test]
    fn test_capability_negotiation() {
        let negotiator = CapabilityNegotiator::default();
        let client = utils::full_client_capabilities();
        let server = utils::full_server_capabilities();

        let capability_set = negotiator
            .negotiate(&client, &server)
            .expect("full client and server capabilities should negotiate");

        for feature in ["sampling", "tools", "roots"] {
            assert!(capability_set.has_feature(feature), "missing feature {feature}");
        }
    }

    #[test]
    fn test_strict_mode() {
        let negotiator = CapabilityNegotiator::default().with_strict_mode();

        let result = negotiator.negotiate(
            &ClientCapabilities::default(),
            &ServerCapabilities::default(),
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_capability_summary() {
        let mut capability_set = CapabilitySet::empty();
        for feature in ["tools", "prompts"] {
            capability_set.enable_feature(feature.to_string());
        }

        let summary = capability_set.summary();
        assert_eq!(summary.total_features, 2);
        assert!(summary.enabled_features.iter().any(|name| name == "tools"));
    }
}