1use crate::config::mcp::McpClientConfig;
16
17pub mod cli;
18mod client;
19pub mod connection_pool;
20pub mod enhanced_config;
21pub mod errors;
22mod provider;
23mod rmcp_client;
24pub mod rmcp_transport;
25pub mod schema;
26pub mod tool_discovery;
27pub mod tool_discovery_cache;
28pub mod traits;
29pub mod types;
30pub mod utils;
31
32pub use client::McpClient;
33
34pub use connection_pool::{
35 ConnectionPoolStats, McpConnectionPool, McpPoolError, PooledMcpManager, PooledMcpStats,
36};
37pub use errors::{
38 ErrorCode, McpResult, configuration_error, initialization_timeout, provider_not_found,
39 provider_unavailable, schema_invalid, tool_invocation_failed, tool_not_found,
40};
41pub use provider::McpProvider;
42pub(crate) use rmcp_client::RmcpClient;
43pub use rmcp_transport::{
44 HttpTransport, create_http_transport, create_stdio_transport,
45 create_stdio_transport_with_stderr,
46};
47pub use schema::{validate_against_schema, validate_tool_input};
48pub use tool_discovery::{DetailLevel, ToolDiscovery, ToolDiscoveryResult};
49pub use traits::{McpElicitationHandler, McpToolExecutor};
50pub use types::{
51 FileParamSchemaEntry, FileUploadResult, McpClientStatus, McpElicitationRequest,
52 McpElicitationResponse, McpPromptDetail, McpPromptInfo, McpResourceData, McpResourceInfo,
53 McpToolInfo, OPENAI_FILE_PARAMS_META_KEY, OPENAI_FILE_PARAMS_VALUE, ProvidedFilePayload,
54};
55pub use utils::{
56 LOCAL_TIMEZONE_ENV_VAR, TIMEZONE_ARGUMENT, TZ_ENV_VAR, build_headers, detect_local_timezone,
57 ensure_timezone_argument, schema_requires_field,
58};
59
60use anyhow::{Result, anyhow};
61use hashbrown::HashMap;
62pub use rmcp::model::ElicitationAction;
63use std::ffi::OsString;
64use std::fmt::Write;
65
/// MCP protocol revision that this module treats as the most recent one.
pub const LATEST_PROTOCOL_VERSION: &str = "2024-11-05";

/// Protocol revisions listed as supported; at present this is only
/// [`LATEST_PROTOCOL_VERSION`].
pub const SUPPORTED_PROTOCOL_VERSIONS: &[&str] = &[LATEST_PROTOCOL_VERSION];
69
70pub(crate) fn convert_to_rmcp<T, U>(value: T) -> Result<U>
72where
73 T: serde::Serialize,
74 U: serde::de::DeserializeOwned,
75{
76 let json = serde_json::to_value(value)?;
77 serde_json::from_value(json).map_err(|err| anyhow!(err))
78}
79
80fn create_env_for_mcp_server(
81 extra_env: Option<HashMap<OsString, OsString>>,
82) -> HashMap<OsString, OsString> {
83 DEFAULT_ENV_VARS
84 .iter()
85 .filter_map(|var| std::env::var_os(var).map(|value| (OsString::from(*var), value)))
86 .chain(extra_env.unwrap_or_default())
87 .collect()
88}
89
90pub fn validate_mcp_config(config: &McpClientConfig) -> Result<()> {
92 if config.server.enabled {
94 if config.server.port == 0 {
96 return Err(anyhow::anyhow!(
97 "Invalid server port: {}",
98 config.server.port
99 ));
100 }
101
102 if config.server.bind_address.is_empty() {
104 return Err(anyhow::anyhow!("Server bind address cannot be empty"));
105 }
106
107 if config.security.auth_enabled && config.security.api_key_env.is_none() {
109 return Err(anyhow::anyhow!(
110 "API key environment variable must be set when auth is enabled"
111 ));
112 }
113 }
114
115 if let Some(startup_timeout) = config.startup_timeout_seconds
117 && startup_timeout > 300
118 {
119 return Err(anyhow::anyhow!("Startup timeout cannot exceed 300 seconds"));
121 }
122
123 if let Some(tool_timeout) = config.tool_timeout_seconds
124 && tool_timeout > 3600
125 {
126 return Err(anyhow::anyhow!("Tool timeout cannot exceed 3600 seconds"));
128 }
129
130 for provider in &config.providers {
132 if provider.name.is_empty() {
133 return Err(anyhow::anyhow!("MCP provider name cannot be empty"));
134 }
135
136 if provider.max_concurrent_requests == 0 {
138 return Err(anyhow::anyhow!(
139 "Max concurrent requests must be greater than 0 for provider '{}'",
140 provider.name
141 ));
142 }
143 }
144
145 Ok(())
146}
147
/// Names of the environment variables that `create_env_for_mcp_server`
/// copies from the current process on Unix targets.
#[cfg(unix)]
const DEFAULT_ENV_VARS: &[&str] = &[
    // User, shell and executable lookup.
    "HOME",
    "LOGNAME",
    "PATH",
    "SHELL",
    "USER",
    // Text encoding and locale.
    "__CF_USER_TEXT_ENCODING",
    "LANG",
    "LC_ALL",
    // Terminal, temporary directory and timezone.
    "TERM",
    "TMPDIR",
    "TZ",
];
162
/// Names of the environment variables that `create_env_for_mcp_server`
/// copies from the current process on Windows targets.
#[cfg(windows)]
const DEFAULT_ENV_VARS: &[&str] = &[
    // Executable lookup.
    "PATH",
    "PATHEXT",
    // Command interpreter.
    "COMSPEC",
    // System locations.
    "SYSTEMROOT",
    "SYSTEMDRIVE",
    // User identity and profile.
    "USERNAME",
    "USERDOMAIN",
    "USERPROFILE",
    "HOMEDRIVE",
    "HOMEPATH",
    // Program installation directories.
    "PROGRAMFILES",
    "PROGRAMFILES(X86)",
    "PROGRAMW6432",
    // Application data directories.
    "PROGRAMDATA",
    "LOCALAPPDATA",
    "APPDATA",
    // Temporary directories.
    "TEMP",
    "TMP",
    // NOTE(review): presumably PowerShell executable hints; confirm with callers.
    "POWERSHELL",
    "PWSH",
];
193
/// Returns `name` with every character that is not alphanumeric, `_` or `-`
/// replaced by `_`.
///
/// Alphanumeric is decided by [`char::is_alphanumeric`], so non-ASCII letters
/// and digits are kept. An empty input yields an empty string.
fn sanitize_filename(name: &str) -> String {
    // A replacement is a single byte, so the result never outgrows the input.
    let mut sanitized = String::with_capacity(name.len());
    for ch in name.chars() {
        let keep = matches!(ch, '_' | '-') || ch.is_alphanumeric();
        sanitized.push(if keep { ch } else { '_' });
    }
    sanitized
}
208
209fn format_tool_markdown(tool: &McpToolInfo) -> String {
211 let mut content = String::new();
212 let _ = write!(content, "# {}\n\n", tool.name);
213 let _ = write!(content, "**Provider**: {}\n\n", tool.provider);
214 content.push_str("## Description\n\n");
215 content.push_str(&tool.description);
216 content.push_str("\n\n");
217
218 content.push_str("## Input Schema\n\n");
219 content.push_str("```json\n");
220 content.push_str(
221 &serde_json::to_string_pretty(&tool.input_schema)
222 .unwrap_or_else(|_| tool.input_schema.to_string()),
223 );
224 content.push_str("\n```\n\n");
225
226 if let Some(obj) = tool.input_schema.as_object() {
228 if let Some(required) = obj.get("required").and_then(|v| v.as_array())
229 && !required.is_empty()
230 {
231 content.push_str("## Required Parameters\n\n");
232 for req in required {
233 if let Some(name) = req.as_str() {
234 let _ = writeln!(content, "- `{}`", name);
235 }
236 }
237 content.push('\n');
238 }
239
240 if let Some(props) = obj.get("properties").and_then(|v| v.as_object())
242 && !props.is_empty()
243 {
244 content.push_str("## Parameters\n\n");
245 for (param_name, param_schema) in props {
246 let param_type = param_schema
247 .get("type")
248 .and_then(|t| t.as_str())
249 .unwrap_or("any");
250 let param_desc = param_schema
251 .get("description")
252 .and_then(|d| d.as_str())
253 .unwrap_or("");
254 let _ = write!(content, "### `{}`\n\n", param_name);
255 let _ = writeln!(content, "- **Type**: {}", param_type);
256 if !param_desc.is_empty() {
257 let _ = writeln!(content, "- **Description**: {}", param_desc);
258 }
259 content.push('\n');
260 }
261 }
262 }
263
264 content.push_str("---\n");
265 content.push_str("*Generated automatically for dynamic context discovery.*\n");
266
267 content
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::mcp::{McpProviderConfig, McpStdioServerConfig, McpTransportConfig};
    use crate::mcp::rmcp_client::{
        build_elicitation_validator, directory_to_file_uri, validate_elicitation_payload,
    };
    use crate::mcp::utils::{clear_test_env_override, set_test_env_override};
    use hashbrown::HashMap;
    use rmcp::model::{
        ClientCapabilities, Implementation, InitializeRequestParams, RootsCapabilities,
    };
    use serde_json::{Map, Value, json};

    #[cfg(unix)]
    use serial_test::serial;

    #[cfg(unix)]
    use std::os::unix::ffi::OsStringExt;

    /// Installs a test env override on creation and clears it again on drop.
    struct EnvGuard {
        key: &'static str,
    }

    impl EnvGuard {
        /// Sets the override for `key` to `value` and returns the guard.
        fn set(key: &'static str, value: &str) -> Self {
            set_test_env_override(key, Some(value));
            Self { key }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            clear_test_env_override(self.key);
        }
    }

    /// A field listed in `required` is detected; an unlisted one is not.
    #[test]
    fn schema_detection_handles_required_entries() {
        let tool_schema = json!({
            "type": "object",
            "required": [TIMEZONE_ARGUMENT],
            "properties": {
                TIMEZONE_ARGUMENT: { "type": "string" }
            }
        });

        assert!(schema_requires_field(&tool_schema, TIMEZONE_ARGUMENT));
        assert!(!schema_requires_field(&tool_schema, "location"));
    }

    /// With the override set, the timezone argument is filled in from it.
    #[test]
    fn ensure_timezone_injects_from_override_env() {
        let _guard = EnvGuard::set(LOCAL_TIMEZONE_ENV_VAR, "Etc/UTC");
        let mut args = Map::new();

        ensure_timezone_argument(&mut args, true).unwrap();

        let injected = args.get(TIMEZONE_ARGUMENT).and_then(Value::as_str);
        assert_eq!(injected, Some("Etc/UTC"));
    }

    /// A timezone argument that is already present is left untouched.
    #[test]
    fn ensure_timezone_does_not_override_existing_value() {
        let mut args = Map::new();
        args.insert(
            TIMEZONE_ARGUMENT.to_string(),
            Value::String("America/New_York".to_owned()),
        );

        ensure_timezone_argument(&mut args, true).unwrap();

        let kept = args.get(TIMEZONE_ARGUMENT).and_then(Value::as_str);
        assert_eq!(kept, Some("America/New_York"));
    }

    /// Entries passed as `extra_env` show up in the resulting environment.
    #[test]
    fn create_env_merges_configured_values() {
        let extra_env: HashMap<OsString, OsString> = [("A", "1"), ("B", "2")]
            .into_iter()
            .map(|(key, value)| (OsString::from(key), OsString::from(value)))
            .collect();

        let env = create_env_for_mcp_server(Some(extra_env));

        assert_eq!(env.get(&OsString::from("A")), Some(&OsString::from("1")));
        assert_eq!(env.get(&OsString::from("B")), Some(&OsString::from("2")));
    }

    /// A `PATH` value that is not valid UTF-8 is copied through unchanged.
    #[test]
    #[cfg(unix)]
    #[serial]
    fn create_env_preserves_non_utf8_path() {
        let env_guard = vtcode_commons::env_lock::lock();
        let saved_path = std::env::var_os("PATH");
        let raw_path = OsString::from_vec(b"/tmp/alpha:\xFFbeta".to_vec());

        env_guard.set_var("PATH", &raw_path);

        let env = create_env_for_mcp_server(None);

        // Restore before asserting so a failure does not leave PATH modified.
        env_guard.restore_var("PATH", saved_path);

        assert_eq!(env.get(&OsString::from("PATH")), Some(&raw_path));
    }

    /// Converting initialize params to the same type keeps the client info.
    #[tokio::test]
    async fn convert_to_rmcp_round_trip() {
        let mut capabilities = ClientCapabilities::default();
        capabilities.roots = Some(RootsCapabilities {
            list_changed: Some(true),
        });
        let client = Implementation::new("vtcode", "1.0");
        let params = InitializeRequestParams::new(capabilities, client)
            .with_protocol_version(rmcp::model::ProtocolVersion::V_2024_11_05);

        let converted: InitializeRequestParams = convert_to_rmcp(params.clone()).unwrap();

        assert_eq!(converted.client_info.name, "vtcode");
        assert_eq!(converted.client_info.version, "1.0");
    }

    /// An accepted payload whose field has the wrong type is rejected.
    #[test]
    fn validate_elicitation_payload_rejects_invalid_content() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            },
            "required": ["name"]
        });
        let validator =
            build_elicitation_validator("test", &schema).expect("schema should compile");
        let payload = json!({ "name": 42 });

        let outcome = validate_elicitation_payload(
            "test",
            Some(&validator),
            &ElicitationAction::Accept,
            Some(&payload),
        );

        assert!(outcome.is_err());
    }

    /// An accepted payload that satisfies the schema passes validation.
    #[test]
    fn validate_elicitation_payload_accepts_valid_content() {
        let schema = json!({
            "type": "object",
            "properties": {
                "email": { "type": "string", "format": "email" }
            },
            "required": ["email"]
        });
        let validator =
            build_elicitation_validator("test", &schema).expect("schema should compile");
        let payload = json!({ "email": "user@example.com" });

        let outcome = validate_elicitation_payload(
            "test",
            Some(&validator),
            &ElicitationAction::Accept,
            Some(&payload),
        );

        outcome.unwrap();
    }

    /// A provider configured with zero concurrent requests ends up with one permit.
    #[tokio::test]
    async fn provider_max_concurrency_defaults_to_one() {
        let transport = McpTransportConfig::Stdio(McpStdioServerConfig {
            command: "cat".into(),
            args: vec![],
            working_directory: None,
        });
        let config = McpProviderConfig {
            name: "test".into(),
            transport,
            env: HashMap::new(),
            enabled: true,
            max_concurrent_requests: 0,
            startup_timeout_ms: None,
        };

        let provider = McpProvider::connect(config, None).await.unwrap();

        assert_eq!(provider.semaphore.available_permits(), 1);
    }

    /// The URI produced for the temp directory uses the `file://` scheme.
    #[test]
    fn directory_to_file_uri_generates_file_scheme() {
        let dir = std::env::temp_dir();
        let uri = directory_to_file_uri(dir.as_path())
            .expect("should create file uri for temp directory");

        assert!(uri.starts_with("file://"));
    }
}