vtcode_core/config/
mcp.rs

1use regex::Regex;
2use serde::{Deserialize, Serialize};
3use std::collections::{BTreeMap, HashMap};
4
5/// Top-level MCP configuration
6#[derive(Debug, Clone, Deserialize, Serialize)]
7pub struct McpClientConfig {
8    /// Enable MCP functionality
9    #[serde(default = "default_mcp_enabled")]
10    pub enabled: bool,
11
12    /// MCP UI display configuration
13    #[serde(default)]
14    pub ui: McpUiConfig,
15
16    /// Configured MCP providers
17    #[serde(default)]
18    pub providers: Vec<McpProviderConfig>,
19
20    /// MCP server configuration (for vtcode to expose tools)
21    #[serde(default)]
22    pub server: McpServerConfig,
23
24    /// Allow list configuration for MCP access control
25    #[serde(default)]
26    pub allowlist: McpAllowListConfig,
27
28    /// Maximum number of concurrent MCP connections
29    #[serde(default = "default_max_concurrent_connections")]
30    pub max_concurrent_connections: usize,
31
32    /// Request timeout in seconds
33    #[serde(default = "default_request_timeout_seconds")]
34    pub request_timeout_seconds: u64,
35
36    /// Connection retry attempts
37    #[serde(default = "default_retry_attempts")]
38    pub retry_attempts: u32,
39}
40
41impl Default for McpClientConfig {
42    fn default() -> Self {
43        Self {
44            enabled: default_mcp_enabled(),
45            ui: McpUiConfig::default(),
46            providers: Vec::new(),
47            server: McpServerConfig::default(),
48            allowlist: McpAllowListConfig::default(),
49            max_concurrent_connections: default_max_concurrent_connections(),
50            request_timeout_seconds: default_request_timeout_seconds(),
51            retry_attempts: default_retry_attempts(),
52        }
53    }
54}
55
56/// UI configuration for MCP display
57#[derive(Debug, Clone, Deserialize, Serialize)]
58pub struct McpUiConfig {
59    /// UI mode for MCP events: "compact" or "full"
60    #[serde(default = "default_mcp_ui_mode")]
61    pub mode: McpUiMode,
62
63    /// Maximum number of MCP events to display
64    #[serde(default = "default_max_mcp_events")]
65    pub max_events: usize,
66
67    /// Show MCP provider names in UI
68    #[serde(default = "default_show_provider_names")]
69    pub show_provider_names: bool,
70}
71
72impl Default for McpUiConfig {
73    fn default() -> Self {
74        Self {
75            mode: default_mcp_ui_mode(),
76            max_events: default_max_mcp_events(),
77            show_provider_names: default_show_provider_names(),
78        }
79    }
80}
81
82/// UI mode for MCP event display
83#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
84#[serde(rename_all = "snake_case")]
85pub enum McpUiMode {
86    /// Compact mode - shows only event titles
87    Compact,
88    /// Full mode - shows detailed event logs
89    Full,
90}
91
92impl std::fmt::Display for McpUiMode {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        match self {
95            McpUiMode::Compact => write!(f, "compact"),
96            McpUiMode::Full => write!(f, "full"),
97        }
98    }
99}
100
101impl Default for McpUiMode {
102    fn default() -> Self {
103        McpUiMode::Compact
104    }
105}
106
107/// Configuration for a single MCP provider
108#[derive(Debug, Clone, Deserialize, Serialize)]
109pub struct McpProviderConfig {
110    /// Provider name (used for identification)
111    pub name: String,
112
113    /// Transport configuration
114    #[serde(flatten)]
115    pub transport: McpTransportConfig,
116
117    /// Provider-specific environment variables
118    #[serde(default)]
119    pub env: HashMap<String, String>,
120
121    /// Whether this provider is enabled
122    #[serde(default = "default_provider_enabled")]
123    pub enabled: bool,
124
125    /// Maximum number of concurrent requests to this provider
126    #[serde(default = "default_provider_max_concurrent")]
127    pub max_concurrent_requests: usize,
128}
129
130impl Default for McpProviderConfig {
131    fn default() -> Self {
132        Self {
133            name: String::new(),
134            transport: McpTransportConfig::Stdio(McpStdioServerConfig::default()),
135            env: HashMap::new(),
136            enabled: default_provider_enabled(),
137            max_concurrent_requests: default_provider_max_concurrent(),
138        }
139    }
140}
141
142/// Allow list configuration for MCP providers
143#[derive(Debug, Clone, Deserialize, Serialize)]
144pub struct McpAllowListConfig {
145    /// Whether to enforce allow list checks
146    #[serde(default = "default_allowlist_enforced")]
147    pub enforce: bool,
148
149    /// Default rules applied when provider-specific rules are absent
150    #[serde(default)]
151    pub default: McpAllowListRules,
152
153    /// Provider-specific allow list rules
154    #[serde(default)]
155    pub providers: BTreeMap<String, McpAllowListRules>,
156}
157
158impl Default for McpAllowListConfig {
159    fn default() -> Self {
160        Self {
161            enforce: default_allowlist_enforced(),
162            default: McpAllowListRules::default(),
163            providers: BTreeMap::new(),
164        }
165    }
166}
167
168impl McpAllowListConfig {
169    /// Determine whether a tool is permitted for the given provider
170    pub fn is_tool_allowed(&self, provider: &str, tool_name: &str) -> bool {
171        if !self.enforce {
172            return true;
173        }
174
175        self.resolve_match(provider, tool_name, |rules| &rules.tools)
176    }
177
178    /// Determine whether a resource is permitted for the given provider
179    pub fn is_resource_allowed(&self, provider: &str, resource: &str) -> bool {
180        if !self.enforce {
181            return true;
182        }
183
184        self.resolve_match(provider, resource, |rules| &rules.resources)
185    }
186
187    /// Determine whether a prompt is permitted for the given provider
188    pub fn is_prompt_allowed(&self, provider: &str, prompt: &str) -> bool {
189        if !self.enforce {
190            return true;
191        }
192
193        self.resolve_match(provider, prompt, |rules| &rules.prompts)
194    }
195
196    /// Determine whether a logging channel is permitted
197    pub fn is_logging_channel_allowed(&self, provider: Option<&str>, channel: &str) -> bool {
198        if !self.enforce {
199            return true;
200        }
201
202        if let Some(name) = provider {
203            if let Some(rules) = self.providers.get(name) {
204                if let Some(patterns) = &rules.logging {
205                    return pattern_matches(patterns, channel);
206                }
207            }
208        }
209
210        if let Some(patterns) = &self.default.logging {
211            if pattern_matches(patterns, channel) {
212                return true;
213            }
214        }
215
216        false
217    }
218
219    /// Determine whether a configuration key can be modified
220    pub fn is_configuration_allowed(
221        &self,
222        provider: Option<&str>,
223        category: &str,
224        key: &str,
225    ) -> bool {
226        if !self.enforce {
227            return true;
228        }
229
230        if let Some(name) = provider {
231            if let Some(rules) = self.providers.get(name) {
232                if let Some(result) = configuration_allowed(rules, category, key) {
233                    return result;
234                }
235            }
236        }
237
238        if let Some(result) = configuration_allowed(&self.default, category, key) {
239            return result;
240        }
241
242        false
243    }
244
245    fn resolve_match<'a, F>(&'a self, provider: &str, candidate: &str, accessor: F) -> bool
246    where
247        F: Fn(&'a McpAllowListRules) -> &'a Option<Vec<String>>,
248    {
249        if let Some(rules) = self.providers.get(provider) {
250            if let Some(patterns) = accessor(rules) {
251                return pattern_matches(patterns, candidate);
252            }
253        }
254
255        if let Some(patterns) = accessor(&self.default) {
256            if pattern_matches(patterns, candidate) {
257                return true;
258            }
259        }
260
261        false
262    }
263}
264
265fn configuration_allowed(rules: &McpAllowListRules, category: &str, key: &str) -> Option<bool> {
266    rules.configuration.as_ref().and_then(|entries| {
267        entries
268            .get(category)
269            .map(|patterns| pattern_matches(patterns, key))
270    })
271}
272
273fn pattern_matches(patterns: &[String], candidate: &str) -> bool {
274    patterns
275        .iter()
276        .any(|pattern| wildcard_match(pattern, candidate))
277}
278
/// Matches `candidate` against a glob-style `pattern`.
///
/// Supported metacharacters:
/// * `*` matches any run of characters, including an empty one.
/// * `?` matches exactly one character.
///
/// Every other character is compared literally and case-sensitively, and the
/// pattern must cover the whole candidate (no implicit prefix or substring
/// matching).
///
/// The comparison works directly on characters rather than translating the
/// pattern into a regular expression on each call. This avoids paying for a
/// regex compilation per check and lets the wildcards treat a line break like
/// any other character, consistent with the lone `*` pattern that accepts
/// every candidate.
fn wildcard_match(pattern: &str, candidate: &str) -> bool {
    // Fast path: a lone star accepts everything.
    if pattern == "*" {
        return true;
    }

    // Work on chars so `?` consumes one character, not one byte.
    let pattern: Vec<char> = pattern.chars().collect();
    let candidate: Vec<char> = candidate.chars().collect();

    let mut p = 0;
    let mut c = 0;
    // For the most recent `*`: the pattern index right after it and the
    // candidate index up to which the star is currently assumed to extend.
    let mut resume: Option<(usize, usize)> = None;

    while c < candidate.len() {
        match pattern.get(p).copied() {
            Some('*') => {
                // First try letting the star match nothing; remember where to
                // come back to if that turns out to be too short.
                resume = Some((p + 1, c));
                p += 1;
            }
            Some(expected) if expected == '?' || expected == candidate[c] => {
                p += 1;
                c += 1;
            }
            _ => match resume {
                Some((after_star, covered)) => {
                    // Mismatch: let the last star swallow one more character
                    // and retry the remainder of the pattern from there.
                    p = after_star;
                    c = covered + 1;
                    resume = Some((after_star, covered + 1));
                }
                // No star to fall back on: the literal part simply differs.
                None => return false,
            },
        }
    }

    // Candidate exhausted: whatever is left of the pattern must be able to
    // match the empty string, i.e. consist solely of stars.
    pattern[p..].iter().all(|&ch| ch == '*')
}
317
318/// Allow list rules for a provider or default configuration
319#[derive(Debug, Clone, Deserialize, Serialize, Default)]
320pub struct McpAllowListRules {
321    /// Tool name patterns permitted for the provider
322    #[serde(default)]
323    pub tools: Option<Vec<String>>,
324
325    /// Resource name patterns permitted for the provider
326    #[serde(default)]
327    pub resources: Option<Vec<String>>,
328
329    /// Prompt name patterns permitted for the provider
330    #[serde(default)]
331    pub prompts: Option<Vec<String>>,
332
333    /// Logging channels permitted for the provider
334    #[serde(default)]
335    pub logging: Option<Vec<String>>,
336
337    /// Configuration keys permitted for the provider grouped by category
338    #[serde(default)]
339    pub configuration: Option<BTreeMap<String, Vec<String>>>,
340}
341
342/// Configuration for the MCP server (vtcode acting as an MCP server)
343#[derive(Debug, Clone, Deserialize, Serialize)]
344pub struct McpServerConfig {
345    /// Enable vtcode's MCP server capability
346    #[serde(default = "default_mcp_server_enabled")]
347    pub enabled: bool,
348
349    /// Bind address for the MCP server
350    #[serde(default = "default_mcp_server_bind")]
351    pub bind_address: String,
352
353    /// Port for the MCP server
354    #[serde(default = "default_mcp_server_port")]
355    pub port: u16,
356
357    /// Server transport type
358    #[serde(default = "default_mcp_server_transport")]
359    pub transport: McpServerTransport,
360
361    /// Server identifier
362    #[serde(default = "default_mcp_server_name")]
363    pub name: String,
364
365    /// Server version
366    #[serde(default = "default_mcp_server_version")]
367    pub version: String,
368
369    /// Tools exposed by the vtcode MCP server
370    #[serde(default)]
371    pub exposed_tools: Vec<String>,
372}
373
374impl Default for McpServerConfig {
375    fn default() -> Self {
376        Self {
377            enabled: default_mcp_server_enabled(),
378            bind_address: default_mcp_server_bind(),
379            port: default_mcp_server_port(),
380            transport: default_mcp_server_transport(),
381            name: default_mcp_server_name(),
382            version: default_mcp_server_version(),
383            exposed_tools: Vec::new(),
384        }
385    }
386}
387
388/// MCP server transport types
389#[derive(Debug, Clone, Deserialize, Serialize)]
390#[serde(rename_all = "snake_case")]
391pub enum McpServerTransport {
392    /// Server Sent Events transport
393    Sse,
394    /// HTTP transport
395    Http,
396}
397
398impl Default for McpServerTransport {
399    fn default() -> Self {
400        McpServerTransport::Sse
401    }
402}
403
404/// Transport configuration for MCP providers
405#[derive(Debug, Clone, Deserialize, Serialize)]
406#[serde(untagged)]
407pub enum McpTransportConfig {
408    /// Standard I/O transport (stdio)
409    Stdio(McpStdioServerConfig),
410    /// HTTP transport
411    Http(McpHttpServerConfig),
412}
413
414/// Configuration for stdio-based MCP servers
415#[derive(Debug, Clone, Deserialize, Serialize)]
416pub struct McpStdioServerConfig {
417    /// Command to execute
418    pub command: String,
419
420    /// Command arguments
421    pub args: Vec<String>,
422
423    /// Working directory for the command
424    #[serde(default)]
425    pub working_directory: Option<String>,
426}
427
428impl Default for McpStdioServerConfig {
429    fn default() -> Self {
430        Self {
431            command: String::new(),
432            args: Vec::new(),
433            working_directory: None,
434        }
435    }
436}
437
438/// Configuration for HTTP-based MCP servers
439///
440/// Note: HTTP transport is partially implemented. Basic connectivity testing is supported,
441/// but full streamable HTTP MCP server support requires additional implementation
442/// using Server-Sent Events (SSE) or WebSocket connections.
443#[derive(Debug, Clone, Deserialize, Serialize)]
444pub struct McpHttpServerConfig {
445    /// Server endpoint URL
446    pub endpoint: String,
447
448    /// API key environment variable name
449    #[serde(default)]
450    pub api_key_env: Option<String>,
451
452    /// Protocol version
453    #[serde(default = "default_mcp_protocol_version")]
454    pub protocol_version: String,
455
456    /// Headers to include in requests
457    #[serde(default)]
458    pub headers: HashMap<String, String>,
459}
460
461impl Default for McpHttpServerConfig {
462    fn default() -> Self {
463        Self {
464            endpoint: String::new(),
465            api_key_env: None,
466            protocol_version: default_mcp_protocol_version(),
467            headers: HashMap::new(),
468        }
469    }
470}
471
// Default value functions referenced by the serde attributes above.

/// Default for `McpClientConfig::enabled`: MCP is off.
fn default_mcp_enabled() -> bool {
    const ENABLED: bool = false;
    ENABLED
}
476
477fn default_mcp_ui_mode() -> McpUiMode {
478    McpUiMode::Compact
479}
480
/// Default for `McpUiConfig::max_events`.
fn default_max_mcp_events() -> usize {
    const MAX_EVENTS: usize = 50;
    MAX_EVENTS
}
484
/// Default for `McpUiConfig::show_provider_names`: names are shown.
fn default_show_provider_names() -> bool {
    const SHOW_NAMES: bool = true;
    SHOW_NAMES
}
488
/// Default for `McpClientConfig::max_concurrent_connections`.
fn default_max_concurrent_connections() -> usize {
    const MAX_CONNECTIONS: usize = 5;
    MAX_CONNECTIONS
}
492
/// Default for `McpClientConfig::request_timeout_seconds`.
fn default_request_timeout_seconds() -> u64 {
    const TIMEOUT_SECONDS: u64 = 30;
    TIMEOUT_SECONDS
}
496
/// Default for `McpClientConfig::retry_attempts`.
fn default_retry_attempts() -> u32 {
    const RETRY_ATTEMPTS: u32 = 3;
    RETRY_ATTEMPTS
}
500
/// Default for `McpProviderConfig::enabled`: a listed provider is on.
fn default_provider_enabled() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
504
/// Default for `McpProviderConfig::max_concurrent_requests`.
fn default_provider_max_concurrent() -> usize {
    const MAX_REQUESTS: usize = 3;
    MAX_REQUESTS
}
508
/// Default for `McpAllowListConfig::enforce`: the allow list is not enforced.
fn default_allowlist_enforced() -> bool {
    const ENFORCED: bool = false;
    ENFORCED
}
512
/// Default for `McpHttpServerConfig::protocol_version`.
fn default_mcp_protocol_version() -> String {
    String::from("2024-11-05")
}
516
/// Default for `McpServerConfig::enabled`: the built-in server is off.
fn default_mcp_server_enabled() -> bool {
    const ENABLED: bool = false;
    ENABLED
}
520
/// Default for `McpServerConfig::bind_address`.
fn default_mcp_server_bind() -> String {
    String::from("127.0.0.1")
}
524
/// Default for `McpServerConfig::port`.
fn default_mcp_server_port() -> u16 {
    const PORT: u16 = 3000;
    PORT
}
528
529fn default_mcp_server_transport() -> McpServerTransport {
530    McpServerTransport::Sse
531}
532
/// Default for `McpServerConfig::name`.
fn default_mcp_server_name() -> String {
    String::from("vtcode-mcp-server")
}
536
537fn default_mcp_server_version() -> String {
538    env!("CARGO_PKG_VERSION").to_string()
539}
540
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned pattern lists the rules store.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| (*item).to_string()).collect()
    }

    /// Allow list with enforcement switched on and everything else defaulted.
    fn enforced() -> McpAllowListConfig {
        McpAllowListConfig {
            enforce: true,
            ..McpAllowListConfig::default()
        }
    }

    #[test]
    fn test_mcp_config_defaults() {
        let config = McpClientConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.ui.mode, McpUiMode::Compact);
        assert_eq!(config.ui.max_events, 50);
        assert!(config.ui.show_provider_names);
        assert_eq!(config.max_concurrent_connections, 5);
        assert_eq!(config.request_timeout_seconds, 30);
        assert_eq!(config.retry_attempts, 3);
        assert!(config.providers.is_empty());
        assert!(!config.server.enabled);
        assert!(!config.allowlist.enforce);
        assert!(config.allowlist.default.tools.is_none());
    }

    #[test]
    fn test_allowlist_pattern_matching() {
        let patterns = owned(&["get_*", "convert_timezone"]);
        assert!(pattern_matches(&patterns, "get_current_time"));
        assert!(pattern_matches(&patterns, "convert_timezone"));
        assert!(!pattern_matches(&patterns, "delete_timezone"));
    }

    #[test]
    fn test_allowlist_provider_override() {
        let mut config = enforced();
        config.default.tools = Some(owned(&["get_*"]));
        config.providers.insert(
            "context7".to_string(),
            McpAllowListRules {
                tools: Some(owned(&["list_*"])),
                ..McpAllowListRules::default()
            },
        );

        // The provider's own tool list replaces the defaults entirely.
        assert!(config.is_tool_allowed("context7", "list_documents"));
        assert!(!config.is_tool_allowed("context7", "get_current_time"));
        // Providers without rules fall back to the defaults.
        assert!(config.is_tool_allowed("other", "get_timezone"));
        assert!(!config.is_tool_allowed("other", "list_documents"));
    }

    #[test]
    fn test_allowlist_configuration_rules() {
        let mut config = enforced();

        // `configuration` is a BTreeMap, so the fixtures must be built as one.
        config.default.configuration = Some(BTreeMap::from([(
            "ui".to_string(),
            owned(&["mode", "max_events"]),
        )]));
        config.providers.insert(
            "time".to_string(),
            McpAllowListRules {
                configuration: Some(BTreeMap::from([(
                    "provider".to_string(),
                    owned(&["max_concurrent_requests"]),
                )])),
                ..McpAllowListRules::default()
            },
        );

        assert!(config.is_configuration_allowed(None, "ui", "mode"));
        assert!(!config.is_configuration_allowed(None, "ui", "show_provider_names"));
        assert!(config.is_configuration_allowed(
            Some("time"),
            "provider",
            "max_concurrent_requests"
        ));
        assert!(!config.is_configuration_allowed(Some("time"), "provider", "retry_attempts"));
    }

    #[test]
    fn test_allowlist_resource_override() {
        let mut config = enforced();
        config.default.resources = Some(owned(&["docs/*"]));
        config.providers.insert(
            "context7".to_string(),
            McpAllowListRules {
                resources: Some(owned(&["journals/*"])),
                ..McpAllowListRules::default()
            },
        );

        assert!(config.is_resource_allowed("context7", "journals/2024"));
        assert!(!config.is_resource_allowed("context7", "docs/manual"));
        assert!(config.is_resource_allowed("other", "docs/reference"));
        assert!(!config.is_resource_allowed("other", "journals/2023"));
    }

    #[test]
    fn test_allowlist_logging_override() {
        let mut config = enforced();
        config.default.logging = Some(owned(&["info", "debug"]));
        config.providers.insert(
            "sequential".to_string(),
            McpAllowListRules {
                logging: Some(owned(&["audit"])),
                ..McpAllowListRules::default()
            },
        );

        assert!(config.is_logging_channel_allowed(Some("sequential"), "audit"));
        assert!(!config.is_logging_channel_allowed(Some("sequential"), "info"));
        assert!(config.is_logging_channel_allowed(Some("other"), "info"));
        assert!(!config.is_logging_channel_allowed(Some("other"), "trace"));
    }
}