Skip to main content

vtcode_core/mcp/
mod.rs

1//! MCP client management built on top of the Codex MCP building blocks.
2//!
3//! This module adapts the reference MCP client, server and type
4//! definitions from <https://github.com/openai/codex> to integrate them
5//! with VT Code's multi-provider configuration model. The original
6//! implementation inside this project had grown organically and mixed a
7//! large amount of bookkeeping logic with the lower level rmcp client
8//! transport. The rewritten version keeps the VT Code specific surface
9//! (allow lists, tool indexing, status reporting) but delegates the
10//! actual protocol interaction to a lightweight `RmcpClient` adapter
11//! that mirrors Codex' `mcp-client` crate. This dramatically reduces
12//! the amount of bespoke glue we have to maintain while aligning the
13//! behaviour with the upstream MCP implementations.
14
15use crate::config::mcp::McpClientConfig;
16
17pub mod cli;
18mod client;
19pub mod connection_pool;
20pub mod enhanced_config;
21pub mod errors;
22mod provider;
23mod rmcp_client;
24pub mod rmcp_transport;
25pub mod schema;
26pub mod tool_discovery;
27pub mod tool_discovery_cache;
28pub mod traits;
29pub mod types;
30pub mod utils;
31
32pub use client::McpClient;
33
34pub use connection_pool::{
35    ConnectionPoolStats, McpConnectionPool, McpPoolError, PooledMcpManager, PooledMcpStats,
36};
37pub use errors::{
38    ErrorCode, McpResult, configuration_error, initialization_timeout, provider_not_found,
39    provider_unavailable, schema_invalid, tool_invocation_failed, tool_not_found,
40};
41pub use provider::McpProvider;
42pub(crate) use rmcp_client::RmcpClient;
43pub use rmcp_transport::{
44    HttpTransport, create_http_transport, create_stdio_transport,
45    create_stdio_transport_with_stderr,
46};
47pub use schema::{validate_against_schema, validate_tool_input};
48pub use tool_discovery::{DetailLevel, ToolDiscovery, ToolDiscoveryResult};
49pub use traits::{McpElicitationHandler, McpToolExecutor};
50pub use types::{
51    FileParamSchemaEntry, FileUploadResult, McpClientStatus, McpElicitationRequest,
52    McpElicitationResponse, McpPromptDetail, McpPromptInfo, McpResourceData, McpResourceInfo,
53    McpToolInfo, OPENAI_FILE_PARAMS_META_KEY, OPENAI_FILE_PARAMS_VALUE, ProvidedFilePayload,
54};
55pub use utils::{
56    LOCAL_TIMEZONE_ENV_VAR, TIMEZONE_ARGUMENT, TZ_ENV_VAR, build_headers, detect_local_timezone,
57    ensure_timezone_argument, schema_requires_field,
58};
59
60use anyhow::{Result, anyhow};
61use hashbrown::HashMap;
62pub use rmcp::model::ElicitationAction;
63use std::ffi::OsString;
64use std::fmt::Write;
65
/// MCP protocol revision string this module treats as the latest version.
pub const LATEST_PROTOCOL_VERSION: &str = "2024-11-05";
/// Every MCP protocol revision this module considers supported.
///
/// Currently this contains only [`LATEST_PROTOCOL_VERSION`].
pub const SUPPORTED_PROTOCOL_VERSIONS: &[&str] = &[LATEST_PROTOCOL_VERSION];
69
70/// Convert any serializable type to rmcp model type via JSON serialization
71pub(crate) fn convert_to_rmcp<T, U>(value: T) -> Result<U>
72where
73    T: serde::Serialize,
74    U: serde::de::DeserializeOwned,
75{
76    let json = serde_json::to_value(value)?;
77    serde_json::from_value(json).map_err(|err| anyhow!(err))
78}
79
80fn create_env_for_mcp_server(
81    extra_env: Option<HashMap<OsString, OsString>>,
82) -> HashMap<OsString, OsString> {
83    DEFAULT_ENV_VARS
84        .iter()
85        .filter_map(|var| std::env::var_os(var).map(|value| (OsString::from(*var), value)))
86        .chain(extra_env.unwrap_or_default())
87        .collect()
88}
89
90/// Validate MCP configuration settings
91pub fn validate_mcp_config(config: &McpClientConfig) -> Result<()> {
92    // Validate server configuration if enabled
93    if config.server.enabled {
94        // Validate port range
95        if config.server.port == 0 {
96            return Err(anyhow::anyhow!(
97                "Invalid server port: {}",
98                config.server.port
99            ));
100        }
101
102        // Validate bind address
103        if config.server.bind_address.is_empty() {
104            return Err(anyhow::anyhow!("Server bind address cannot be empty"));
105        }
106
107        // Validate security settings if auth is enabled
108        if config.security.auth_enabled && config.security.api_key_env.is_none() {
109            return Err(anyhow::anyhow!(
110                "API key environment variable must be set when auth is enabled"
111            ));
112        }
113    }
114
115    // Validate timeouts
116    if let Some(startup_timeout) = config.startup_timeout_seconds
117        && startup_timeout > 300
118    {
119        // Max 5 minutes
120        return Err(anyhow::anyhow!("Startup timeout cannot exceed 300 seconds"));
121    }
122
123    if let Some(tool_timeout) = config.tool_timeout_seconds
124        && tool_timeout > 3600
125    {
126        // Max 1 hour
127        return Err(anyhow::anyhow!("Tool timeout cannot exceed 3600 seconds"));
128    }
129
130    // Validate provider configurations
131    for provider in &config.providers {
132        if provider.name.is_empty() {
133            return Err(anyhow::anyhow!("MCP provider name cannot be empty"));
134        }
135
136        // Validate max_concurrent_requests
137        if provider.max_concurrent_requests == 0 {
138            return Err(anyhow::anyhow!(
139                "Max concurrent requests must be greater than 0 for provider '{}'",
140                provider.name
141            ));
142        }
143    }
144
145    Ok(())
146}
147
/// Environment variable names forwarded from the current process into the
/// map built by `create_env_for_mcp_server` on Unix-like platforms.
///
/// Names that are not set in the current process are skipped, not forwarded
/// as empty values.
///
/// NOTE(review): only `unix` and `windows` variants are defined in this file.
/// Other targets would have no `DEFAULT_ENV_VARS`.
#[cfg(unix)]
const DEFAULT_ENV_VARS: &[&str] = &[
    // User identity, executable search path and login shell
    "HOME",
    "LOGNAME",
    "PATH",
    "SHELL",
    "USER",
    // macOS CoreFoundation text-encoding hint
    "__CF_USER_TEXT_ENCODING",
    // Locale, terminal type, temp directory and timezone
    "LANG",
    "LC_ALL",
    "TERM",
    "TMPDIR",
    "TZ",
];
162
/// Environment variable names forwarded from the current process into the
/// map built by `create_env_for_mcp_server` on Windows.
///
/// Names that are not set in the current process are skipped, not forwarded
/// as empty values.
///
/// NOTE(review): names are upper-case. This appears to rely on Windows'
/// case-insensitive environment lookup, so that e.g. `SystemRoot` is found —
/// confirm.
#[cfg(windows)]
const DEFAULT_ENV_VARS: &[&str] = &[
    // Core path resolution
    "PATH",
    "PATHEXT",
    // Shell and system roots
    "COMSPEC",
    "SYSTEMROOT",
    "SYSTEMDRIVE",
    // User context and profiles
    "USERNAME",
    "USERDOMAIN",
    "USERPROFILE",
    "HOMEDRIVE",
    "HOMEPATH",
    // Program locations
    "PROGRAMFILES",
    "PROGRAMFILES(X86)",
    "PROGRAMW6432",
    "PROGRAMDATA",
    // App data and caches
    "LOCALAPPDATA",
    "APPDATA",
    // Temp locations
    "TEMP",
    "TMP",
    // Common shells/pwsh hints
    "POWERSHELL",
    "PWSH",
];
193
194// Helper functions for file-based tool discovery
195
/// Make `name` safe to embed in a filename.
///
/// Characters that are alphanumeric (Unicode-aware), `_` or `-` are kept.
/// Every other character, including path separators and `.`, becomes exactly
/// one `_`, so the character count is preserved.
fn sanitize_filename(name: &str) -> String {
    let mut sanitized = String::with_capacity(name.len());
    for ch in name.chars() {
        let allowed = ch.is_alphanumeric() || matches!(ch, '_' | '-');
        sanitized.push(if allowed { ch } else { '_' });
    }
    sanitized
}
208
209/// Format a tool description as Markdown
210fn format_tool_markdown(tool: &McpToolInfo) -> String {
211    let mut content = String::new();
212    let _ = write!(content, "# {}\n\n", tool.name);
213    let _ = write!(content, "**Provider**: {}\n\n", tool.provider);
214    content.push_str("## Description\n\n");
215    content.push_str(&tool.description);
216    content.push_str("\n\n");
217
218    content.push_str("## Input Schema\n\n");
219    content.push_str("```json\n");
220    content.push_str(
221        &serde_json::to_string_pretty(&tool.input_schema)
222            .unwrap_or_else(|_| tool.input_schema.to_string()),
223    );
224    content.push_str("\n```\n\n");
225
226    // Extract required fields if present
227    if let Some(obj) = tool.input_schema.as_object() {
228        if let Some(required) = obj.get("required").and_then(|v| v.as_array())
229            && !required.is_empty()
230        {
231            content.push_str("## Required Parameters\n\n");
232            for req in required {
233                if let Some(name) = req.as_str() {
234                    let _ = writeln!(content, "- `{}`", name);
235                }
236            }
237            content.push('\n');
238        }
239
240        // Extract properties descriptions
241        if let Some(props) = obj.get("properties").and_then(|v| v.as_object())
242            && !props.is_empty()
243        {
244            content.push_str("## Parameters\n\n");
245            for (param_name, param_schema) in props {
246                let param_type = param_schema
247                    .get("type")
248                    .and_then(|t| t.as_str())
249                    .unwrap_or("any");
250                let param_desc = param_schema
251                    .get("description")
252                    .and_then(|d| d.as_str())
253                    .unwrap_or("");
254                let _ = write!(content, "### `{}`\n\n", param_name);
255                let _ = writeln!(content, "- **Type**: {}", param_type);
256                if !param_desc.is_empty() {
257                    let _ = writeln!(content, "- **Description**: {}", param_desc);
258                }
259                content.push('\n');
260            }
261        }
262    }
263
264    content.push_str("---\n");
265    content.push_str("*Generated automatically for dynamic context discovery.*\n");
266
267    content
268}
269
/// Unit tests for the MCP helpers in this module and its submodules:
/// timezone injection, environment construction, rmcp conversion,
/// elicitation validation, provider concurrency defaults and file URIs.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::mcp::{McpProviderConfig, McpStdioServerConfig, McpTransportConfig};
    use crate::mcp::rmcp_client::{
        build_elicitation_validator, directory_to_file_uri, validate_elicitation_payload,
    };
    use crate::mcp::utils::{clear_test_env_override, set_test_env_override};
    use hashbrown::HashMap;
    use serde_json::{Map, Value, json};
    // rmcp model types used to build `InitializeRequestParams` for the
    // round-trip conversion test (plain imports, not re-exports).
    use rmcp::model::{
        ClientCapabilities, Implementation, InitializeRequestParams, RootsCapabilities,
    };

    #[cfg(unix)]
    use serial_test::serial;

    // Needed for `OsString::from_vec` to build a non-UTF-8 PATH value.
    #[cfg(unix)]
    use std::os::unix::ffi::OsStringExt;

    /// Scoped test override for an environment key.
    ///
    /// `set` installs the override via `set_test_env_override`, and dropping
    /// the guard removes it via `clear_test_env_override`.
    /// NOTE(review): presumably these overrides are consulted by `utils`
    /// instead of the real process environment — confirm.
    struct EnvGuard {
        key: &'static str,
    }

    impl EnvGuard {
        fn set(key: &'static str, value: &str) -> Self {
            set_test_env_override(key, Some(value));
            Self { key }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            clear_test_env_override(self.key);
        }
    }

    /// `schema_requires_field` is true for names listed in `required` and false otherwise.
    #[test]
    fn schema_detection_handles_required_entries() {
        let schema = json!({
            "type": "object",
            "required": [TIMEZONE_ARGUMENT],
            "properties": {
                TIMEZONE_ARGUMENT: { "type": "string" }
            }
        });

        assert!(schema_requires_field(&schema, TIMEZONE_ARGUMENT));
        assert!(!schema_requires_field(&schema, "location"));
    }

    /// A missing timezone argument is filled from the `LOCAL_TIMEZONE_ENV_VAR` override.
    #[test]
    fn ensure_timezone_injects_from_override_env() {
        let _guard = EnvGuard::set(LOCAL_TIMEZONE_ENV_VAR, "Etc/UTC");
        let mut arguments = Map::new();

        ensure_timezone_argument(&mut arguments, true).unwrap();

        assert_eq!(
            arguments.get(TIMEZONE_ARGUMENT).and_then(Value::as_str),
            Some("Etc/UTC")
        );
    }

    /// A caller-supplied timezone argument is left untouched.
    #[test]
    fn ensure_timezone_does_not_override_existing_value() {
        let mut arguments = Map::new();
        arguments.insert(
            TIMEZONE_ARGUMENT.to_string(),
            Value::String("America/New_York".to_owned()),
        );

        ensure_timezone_argument(&mut arguments, true).unwrap();

        assert_eq!(
            arguments.get(TIMEZONE_ARGUMENT).and_then(Value::as_str),
            Some("America/New_York")
        );
    }

    /// Configured extra variables appear in the generated environment.
    #[test]
    fn create_env_merges_configured_values() {
        let mut extra_env = HashMap::new();
        extra_env.insert(OsString::from("A"), OsString::from("1"));
        extra_env.insert(OsString::from("B"), OsString::from("2"));

        let env = create_env_for_mcp_server(Some(extra_env));

        assert_eq!(env.get(&OsString::from("A")), Some(&OsString::from("1")));
        assert_eq!(env.get(&OsString::from("B")), Some(&OsString::from("2")));
    }

    /// A non-UTF-8 `PATH` is forwarded byte-for-byte.
    ///
    /// This mutates the real process environment, so it is `#[serial]` and
    /// holds `vtcode_commons::env_lock` (presumably to serialize env access
    /// across tests — confirm). `PATH` is restored before asserting.
    #[test]
    #[cfg(unix)]
    #[serial]
    fn create_env_preserves_non_utf8_path() {
        let env_guard = vtcode_commons::env_lock::lock();
        let original_path = std::env::var_os("PATH");
        let non_utf8_path = OsString::from_vec(b"/tmp/alpha:\xFFbeta".to_vec());

        env_guard.set_var("PATH", &non_utf8_path);

        let env = create_env_for_mcp_server(None);

        // Restore before asserting so a failing assertion does not leave PATH modified.
        env_guard.restore_var("PATH", original_path);

        assert_eq!(env.get(&OsString::from("PATH")), Some(&non_utf8_path));
    }

    /// `convert_to_rmcp` round-trips `InitializeRequestParams` through JSON.
    ///
    /// NOTE(review): nothing here is awaited; this could be a plain `#[test]`.
    #[tokio::test]
    async fn convert_to_rmcp_round_trip() {
        let mut capabilities = ClientCapabilities::default();
        capabilities.roots = Some(RootsCapabilities {
            list_changed: Some(true),
        });
        let params =
            InitializeRequestParams::new(capabilities, Implementation::new("vtcode", "1.0"))
                .with_protocol_version(rmcp::model::ProtocolVersion::V_2024_11_05);

        let converted: InitializeRequestParams = convert_to_rmcp(params.clone()).unwrap();
        // Verify the conversion succeeded by checking the client name and version
        assert_eq!(converted.client_info.name, "vtcode");
        assert_eq!(converted.client_info.version, "1.0");
    }

    /// An accepted elicitation whose content violates the schema (number for a string) is rejected.
    #[test]
    fn validate_elicitation_payload_rejects_invalid_content() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            },
            "required": ["name"]
        });
        let validator =
            build_elicitation_validator("test", &schema).expect("schema should compile");

        let result = validate_elicitation_payload(
            "test",
            Some(&validator),
            &ElicitationAction::Accept,
            Some(&json!({ "name": 42 })),
        );

        assert!(result.is_err());
    }

    /// An accepted elicitation whose content matches the schema passes validation.
    #[test]
    fn validate_elicitation_payload_accepts_valid_content() {
        let schema = json!({
            "type": "object",
            "properties": {
                "email": { "type": "string", "format": "email" }
            },
            "required": ["email"]
        });
        let validator =
            build_elicitation_validator("test", &schema).expect("schema should compile");

        let result = validate_elicitation_payload(
            "test",
            Some(&validator),
            &ElicitationAction::Accept,
            Some(&json!({ "email": "user@example.com" })),
        );

        result.unwrap();
    }

    /// A provider configured with `max_concurrent_requests: 0` ends up with one permit.
    ///
    /// NOTE(review): spawns `cat` as the stdio server, which assumes a `cat`
    /// binary on PATH. The test is not gated on `unix` — confirm it is
    /// expected to run on Windows.
    #[tokio::test]
    async fn provider_max_concurrency_defaults_to_one() {
        let config = McpProviderConfig {
            name: "test".into(),
            transport: McpTransportConfig::Stdio(McpStdioServerConfig {
                command: "cat".into(),
                args: vec![],
                working_directory: None,
            }),
            env: HashMap::new(),
            enabled: true,
            max_concurrent_requests: 0,
            startup_timeout_ms: None,
        };

        let provider = McpProvider::connect(config, None).await.unwrap();
        assert_eq!(provider.semaphore.available_permits(), 1);
    }

    /// The system temp directory converts to a `file://` URI.
    #[test]
    fn directory_to_file_uri_generates_file_scheme() {
        let temp_dir = std::env::temp_dir();
        let uri = directory_to_file_uri(temp_dir.as_path())
            .expect("should create file uri for temp directory");
        assert!(uri.starts_with("file://"));
    }
}