1use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet};
8
9use crate::types::{ClientCapabilities, ServerCapabilities};
10
/// Decides, feature by feature, whether a client/server capability pair supports it.
///
/// Holds one [`CompatibilityRule`] per registered feature name plus a table of
/// per-feature default flags that negotiation consults.
#[derive(Debug, Clone)]
pub struct CapabilityMatcher {
    /// Feature name -> rule evaluated by `is_compatible`.
    compatibility_rules: HashMap<String, CompatibilityRule>,
    /// Feature name -> default enabled flag, populated through `set_default`.
    defaults: HashMap<String, bool>,
}
19
/// How compatibility of a single feature is evaluated by `CapabilityMatcher::is_compatible`.
#[derive(Debug, Clone)]
pub enum CompatibilityRule {
    /// Both the client and the server must advertise the feature.
    RequireBoth,
    /// Only the client needs to advertise the feature.
    RequireClient,
    /// Only the server needs to advertise the feature.
    RequireServer,
    /// The feature is always considered compatible.
    Optional,
    /// Compatibility is decided by calling the supplied function with both sides' capabilities.
    Custom(fn(&ClientCapabilities, &ServerCapabilities) -> bool),
}
34
/// Result of a capability negotiation: the enabled feature names together with
/// the capabilities each side advertised.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CapabilitySet {
    /// Names of the features enabled for this session.
    pub enabled_features: HashSet<String>,
    /// Capabilities advertised by the client (cloned at negotiation time).
    pub client_capabilities: ClientCapabilities,
    /// Capabilities advertised by the server (cloned at negotiation time).
    pub server_capabilities: ServerCapabilities,
    /// Free-form extra data attached via `add_metadata`; starts empty after negotiation.
    pub metadata: HashMap<String, serde_json::Value>,
}
47
/// Front end over a [`CapabilityMatcher`] that chooses how incompatibilities are handled.
#[derive(Debug, Clone)]
pub struct CapabilityNegotiator {
    /// Matcher that performs the per-feature compatibility checks.
    matcher: CapabilityMatcher,
    /// When `true`, incompatible features make negotiation fail; when `false`
    /// they are dropped with a warning and negotiation succeeds.
    strict_mode: bool,
}
56
57impl Default for CapabilityMatcher {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63impl CapabilityMatcher {
64 pub fn new() -> Self {
66 let mut matcher = Self {
67 compatibility_rules: HashMap::new(),
68 defaults: HashMap::new(),
69 };
70
71 matcher.add_rule("tools", CompatibilityRule::RequireServer);
73 matcher.add_rule("prompts", CompatibilityRule::RequireServer);
74 matcher.add_rule("resources", CompatibilityRule::RequireServer);
75 matcher.add_rule("logging", CompatibilityRule::RequireServer);
76 matcher.add_rule("sampling", CompatibilityRule::RequireClient);
77 matcher.add_rule("roots", CompatibilityRule::RequireClient);
78 matcher.add_rule("progress", CompatibilityRule::Optional);
79
80 matcher.set_default("progress", true);
82
83 matcher
84 }
85
86 pub fn add_rule(&mut self, feature: &str, rule: CompatibilityRule) {
88 self.compatibility_rules.insert(feature.to_string(), rule);
89 }
90
91 pub fn set_default(&mut self, feature: &str, enabled: bool) {
93 self.defaults.insert(feature.to_string(), enabled);
94 }
95
96 pub fn is_compatible(
98 &self,
99 feature: &str,
100 client: &ClientCapabilities,
101 server: &ServerCapabilities,
102 ) -> bool {
103 self.compatibility_rules.get(feature).map_or_else(
104 || {
105 Self::client_has_feature(feature, client)
107 || Self::server_has_feature(feature, server)
108 },
109 |rule| match rule {
110 CompatibilityRule::RequireBoth => {
111 Self::client_has_feature(feature, client)
112 && Self::server_has_feature(feature, server)
113 }
114 CompatibilityRule::RequireClient => Self::client_has_feature(feature, client),
115 CompatibilityRule::RequireServer => Self::server_has_feature(feature, server),
116 CompatibilityRule::Optional => true,
117 CompatibilityRule::Custom(func) => func(client, server),
118 },
119 )
120 }
121
122 fn client_has_feature(feature: &str, client: &ClientCapabilities) -> bool {
124 match feature {
125 "sampling" => client.sampling.is_some(),
126 "roots" => client.roots.is_some(),
127 _ => {
128 client
130 .experimental
131 .as_ref()
132 .is_some_and(|experimental| experimental.contains_key(feature))
133 }
134 }
135 }
136
137 fn server_has_feature(feature: &str, server: &ServerCapabilities) -> bool {
139 match feature {
140 "tools" => server.tools.is_some(),
141 "prompts" => server.prompts.is_some(),
142 "resources" => server.resources.is_some(),
143 "logging" => server.logging.is_some(),
144 _ => {
145 server
147 .experimental
148 .as_ref()
149 .is_some_and(|experimental| experimental.contains_key(feature))
150 }
151 }
152 }
153
154 fn get_all_features(
156 &self,
157 client: &ClientCapabilities,
158 server: &ServerCapabilities,
159 ) -> HashSet<String> {
160 let mut features = HashSet::new();
161
162 if client.sampling.is_some() {
164 features.insert("sampling".to_string());
165 }
166 if client.roots.is_some() {
167 features.insert("roots".to_string());
168 }
169
170 if server.tools.is_some() {
172 features.insert("tools".to_string());
173 }
174 if server.prompts.is_some() {
175 features.insert("prompts".to_string());
176 }
177 if server.resources.is_some() {
178 features.insert("resources".to_string());
179 }
180 if server.logging.is_some() {
181 features.insert("logging".to_string());
182 }
183
184 if let Some(experimental) = &client.experimental {
186 features.extend(experimental.keys().cloned());
187 }
188 if let Some(experimental) = &server.experimental {
189 features.extend(experimental.keys().cloned());
190 }
191
192 features.extend(self.defaults.keys().cloned());
194
195 features
196 }
197
198 pub fn negotiate(
200 &self,
201 client: &ClientCapabilities,
202 server: &ServerCapabilities,
203 ) -> Result<CapabilitySet, CapabilityError> {
204 let all_features = self.get_all_features(client, server);
205 let mut enabled_features = HashSet::new();
206 let mut incompatible_features = Vec::new();
207
208 for feature in &all_features {
209 if self.is_compatible(feature, client, server) {
210 enabled_features.insert(feature.clone());
211 } else {
212 incompatible_features.push(feature.clone());
213 }
214 }
215
216 if !incompatible_features.is_empty() {
217 return Err(CapabilityError::IncompatibleFeatures(incompatible_features));
218 }
219
220 for (feature, enabled) in &self.defaults {
222 if *enabled && !enabled_features.contains(feature) && all_features.contains(feature) {
223 enabled_features.insert(feature.clone());
224 }
225 }
226
227 Ok(CapabilitySet {
228 enabled_features,
229 client_capabilities: client.clone(),
230 server_capabilities: server.clone(),
231 metadata: HashMap::new(),
232 })
233 }
234}
235
236impl CapabilityNegotiator {
237 pub const fn new(matcher: CapabilityMatcher) -> Self {
239 Self {
240 matcher,
241 strict_mode: false,
242 }
243 }
244
245 pub const fn with_strict_mode(mut self) -> Self {
247 self.strict_mode = true;
248 self
249 }
250
251 pub fn negotiate(
253 &self,
254 client: &ClientCapabilities,
255 server: &ServerCapabilities,
256 ) -> Result<CapabilitySet, CapabilityError> {
257 match self.matcher.negotiate(client, server) {
258 Ok(capability_set) => Ok(capability_set),
259 Err(CapabilityError::IncompatibleFeatures(features)) if !self.strict_mode => {
260 tracing::warn!(
262 "Some features are incompatible and will be disabled: {:?}",
263 features
264 );
265
266 let all_features = self.matcher.get_all_features(client, server);
268 let mut enabled_features = HashSet::new();
269
270 for feature in &all_features {
271 if self.matcher.is_compatible(feature, client, server) {
272 enabled_features.insert(feature.clone());
273 }
274 }
275
276 Ok(CapabilitySet {
277 enabled_features,
278 client_capabilities: client.clone(),
279 server_capabilities: server.clone(),
280 metadata: HashMap::new(),
281 })
282 }
283 Err(err) => Err(err),
284 }
285 }
286
287 pub fn is_feature_enabled(capability_set: &CapabilitySet, feature: &str) -> bool {
289 capability_set.enabled_features.contains(feature)
290 }
291
292 pub fn get_enabled_features(capability_set: &CapabilitySet) -> Vec<String> {
294 let mut features: Vec<String> = capability_set.enabled_features.iter().cloned().collect();
295 features.sort();
296 features
297 }
298}
299
300impl Default for CapabilityNegotiator {
301 fn default() -> Self {
302 Self::new(CapabilityMatcher::new())
303 }
304}
305
306impl CapabilitySet {
307 pub fn empty() -> Self {
309 Self {
310 enabled_features: HashSet::new(),
311 client_capabilities: ClientCapabilities::default(),
312 server_capabilities: ServerCapabilities::default(),
313 metadata: HashMap::new(),
314 }
315 }
316
317 pub fn has_feature(&self, feature: &str) -> bool {
319 self.enabled_features.contains(feature)
320 }
321
322 pub fn enable_feature(&mut self, feature: String) {
324 self.enabled_features.insert(feature);
325 }
326
327 pub fn disable_feature(&mut self, feature: &str) {
329 self.enabled_features.remove(feature);
330 }
331
332 pub fn feature_count(&self) -> usize {
334 self.enabled_features.len()
335 }
336
337 pub fn add_metadata(&mut self, key: String, value: serde_json::Value) {
339 self.metadata.insert(key, value);
340 }
341
342 pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
344 self.metadata.get(key)
345 }
346
347 pub fn summary(&self) -> CapabilitySummary {
349 CapabilitySummary {
350 total_features: self.enabled_features.len(),
351 client_features: self.count_client_features(),
352 server_features: self.count_server_features(),
353 enabled_features: self.enabled_features.iter().cloned().collect(),
354 }
355 }
356
357 fn count_client_features(&self) -> usize {
358 let mut count = 0;
359 if self.client_capabilities.sampling.is_some() {
360 count += 1;
361 }
362 if self.client_capabilities.roots.is_some() {
363 count += 1;
364 }
365 if let Some(experimental) = &self.client_capabilities.experimental {
366 count += experimental.len();
367 }
368 count
369 }
370
371 fn count_server_features(&self) -> usize {
372 let mut count = 0;
373 if self.server_capabilities.tools.is_some() {
374 count += 1;
375 }
376 if self.server_capabilities.prompts.is_some() {
377 count += 1;
378 }
379 if self.server_capabilities.resources.is_some() {
380 count += 1;
381 }
382 if self.server_capabilities.logging.is_some() {
383 count += 1;
384 }
385 if let Some(experimental) = &self.server_capabilities.experimental {
386 count += experimental.len();
387 }
388 count
389 }
390}
391
/// Errors produced while negotiating capabilities.
#[derive(Debug, Clone, thiserror::Error)]
pub enum CapabilityError {
    /// One or more candidate features failed their compatibility rule; carries their names.
    #[error("Incompatible features: {0:?}")]
    IncompatibleFeatures(Vec<String>),
    /// A feature that had to be present was missing.
    // NOTE(review): not constructed by the negotiation code visible in this file.
    #[error("Required feature missing: {0}")]
    RequiredFeatureMissing(String),
    /// Client and server report different protocol versions.
    // NOTE(review): not constructed by the negotiation code visible in this file.
    #[error("Protocol version mismatch: client={client}, server={server}")]
    VersionMismatch {
        /// Version string reported by the client.
        client: String,
        /// Version string reported by the server.
        server: String,
    },
    /// Catch-all negotiation failure with a human-readable reason.
    #[error("Capability negotiation failed: {0}")]
    NegotiationFailed(String),
}
413
/// Serializable overview of a [`CapabilitySet`], produced by `CapabilitySet::summary`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CapabilitySummary {
    /// Number of enabled features.
    pub total_features: usize,
    /// Number of capabilities counted on the client side.
    pub client_features: usize,
    /// Number of capabilities counted on the server side.
    pub server_features: usize,
    /// Names of the enabled features.
    pub enabled_features: Vec<String>,
}
426
427pub mod utils {
429 use super::*;
430
431 pub fn minimal_client_capabilities() -> ClientCapabilities {
433 ClientCapabilities::default()
434 }
435
436 pub fn minimal_server_capabilities() -> ServerCapabilities {
438 ServerCapabilities::default()
439 }
440
441 pub fn full_client_capabilities() -> ClientCapabilities {
443 ClientCapabilities {
444 sampling: Some(Default::default()),
445 roots: Some(Default::default()),
446 elicitation: Some(Default::default()),
447 experimental: None,
448 }
449 }
450
451 pub fn full_server_capabilities() -> ServerCapabilities {
453 ServerCapabilities {
454 tools: Some(Default::default()),
455 prompts: Some(Default::default()),
456 resources: Some(Default::default()),
457 completions: Some(Default::default()),
458 logging: Some(Default::default()),
459 experimental: None,
460 }
461 }
462
463 pub fn are_compatible(client: &ClientCapabilities, server: &ServerCapabilities) -> bool {
465 let matcher = CapabilityMatcher::new();
466 matcher.negotiate(client, server).is_ok()
467 }
468}
469
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::*;

    #[test]
    fn test_capability_matcher() {
        let client = ClientCapabilities {
            sampling: Some(SamplingCapabilities),
            roots: None,
            elicitation: None,
            experimental: None,
        };
        let server = ServerCapabilities {
            tools: Some(ToolsCapabilities::default()),
            prompts: None,
            resources: None,
            logging: None,
            completions: None,
            experimental: None,
        };
        let matcher = CapabilityMatcher::new();

        assert!(matcher.is_compatible("sampling", &client, &server));
        assert!(matcher.is_compatible("tools", &client, &server));
        assert!(!matcher.is_compatible("roots", &client, &server));
    }

    #[test]
    fn test_capability_negotiation() {
        let client = utils::full_client_capabilities();
        let server = utils::full_server_capabilities();

        let capability_set = CapabilityNegotiator::default()
            .negotiate(&client, &server)
            .expect("full client and server capabilities should negotiate");

        for feature in ["sampling", "tools", "roots"] {
            assert!(
                capability_set.has_feature(feature),
                "expected feature to be enabled: {feature}"
            );
        }
    }

    #[test]
    fn test_strict_mode() {
        let negotiator = CapabilityNegotiator::default().with_strict_mode();

        let result = negotiator.negotiate(
            &ClientCapabilities::default(),
            &ServerCapabilities::default(),
        );

        assert!(result.is_ok());
    }

    #[test]
    fn test_capability_summary() {
        let mut capability_set = CapabilitySet::empty();
        for feature in ["tools", "prompts"] {
            capability_set.enable_feature(feature.to_string());
        }

        let summary = capability_set.summary();

        assert_eq!(summary.total_features, 2);
        assert!(summary.enabled_features.iter().any(|name| name == "tools"));
    }
}
538
539pub mod builders {
549 use crate::types::{
550 ClientCapabilities, CompletionCapabilities, ElicitationCapabilities, LoggingCapabilities,
551 PromptsCapabilities, ResourcesCapabilities, RootsCapabilities, SamplingCapabilities,
552 ServerCapabilities, ToolsCapabilities,
553 };
554 use serde_json;
555 use std::collections::HashMap;
556 use std::marker::PhantomData;
557
    /// Zero-sized type-state marker for [`ServerCapabilitiesBuilder`].
    ///
    /// Each const flag records whether the matching capability has been enabled
    /// on the builder. The `enable_*` methods are implemented only for the
    /// `false` state and return the `true` state. All flags default to `false`,
    /// meaning a fresh builder.
    #[derive(Debug, Clone)]
    pub struct ServerCapabilitiesBuilderState<
        const EXPERIMENTAL: bool = false,
        const LOGGING: bool = false,
        const COMPLETIONS: bool = false,
        const PROMPTS: bool = false,
        const RESOURCES: bool = false,
        const TOOLS: bool = false,
    >;
580
    /// Type-state builder for [`ServerCapabilities`].
    ///
    /// `S` is a `ServerCapabilitiesBuilderState` whose flags track which
    /// capabilities are already enabled. Each `enable_*` method can therefore be
    /// called at most once, and refinement methods (e.g. `enable_tool_list_changed`)
    /// are available only after their capability is enabled.
    #[derive(Debug, Clone)]
    pub struct ServerCapabilitiesBuilder<S = ServerCapabilitiesBuilderState> {
        experimental: Option<HashMap<String, serde_json::Value>>,
        logging: Option<LoggingCapabilities>,
        completions: Option<CompletionCapabilities>,
        prompts: Option<PromptsCapabilities>,
        resources: Option<ResourcesCapabilities>,
        tools: Option<ToolsCapabilities>,

        /// Negotiator attached via `with_negotiator`.
        // NOTE(review): no code visible in this file reads this field; confirm its intended use.
        negotiator: Option<super::CapabilityNegotiator>,
        /// When `true`, `validate` performs the strict checks.
        strict_validation: bool,

        /// Zero-sized marker carrying the type state `S`.
        _state: PhantomData<S>,
    }
601
602 impl ServerCapabilities {
603 pub fn builder() -> ServerCapabilitiesBuilder {
608 ServerCapabilitiesBuilder::new()
609 }
610 }
611
612 impl Default for ServerCapabilitiesBuilder {
613 fn default() -> Self {
614 Self::new()
615 }
616 }
617
618 impl ServerCapabilitiesBuilder {
619 pub fn new() -> Self {
621 Self {
622 experimental: None,
623 logging: None,
624 completions: None,
625 prompts: None,
626 resources: None,
627 tools: None,
628 negotiator: None,
629 strict_validation: false,
630 _state: PhantomData,
631 }
632 }
633 }
634
635 impl<S> ServerCapabilitiesBuilder<S> {
637 pub fn build(self) -> ServerCapabilities {
642 ServerCapabilities {
643 experimental: self.experimental,
644 logging: self.logging,
645 completions: self.completions,
646 prompts: self.prompts,
647 resources: self.resources,
648 tools: self.tools,
649 }
650 }
651
652 pub fn with_strict_validation(mut self) -> Self {
657 self.strict_validation = true;
658 self
659 }
660
661 pub fn with_negotiator(mut self, negotiator: super::CapabilityNegotiator) -> Self {
666 self.negotiator = Some(negotiator);
667 self
668 }
669
670 pub fn validate(&self) -> Result<(), String> {
675 if self.strict_validation {
676 if self.tools.is_none() && self.prompts.is_none() && self.resources.is_none() {
678 return Err("Server must provide at least one capability (tools, prompts, or resources)".to_string());
679 }
680
681 if let Some(ref experimental) = self.experimental {
683 for (key, value) in experimental {
684 if key.starts_with("turbomcp_") {
685 match key.as_str() {
687 "turbomcp_simd_level" => {
688 if !value.is_string() {
689 return Err(
690 "turbomcp_simd_level must be a string".to_string()
691 );
692 }
693 let level = value.as_str().unwrap_or("");
694 if !["none", "sse2", "sse4", "avx2", "avx512"].contains(&level)
695 {
696 return Err(format!("Invalid SIMD level: {}", level));
697 }
698 }
699 "turbomcp_enterprise_security" => {
700 if !value.is_boolean() {
701 return Err(
702 "turbomcp_enterprise_security must be a boolean"
703 .to_string(),
704 );
705 }
706 }
707 _ => {
708 }
710 }
711 }
712 }
713 }
714 }
715 Ok(())
716 }
717
718 pub fn summary(&self) -> String {
722 let mut capabilities = Vec::new();
723 if self.experimental.is_some() {
724 capabilities.push("experimental");
725 }
726 if self.logging.is_some() {
727 capabilities.push("logging");
728 }
729 if self.completions.is_some() {
730 capabilities.push("completions");
731 }
732 if self.prompts.is_some() {
733 capabilities.push("prompts");
734 }
735 if self.resources.is_some() {
736 capabilities.push("resources");
737 }
738 if self.tools.is_some() {
739 capabilities.push("tools");
740 }
741
742 if capabilities.is_empty() {
743 "No capabilities enabled".to_string()
744 } else {
745 format!("Enabled capabilities: {}", capabilities.join(", "))
746 }
747 }
748 }
749
750 impl<const L: bool, const C: bool, const P: bool, const R: bool, const T: bool>
756 ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<false, L, C, P, R, T>>
757 {
758 pub fn enable_experimental(
763 self,
764 ) -> ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<true, L, C, P, R, T>>
765 {
766 ServerCapabilitiesBuilder {
767 experimental: Some(HashMap::new()),
768 logging: self.logging,
769 completions: self.completions,
770 prompts: self.prompts,
771 resources: self.resources,
772 tools: self.tools,
773 negotiator: self.negotiator,
774 strict_validation: self.strict_validation,
775 _state: PhantomData,
776 }
777 }
778
779 pub fn enable_experimental_with(
781 self,
782 experimental: HashMap<String, serde_json::Value>,
783 ) -> ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<true, L, C, P, R, T>>
784 {
785 ServerCapabilitiesBuilder {
786 experimental: Some(experimental),
787 logging: self.logging,
788 completions: self.completions,
789 prompts: self.prompts,
790 resources: self.resources,
791 tools: self.tools,
792 negotiator: self.negotiator,
793 strict_validation: self.strict_validation,
794 _state: PhantomData,
795 }
796 }
797 }
798
799 impl<const E: bool, const C: bool, const P: bool, const R: bool, const T: bool>
801 ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, false, C, P, R, T>>
802 {
803 pub fn enable_logging(
805 self,
806 ) -> ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, true, C, P, R, T>>
807 {
808 ServerCapabilitiesBuilder {
809 experimental: self.experimental,
810 logging: Some(LoggingCapabilities),
811 completions: self.completions,
812 prompts: self.prompts,
813 resources: self.resources,
814 tools: self.tools,
815 negotiator: self.negotiator,
816 strict_validation: self.strict_validation,
817 _state: PhantomData,
818 }
819 }
820 }
821
822 impl<const E: bool, const L: bool, const P: bool, const R: bool, const T: bool>
824 ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, L, false, P, R, T>>
825 {
826 pub fn enable_completions(
828 self,
829 ) -> ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, L, true, P, R, T>>
830 {
831 ServerCapabilitiesBuilder {
832 experimental: self.experimental,
833 logging: self.logging,
834 completions: Some(CompletionCapabilities),
835 prompts: self.prompts,
836 resources: self.resources,
837 tools: self.tools,
838 negotiator: self.negotiator,
839 strict_validation: self.strict_validation,
840 _state: PhantomData,
841 }
842 }
843 }
844
845 impl<const E: bool, const L: bool, const C: bool, const R: bool, const T: bool>
847 ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, L, C, false, R, T>>
848 {
849 pub fn enable_prompts(
851 self,
852 ) -> ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, L, C, true, R, T>>
853 {
854 ServerCapabilitiesBuilder {
855 experimental: self.experimental,
856 logging: self.logging,
857 completions: self.completions,
858 prompts: Some(PromptsCapabilities { list_changed: None }),
859 resources: self.resources,
860 tools: self.tools,
861 negotiator: self.negotiator,
862 strict_validation: self.strict_validation,
863 _state: PhantomData,
864 }
865 }
866 }
867
868 impl<const E: bool, const L: bool, const C: bool, const P: bool, const T: bool>
870 ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, L, C, P, false, T>>
871 {
872 pub fn enable_resources(
874 self,
875 ) -> ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, L, C, P, true, T>>
876 {
877 ServerCapabilitiesBuilder {
878 experimental: self.experimental,
879 logging: self.logging,
880 completions: self.completions,
881 prompts: self.prompts,
882 resources: Some(ResourcesCapabilities {
883 subscribe: None,
884 list_changed: None,
885 }),
886 tools: self.tools,
887 negotiator: self.negotiator,
888 strict_validation: self.strict_validation,
889 _state: PhantomData,
890 }
891 }
892 }
893
894 impl<const E: bool, const L: bool, const C: bool, const P: bool, const R: bool>
896 ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, L, C, P, R, false>>
897 {
898 pub fn enable_tools(
900 self,
901 ) -> ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, L, C, P, R, true>>
902 {
903 ServerCapabilitiesBuilder {
904 experimental: self.experimental,
905 logging: self.logging,
906 completions: self.completions,
907 prompts: self.prompts,
908 resources: self.resources,
909 tools: Some(ToolsCapabilities { list_changed: None }),
910 negotiator: self.negotiator,
911 strict_validation: self.strict_validation,
912 _state: PhantomData,
913 }
914 }
915 }
916
917 impl<const E: bool, const L: bool, const C: bool, const P: bool, const R: bool>
923 ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, L, C, P, R, true>>
924 {
925 pub fn enable_tool_list_changed(mut self) -> Self {
930 if let Some(ref mut tools) = self.tools {
931 tools.list_changed = Some(true);
932 }
933 self
934 }
935 }
936
937 impl<const E: bool, const L: bool, const C: bool, const R: bool, const T: bool>
939 ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, L, C, true, R, T>>
940 {
941 pub fn enable_prompts_list_changed(mut self) -> Self {
943 if let Some(ref mut prompts) = self.prompts {
944 prompts.list_changed = Some(true);
945 }
946 self
947 }
948 }
949
950 impl<const E: bool, const L: bool, const C: bool, const P: bool, const T: bool>
952 ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, L, C, P, true, T>>
953 {
954 pub fn enable_resources_list_changed(mut self) -> Self {
956 if let Some(ref mut resources) = self.resources {
957 resources.list_changed = Some(true);
958 }
959 self
960 }
961
962 pub fn enable_resources_subscribe(mut self) -> Self {
964 if let Some(ref mut resources) = self.resources {
965 resources.subscribe = Some(true);
966 }
967 self
968 }
969 }
970
971 impl<const L: bool, const C: bool, const P: bool, const R: bool, const T: bool>
973 ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<true, L, C, P, R, T>>
974 {
975 pub fn add_experimental_capability<K, V>(mut self, key: K, value: V) -> Self
979 where
980 K: Into<String>,
981 V: Into<serde_json::Value>,
982 {
983 if let Some(ref mut experimental) = self.experimental {
984 experimental.insert(key.into(), value.into());
985 }
986 self
987 }
988
989 pub fn with_simd_optimization(mut self, level: &str) -> Self {
993 if let Some(ref mut experimental) = self.experimental {
994 experimental.insert(
995 "turbomcp_simd_level".to_string(),
996 serde_json::Value::String(level.to_string()),
997 );
998 }
999 self
1000 }
1001
1002 pub fn with_enterprise_security(mut self, enabled: bool) -> Self {
1006 if let Some(ref mut experimental) = self.experimental {
1007 experimental.insert(
1008 "turbomcp_enterprise_security".to_string(),
1009 serde_json::Value::Bool(enabled),
1010 );
1011 }
1012 self
1013 }
1014 }
1015
1016 impl ServerCapabilitiesBuilder {
1018 pub fn full_featured() -> ServerCapabilitiesBuilder<
1022 ServerCapabilitiesBuilderState<true, true, true, true, true, true>,
1023 > {
1024 Self::new()
1025 .enable_experimental()
1026 .enable_logging()
1027 .enable_completions()
1028 .enable_prompts()
1029 .enable_resources()
1030 .enable_tools()
1031 .enable_tool_list_changed()
1032 .enable_prompts_list_changed()
1033 .enable_resources_list_changed()
1034 .enable_resources_subscribe()
1035 .with_simd_optimization("avx2")
1036 .with_enterprise_security(true)
1037 }
1038
1039 pub fn minimal() -> ServerCapabilitiesBuilder<
1043 ServerCapabilitiesBuilderState<false, false, false, false, false, true>,
1044 > {
1045 Self::new().enable_tools()
1046 }
1047 }
1048
    /// Zero-sized type-state marker for [`ClientCapabilitiesBuilder`].
    ///
    /// Each const flag records whether the matching capability has been enabled.
    /// The `enable_*` methods exist only for the `false` state and return the
    /// `true` state. All flags default to `false`, meaning a fresh builder.
    #[derive(Debug, Clone)]
    pub struct ClientCapabilitiesBuilderState<
        const EXPERIMENTAL: bool = false,
        const ROOTS: bool = false,
        const SAMPLING: bool = false,
        const ELICITATION: bool = false,
    >;
1067
    /// Type-state builder for [`ClientCapabilities`].
    ///
    /// `S` is a `ClientCapabilitiesBuilderState` tracking which capabilities are
    /// already enabled, so each `enable_*` method can be called at most once.
    #[derive(Debug, Clone)]
    pub struct ClientCapabilitiesBuilder<S = ClientCapabilitiesBuilderState> {
        experimental: Option<HashMap<String, serde_json::Value>>,
        roots: Option<RootsCapabilities>,
        sampling: Option<SamplingCapabilities>,
        elicitation: Option<ElicitationCapabilities>,

        /// Negotiator attached via `with_negotiator`.
        // NOTE(review): no code visible in this file reads this field; confirm its intended use.
        negotiator: Option<super::CapabilityNegotiator>,
        /// When `true`, `validate` performs the strict checks.
        strict_validation: bool,

        /// Zero-sized marker carrying the type state `S`.
        _state: PhantomData<S>,
    }
1086
1087 impl ClientCapabilities {
1088 pub fn builder() -> ClientCapabilitiesBuilder {
1093 ClientCapabilitiesBuilder::new()
1094 }
1095 }
1096
1097 impl Default for ClientCapabilitiesBuilder {
1098 fn default() -> Self {
1099 Self::new()
1100 }
1101 }
1102
1103 impl ClientCapabilitiesBuilder {
1104 pub fn new() -> Self {
1106 Self {
1107 experimental: None,
1108 roots: None,
1109 sampling: None,
1110 elicitation: None,
1111 negotiator: None,
1112 strict_validation: false,
1113 _state: PhantomData,
1114 }
1115 }
1116 }
1117
1118 impl<S> ClientCapabilitiesBuilder<S> {
1120 pub fn build(self) -> ClientCapabilities {
1125 ClientCapabilities {
1126 experimental: self.experimental,
1127 roots: self.roots,
1128 sampling: self.sampling,
1129 elicitation: self.elicitation,
1130 }
1131 }
1132
1133 pub fn with_strict_validation(mut self) -> Self {
1138 self.strict_validation = true;
1139 self
1140 }
1141
1142 pub fn with_negotiator(mut self, negotiator: super::CapabilityNegotiator) -> Self {
1147 self.negotiator = Some(negotiator);
1148 self
1149 }
1150
1151 pub fn validate(&self) -> Result<(), String> {
1156 if self.strict_validation {
1157 if let Some(ref experimental) = self.experimental {
1159 for (key, value) in experimental {
1160 if key.starts_with("turbomcp_") {
1161 match key.as_str() {
1163 "turbomcp_llm_provider" => {
1164 if !value.is_object() {
1165 return Err(
1166 "turbomcp_llm_provider must be an object".to_string()
1167 );
1168 }
1169 let obj = value.as_object().unwrap();
1170 if !obj.contains_key("provider") || !obj.contains_key("version")
1171 {
1172 return Err("turbomcp_llm_provider must have 'provider' and 'version' fields".to_string());
1173 }
1174 }
1175 "turbomcp_ui_capabilities" => {
1176 if !value.is_array() {
1177 return Err(
1178 "turbomcp_ui_capabilities must be an array".to_string()
1179 );
1180 }
1181 let arr = value.as_array().unwrap();
1182 let valid_ui_caps = [
1183 "form",
1184 "dialog",
1185 "notification",
1186 "toast",
1187 "modal",
1188 "sidebar",
1189 ];
1190 for cap in arr {
1191 if let Some(cap_str) = cap.as_str() {
1192 if !valid_ui_caps.contains(&cap_str) {
1193 return Err(format!(
1194 "Invalid UI capability: {}",
1195 cap_str
1196 ));
1197 }
1198 } else {
1199 return Err(
1200 "UI capabilities must be strings".to_string()
1201 );
1202 }
1203 }
1204 }
1205 _ => {
1206 }
1208 }
1209 }
1210 }
1211 }
1212 }
1213 Ok(())
1214 }
1215
1216 pub fn summary(&self) -> String {
1220 let mut capabilities = Vec::new();
1221 if self.experimental.is_some() {
1222 capabilities.push("experimental");
1223 }
1224 if self.roots.is_some() {
1225 capabilities.push("roots");
1226 }
1227 if self.sampling.is_some() {
1228 capabilities.push("sampling");
1229 }
1230 if self.elicitation.is_some() {
1231 capabilities.push("elicitation");
1232 }
1233
1234 if capabilities.is_empty() {
1235 "No capabilities enabled".to_string()
1236 } else {
1237 format!("Enabled capabilities: {}", capabilities.join(", "))
1238 }
1239 }
1240 }
1241
1242 impl<const R: bool, const S: bool, const E: bool>
1248 ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<false, R, S, E>>
1249 {
1250 pub fn enable_experimental(
1255 self,
1256 ) -> ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<true, R, S, E>> {
1257 ClientCapabilitiesBuilder {
1258 experimental: Some(HashMap::new()),
1259 roots: self.roots,
1260 sampling: self.sampling,
1261 elicitation: self.elicitation,
1262 negotiator: self.negotiator,
1263 strict_validation: self.strict_validation,
1264 _state: PhantomData,
1265 }
1266 }
1267
1268 pub fn enable_experimental_with(
1270 self,
1271 experimental: HashMap<String, serde_json::Value>,
1272 ) -> ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<true, R, S, E>> {
1273 ClientCapabilitiesBuilder {
1274 experimental: Some(experimental),
1275 roots: self.roots,
1276 sampling: self.sampling,
1277 elicitation: self.elicitation,
1278 negotiator: self.negotiator,
1279 strict_validation: self.strict_validation,
1280 _state: PhantomData,
1281 }
1282 }
1283 }
1284
1285 impl<const X: bool, const S: bool, const E: bool>
1287 ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<X, false, S, E>>
1288 {
1289 pub fn enable_roots(
1291 self,
1292 ) -> ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<X, true, S, E>> {
1293 ClientCapabilitiesBuilder {
1294 experimental: self.experimental,
1295 roots: Some(RootsCapabilities { list_changed: None }),
1296 sampling: self.sampling,
1297 elicitation: self.elicitation,
1298 negotiator: self.negotiator,
1299 strict_validation: self.strict_validation,
1300 _state: PhantomData,
1301 }
1302 }
1303 }
1304
1305 impl<const X: bool, const R: bool, const E: bool>
1307 ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<X, R, false, E>>
1308 {
1309 pub fn enable_sampling(
1311 self,
1312 ) -> ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<X, R, true, E>> {
1313 ClientCapabilitiesBuilder {
1314 experimental: self.experimental,
1315 roots: self.roots,
1316 sampling: Some(SamplingCapabilities),
1317 elicitation: self.elicitation,
1318 negotiator: self.negotiator,
1319 strict_validation: self.strict_validation,
1320 _state: PhantomData,
1321 }
1322 }
1323 }
1324
1325 impl<const X: bool, const R: bool, const S: bool>
1327 ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<X, R, S, false>>
1328 {
1329 pub fn enable_elicitation(
1331 self,
1332 ) -> ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<X, R, S, true>> {
1333 ClientCapabilitiesBuilder {
1334 experimental: self.experimental,
1335 roots: self.roots,
1336 sampling: self.sampling,
1337 elicitation: Some(ElicitationCapabilities),
1338 negotiator: self.negotiator,
1339 strict_validation: self.strict_validation,
1340 _state: PhantomData,
1341 }
1342 }
1343 }
1344
1345 impl<const X: bool, const S: bool, const E: bool>
1351 ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<X, true, S, E>>
1352 {
1353 pub fn enable_roots_list_changed(mut self) -> Self {
1358 if let Some(ref mut roots) = self.roots {
1359 roots.list_changed = Some(true);
1360 }
1361 self
1362 }
1363 }
1364
1365 impl<const R: bool, const S: bool, const E: bool>
1367 ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<true, R, S, E>>
1368 {
1369 pub fn add_experimental_capability<K, V>(mut self, key: K, value: V) -> Self
1373 where
1374 K: Into<String>,
1375 V: Into<serde_json::Value>,
1376 {
1377 if let Some(ref mut experimental) = self.experimental {
1378 experimental.insert(key.into(), value.into());
1379 }
1380 self
1381 }
1382
1383 pub fn with_llm_provider(mut self, provider: &str, version: &str) -> Self {
1387 if let Some(ref mut experimental) = self.experimental {
1388 experimental.insert(
1389 "turbomcp_llm_provider".to_string(),
1390 serde_json::json!({
1391 "provider": provider,
1392 "version": version
1393 }),
1394 );
1395 }
1396 self
1397 }
1398
1399 pub fn with_ui_capabilities(mut self, capabilities: Vec<&str>) -> Self {
1403 if let Some(ref mut experimental) = self.experimental {
1404 experimental.insert(
1405 "turbomcp_ui_capabilities".to_string(),
1406 serde_json::Value::Array(
1407 capabilities
1408 .into_iter()
1409 .map(|s| serde_json::Value::String(s.to_string()))
1410 .collect(),
1411 ),
1412 );
1413 }
1414 self
1415 }
1416 }
1417
1418 impl ClientCapabilitiesBuilder {
1420 pub fn full_featured()
1424 -> ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<true, true, true, true>>
1425 {
1426 Self::new()
1427 .enable_experimental()
1428 .enable_roots()
1429 .enable_sampling()
1430 .enable_elicitation()
1431 .enable_roots_list_changed()
1432 .with_llm_provider("openai", "gpt-4")
1433 .with_ui_capabilities(vec!["form", "dialog", "notification"])
1434 }
1435
1436 pub fn minimal()
1440 -> ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<false, false, true, false>>
1441 {
1442 Self::new().enable_sampling()
1443 }
1444
1445 pub fn sampling_focused()
1449 -> ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<true, false, true, false>>
1450 {
1451 Self::new()
1452 .enable_experimental()
1453 .enable_sampling()
1454 .with_llm_provider("anthropic", "claude-3")
1455 }
1456 }
1457
#[cfg(test)]
mod type_state_tests {
    use super::*;

    /// Asserts that every `(name, is_set)` pair is `true`.
    /// On failure, the panic message names the offending capability.
    fn expect_all_set(checks: &[(&str, bool)]) {
        for &(name, is_set) in checks {
            assert!(is_set, "expected `{}` to be set", name);
        }
    }

    /// Walks the server builder through a few type-state transitions.
    /// Then checks that the `full_featured` preset populates every server capability.
    #[test]
    fn test_server_capabilities_builder_type_state() {
        let initial = ServerCapabilities::builder();
        let debug_repr = format!("{:?}", initial);
        assert!(debug_repr.contains("ServerCapabilitiesBuilder"));

        // Each call consumes the builder and hands back the next type state.
        let tools_on = initial.enable_tools();
        let _final_builder = tools_on.enable_tool_list_changed();

        let caps = ServerCapabilitiesBuilder::full_featured().build();
        expect_all_set(&[
            ("experimental", caps.experimental.is_some()),
            ("logging", caps.logging.is_some()),
            ("completions", caps.completions.is_some()),
            ("prompts", caps.prompts.is_some()),
            ("resources", caps.resources.is_some()),
            ("tools", caps.tools.is_some()),
        ]);

        let tools = caps.tools.as_ref().expect("tools capability should be set");
        assert_eq!(tools.list_changed, Some(true));

        let resources = caps
            .resources
            .as_ref()
            .expect("resources capability should be set");
        assert_eq!(resources.list_changed, Some(true));
        assert_eq!(resources.subscribe, Some(true));
    }

    /// Client-side counterpart of the server type-state test.
    #[test]
    fn test_client_capabilities_builder_type_state() {
        let initial = ClientCapabilities::builder();
        let debug_repr = format!("{:?}", initial);
        assert!(debug_repr.contains("ClientCapabilitiesBuilder"));

        let roots_on = initial.enable_roots();
        let _final_builder = roots_on.enable_roots_list_changed();

        let caps = ClientCapabilitiesBuilder::full_featured().build();
        expect_all_set(&[
            ("experimental", caps.experimental.is_some()),
            ("roots", caps.roots.is_some()),
            ("sampling", caps.sampling.is_some()),
            ("elicitation", caps.elicitation.is_some()),
        ]);

        let roots = caps.roots.as_ref().expect("roots capability should be set");
        assert_eq!(roots.list_changed, Some(true));
    }

    /// TurboMCP-specific helpers must write their entries into the experimental map.
    #[test]
    fn test_turbomcp_extensions() {
        let server_caps = ServerCapabilities::builder()
            .enable_experimental()
            .with_simd_optimization("avx2")
            .with_enterprise_security(true)
            .build();

        let server_experimental = server_caps
            .experimental
            .as_ref()
            .expect("Expected experimental capabilities to be set");
        for key in ["turbomcp_simd_level", "turbomcp_enterprise_security"] {
            assert!(server_experimental.contains_key(key), "missing key `{}`", key);
        }
        let simd_level = server_experimental
            .get("turbomcp_simd_level")
            .unwrap()
            .as_str();
        assert_eq!(simd_level, Some("avx2"));
        let security = server_experimental
            .get("turbomcp_enterprise_security")
            .unwrap()
            .as_bool();
        assert_eq!(security, Some(true));

        let client_caps = ClientCapabilities::builder()
            .enable_experimental()
            .with_llm_provider("openai", "gpt-4")
            .with_ui_capabilities(vec!["form", "dialog"])
            .build();

        let client_experimental = client_caps
            .experimental
            .as_ref()
            .expect("Expected experimental capabilities to be set");
        for key in ["turbomcp_llm_provider", "turbomcp_ui_capabilities"] {
            assert!(client_experimental.contains_key(key), "missing key `{}`", key);
        }
    }

    /// Presets enable exactly the capabilities they advertise.
    #[test]
    fn test_convenience_builders() {
        let server = ServerCapabilitiesBuilder::minimal().build();
        assert!(server.tools.is_some());
        assert!(server.prompts.is_none());

        let client = ClientCapabilitiesBuilder::minimal().build();
        assert!(client.sampling.is_some());
        assert!(client.roots.is_none());

        let focused = ClientCapabilitiesBuilder::sampling_focused().build();
        expect_all_set(&[
            ("experimental", focused.experimental.is_some()),
            ("sampling", focused.sampling.is_some()),
        ]);
    }

    /// `Default` builders produce empty capability sets.
    #[test]
    fn test_builder_default_implementations() {
        let server_builder = ServerCapabilitiesBuilder::default();
        let server_caps = server_builder.build();
        assert!(server_caps.tools.is_none());

        let client_builder = ClientCapabilitiesBuilder::default();
        let client_caps = client_builder.build();
        assert!(client_caps.sampling.is_none());
    }

    /// A long fluent chain applies every step.
    #[test]
    fn test_builder_chaining() {
        let server_caps = ServerCapabilities::builder()
            .enable_experimental()
            .enable_tools()
            .enable_prompts()
            .enable_resources()
            .enable_tool_list_changed()
            .enable_prompts_list_changed()
            .enable_resources_list_changed()
            .enable_resources_subscribe()
            .add_experimental_capability("custom_feature", true)
            .build();

        expect_all_set(&[
            ("experimental", server_caps.experimental.is_some()),
            ("tools", server_caps.tools.is_some()),
            ("prompts", server_caps.prompts.is_some()),
            ("resources", server_caps.resources.is_some()),
        ]);

        let experimental = server_caps
            .experimental
            .as_ref()
            .expect("experimental capability should be set");
        assert!(experimental.contains_key("custom_feature"));
    }

    /// Attaching a negotiator and strict validation does not disturb the capabilities.
    #[test]
    fn test_with_negotiator_integration() {
        let server_caps = ServerCapabilities::builder()
            .enable_tools()
            .with_negotiator(super::super::CapabilityNegotiator::default())
            .with_strict_validation()
            .build();

        assert!(server_caps.tools.is_some());
    }

    /// Well-formed builders validate cleanly and summarize their enabled capabilities.
    #[test]
    fn test_builder_validation_methods() {
        let server_builder = ServerCapabilities::builder()
            .enable_experimental()
            .enable_tools()
            .with_simd_optimization("avx2")
            .with_enterprise_security(true)
            .with_strict_validation();
        assert!(server_builder.validate().is_ok());

        let server_summary = server_builder.summary();
        for expected in ["experimental", "tools"] {
            assert!(server_summary.contains(expected), "summary missing `{}`", expected);
        }

        let client_builder = ClientCapabilities::builder()
            .enable_experimental()
            .enable_sampling()
            .with_llm_provider("openai", "gpt-4")
            .with_ui_capabilities(vec!["form", "dialog"])
            .with_strict_validation();
        assert!(client_builder.validate().is_ok());

        let client_summary = client_builder.summary();
        for expected in ["experimental", "sampling"] {
            assert!(client_summary.contains(expected), "summary missing `{}`", expected);
        }
    }

    /// Strict validation rejects empty and malformed configurations with descriptive errors.
    #[test]
    fn test_builder_validation_errors() {
        let empty_server = ServerCapabilities::builder()
            .enable_experimental()
            .with_strict_validation();
        let error = empty_server.validate().unwrap_err();
        assert!(error.contains("at least one capability"));

        let bad_simd_server = ServerCapabilities::builder()
            .enable_experimental()
            .enable_tools()
            .add_experimental_capability("turbomcp_simd_level", "invalid_level")
            .with_strict_validation();
        let error = bad_simd_server.validate().unwrap_err();
        assert!(error.contains("Invalid SIMD level"));

        let bad_ui_client = ClientCapabilities::builder()
            .enable_experimental()
            .enable_sampling()
            .add_experimental_capability("turbomcp_ui_capabilities", vec!["invalid_capability"])
            .with_strict_validation();
        let error = bad_ui_client.validate().unwrap_err();
        assert!(error.contains("Invalid UI capability"));
    }

    /// A cloned builder builds the same capability set as the original.
    #[test]
    fn test_builder_clone_support() {
        let server_original = ServerCapabilities::builder()
            .enable_tools()
            .enable_prompts();
        let server_copy = server_original.clone();
        let (from_original, from_copy) = (server_original.build(), server_copy.build());
        assert_eq!(
            (from_original.tools.is_some(), from_original.prompts.is_some()),
            (from_copy.tools.is_some(), from_copy.prompts.is_some())
        );

        let client_original = ClientCapabilities::builder()
            .enable_sampling()
            .enable_elicitation();
        let client_copy = client_original.clone();
        let (from_original, from_copy) = (client_original.build(), client_copy.build());
        assert_eq!(
            (from_original.sampling.is_some(), from_original.elicitation.is_some()),
            (from_copy.sampling.is_some(), from_copy.elicitation.is_some())
        );
    }
}
1743}