1use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet};
8
9use crate::types::{ClientCapabilities, ServerCapabilities};
10
11#[derive(Debug, Clone)]
31pub struct CapabilityMatcher {
32 compatibility_rules: HashMap<String, CompatibilityRule>,
34 defaults: HashMap<String, bool>,
36}
37
38#[derive(Debug, Clone)]
40pub enum CompatibilityRule {
41 RequireBoth,
43 RequireClient,
45 RequireServer,
47 Optional,
49 Custom(fn(&ClientCapabilities, &ServerCapabilities) -> bool),
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct CapabilitySet {
72 pub enabled_features: HashSet<String>,
74 pub client_capabilities: ClientCapabilities,
76 pub server_capabilities: ServerCapabilities,
78 pub metadata: HashMap<String, serde_json::Value>,
80}
81
82#[derive(Debug, Clone)]
84pub struct CapabilityNegotiator {
85 matcher: CapabilityMatcher,
87 strict_mode: bool,
89}
90
91impl Default for CapabilityMatcher {
92 fn default() -> Self {
93 Self::new()
94 }
95}
96
97impl CapabilityMatcher {
98 pub fn new() -> Self {
100 let mut matcher = Self {
101 compatibility_rules: HashMap::new(),
102 defaults: HashMap::new(),
103 };
104
105 matcher.add_rule("tools", CompatibilityRule::RequireServer);
107 matcher.add_rule("prompts", CompatibilityRule::RequireServer);
108 matcher.add_rule("resources", CompatibilityRule::RequireServer);
109 matcher.add_rule("logging", CompatibilityRule::RequireServer);
110 matcher.add_rule("sampling", CompatibilityRule::RequireClient);
111 matcher.add_rule("roots", CompatibilityRule::RequireClient);
112 matcher.add_rule("progress", CompatibilityRule::Optional);
113
114 matcher.set_default("progress", true);
116
117 matcher
118 }
119
120 pub fn add_rule(&mut self, feature: &str, rule: CompatibilityRule) {
122 self.compatibility_rules.insert(feature.to_string(), rule);
123 }
124
125 pub fn set_default(&mut self, feature: &str, enabled: bool) {
127 self.defaults.insert(feature.to_string(), enabled);
128 }
129
130 pub fn is_compatible(
132 &self,
133 feature: &str,
134 client: &ClientCapabilities,
135 server: &ServerCapabilities,
136 ) -> bool {
137 self.compatibility_rules.get(feature).map_or_else(
138 || {
139 Self::client_has_feature(feature, client)
141 || Self::server_has_feature(feature, server)
142 },
143 |rule| match rule {
144 CompatibilityRule::RequireBoth => {
145 Self::client_has_feature(feature, client)
146 && Self::server_has_feature(feature, server)
147 }
148 CompatibilityRule::RequireClient => Self::client_has_feature(feature, client),
149 CompatibilityRule::RequireServer => Self::server_has_feature(feature, server),
150 CompatibilityRule::Optional => true,
151 CompatibilityRule::Custom(func) => func(client, server),
152 },
153 )
154 }
155
156 fn client_has_feature(feature: &str, client: &ClientCapabilities) -> bool {
158 match feature {
159 "sampling" => client.sampling.is_some(),
160 "roots" => client.roots.is_some(),
161 _ => {
162 client
164 .experimental
165 .as_ref()
166 .is_some_and(|experimental| experimental.contains_key(feature))
167 }
168 }
169 }
170
171 fn server_has_feature(feature: &str, server: &ServerCapabilities) -> bool {
173 match feature {
174 "tools" => server.tools.is_some(),
175 "prompts" => server.prompts.is_some(),
176 "resources" => server.resources.is_some(),
177 "logging" => server.logging.is_some(),
178 _ => {
179 server
181 .experimental
182 .as_ref()
183 .is_some_and(|experimental| experimental.contains_key(feature))
184 }
185 }
186 }
187
188 fn get_all_features(
190 &self,
191 client: &ClientCapabilities,
192 server: &ServerCapabilities,
193 ) -> HashSet<String> {
194 let mut features = HashSet::new();
195
196 if client.sampling.is_some() {
198 features.insert("sampling".to_string());
199 }
200 if client.roots.is_some() {
201 features.insert("roots".to_string());
202 }
203
204 if server.tools.is_some() {
206 features.insert("tools".to_string());
207 }
208 if server.prompts.is_some() {
209 features.insert("prompts".to_string());
210 }
211 if server.resources.is_some() {
212 features.insert("resources".to_string());
213 }
214 if server.logging.is_some() {
215 features.insert("logging".to_string());
216 }
217
218 if let Some(experimental) = &client.experimental {
220 features.extend(experimental.keys().cloned());
221 }
222 if let Some(experimental) = &server.experimental {
223 features.extend(experimental.keys().cloned());
224 }
225
226 features.extend(self.defaults.keys().cloned());
228
229 features
230 }
231
232 pub fn negotiate(
239 &self,
240 client: &ClientCapabilities,
241 server: &ServerCapabilities,
242 ) -> Result<CapabilitySet, CapabilityError> {
243 let all_features = self.get_all_features(client, server);
244 let mut enabled_features = HashSet::new();
245 let mut incompatible_features = Vec::new();
246
247 for feature in &all_features {
248 if self.is_compatible(feature, client, server) {
249 enabled_features.insert(feature.clone());
250 } else {
251 incompatible_features.push(feature.clone());
252 }
253 }
254
255 if !incompatible_features.is_empty() {
256 return Err(CapabilityError::IncompatibleFeatures(incompatible_features));
257 }
258
259 for (feature, enabled) in &self.defaults {
261 if *enabled && !enabled_features.contains(feature) && all_features.contains(feature) {
262 enabled_features.insert(feature.clone());
263 }
264 }
265
266 Ok(CapabilitySet {
267 enabled_features,
268 client_capabilities: client.clone(),
269 server_capabilities: server.clone(),
270 metadata: HashMap::new(),
271 })
272 }
273}
274
275impl CapabilityNegotiator {
276 pub const fn new(matcher: CapabilityMatcher) -> Self {
278 Self {
279 matcher,
280 strict_mode: false,
281 }
282 }
283
284 pub const fn with_strict_mode(mut self) -> Self {
286 self.strict_mode = true;
287 self
288 }
289
290 pub fn negotiate(
298 &self,
299 client: &ClientCapabilities,
300 server: &ServerCapabilities,
301 ) -> Result<CapabilitySet, CapabilityError> {
302 match self.matcher.negotiate(client, server) {
303 Ok(capability_set) => Ok(capability_set),
304 Err(CapabilityError::IncompatibleFeatures(features)) if !self.strict_mode => {
305 tracing::warn!(
307 "Some features are incompatible and will be disabled: {:?}",
308 features
309 );
310
311 let all_features = self.matcher.get_all_features(client, server);
313 let mut enabled_features = HashSet::new();
314
315 for feature in &all_features {
316 if self.matcher.is_compatible(feature, client, server) {
317 enabled_features.insert(feature.clone());
318 }
319 }
320
321 Ok(CapabilitySet {
322 enabled_features,
323 client_capabilities: client.clone(),
324 server_capabilities: server.clone(),
325 metadata: HashMap::new(),
326 })
327 }
328 Err(err) => Err(err),
329 }
330 }
331
332 pub fn is_feature_enabled(capability_set: &CapabilitySet, feature: &str) -> bool {
334 capability_set.enabled_features.contains(feature)
335 }
336
337 pub fn get_enabled_features(capability_set: &CapabilitySet) -> Vec<String> {
339 let mut features: Vec<String> = capability_set.enabled_features.iter().cloned().collect();
340 features.sort();
341 features
342 }
343}
344
345impl Default for CapabilityNegotiator {
346 fn default() -> Self {
347 Self::new(CapabilityMatcher::new())
348 }
349}
350
351impl CapabilitySet {
352 pub fn empty() -> Self {
354 Self {
355 enabled_features: HashSet::new(),
356 client_capabilities: ClientCapabilities::default(),
357 server_capabilities: ServerCapabilities::default(),
358 metadata: HashMap::new(),
359 }
360 }
361
362 pub fn has_feature(&self, feature: &str) -> bool {
364 self.enabled_features.contains(feature)
365 }
366
367 pub fn enable_feature(&mut self, feature: String) {
369 self.enabled_features.insert(feature);
370 }
371
372 pub fn disable_feature(&mut self, feature: &str) {
374 self.enabled_features.remove(feature);
375 }
376
377 pub fn feature_count(&self) -> usize {
379 self.enabled_features.len()
380 }
381
382 pub fn add_metadata(&mut self, key: String, value: serde_json::Value) {
384 self.metadata.insert(key, value);
385 }
386
387 pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
389 self.metadata.get(key)
390 }
391
392 pub fn summary(&self) -> CapabilitySummary {
394 CapabilitySummary {
395 total_features: self.enabled_features.len(),
396 client_features: self.count_client_features(),
397 server_features: self.count_server_features(),
398 enabled_features: self.enabled_features.iter().cloned().collect(),
399 }
400 }
401
402 fn count_client_features(&self) -> usize {
403 let mut count = 0;
404 if self.client_capabilities.sampling.is_some() {
405 count += 1;
406 }
407 if self.client_capabilities.roots.is_some() {
408 count += 1;
409 }
410 if let Some(experimental) = &self.client_capabilities.experimental {
411 count += experimental.len();
412 }
413 count
414 }
415
416 fn count_server_features(&self) -> usize {
417 let mut count = 0;
418 if self.server_capabilities.tools.is_some() {
419 count += 1;
420 }
421 if self.server_capabilities.prompts.is_some() {
422 count += 1;
423 }
424 if self.server_capabilities.resources.is_some() {
425 count += 1;
426 }
427 if self.server_capabilities.logging.is_some() {
428 count += 1;
429 }
430 if let Some(experimental) = &self.server_capabilities.experimental {
431 count += experimental.len();
432 }
433 count
434 }
435}
436
437#[derive(Debug, Clone, thiserror::Error)]
439pub enum CapabilityError {
440 #[error("Incompatible features: {0:?}")]
442 IncompatibleFeatures(Vec<String>),
443 #[error("Required feature missing: {0}")]
445 RequiredFeatureMissing(String),
446 #[error("Protocol version mismatch: client={client}, server={server}")]
448 VersionMismatch {
449 client: String,
451 server: String,
453 },
454 #[error("Capability negotiation failed: {0}")]
456 NegotiationFailed(String),
457}
458
459#[derive(Debug, Clone, Serialize, Deserialize)]
461pub struct CapabilitySummary {
462 pub total_features: usize,
464 pub client_features: usize,
466 pub server_features: usize,
468 pub enabled_features: Vec<String>,
470}
471
472pub mod utils {
474 use super::*;
475
476 pub fn minimal_client_capabilities() -> ClientCapabilities {
478 ClientCapabilities::default()
479 }
480
481 pub fn minimal_server_capabilities() -> ServerCapabilities {
483 ServerCapabilities::default()
484 }
485
486 pub fn full_client_capabilities() -> ClientCapabilities {
488 ClientCapabilities {
489 sampling: Some(Default::default()),
490 roots: Some(Default::default()),
491 elicitation: Some(Default::default()),
492 experimental: None,
493 }
494 }
495
496 pub fn full_server_capabilities() -> ServerCapabilities {
498 ServerCapabilities {
499 tools: Some(Default::default()),
500 prompts: Some(Default::default()),
501 resources: Some(Default::default()),
502 completions: Some(Default::default()),
503 logging: Some(Default::default()),
504 experimental: None,
505 }
506 }
507
508 pub fn are_compatible(client: &ClientCapabilities, server: &ServerCapabilities) -> bool {
510 let matcher = CapabilityMatcher::new();
511 matcher.negotiate(client, server).is_ok()
512 }
513}
514
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::*;

    #[test]
    fn test_capability_matcher() {
        let matcher = CapabilityMatcher::new();

        let client = ClientCapabilities {
            sampling: Some(SamplingCapabilities),
            roots: None,
            elicitation: None,
            experimental: None,
        };
        let server = ServerCapabilities {
            tools: Some(ToolsCapabilities::default()),
            prompts: None,
            resources: None,
            logging: None,
            completions: None,
            experimental: None,
        };

        // (feature, expected compatibility)
        let expectations = [("sampling", true), ("tools", true), ("roots", false)];
        for (feature, expected) in expectations {
            assert_eq!(matcher.is_compatible(feature, &client, &server), expected);
        }
    }

    #[test]
    fn test_capability_negotiation() {
        let client = utils::full_client_capabilities();
        let server = utils::full_server_capabilities();

        let result = CapabilityNegotiator::default().negotiate(&client, &server);
        assert!(result.is_ok());

        let capability_set = result.unwrap();
        for feature in ["sampling", "tools", "roots"] {
            assert!(capability_set.has_feature(feature));
        }
    }

    #[test]
    fn test_strict_mode() {
        let negotiator = CapabilityNegotiator::default().with_strict_mode();
        let outcome = negotiator.negotiate(
            &ClientCapabilities::default(),
            &ServerCapabilities::default(),
        );
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_capability_summary() {
        let mut capability_set = CapabilitySet::empty();
        for feature in ["tools", "prompts"] {
            capability_set.enable_feature(feature.to_string());
        }

        let summary = capability_set.summary();
        assert_eq!(summary.total_features, 2);
        assert!(summary.enabled_features.iter().any(|name| name == "tools"));
    }
}
583
584pub mod builders {
594 use crate::types::{
595 ClientCapabilities, CompletionCapabilities, ElicitationCapabilities, LoggingCapabilities,
596 PromptsCapabilities, ResourcesCapabilities, RootsCapabilities, SamplingCapabilities,
597 ServerCapabilities, ToolsCapabilities,
598 };
599 use serde_json;
600 use std::collections::HashMap;
601 use std::marker::PhantomData;
602
603 #[derive(Debug, Clone)]
617 pub struct ServerCapabilitiesBuilderState<
618 const EXPERIMENTAL: bool = false,
619 const LOGGING: bool = false,
620 const COMPLETIONS: bool = false,
621 const PROMPTS: bool = false,
622 const RESOURCES: bool = false,
623 const TOOLS: bool = false,
624 >;
625
626 #[derive(Debug, Clone)]
632 pub struct ServerCapabilitiesBuilder<S = ServerCapabilitiesBuilderState> {
633 experimental: Option<HashMap<String, serde_json::Value>>,
634 logging: Option<LoggingCapabilities>,
635 completions: Option<CompletionCapabilities>,
636 prompts: Option<PromptsCapabilities>,
637 resources: Option<ResourcesCapabilities>,
638 tools: Option<ToolsCapabilities>,
639
640 negotiator: Option<super::CapabilityNegotiator>,
642 strict_validation: bool,
643
644 _state: PhantomData<S>,
645 }
646
647 impl ServerCapabilities {
648 pub fn builder() -> ServerCapabilitiesBuilder {
653 ServerCapabilitiesBuilder::new()
654 }
655 }
656
657 impl Default for ServerCapabilitiesBuilder {
658 fn default() -> Self {
659 Self::new()
660 }
661 }
662
663 impl ServerCapabilitiesBuilder {
664 pub fn new() -> Self {
666 Self {
667 experimental: None,
668 logging: None,
669 completions: None,
670 prompts: None,
671 resources: None,
672 tools: None,
673 negotiator: None,
674 strict_validation: false,
675 _state: PhantomData,
676 }
677 }
678 }
679
680 impl<S> ServerCapabilitiesBuilder<S> {
682 pub fn build(self) -> ServerCapabilities {
687 ServerCapabilities {
688 experimental: self.experimental,
689 logging: self.logging,
690 completions: self.completions,
691 prompts: self.prompts,
692 resources: self.resources,
693 tools: self.tools,
694 }
695 }
696
697 pub fn with_strict_validation(mut self) -> Self {
702 self.strict_validation = true;
703 self
704 }
705
706 pub fn with_negotiator(mut self, negotiator: super::CapabilityNegotiator) -> Self {
711 self.negotiator = Some(negotiator);
712 self
713 }
714
715 pub fn validate(&self) -> Result<(), String> {
720 if self.strict_validation {
721 if self.tools.is_none() && self.prompts.is_none() && self.resources.is_none() {
723 return Err("Server must provide at least one capability (tools, prompts, or resources)".to_string());
724 }
725
726 if let Some(ref experimental) = self.experimental {
728 for (key, value) in experimental {
729 if key.starts_with("turbomcp_") {
730 match key.as_str() {
732 "turbomcp_simd_level" => {
733 if !value.is_string() {
734 return Err(
735 "turbomcp_simd_level must be a string".to_string()
736 );
737 }
738 let level = value.as_str().unwrap_or("");
739 if !["none", "sse2", "sse4", "avx2", "avx512"].contains(&level)
740 {
741 return Err(format!("Invalid SIMD level: {}", level));
742 }
743 }
744 "turbomcp_enterprise_security" => {
745 if !value.is_boolean() {
746 return Err(
747 "turbomcp_enterprise_security must be a boolean"
748 .to_string(),
749 );
750 }
751 }
752 _ => {
753 }
755 }
756 }
757 }
758 }
759 }
760 Ok(())
761 }
762
763 pub fn summary(&self) -> String {
767 let mut capabilities = Vec::new();
768 if self.experimental.is_some() {
769 capabilities.push("experimental");
770 }
771 if self.logging.is_some() {
772 capabilities.push("logging");
773 }
774 if self.completions.is_some() {
775 capabilities.push("completions");
776 }
777 if self.prompts.is_some() {
778 capabilities.push("prompts");
779 }
780 if self.resources.is_some() {
781 capabilities.push("resources");
782 }
783 if self.tools.is_some() {
784 capabilities.push("tools");
785 }
786
787 if capabilities.is_empty() {
788 "No capabilities enabled".to_string()
789 } else {
790 format!("Enabled capabilities: {}", capabilities.join(", "))
791 }
792 }
793 }
794
795 impl<const L: bool, const C: bool, const P: bool, const R: bool, const T: bool>
801 ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<false, L, C, P, R, T>>
802 {
803 pub fn enable_experimental(
808 self,
809 ) -> ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<true, L, C, P, R, T>>
810 {
811 ServerCapabilitiesBuilder {
812 experimental: Some(HashMap::new()),
813 logging: self.logging,
814 completions: self.completions,
815 prompts: self.prompts,
816 resources: self.resources,
817 tools: self.tools,
818 negotiator: self.negotiator,
819 strict_validation: self.strict_validation,
820 _state: PhantomData,
821 }
822 }
823
824 pub fn enable_experimental_with(
826 self,
827 experimental: HashMap<String, serde_json::Value>,
828 ) -> ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<true, L, C, P, R, T>>
829 {
830 ServerCapabilitiesBuilder {
831 experimental: Some(experimental),
832 logging: self.logging,
833 completions: self.completions,
834 prompts: self.prompts,
835 resources: self.resources,
836 tools: self.tools,
837 negotiator: self.negotiator,
838 strict_validation: self.strict_validation,
839 _state: PhantomData,
840 }
841 }
842 }
843
844 impl<const E: bool, const C: bool, const P: bool, const R: bool, const T: bool>
846 ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, false, C, P, R, T>>
847 {
848 pub fn enable_logging(
850 self,
851 ) -> ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, true, C, P, R, T>>
852 {
853 ServerCapabilitiesBuilder {
854 experimental: self.experimental,
855 logging: Some(LoggingCapabilities),
856 completions: self.completions,
857 prompts: self.prompts,
858 resources: self.resources,
859 tools: self.tools,
860 negotiator: self.negotiator,
861 strict_validation: self.strict_validation,
862 _state: PhantomData,
863 }
864 }
865 }
866
867 impl<const E: bool, const L: bool, const P: bool, const R: bool, const T: bool>
869 ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, L, false, P, R, T>>
870 {
871 pub fn enable_completions(
873 self,
874 ) -> ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, L, true, P, R, T>>
875 {
876 ServerCapabilitiesBuilder {
877 experimental: self.experimental,
878 logging: self.logging,
879 completions: Some(CompletionCapabilities),
880 prompts: self.prompts,
881 resources: self.resources,
882 tools: self.tools,
883 negotiator: self.negotiator,
884 strict_validation: self.strict_validation,
885 _state: PhantomData,
886 }
887 }
888 }
889
890 impl<const E: bool, const L: bool, const C: bool, const R: bool, const T: bool>
892 ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, L, C, false, R, T>>
893 {
894 pub fn enable_prompts(
896 self,
897 ) -> ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, L, C, true, R, T>>
898 {
899 ServerCapabilitiesBuilder {
900 experimental: self.experimental,
901 logging: self.logging,
902 completions: self.completions,
903 prompts: Some(PromptsCapabilities { list_changed: None }),
904 resources: self.resources,
905 tools: self.tools,
906 negotiator: self.negotiator,
907 strict_validation: self.strict_validation,
908 _state: PhantomData,
909 }
910 }
911 }
912
913 impl<const E: bool, const L: bool, const C: bool, const P: bool, const T: bool>
915 ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, L, C, P, false, T>>
916 {
917 pub fn enable_resources(
919 self,
920 ) -> ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, L, C, P, true, T>>
921 {
922 ServerCapabilitiesBuilder {
923 experimental: self.experimental,
924 logging: self.logging,
925 completions: self.completions,
926 prompts: self.prompts,
927 resources: Some(ResourcesCapabilities {
928 subscribe: None,
929 list_changed: None,
930 }),
931 tools: self.tools,
932 negotiator: self.negotiator,
933 strict_validation: self.strict_validation,
934 _state: PhantomData,
935 }
936 }
937 }
938
939 impl<const E: bool, const L: bool, const C: bool, const P: bool, const R: bool>
941 ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, L, C, P, R, false>>
942 {
943 pub fn enable_tools(
945 self,
946 ) -> ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, L, C, P, R, true>>
947 {
948 ServerCapabilitiesBuilder {
949 experimental: self.experimental,
950 logging: self.logging,
951 completions: self.completions,
952 prompts: self.prompts,
953 resources: self.resources,
954 tools: Some(ToolsCapabilities { list_changed: None }),
955 negotiator: self.negotiator,
956 strict_validation: self.strict_validation,
957 _state: PhantomData,
958 }
959 }
960 }
961
962 impl<const E: bool, const L: bool, const C: bool, const P: bool, const R: bool>
968 ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, L, C, P, R, true>>
969 {
970 pub fn enable_tool_list_changed(mut self) -> Self {
975 if let Some(ref mut tools) = self.tools {
976 tools.list_changed = Some(true);
977 }
978 self
979 }
980 }
981
982 impl<const E: bool, const L: bool, const C: bool, const R: bool, const T: bool>
984 ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, L, C, true, R, T>>
985 {
986 pub fn enable_prompts_list_changed(mut self) -> Self {
988 if let Some(ref mut prompts) = self.prompts {
989 prompts.list_changed = Some(true);
990 }
991 self
992 }
993 }
994
995 impl<const E: bool, const L: bool, const C: bool, const P: bool, const T: bool>
997 ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<E, L, C, P, true, T>>
998 {
999 pub fn enable_resources_list_changed(mut self) -> Self {
1001 if let Some(ref mut resources) = self.resources {
1002 resources.list_changed = Some(true);
1003 }
1004 self
1005 }
1006
1007 pub fn enable_resources_subscribe(mut self) -> Self {
1009 if let Some(ref mut resources) = self.resources {
1010 resources.subscribe = Some(true);
1011 }
1012 self
1013 }
1014 }
1015
1016 impl<const L: bool, const C: bool, const P: bool, const R: bool, const T: bool>
1018 ServerCapabilitiesBuilder<ServerCapabilitiesBuilderState<true, L, C, P, R, T>>
1019 {
1020 pub fn add_experimental_capability<K, V>(mut self, key: K, value: V) -> Self
1024 where
1025 K: Into<String>,
1026 V: Into<serde_json::Value>,
1027 {
1028 if let Some(ref mut experimental) = self.experimental {
1029 experimental.insert(key.into(), value.into());
1030 }
1031 self
1032 }
1033
1034 pub fn with_simd_optimization(mut self, level: &str) -> Self {
1038 if let Some(ref mut experimental) = self.experimental {
1039 experimental.insert(
1040 "turbomcp_simd_level".to_string(),
1041 serde_json::Value::String(level.to_string()),
1042 );
1043 }
1044 self
1045 }
1046
1047 pub fn with_enterprise_security(mut self, enabled: bool) -> Self {
1051 if let Some(ref mut experimental) = self.experimental {
1052 experimental.insert(
1053 "turbomcp_enterprise_security".to_string(),
1054 serde_json::Value::Bool(enabled),
1055 );
1056 }
1057 self
1058 }
1059 }
1060
1061 impl ServerCapabilitiesBuilder {
1063 pub fn full_featured() -> ServerCapabilitiesBuilder<
1067 ServerCapabilitiesBuilderState<true, true, true, true, true, true>,
1068 > {
1069 Self::new()
1070 .enable_experimental()
1071 .enable_logging()
1072 .enable_completions()
1073 .enable_prompts()
1074 .enable_resources()
1075 .enable_tools()
1076 .enable_tool_list_changed()
1077 .enable_prompts_list_changed()
1078 .enable_resources_list_changed()
1079 .enable_resources_subscribe()
1080 .with_simd_optimization("avx2")
1081 .with_enterprise_security(true)
1082 }
1083
1084 pub fn minimal() -> ServerCapabilitiesBuilder<
1088 ServerCapabilitiesBuilderState<false, false, false, false, false, true>,
1089 > {
1090 Self::new().enable_tools()
1091 }
1092 }
1093
1094 #[derive(Debug, Clone)]
1106 pub struct ClientCapabilitiesBuilderState<
1107 const EXPERIMENTAL: bool = false,
1108 const ROOTS: bool = false,
1109 const SAMPLING: bool = false,
1110 const ELICITATION: bool = false,
1111 >;
1112
1113 #[derive(Debug, Clone)]
1119 pub struct ClientCapabilitiesBuilder<S = ClientCapabilitiesBuilderState> {
1120 experimental: Option<HashMap<String, serde_json::Value>>,
1121 roots: Option<RootsCapabilities>,
1122 sampling: Option<SamplingCapabilities>,
1123 elicitation: Option<ElicitationCapabilities>,
1124
1125 negotiator: Option<super::CapabilityNegotiator>,
1127 strict_validation: bool,
1128
1129 _state: PhantomData<S>,
1130 }
1131
1132 impl ClientCapabilities {
1133 pub fn builder() -> ClientCapabilitiesBuilder {
1138 ClientCapabilitiesBuilder::new()
1139 }
1140 }
1141
1142 impl Default for ClientCapabilitiesBuilder {
1143 fn default() -> Self {
1144 Self::new()
1145 }
1146 }
1147
1148 impl ClientCapabilitiesBuilder {
1149 pub fn new() -> Self {
1151 Self {
1152 experimental: None,
1153 roots: None,
1154 sampling: None,
1155 elicitation: None,
1156 negotiator: None,
1157 strict_validation: false,
1158 _state: PhantomData,
1159 }
1160 }
1161 }
1162
1163 impl<S> ClientCapabilitiesBuilder<S> {
1165 pub fn build(self) -> ClientCapabilities {
1170 ClientCapabilities {
1171 experimental: self.experimental,
1172 roots: self.roots,
1173 sampling: self.sampling,
1174 elicitation: self.elicitation,
1175 }
1176 }
1177
1178 pub fn with_strict_validation(mut self) -> Self {
1183 self.strict_validation = true;
1184 self
1185 }
1186
1187 pub fn with_negotiator(mut self, negotiator: super::CapabilityNegotiator) -> Self {
1192 self.negotiator = Some(negotiator);
1193 self
1194 }
1195
1196 pub fn validate(&self) -> Result<(), String> {
1201 if self.strict_validation {
1202 if let Some(ref experimental) = self.experimental {
1204 for (key, value) in experimental {
1205 if key.starts_with("turbomcp_") {
1206 match key.as_str() {
1208 "turbomcp_llm_provider" => {
1209 let obj = value.as_object().ok_or_else(|| {
1210 "turbomcp_llm_provider must be an object".to_string()
1211 })?;
1212 if !obj.contains_key("provider") || !obj.contains_key("version")
1213 {
1214 return Err("turbomcp_llm_provider must have 'provider' and 'version' fields".to_string());
1215 }
1216 }
1217 "turbomcp_ui_capabilities" => {
1218 let arr = value.as_array().ok_or_else(|| {
1219 "turbomcp_ui_capabilities must be an array".to_string()
1220 })?;
1221 let valid_ui_caps = [
1222 "form",
1223 "dialog",
1224 "notification",
1225 "toast",
1226 "modal",
1227 "sidebar",
1228 ];
1229 for cap in arr {
1230 if let Some(cap_str) = cap.as_str() {
1231 if !valid_ui_caps.contains(&cap_str) {
1232 return Err(format!(
1233 "Invalid UI capability: {}",
1234 cap_str
1235 ));
1236 }
1237 } else {
1238 return Err(
1239 "UI capabilities must be strings".to_string()
1240 );
1241 }
1242 }
1243 }
1244 _ => {
1245 }
1247 }
1248 }
1249 }
1250 }
1251 }
1252 Ok(())
1253 }
1254
1255 pub fn summary(&self) -> String {
1259 let mut capabilities = Vec::new();
1260 if self.experimental.is_some() {
1261 capabilities.push("experimental");
1262 }
1263 if self.roots.is_some() {
1264 capabilities.push("roots");
1265 }
1266 if self.sampling.is_some() {
1267 capabilities.push("sampling");
1268 }
1269 if self.elicitation.is_some() {
1270 capabilities.push("elicitation");
1271 }
1272
1273 if capabilities.is_empty() {
1274 "No capabilities enabled".to_string()
1275 } else {
1276 format!("Enabled capabilities: {}", capabilities.join(", "))
1277 }
1278 }
1279 }
1280
1281 impl<const R: bool, const S: bool, const E: bool>
1287 ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<false, R, S, E>>
1288 {
1289 pub fn enable_experimental(
1294 self,
1295 ) -> ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<true, R, S, E>> {
1296 ClientCapabilitiesBuilder {
1297 experimental: Some(HashMap::new()),
1298 roots: self.roots,
1299 sampling: self.sampling,
1300 elicitation: self.elicitation,
1301 negotiator: self.negotiator,
1302 strict_validation: self.strict_validation,
1303 _state: PhantomData,
1304 }
1305 }
1306
1307 pub fn enable_experimental_with(
1309 self,
1310 experimental: HashMap<String, serde_json::Value>,
1311 ) -> ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<true, R, S, E>> {
1312 ClientCapabilitiesBuilder {
1313 experimental: Some(experimental),
1314 roots: self.roots,
1315 sampling: self.sampling,
1316 elicitation: self.elicitation,
1317 negotiator: self.negotiator,
1318 strict_validation: self.strict_validation,
1319 _state: PhantomData,
1320 }
1321 }
1322 }
1323
1324 impl<const X: bool, const S: bool, const E: bool>
1326 ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<X, false, S, E>>
1327 {
1328 pub fn enable_roots(
1330 self,
1331 ) -> ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<X, true, S, E>> {
1332 ClientCapabilitiesBuilder {
1333 experimental: self.experimental,
1334 roots: Some(RootsCapabilities { list_changed: None }),
1335 sampling: self.sampling,
1336 elicitation: self.elicitation,
1337 negotiator: self.negotiator,
1338 strict_validation: self.strict_validation,
1339 _state: PhantomData,
1340 }
1341 }
1342 }
1343
1344 impl<const X: bool, const R: bool, const E: bool>
1346 ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<X, R, false, E>>
1347 {
1348 pub fn enable_sampling(
1350 self,
1351 ) -> ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<X, R, true, E>> {
1352 ClientCapabilitiesBuilder {
1353 experimental: self.experimental,
1354 roots: self.roots,
1355 sampling: Some(SamplingCapabilities),
1356 elicitation: self.elicitation,
1357 negotiator: self.negotiator,
1358 strict_validation: self.strict_validation,
1359 _state: PhantomData,
1360 }
1361 }
1362 }
1363
1364 impl<const X: bool, const R: bool, const S: bool>
1366 ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<X, R, S, false>>
1367 {
1368 pub fn enable_elicitation(
1370 self,
1371 ) -> ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<X, R, S, true>> {
1372 ClientCapabilitiesBuilder {
1373 experimental: self.experimental,
1374 roots: self.roots,
1375 sampling: self.sampling,
1376 elicitation: Some(ElicitationCapabilities::default()),
1377 negotiator: self.negotiator,
1378 strict_validation: self.strict_validation,
1379 _state: PhantomData,
1380 }
1381 }
1382
1383 pub fn enable_elicitation_with_schema_validation(
1385 self,
1386 ) -> ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<X, R, S, true>> {
1387 ClientCapabilitiesBuilder {
1388 experimental: self.experimental,
1389 roots: self.roots,
1390 sampling: self.sampling,
1391 elicitation: Some(ElicitationCapabilities::default().with_schema_validation()),
1392 negotiator: self.negotiator,
1393 strict_validation: self.strict_validation,
1394 _state: PhantomData,
1395 }
1396 }
1397 }
1398
1399 impl<const X: bool, const S: bool, const E: bool>
1405 ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<X, true, S, E>>
1406 {
1407 pub fn enable_roots_list_changed(mut self) -> Self {
1412 if let Some(ref mut roots) = self.roots {
1413 roots.list_changed = Some(true);
1414 }
1415 self
1416 }
1417 }
1418
1419 impl<const R: bool, const S: bool, const E: bool>
1421 ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<true, R, S, E>>
1422 {
1423 pub fn add_experimental_capability<K, V>(mut self, key: K, value: V) -> Self
1427 where
1428 K: Into<String>,
1429 V: Into<serde_json::Value>,
1430 {
1431 if let Some(ref mut experimental) = self.experimental {
1432 experimental.insert(key.into(), value.into());
1433 }
1434 self
1435 }
1436
1437 pub fn with_llm_provider(mut self, provider: &str, version: &str) -> Self {
1441 if let Some(ref mut experimental) = self.experimental {
1442 experimental.insert(
1443 "turbomcp_llm_provider".to_string(),
1444 serde_json::json!({
1445 "provider": provider,
1446 "version": version
1447 }),
1448 );
1449 }
1450 self
1451 }
1452
1453 pub fn with_ui_capabilities(mut self, capabilities: Vec<&str>) -> Self {
1457 if let Some(ref mut experimental) = self.experimental {
1458 experimental.insert(
1459 "turbomcp_ui_capabilities".to_string(),
1460 serde_json::Value::Array(
1461 capabilities
1462 .into_iter()
1463 .map(|s| serde_json::Value::String(s.to_string()))
1464 .collect(),
1465 ),
1466 );
1467 }
1468 self
1469 }
1470 }
1471
1472 impl ClientCapabilitiesBuilder {
1474 pub fn full_featured()
1478 -> ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<true, true, true, true>>
1479 {
1480 Self::new()
1481 .enable_experimental()
1482 .enable_roots()
1483 .enable_sampling()
1484 .enable_elicitation()
1485 .enable_roots_list_changed()
1486 .with_llm_provider("openai", "gpt-4")
1487 .with_ui_capabilities(vec!["form", "dialog", "notification"])
1488 }
1489
1490 pub fn minimal()
1494 -> ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<false, false, true, false>>
1495 {
1496 Self::new().enable_sampling()
1497 }
1498
1499 pub fn sampling_focused()
1503 -> ClientCapabilitiesBuilder<ClientCapabilitiesBuilderState<true, false, true, false>>
1504 {
1505 Self::new()
1506 .enable_experimental()
1507 .enable_sampling()
1508 .with_llm_provider("anthropic", "claude-3")
1509 }
1510 }
1511
#[cfg(test)]
mod type_state_tests {
    //! Tests for the type-state capability builders, their preset constructors,
    //! and the TurboMCP-specific experimental-capability helpers.

    use super::*;

    #[test]
    fn test_server_capabilities_builder_type_state() {
        let builder = ServerCapabilities::builder();
        assert!(format!("{:?}", builder).contains("ServerCapabilitiesBuilder"));

        // `enable_tool_list_changed` only type-checks once tools are enabled.
        let builder_with_tools = builder.enable_tools();
        let _final_builder = builder_with_tools.enable_tool_list_changed();

        let full_capabilities = ServerCapabilitiesBuilder::full_featured().build();

        assert!(full_capabilities.experimental.is_some());
        assert!(full_capabilities.logging.is_some());
        assert!(full_capabilities.completions.is_some());
        assert!(full_capabilities.prompts.is_some());

        let tools = full_capabilities
            .tools
            .as_ref()
            .expect("full_featured() should enable tools");
        assert_eq!(tools.list_changed, Some(true));

        let resources = full_capabilities
            .resources
            .as_ref()
            .expect("full_featured() should enable resources");
        assert_eq!(resources.list_changed, Some(true));
        assert_eq!(resources.subscribe, Some(true));
    }

    #[test]
    fn test_client_capabilities_builder_type_state() {
        let builder = ClientCapabilities::builder();
        assert!(format!("{:?}", builder).contains("ClientCapabilitiesBuilder"));

        // `enable_roots_list_changed` only type-checks once roots are enabled.
        let builder_with_roots = builder.enable_roots();
        let _final_builder = builder_with_roots.enable_roots_list_changed();

        let full_capabilities = ClientCapabilitiesBuilder::full_featured().build();

        assert!(full_capabilities.experimental.is_some());
        assert!(full_capabilities.sampling.is_some());
        assert!(full_capabilities.elicitation.is_some());

        let roots = full_capabilities
            .roots
            .as_ref()
            .expect("full_featured() should enable roots");
        assert_eq!(roots.list_changed, Some(true));
    }

    #[test]
    fn test_turbomcp_extensions() {
        let server_caps = ServerCapabilities::builder()
            .enable_experimental()
            .with_simd_optimization("avx2")
            .with_enterprise_security(true)
            .build();

        let experimental = server_caps
            .experimental
            .as_ref()
            .expect("Expected experimental capabilities to be set");
        assert_eq!(
            experimental
                .get("turbomcp_simd_level")
                .and_then(|value| value.as_str()),
            Some("avx2")
        );
        assert_eq!(
            experimental
                .get("turbomcp_enterprise_security")
                .and_then(|value| value.as_bool()),
            Some(true)
        );

        let client_caps = ClientCapabilities::builder()
            .enable_experimental()
            .with_llm_provider("openai", "gpt-4")
            .with_ui_capabilities(vec!["form", "dialog"])
            .build();

        let experimental = client_caps
            .experimental
            .as_ref()
            .expect("Expected experimental capabilities to be set");
        // Pin the exact JSON shapes written by the client helpers, not just key presence.
        assert_eq!(
            experimental.get("turbomcp_llm_provider"),
            Some(&serde_json::json!({ "provider": "openai", "version": "gpt-4" }))
        );
        assert_eq!(
            experimental.get("turbomcp_ui_capabilities"),
            Some(&serde_json::json!(["form", "dialog"]))
        );
    }

    #[test]
    fn test_convenience_builders() {
        let minimal_server = ServerCapabilitiesBuilder::minimal().build();
        assert!(minimal_server.tools.is_some());
        assert!(minimal_server.prompts.is_none());

        let minimal_client = ClientCapabilitiesBuilder::minimal().build();
        assert!(minimal_client.sampling.is_some());
        assert!(minimal_client.roots.is_none());

        let sampling_focused_client = ClientCapabilitiesBuilder::sampling_focused().build();
        assert!(sampling_focused_client.experimental.is_some());
        assert!(sampling_focused_client.sampling.is_some());
    }

    #[test]
    fn test_builder_default_implementations() {
        let default_server_builder = ServerCapabilitiesBuilder::default();
        let server_caps = default_server_builder.build();
        assert!(server_caps.tools.is_none());

        let default_client_builder = ClientCapabilitiesBuilder::default();
        let client_caps = default_client_builder.build();
        assert!(client_caps.sampling.is_none());
    }

    #[test]
    fn test_builder_chaining() {
        let server_caps = ServerCapabilities::builder()
            .enable_experimental()
            .enable_tools()
            .enable_prompts()
            .enable_resources()
            .enable_tool_list_changed()
            .enable_prompts_list_changed()
            .enable_resources_list_changed()
            .enable_resources_subscribe()
            .add_experimental_capability("custom_feature", true)
            .build();

        assert!(server_caps.tools.is_some());
        assert!(server_caps.prompts.is_some());
        assert!(server_caps.resources.is_some());

        let experimental = server_caps
            .experimental
            .as_ref()
            .expect("experimental capabilities should be enabled");
        assert!(experimental.contains_key("custom_feature"));
    }

    #[test]
    fn test_with_negotiator_integration() {
        let negotiator = super::super::CapabilityNegotiator::default();

        // The negotiator is not used again, so it is moved rather than cloned.
        let server_caps = ServerCapabilities::builder()
            .enable_tools()
            .with_negotiator(negotiator)
            .with_strict_validation()
            .build();

        assert!(server_caps.tools.is_some());
    }

    #[test]
    fn test_builder_validation_methods() {
        let server_builder = ServerCapabilities::builder()
            .enable_experimental()
            .enable_tools()
            .with_simd_optimization("avx2")
            .with_enterprise_security(true)
            .with_strict_validation();

        assert!(server_builder.validate().is_ok());

        let summary = server_builder.summary();
        assert!(summary.contains("experimental"));
        assert!(summary.contains("tools"));

        let client_builder = ClientCapabilities::builder()
            .enable_experimental()
            .enable_sampling()
            .with_llm_provider("openai", "gpt-4")
            .with_ui_capabilities(vec!["form", "dialog"])
            .with_strict_validation();

        assert!(client_builder.validate().is_ok());

        let summary = client_builder.summary();
        assert!(summary.contains("experimental"));
        assert!(summary.contains("sampling"));
    }

    #[test]
    fn test_builder_validation_errors() {
        let server_builder = ServerCapabilities::builder()
            .enable_experimental()
            .with_strict_validation();
        let error = server_builder
            .validate()
            .expect_err("experimental-only server capabilities should fail validation");
        assert!(error.contains("at least one capability"));

        let invalid_server_builder = ServerCapabilities::builder()
            .enable_experimental()
            .enable_tools()
            .add_experimental_capability("turbomcp_simd_level", "invalid_level")
            .with_strict_validation();
        let error = invalid_server_builder
            .validate()
            .expect_err("an unknown SIMD level should fail validation");
        assert!(error.contains("Invalid SIMD level"));

        let invalid_client_builder = ClientCapabilities::builder()
            .enable_experimental()
            .enable_sampling()
            .add_experimental_capability("turbomcp_ui_capabilities", vec!["invalid_capability"])
            .with_strict_validation();
        let error = invalid_client_builder
            .validate()
            .expect_err("an unknown UI capability should fail validation");
        assert!(error.contains("Invalid UI capability"));
    }

    #[test]
    fn test_builder_clone_support() {
        let original_server_builder = ServerCapabilities::builder()
            .enable_tools()
            .enable_prompts();

        let cloned_server_builder = original_server_builder.clone();

        let original_caps = original_server_builder.build();
        let cloned_caps = cloned_server_builder.build();

        // Assert presence on both sides: equality alone would pass if both were `None`.
        assert!(original_caps.tools.is_some());
        assert!(cloned_caps.tools.is_some());
        assert!(original_caps.prompts.is_some());
        assert!(cloned_caps.prompts.is_some());

        let original_client_builder = ClientCapabilities::builder()
            .enable_sampling()
            .enable_elicitation();

        let cloned_client_builder = original_client_builder.clone();

        let original_caps = original_client_builder.build();
        let cloned_caps = cloned_client_builder.build();

        assert!(original_caps.sampling.is_some());
        assert!(cloned_caps.sampling.is_some());
        assert!(original_caps.elicitation.is_some());
        assert!(cloned_caps.elicitation.is_some());
    }
}
1797}