1use regex::Regex;
2use serde::{Deserialize, Serialize};
3use std::collections::{BTreeMap, HashMap};
4
/// MCP client settings.
///
/// Every field carries a serde default, so omitted fields deserialize to the
/// same values produced by `McpClientConfig::default()`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpClientConfig {
    /// Whether MCP support is enabled; defaults to `false`.
    #[serde(default = "default_mcp_enabled")]
    pub enabled: bool,

    /// UI presentation settings for MCP activity (`McpUiConfig`).
    #[serde(default)]
    pub ui: McpUiConfig,

    /// Configured MCP providers; empty by default.
    #[serde(default)]
    pub providers: Vec<McpProviderConfig>,

    /// Settings for the MCP server side (`McpServerConfig`).
    #[serde(default)]
    pub server: McpServerConfig,

    /// Allow-list rules gating tools, resources, prompts, logging channels and
    /// configuration keys per provider.
    #[serde(default)]
    pub allowlist: McpAllowListConfig,

    /// Maximum number of concurrent connections; defaults to 5.
    #[serde(default = "default_max_concurrent_connections")]
    pub max_concurrent_connections: usize,

    /// Request timeout in seconds; defaults to 30.
    #[serde(default = "default_request_timeout_seconds")]
    pub request_timeout_seconds: u64,

    /// Number of retry attempts; defaults to 3.
    #[serde(default = "default_retry_attempts")]
    pub retry_attempts: u32,
}
40
41impl Default for McpClientConfig {
42 fn default() -> Self {
43 Self {
44 enabled: default_mcp_enabled(),
45 ui: McpUiConfig::default(),
46 providers: Vec::new(),
47 server: McpServerConfig::default(),
48 allowlist: McpAllowListConfig::default(),
49 max_concurrent_connections: default_max_concurrent_connections(),
50 request_timeout_seconds: default_request_timeout_seconds(),
51 retry_attempts: default_retry_attempts(),
52 }
53 }
54}
55
/// UI presentation settings for MCP activity.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpUiConfig {
    /// Rendering mode; defaults to `McpUiMode::Compact`.
    #[serde(default = "default_mcp_ui_mode")]
    pub mode: McpUiMode,

    /// Maximum number of MCP events (presumably retained or shown by the UI —
    /// confirm at the consumer); defaults to 50.
    #[serde(default = "default_max_mcp_events")]
    pub max_events: usize,

    /// Whether provider names are shown; defaults to `true`.
    #[serde(default = "default_show_provider_names")]
    pub show_provider_names: bool,
}
71
72impl Default for McpUiConfig {
73 fn default() -> Self {
74 Self {
75 mode: default_mcp_ui_mode(),
76 max_events: default_max_mcp_events(),
77 show_provider_names: default_show_provider_names(),
78 }
79 }
80}
81
82#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
84#[serde(rename_all = "snake_case")]
85pub enum McpUiMode {
86 Compact,
88 Full,
90}
91
92impl std::fmt::Display for McpUiMode {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 match self {
95 McpUiMode::Compact => write!(f, "compact"),
96 McpUiMode::Full => write!(f, "full"),
97 }
98 }
99}
100
101impl Default for McpUiMode {
102 fn default() -> Self {
103 McpUiMode::Compact
104 }
105}
106
/// A single MCP provider entry.
///
/// `name` and the transport fields have no serde default and must be present
/// when deserializing.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpProviderConfig {
    /// Provider name. Presumably the same string callers pass as `provider` to the
    /// `McpAllowListConfig::is_*_allowed` checks — confirm at call sites.
    pub name: String,

    /// Transport settings. `#[serde(flatten)]` places the stdio or HTTP fields
    /// directly in the provider entry instead of under a nested key.
    #[serde(flatten)]
    pub transport: McpTransportConfig,

    /// Extra environment variables for this provider; empty by default.
    /// NOTE(review): presumably applied to the launched process — confirm.
    #[serde(default)]
    pub env: HashMap<String, String>,

    /// Whether this provider is enabled; defaults to `true`.
    #[serde(default = "default_provider_enabled")]
    pub enabled: bool,

    /// Maximum number of concurrent requests to this provider; defaults to 3.
    #[serde(default = "default_provider_max_concurrent")]
    pub max_concurrent_requests: usize,
}
129
130impl Default for McpProviderConfig {
131 fn default() -> Self {
132 Self {
133 name: String::new(),
134 transport: McpTransportConfig::Stdio(McpStdioServerConfig::default()),
135 env: HashMap::new(),
136 enabled: default_provider_enabled(),
137 max_concurrent_requests: default_provider_max_concurrent(),
138 }
139 }
140}
141
/// Allow-list controlling which provider capabilities may be used.
///
/// When `enforce` is `false`, every `is_*_allowed` check returns `true`.
/// Otherwise per-provider rules take precedence over `default`, and anything
/// not matched by either level is denied.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpAllowListConfig {
    /// Whether the allow-list is enforced; defaults to `false`.
    #[serde(default = "default_allowlist_enforced")]
    pub enforce: bool,

    /// Rules applied when a provider has no rule of its own for a category.
    #[serde(default)]
    pub default: McpAllowListRules,

    /// Per-provider rules keyed by provider name.
    #[serde(default)]
    pub providers: BTreeMap<String, McpAllowListRules>,
}
157
158impl Default for McpAllowListConfig {
159 fn default() -> Self {
160 Self {
161 enforce: default_allowlist_enforced(),
162 default: McpAllowListRules::default(),
163 providers: BTreeMap::new(),
164 }
165 }
166}
167
168impl McpAllowListConfig {
169 pub fn is_tool_allowed(&self, provider: &str, tool_name: &str) -> bool {
171 if !self.enforce {
172 return true;
173 }
174
175 self.resolve_match(provider, tool_name, |rules| &rules.tools)
176 }
177
178 pub fn is_resource_allowed(&self, provider: &str, resource: &str) -> bool {
180 if !self.enforce {
181 return true;
182 }
183
184 self.resolve_match(provider, resource, |rules| &rules.resources)
185 }
186
187 pub fn is_prompt_allowed(&self, provider: &str, prompt: &str) -> bool {
189 if !self.enforce {
190 return true;
191 }
192
193 self.resolve_match(provider, prompt, |rules| &rules.prompts)
194 }
195
196 pub fn is_logging_channel_allowed(&self, provider: Option<&str>, channel: &str) -> bool {
198 if !self.enforce {
199 return true;
200 }
201
202 if let Some(name) = provider {
203 if let Some(rules) = self.providers.get(name) {
204 if let Some(patterns) = &rules.logging {
205 return pattern_matches(patterns, channel);
206 }
207 }
208 }
209
210 if let Some(patterns) = &self.default.logging {
211 if pattern_matches(patterns, channel) {
212 return true;
213 }
214 }
215
216 false
217 }
218
219 pub fn is_configuration_allowed(
221 &self,
222 provider: Option<&str>,
223 category: &str,
224 key: &str,
225 ) -> bool {
226 if !self.enforce {
227 return true;
228 }
229
230 if let Some(name) = provider {
231 if let Some(rules) = self.providers.get(name) {
232 if let Some(result) = configuration_allowed(rules, category, key) {
233 return result;
234 }
235 }
236 }
237
238 if let Some(result) = configuration_allowed(&self.default, category, key) {
239 return result;
240 }
241
242 false
243 }
244
245 fn resolve_match<'a, F>(&'a self, provider: &str, candidate: &str, accessor: F) -> bool
246 where
247 F: Fn(&'a McpAllowListRules) -> &'a Option<Vec<String>>,
248 {
249 if let Some(rules) = self.providers.get(provider) {
250 if let Some(patterns) = accessor(rules) {
251 return pattern_matches(patterns, candidate);
252 }
253 }
254
255 if let Some(patterns) = accessor(&self.default) {
256 if pattern_matches(patterns, candidate) {
257 return true;
258 }
259 }
260
261 false
262 }
263}
264
265fn configuration_allowed(rules: &McpAllowListRules, category: &str, key: &str) -> Option<bool> {
266 rules.configuration.as_ref().and_then(|entries| {
267 entries
268 .get(category)
269 .map(|patterns| pattern_matches(patterns, key))
270 })
271}
272
273fn pattern_matches(patterns: &[String], candidate: &str) -> bool {
274 patterns
275 .iter()
276 .any(|pattern| wildcard_match(pattern, candidate))
277}
278
279fn wildcard_match(pattern: &str, candidate: &str) -> bool {
280 if pattern == "*" {
281 return true;
282 }
283
284 let mut regex_pattern = String::from("^");
285 let mut literal_buffer = String::new();
286
287 for ch in pattern.chars() {
288 match ch {
289 '*' => {
290 if !literal_buffer.is_empty() {
291 regex_pattern.push_str(®ex::escape(&literal_buffer));
292 literal_buffer.clear();
293 }
294 regex_pattern.push_str(".*");
295 }
296 '?' => {
297 if !literal_buffer.is_empty() {
298 regex_pattern.push_str(®ex::escape(&literal_buffer));
299 literal_buffer.clear();
300 }
301 regex_pattern.push('.');
302 }
303 _ => literal_buffer.push(ch),
304 }
305 }
306
307 if !literal_buffer.is_empty() {
308 regex_pattern.push_str(®ex::escape(&literal_buffer));
309 }
310
311 regex_pattern.push('$');
312
313 Regex::new(®ex_pattern)
314 .map(|regex| regex.is_match(candidate))
315 .unwrap_or(false)
316}
317
/// One level of allow-list rules: either `McpAllowListConfig::default` or a
/// single provider's entry in `McpAllowListConfig::providers`.
///
/// Lists hold wildcard patterns (`*` = any run of characters, `?` = one
/// character). At the provider level, a field left as `None` means "no rule
/// here", and lookups fall back to the default level. At the default level,
/// `None` means deny. `Some(vec![])` matches nothing and does not fall back.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct McpAllowListRules {
    /// Patterns for allowed tool names.
    #[serde(default)]
    pub tools: Option<Vec<String>>,

    /// Patterns for allowed resource identifiers.
    #[serde(default)]
    pub resources: Option<Vec<String>>,

    /// Patterns for allowed prompt names.
    #[serde(default)]
    pub prompts: Option<Vec<String>>,

    /// Patterns for allowed logging channels.
    #[serde(default)]
    pub logging: Option<Vec<String>>,

    /// Maps a configuration category to patterns for the keys allowed in it.
    /// A category missing here falls back to the default level, just like `None`.
    #[serde(default)]
    pub configuration: Option<BTreeMap<String, Vec<String>>>,
}
341
/// Settings for the MCP server (default name `vtcode-mcp-server`).
///
/// Every field carries a serde default matching `McpServerConfig::default()`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpServerConfig {
    /// Whether the server is enabled; defaults to `false`.
    #[serde(default = "default_mcp_server_enabled")]
    pub enabled: bool,

    /// Address to bind; defaults to `"127.0.0.1"`.
    #[serde(default = "default_mcp_server_bind")]
    pub bind_address: String,

    /// Port to listen on; defaults to 3000.
    #[serde(default = "default_mcp_server_port")]
    pub port: u16,

    /// Server transport; defaults to `McpServerTransport::Sse`.
    #[serde(default = "default_mcp_server_transport")]
    pub transport: McpServerTransport,

    /// Server name; defaults to `"vtcode-mcp-server"`.
    #[serde(default = "default_mcp_server_name")]
    pub name: String,

    /// Server version; defaults to this crate's `CARGO_PKG_VERSION`.
    #[serde(default = "default_mcp_server_version")]
    pub version: String,

    /// Names of tools the server exposes; empty by default.
    /// NOTE(review): whether empty means "none" or "all" is decided by the
    /// consumer of this field — confirm.
    #[serde(default)]
    pub exposed_tools: Vec<String>,
}
373
374impl Default for McpServerConfig {
375 fn default() -> Self {
376 Self {
377 enabled: default_mcp_server_enabled(),
378 bind_address: default_mcp_server_bind(),
379 port: default_mcp_server_port(),
380 transport: default_mcp_server_transport(),
381 name: default_mcp_server_name(),
382 version: default_mcp_server_version(),
383 exposed_tools: Vec::new(),
384 }
385 }
386}
387
388#[derive(Debug, Clone, Deserialize, Serialize)]
390#[serde(rename_all = "snake_case")]
391pub enum McpServerTransport {
392 Sse,
394 Http,
396}
397
398impl Default for McpServerTransport {
399 fn default() -> Self {
400 McpServerTransport::Sse
401 }
402}
403
/// Transport settings for a provider.
///
/// `#[serde(untagged)]`: deserialization tries `Stdio` first, then `Http`, and
/// picks the first variant whose required fields are present (`command` and
/// `args` for stdio, `endpoint` for HTTP).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum McpTransportConfig {
    /// Stdio-based provider described by `command`/`args`.
    Stdio(McpStdioServerConfig),
    /// HTTP-based provider described by `endpoint`.
    Http(McpHttpServerConfig),
}
413
414#[derive(Debug, Clone, Deserialize, Serialize)]
416pub struct McpStdioServerConfig {
417 pub command: String,
419
420 pub args: Vec<String>,
422
423 #[serde(default)]
425 pub working_directory: Option<String>,
426}
427
428impl Default for McpStdioServerConfig {
429 fn default() -> Self {
430 Self {
431 command: String::new(),
432 args: Vec::new(),
433 working_directory: None,
434 }
435 }
436}
437
/// Settings for a provider reached over HTTP.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpHttpServerConfig {
    /// Provider endpoint (presumably a URL). Required when deserializing.
    pub endpoint: String,

    /// Name of an environment variable holding the API key, if any.
    /// NOTE(review): presumably read when connecting — confirm.
    #[serde(default)]
    pub api_key_env: Option<String>,

    /// MCP protocol revision; defaults to `"2024-11-05"`.
    #[serde(default = "default_mcp_protocol_version")]
    pub protocol_version: String,

    /// Additional headers (presumably sent with each request); empty by default.
    #[serde(default)]
    pub headers: HashMap<String, String>,
}
460
461impl Default for McpHttpServerConfig {
462 fn default() -> Self {
463 Self {
464 endpoint: String::new(),
465 api_key_env: None,
466 protocol_version: default_mcp_protocol_version(),
467 headers: HashMap::new(),
468 }
469 }
470}
471
472fn default_mcp_enabled() -> bool {
474 false
475}
476
477fn default_mcp_ui_mode() -> McpUiMode {
478 McpUiMode::Compact
479}
480
/// MCP protocol revision assumed when an HTTP provider does not specify one.
const DEFAULT_MCP_PROTOCOL_VERSION: &str = "2024-11-05";

/// Loopback address the MCP server binds to unless configured otherwise.
const DEFAULT_MCP_SERVER_BIND: &str = "127.0.0.1";

/// Port the MCP server listens on unless configured otherwise.
const DEFAULT_MCP_SERVER_PORT: u16 = 3000;

/// Serde default for `McpUiConfig::max_events`.
fn default_max_mcp_events() -> usize {
    50
}

/// Serde default for `McpUiConfig::show_provider_names`.
fn default_show_provider_names() -> bool {
    true
}

/// Serde default for `McpClientConfig::max_concurrent_connections`.
fn default_max_concurrent_connections() -> usize {
    5
}

/// Serde default for `McpClientConfig::request_timeout_seconds`.
fn default_request_timeout_seconds() -> u64 {
    30
}

/// Serde default for `McpClientConfig::retry_attempts`.
fn default_retry_attempts() -> u32 {
    3
}

/// Serde default for `McpProviderConfig::enabled`.
fn default_provider_enabled() -> bool {
    true
}

/// Serde default for `McpProviderConfig::max_concurrent_requests`.
fn default_provider_max_concurrent() -> usize {
    3
}

/// Serde default for `McpAllowListConfig::enforce`: the allow-list is off.
fn default_allowlist_enforced() -> bool {
    false
}

/// Serde default for `McpHttpServerConfig::protocol_version`.
fn default_mcp_protocol_version() -> String {
    String::from(DEFAULT_MCP_PROTOCOL_VERSION)
}

/// Serde default for `McpServerConfig::enabled`.
fn default_mcp_server_enabled() -> bool {
    false
}

/// Serde default for `McpServerConfig::bind_address`.
fn default_mcp_server_bind() -> String {
    String::from(DEFAULT_MCP_SERVER_BIND)
}

/// Serde default for `McpServerConfig::port`.
fn default_mcp_server_port() -> u16 {
    DEFAULT_MCP_SERVER_PORT
}
528
529fn default_mcp_server_transport() -> McpServerTransport {
530 McpServerTransport::Sse
531}
532
533fn default_mcp_server_name() -> String {
534 "vtcode-mcp-server".to_string()
535}
536
537fn default_mcp_server_version() -> String {
538 env!("CARGO_PKG_VERSION").to_string()
539}
540
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mcp_config_defaults() {
        let config = McpClientConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.ui.mode, McpUiMode::Compact);
        assert_eq!(config.ui.max_events, 50);
        assert!(config.ui.show_provider_names);
        assert_eq!(config.max_concurrent_connections, 5);
        assert_eq!(config.request_timeout_seconds, 30);
        assert_eq!(config.retry_attempts, 3);
        assert!(config.providers.is_empty());
        assert!(!config.server.enabled);
        assert!(!config.allowlist.enforce);
        assert!(config.allowlist.default.tools.is_none());
    }

    #[test]
    fn test_allowlist_pattern_matching() {
        let patterns = vec!["get_*".to_string(), "convert_timezone".to_string()];
        assert!(pattern_matches(&patterns, "get_current_time"));
        assert!(pattern_matches(&patterns, "convert_timezone"));
        assert!(!pattern_matches(&patterns, "delete_timezone"));
    }

    #[test]
    fn test_wildcard_match_semantics() {
        // `*` alone matches anything, including the empty string.
        assert!(wildcard_match("*", "anything"));
        assert!(wildcard_match("*", ""));
        // `?` matches exactly one character.
        assert!(wildcard_match("get_?", "get_a"));
        assert!(!wildcard_match("get_?", "get_ab"));
        assert!(!wildcard_match("get_?", "get_"));
        // Regex metacharacters in literals are matched verbatim.
        assert!(wildcard_match("a.b", "a.b"));
        assert!(!wildcard_match("a.b", "axb"));
        assert!(wildcard_match("file(1)*", "file(1)-final"));
        // Matches are anchored at both ends.
        assert!(!wildcard_match("get", "get_time"));
        assert!(!wildcard_match("time", "get_time"));
    }

    #[test]
    fn test_allowlist_provider_override() {
        let mut config = McpAllowListConfig::default();
        config.enforce = true;
        config.default.tools = Some(vec!["get_*".to_string()]);

        let mut provider_rules = McpAllowListRules::default();
        provider_rules.tools = Some(vec!["list_*".to_string()]);
        config
            .providers
            .insert("context7".to_string(), provider_rules);

        assert!(config.is_tool_allowed("context7", "list_documents"));
        assert!(!config.is_tool_allowed("context7", "get_current_time"));
        assert!(config.is_tool_allowed("other", "get_timezone"));
        assert!(!config.is_tool_allowed("other", "list_documents"));
    }

    #[test]
    fn test_allowlist_not_enforced_allows_everything() {
        let config = McpAllowListConfig::default();
        assert!(config.is_tool_allowed("any", "delete_everything"));
        assert!(config.is_resource_allowed("any", "secrets/*"));
        assert!(config.is_prompt_allowed("any", "anything"));
        assert!(config.is_logging_channel_allowed(None, "trace"));
        assert!(config.is_configuration_allowed(None, "ui", "mode"));
    }

    #[test]
    fn test_allowlist_enforced_without_rules_denies() {
        let config = McpAllowListConfig {
            enforce: true,
            ..Default::default()
        };
        assert!(!config.is_tool_allowed("any", "get_time"));
        assert!(!config.is_resource_allowed("any", "docs/manual"));
        assert!(!config.is_prompt_allowed("any", "summarize"));
        assert!(!config.is_logging_channel_allowed(None, "info"));
        assert!(!config.is_configuration_allowed(Some("any"), "ui", "mode"));
    }

    #[test]
    fn test_allowlist_configuration_rules() {
        let mut config = McpAllowListConfig::default();
        config.enforce = true;

        // `configuration` is an `Option<BTreeMap<..>>`, so build it with `BTreeMap`.
        let mut default_rules = McpAllowListRules::default();
        default_rules.configuration = Some(BTreeMap::from([(
            "ui".to_string(),
            vec!["mode".to_string(), "max_events".to_string()],
        )]));
        config.default = default_rules;

        let mut provider_rules = McpAllowListRules::default();
        provider_rules.configuration = Some(BTreeMap::from([(
            "provider".to_string(),
            vec!["max_concurrent_requests".to_string()],
        )]));
        config.providers.insert("time".to_string(), provider_rules);

        assert!(config.is_configuration_allowed(None, "ui", "mode"));
        assert!(!config.is_configuration_allowed(None, "ui", "show_provider_names"));
        assert!(config.is_configuration_allowed(
            Some("time"),
            "provider",
            "max_concurrent_requests"
        ));
        assert!(!config.is_configuration_allowed(Some("time"), "provider", "retry_attempts"));
        // A category the provider does not define falls back to the default rules.
        assert!(config.is_configuration_allowed(Some("time"), "ui", "mode"));
    }

    #[test]
    fn test_allowlist_resource_override() {
        let mut config = McpAllowListConfig::default();
        config.enforce = true;
        config.default.resources = Some(vec!["docs/*".to_string()]);

        let mut provider_rules = McpAllowListRules::default();
        provider_rules.resources = Some(vec!["journals/*".to_string()]);
        config
            .providers
            .insert("context7".to_string(), provider_rules);

        assert!(config.is_resource_allowed("context7", "journals/2024"));
        assert!(!config.is_resource_allowed("context7", "docs/manual"));
        assert!(config.is_resource_allowed("other", "docs/reference"));
        assert!(!config.is_resource_allowed("other", "journals/2023"));
    }

    #[test]
    fn test_allowlist_logging_override() {
        let mut config = McpAllowListConfig::default();
        config.enforce = true;
        config.default.logging = Some(vec!["info".to_string(), "debug".to_string()]);

        let mut provider_rules = McpAllowListRules::default();
        provider_rules.logging = Some(vec!["audit".to_string()]);
        config
            .providers
            .insert("sequential".to_string(), provider_rules);

        assert!(config.is_logging_channel_allowed(Some("sequential"), "audit"));
        assert!(!config.is_logging_channel_allowed(Some("sequential"), "info"));
        assert!(config.is_logging_channel_allowed(Some("other"), "info"));
        assert!(!config.is_logging_channel_allowed(Some("other"), "trace"));
    }
}