1use anyhow::{Context, Result, anyhow};
2use arc_swap::ArcSwap;
3use hashbrown::HashMap;
4use rmcp::model::{
5 CallToolRequestParams, CallToolResult, GetPromptRequestParams, InitializeRequestParams,
6 InitializeResult, Prompt, ReadResourceRequestParams, Resource, Tool,
7};
8use serde_json::{Map, Value};
9use std::ffi::OsString;
10use std::path::PathBuf;
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::sync::{Mutex, Semaphore};
14use tracing::{Instrument, Span, warn};
15use url::Url;
16
17use super::{LATEST_PROTOCOL_VERSION, SUPPORTED_PROTOCOL_VERSIONS};
18
19use crate::config::mcp::{McpAllowListConfig, McpProviderConfig, McpTransportConfig};
20use vtcode_config::auth::McpOAuthService;
21use vtcode_utility_tool_specs::parse_mcp_tool;
22
23use super::{McpClient, RmcpClient};
24use super::{
25 McpElicitationHandler, McpPromptDetail, McpPromptInfo, McpResourceData, McpResourceInfo,
26 McpToolInfo, TIMEZONE_ARGUMENT, build_headers, ensure_timezone_argument, schema_requires_field,
27};
28
/// One configured MCP server connection plus its per-provider state: the
/// request concurrency limit and cached tool/resource/prompt listings.
pub struct McpProvider {
    /// Provider name from the configuration; used for allow-list checks and in
    /// error/log messages.
    pub(super) name: String,
    /// Protocol version chosen in `connect`: the configured version for HTTP
    /// providers, `LATEST_PROTOCOL_VERSION` for stdio providers.
    #[expect(dead_code)]
    pub(super) protocol_version: String,
    /// Active client. Kept in an `ArcSwap` so `reconnect` can replace it through
    /// `&self` while callers that already did `load_full()` keep their own `Arc`.
    client: ArcSwap<RmcpClient>,
    /// Configuration this provider was created from; reused by `reconnect`.
    config: McpProviderConfig,
    /// Optional elicitation handler handed to every client this provider creates.
    elicitation_handler: Option<Arc<dyn McpElicitationHandler>>,
    /// Request slots sized to `max(1, max_concurrent_requests)`; a permit is
    /// acquired by `call_tool`, `read_resource` and `get_prompt` in this file.
    pub(crate) semaphore: Arc<Semaphore>,
    /// Cached, allow-list–filtered listings (see `ProviderCaches`).
    caches: Mutex<ProviderCaches>,
    /// Result of the most recent `initialize` handshake whose negotiated
    /// protocol version was accepted.
    initialize_result: Mutex<Option<InitializeResult>>,
}
42
/// Per-provider listing caches.
///
/// `None` means "not fetched yet, or invalidated"; the listing methods fetch
/// from the server in that case. Entries are `Arc`-shared so internal lookups
/// such as `has_tool` can reuse them without copying the vectors.
#[derive(Default)]
struct ProviderCaches {
    /// Tools that passed the allow list.
    tools: Option<Arc<Vec<McpToolInfo>>>,
    /// Resources that passed the allow list.
    resources: Option<Arc<Vec<McpResourceInfo>>>,
    /// Prompts that passed the allow list.
    prompts: Option<Arc<Vec<McpPromptInfo>>>,
}
49
50impl McpProvider {
51 pub(super) async fn connect(
52 config: McpProviderConfig,
53 elicitation_handler: Option<Arc<dyn McpElicitationHandler>>,
54 ) -> Result<Self> {
55 if config.name.trim().is_empty() {
56 return Err(anyhow!("MCP provider name cannot be empty"));
57 }
58
59 let max_requests = std::cmp::max(1, config.max_concurrent_requests);
60
61 let (client, protocol_version) = match &config.transport {
62 McpTransportConfig::Stdio(stdio) => {
63 let program = OsString::from(&stdio.command);
64 let args: Vec<OsString> = stdio.args.iter().map(OsString::from).collect();
65 let working_dir = stdio.working_directory.as_ref().map(PathBuf::from);
66 let env: HashMap<OsString, OsString> = config
67 .env
68 .iter()
69 .map(|(key, value)| (OsString::from(key), OsString::from(value)))
70 .collect();
71 let client = RmcpClient::new_stdio_client(
72 config.name.clone(),
73 program,
74 args,
75 working_dir,
76 Some(env),
77 elicitation_handler.clone(),
78 )
79 .await?;
80 (client, LATEST_PROTOCOL_VERSION.to_string())
81 }
82 McpTransportConfig::Http(http) => {
83 if !SUPPORTED_PROTOCOL_VERSIONS
84 .iter()
85 .any(|supported| supported == &http.protocol_version)
86 {
87 return Err(anyhow!(
88 "MCP HTTP provider '{}' requested unsupported protocol version '{}'",
89 config.name,
90 http.protocol_version
91 ));
92 }
93
94 let bearer_token = if let Some(oauth) = http.oauth.as_ref() {
95 McpOAuthService::new()
96 .resolve_access_token(&config.name, oauth)
97 .await?
98 .ok_or_else(|| {
99 anyhow!(
100 "MCP HTTP provider '{}' requires OAuth login. Run `vtcode mcp login {}`.",
101 config.name,
102 config.name
103 )
104 })
105 .map(Some)?
106 } else {
107 match http.api_key_env.as_ref() {
108 Some(var) => Some(std::env::var(var).with_context(|| {
109 format!("Missing MCP API key environment variable: {var}")
110 })?),
111 None => None,
112 }
113 };
114
115 let headers = build_headers(&http.http_headers, &http.env_http_headers);
116 let client = RmcpClient::new_streamable_http_client(
117 config.name.clone(),
118 &http.endpoint,
119 bearer_token,
120 headers,
121 elicitation_handler.clone(),
122 )
123 .await?;
124 (client, http.protocol_version.clone())
125 }
126 };
127
128 Ok(Self {
129 name: config.name.clone(),
130 protocol_version,
131 client: ArcSwap::from_pointee(client),
132 config,
133 elicitation_handler,
134 semaphore: Arc::new(Semaphore::new(max_requests)),
135 caches: Mutex::new(ProviderCaches::default()),
136 initialize_result: Mutex::new(None),
137 })
138 }
139
140 pub(super) fn invalidate_caches(&self) {
141 if let Ok(mut caches) = self.caches.try_lock() {
142 caches.tools = None;
143 caches.resources = None;
144 caches.prompts = None;
145 }
146 }
147
148 pub(super) async fn initialize(
149 &self,
150 params: InitializeRequestParams,
151 startup_timeout: Option<Duration>,
152 tool_timeout: Option<Duration>,
153 allowlist: &McpAllowListConfig,
154 ) -> Result<()> {
155 let client = self.client.load_full();
156 let result = client.initialize(params, startup_timeout).await?;
157
158 let protocol_version_str = result.protocol_version.to_string();
159 if !SUPPORTED_PROTOCOL_VERSIONS
160 .iter()
161 .any(|supported| *supported == protocol_version_str)
162 {
163 return Err(anyhow!(
164 "MCP server for '{}' negotiated unsupported protocol version '{}'",
165 self.name,
166 protocol_version_str
167 ));
168 }
169
170 *self.initialize_result.lock().await = Some(result);
171 let _ = self.refresh_tools(allowlist, tool_timeout).await;
172 Ok(())
173 }
174
175 pub(super) async fn list_tools(
176 &self,
177 allowlist: &McpAllowListConfig,
178 timeout: Option<Duration>,
179 ) -> Result<Vec<McpToolInfo>> {
180 Ok(self
181 .list_tools_shared(allowlist, timeout)
182 .await?
183 .as_ref()
184 .clone())
185 }
186
187 async fn list_tools_shared(
188 &self,
189 allowlist: &McpAllowListConfig,
190 timeout: Option<Duration>,
191 ) -> Result<Arc<Vec<McpToolInfo>>> {
192 let mut caches = self.caches.lock().await;
193 if self.client.load_full().take_tool_list_changed() {
194 caches.tools = None;
195 }
196
197 if let Some(cache) = &caches.tools {
198 return Ok(Arc::clone(cache));
199 }
200 drop(caches);
201
202 self.refresh_tools_shared(allowlist, timeout).await
203 }
204
205 pub(super) async fn refresh_tools(
206 &self,
207 allowlist: &McpAllowListConfig,
208 timeout: Option<Duration>,
209 ) -> Result<Vec<McpToolInfo>> {
210 Ok(self
211 .refresh_tools_shared(allowlist, timeout)
212 .await?
213 .as_ref()
214 .clone())
215 }
216
217 async fn refresh_tools_shared(
218 &self,
219 allowlist: &McpAllowListConfig,
220 timeout: Option<Duration>,
221 ) -> Result<Arc<Vec<McpToolInfo>>> {
222 let client = self.client.load_full();
223 let tools = client.list_all_tools(timeout).await?;
224 let filtered = Arc::new(self.filter_tools(tools, allowlist));
225 self.caches.lock().await.tools = Some(Arc::clone(&filtered));
226 Ok(filtered)
227 }
228
229 pub(super) async fn has_tool(
230 &self,
231 tool_name: &str,
232 allowlist: &McpAllowListConfig,
233 timeout: Option<Duration>,
234 ) -> Result<bool> {
235 let tools = self.list_tools_shared(allowlist, timeout).await?;
236 Ok(tools.iter().any(|tool| tool.name == tool_name))
237 }
238
239 pub(super) async fn call_tool(
240 &self,
241 tool_name: &str,
242 args: &Value,
243 timeout: Option<Duration>,
244 allowlist: &McpAllowListConfig,
245 ) -> Result<CallToolResult> {
246 if !allowlist.is_tool_allowed(&self.name, tool_name) {
247 return Err(anyhow!(
248 "Tool '{}' is blocked by the MCP allow list for provider '{}'",
249 tool_name,
250 self.name
251 ));
252 }
253
254 let _permit = self
255 .semaphore
256 .clone()
257 .acquire_owned()
258 .await
259 .context("Failed to acquire MCP request slot")?;
260 let mut arguments = McpClient::normalize_arguments(args);
261 self.add_argument_defaults(tool_name, &mut arguments, allowlist, timeout)
262 .await
263 .with_context(|| {
264 format!(
265 "failed to prepare arguments for MCP tool '{}' on provider '{}'",
266 tool_name, self.name
267 )
268 })?;
269 let params = CallToolRequestParams::new(tool_name.to_string()).with_arguments(arguments);
270 let client = self.client.load_full();
271 async move { client.call_tool(params, timeout).await }
272 .instrument(mcp_tool_call_span(
273 &self.name,
274 tool_name,
275 &self.config.transport,
276 ))
277 .await
278 }
279
280 async fn add_argument_defaults(
281 &self,
282 tool_name: &str,
283 arguments: &mut Map<String, Value>,
284 allowlist: &McpAllowListConfig,
285 timeout: Option<Duration>,
286 ) -> Result<()> {
287 let requires_timezone = self
288 .tool_requires_field(tool_name, TIMEZONE_ARGUMENT, allowlist, timeout)
289 .await?;
290 ensure_timezone_argument(arguments, requires_timezone)?;
291 Ok(())
292 }
293
294 async fn tool_requires_field(
295 &self,
296 tool_name: &str,
297 field: &str,
298 allowlist: &McpAllowListConfig,
299 timeout: Option<Duration>,
300 ) -> Result<bool> {
301 if let Some(tools) = &self.caches.lock().await.tools
302 && let Some(tool) = tools.iter().find(|tool| tool.name == tool_name)
303 {
304 return Ok(schema_requires_field(&tool.input_schema, field));
305 }
306
307 match self.refresh_tools_shared(allowlist, timeout).await {
308 Ok(tools) => Ok(tools
309 .iter()
310 .find(|tool| tool.name == tool_name)
311 .map(|tool| schema_requires_field(&tool.input_schema, field))
312 .unwrap_or(false)),
313 Err(err) => {
314 warn!(
315 "Failed to refresh tools while inspecting schema for '{}' on provider '{}': {err}",
316 tool_name, self.name
317 );
318 Ok(false)
319 }
320 }
321 }
322
323 pub(super) async fn list_resources(
324 &self,
325 allowlist: &McpAllowListConfig,
326 timeout: Option<Duration>,
327 ) -> Result<Vec<McpResourceInfo>> {
328 Ok(self
329 .list_resources_shared(allowlist, timeout)
330 .await?
331 .as_ref()
332 .clone())
333 }
334
335 async fn list_resources_shared(
336 &self,
337 allowlist: &McpAllowListConfig,
338 timeout: Option<Duration>,
339 ) -> Result<Arc<Vec<McpResourceInfo>>> {
340 let mut caches = self.caches.lock().await;
341 if self.client.load_full().take_resource_list_changed() {
342 caches.resources = None;
343 }
344
345 if let Some(cache) = &caches.resources {
346 return Ok(Arc::clone(cache));
347 }
348 drop(caches);
349
350 self.refresh_resources_shared(allowlist, timeout).await
351 }
352
353 pub(super) async fn refresh_resources(
354 &self,
355 allowlist: &McpAllowListConfig,
356 timeout: Option<Duration>,
357 ) -> Result<Vec<McpResourceInfo>> {
358 Ok(self
359 .refresh_resources_shared(allowlist, timeout)
360 .await?
361 .as_ref()
362 .clone())
363 }
364
365 async fn refresh_resources_shared(
366 &self,
367 allowlist: &McpAllowListConfig,
368 timeout: Option<Duration>,
369 ) -> Result<Arc<Vec<McpResourceInfo>>> {
370 let client = self.client.load_full();
371 let resources = client.list_all_resources(timeout).await?;
372 let filtered = Arc::new(self.filter_resources(resources, allowlist));
373 self.caches.lock().await.resources = Some(Arc::clone(&filtered));
374 Ok(filtered)
375 }
376
377 pub(super) async fn has_resource(
378 &self,
379 uri: &str,
380 allowlist: &McpAllowListConfig,
381 timeout: Option<Duration>,
382 ) -> Result<bool> {
383 let resources = self.list_resources_shared(allowlist, timeout).await?;
384 Ok(resources.iter().any(|resource| resource.uri == uri))
385 }
386
387 pub(super) async fn read_resource(
388 &self,
389 uri: &str,
390 timeout: Option<Duration>,
391 allowlist: &McpAllowListConfig,
392 ) -> Result<McpResourceData> {
393 if !allowlist.is_resource_allowed(&self.name, uri) {
394 return Err(anyhow!(
395 "Resource '{}' is blocked by the MCP allow list for provider '{}'",
396 uri,
397 self.name
398 ));
399 }
400
401 let _permit = self
402 .semaphore
403 .clone()
404 .acquire_owned()
405 .await
406 .context("Failed to acquire MCP request slot")?;
407 let params = ReadResourceRequestParams::new(uri.to_string());
408 let client = self.client.load_full();
409 let result = client.read_resource(params, timeout).await?;
410 Ok(McpResourceData {
411 provider: self.name.clone(),
412 uri: uri.to_string(),
413 contents: result.contents,
414 meta: Map::new(),
415 })
416 }
417
418 pub(super) async fn list_prompts(
419 &self,
420 allowlist: &McpAllowListConfig,
421 timeout: Option<Duration>,
422 ) -> Result<Vec<McpPromptInfo>> {
423 Ok(self
424 .list_prompts_shared(allowlist, timeout)
425 .await?
426 .as_ref()
427 .clone())
428 }
429
430 async fn list_prompts_shared(
431 &self,
432 allowlist: &McpAllowListConfig,
433 timeout: Option<Duration>,
434 ) -> Result<Arc<Vec<McpPromptInfo>>> {
435 let mut caches = self.caches.lock().await;
436 if self.client.load_full().take_prompt_list_changed() {
437 caches.prompts = None;
438 }
439
440 if let Some(cache) = &caches.prompts {
441 return Ok(Arc::clone(cache));
442 }
443 drop(caches);
444
445 self.refresh_prompts_shared(allowlist, timeout).await
446 }
447
448 pub(super) async fn refresh_prompts(
449 &self,
450 allowlist: &McpAllowListConfig,
451 timeout: Option<Duration>,
452 ) -> Result<Vec<McpPromptInfo>> {
453 Ok(self
454 .refresh_prompts_shared(allowlist, timeout)
455 .await?
456 .as_ref()
457 .clone())
458 }
459
460 async fn refresh_prompts_shared(
461 &self,
462 allowlist: &McpAllowListConfig,
463 timeout: Option<Duration>,
464 ) -> Result<Arc<Vec<McpPromptInfo>>> {
465 let client = self.client.load_full();
466 let prompts = client.list_all_prompts(timeout).await?;
467 let filtered = Arc::new(self.filter_prompts(prompts, allowlist));
468 self.caches.lock().await.prompts = Some(Arc::clone(&filtered));
469 Ok(filtered)
470 }
471
472 pub(super) async fn has_prompt(
473 &self,
474 prompt_name: &str,
475 allowlist: &McpAllowListConfig,
476 timeout: Option<Duration>,
477 ) -> Result<bool> {
478 let prompts = self.list_prompts_shared(allowlist, timeout).await?;
479 Ok(prompts.iter().any(|prompt| prompt.name == prompt_name))
480 }
481
482 pub(super) async fn get_prompt(
483 &self,
484 prompt_name: &str,
485 arguments: HashMap<String, String>,
486 timeout: Option<Duration>,
487 allowlist: &McpAllowListConfig,
488 ) -> Result<McpPromptDetail> {
489 if !allowlist.is_prompt_allowed(&self.name, prompt_name) {
490 return Err(anyhow!(
491 "Prompt '{}' is blocked by the MCP allow list for provider '{}'",
492 prompt_name,
493 self.name
494 ));
495 }
496
497 let _permit = self
498 .semaphore
499 .clone()
500 .acquire_owned()
501 .await
502 .context("Failed to acquire MCP request slot")?;
503 let args_json: Map<String, Value> = arguments
505 .into_iter()
506 .map(|(k, v)| (k, Value::String(v)))
507 .collect();
508
509 let params = GetPromptRequestParams::new(prompt_name.to_string()).with_arguments(args_json);
510 let client = self.client.load_full();
511 let result = client.get_prompt(params, timeout).await?;
512 Ok(McpPromptDetail {
513 provider: self.name.clone(),
514 name: prompt_name.to_string(),
515 description: result.description,
516 messages: result.messages,
517 meta: Map::new(),
518 })
519 }
520
521 pub(super) async fn cached_tools(&self) -> Option<Vec<McpToolInfo>> {
522 self.caches
523 .lock()
524 .await
525 .tools
526 .as_ref()
527 .map(|tools| tools.as_ref().clone())
528 }
529
530 pub(super) async fn cached_tools_or_refresh(
531 &self,
532 allowlist: &McpAllowListConfig,
533 timeout: Option<Duration>,
534 ) -> Result<Vec<McpToolInfo>> {
535 if let Some(tools) = self.cached_tools().await {
536 return Ok(tools);
537 }
538
539 self.refresh_tools(allowlist, timeout).await
540 }
541
542 pub(super) async fn shutdown(&self) -> Result<()> {
543 let client = self.client.load_full();
544 client.shutdown().await
545 }
546
547 pub(super) async fn is_healthy(&self) -> bool {
549 let client = self.client.load_full();
550 client.is_healthy().await
551 }
552
553 pub(super) async fn reconnect(
558 &self,
559 startup_timeout: Option<Duration>,
560 tool_timeout: Option<Duration>,
561 allowlist: &McpAllowListConfig,
562 ) -> Result<()> {
563 tracing::info!(provider = self.name.as_str(), "Attempting MCP reconnection");
564
565 {
567 let old = self.client.load_full();
568 let _ = old.shutdown().await;
569 }
570
571 let new_provider =
573 McpProvider::connect(self.config.clone(), self.elicitation_handler.clone())
574 .await
575 .with_context(|| format!("MCP reconnect failed for provider '{}'", self.name))?;
576
577 {
579 let new_client = new_provider.client.load_full();
580 self.client.store(new_client);
581 }
582
583 self.invalidate_caches();
585
586 let init_params = InitializeRequestParams::new(
588 rmcp::model::ClientCapabilities::default(),
589 super::utils::build_client_implementation(),
590 )
591 .with_protocol_version(rmcp::model::ProtocolVersion::V_2024_11_05);
592 self.initialize(init_params, startup_timeout, tool_timeout, allowlist)
593 .await
594 .with_context(|| {
595 format!("MCP re-initialization failed for provider '{}'", self.name)
596 })?;
597
598 tracing::info!(provider = self.name.as_str(), "MCP reconnection successful");
599 Ok(())
600 }
601
602 fn filter_tools(&self, tools: Vec<Tool>, allowlist: &McpAllowListConfig) -> Vec<McpToolInfo> {
603 tools
604 .into_iter()
605 .filter(|tool| allowlist.is_tool_allowed(&self.name, &tool.name))
606 .map(|tool| {
607 let parsed = parse_mcp_tool(&tool);
608 McpToolInfo {
609 description: parsed.description,
610 input_schema: parsed.input_schema,
611 provider: self.name.clone(),
612 name: parsed.name,
613 }
614 })
615 .collect()
616 }
617
618 fn filter_resources(
619 &self,
620 resources: Vec<Resource>,
621 allowlist: &McpAllowListConfig,
622 ) -> Vec<McpResourceInfo> {
623 resources
624 .into_iter()
625 .filter(|resource| allowlist.is_resource_allowed(&self.name, &resource.uri))
626 .map(|resource| McpResourceInfo {
627 provider: self.name.clone(),
628 uri: resource.uri.clone(),
629 name: resource.name.clone(),
630 description: resource.description.clone(),
631 mime_type: resource.mime_type.clone(),
632 size: resource.size.map(|s| s as i64),
633 })
634 .collect()
635 }
636
637 fn filter_prompts(
638 &self,
639 prompts: Vec<Prompt>,
640 allowlist: &McpAllowListConfig,
641 ) -> Vec<McpPromptInfo> {
642 prompts
643 .into_iter()
644 .filter(|prompt| allowlist.is_prompt_allowed(&self.name, &prompt.name))
645 .map(|prompt| McpPromptInfo {
646 provider: self.name.clone(),
647 name: prompt.name.clone(),
648 description: prompt.description.clone(),
649 arguments: prompt.arguments.clone().unwrap_or_default(),
650 })
651 .collect()
652 }
653}
654
655fn mcp_tool_call_span(
656 provider_name: &str,
657 tool_name: &str,
658 transport: &McpTransportConfig,
659) -> Span {
660 let (transport_label, server_address, server_port) = match transport {
661 McpTransportConfig::Stdio(_) => ("stdio", String::new(), 0_u16),
662 McpTransportConfig::Http(http) => {
663 let (server_address, server_port) = Url::parse(&http.endpoint)
664 .ok()
665 .and_then(|url| {
666 url.host_str().map(|host| {
667 (
668 host.to_string(),
669 url.port_or_known_default().unwrap_or_default(),
670 )
671 })
672 })
673 .unwrap_or_default();
674 ("streamable_http", server_address, server_port)
675 }
676 };
677
678 tracing::info_span!(
679 "mcp.tools.call",
680 provider = provider_name,
681 tool = tool_name,
682 rpc_system = "jsonrpc",
683 rpc_method = "tools/call",
684 transport = transport_label,
685 server_address = server_address.as_str(),
686 server_port = server_port,
687 )
688}
689
#[cfg(test)]
mod tests {
    use super::mcp_tool_call_span;
    use crate::config::mcp::{McpHttpServerConfig, McpStdioServerConfig, McpTransportConfig};
    use crate::utils::trace_writer::FlushableWriter;
    use std::fs;
    use tempfile::tempdir;
    use tracing_subscriber::{fmt::format::FmtSpan, prelude::*};

    /// Creates and enters an `mcp.tools.call` span under a file-backed `fmt`
    /// subscriber that logs every span event, and returns the captured output.
    fn capture_span_log(provider: &str, tool: &str, transport: &McpTransportConfig) -> String {
        let tempdir = tempdir().expect("tempdir");
        let log_file = tempdir.path().join("trace.log");
        let writer = FlushableWriter::open(&log_file).expect("trace writer");
        let writer_for_layer = writer.clone();
        let subscriber = tracing_subscriber::registry().with(
            tracing_subscriber::fmt::layer()
                .with_writer(move || writer_for_layer.clone())
                .with_span_events(FmtSpan::FULL)
                .with_ansi(false),
        );
        let _guard = tracing::subscriber::set_default(subscriber);

        {
            let span = mcp_tool_call_span(provider, tool, transport);
            let _entered = span.enter();
        }

        writer.flush();
        fs::read_to_string(&log_file).expect("trace log")
    }

    /// Asserts every fragment occurs in `logs`, naming the missing one and
    /// dumping the log on failure.
    fn assert_log_contains(logs: &str, fragments: &[&str]) {
        for &fragment in fragments {
            assert!(
                logs.contains(fragment),
                "expected `{fragment}` in trace log:\n{logs}"
            );
        }
    }

    #[test]
    fn mcp_tool_call_span_records_http_metadata() {
        let transport = McpTransportConfig::Http(McpHttpServerConfig {
            endpoint: "https://example.com:8443/mcp".to_string(),
            api_key_env: None,
            oauth: None,
            protocol_version: "2024-11-05".to_string(),
            http_headers: Default::default(),
            env_http_headers: Default::default(),
        });

        let logs = capture_span_log("calendar", "get_events", &transport);
        assert_log_contains(
            &logs,
            &[
                "mcp.tools.call",
                "provider=\"calendar\"",
                "tool=\"get_events\"",
                "rpc_system=\"jsonrpc\"",
                "rpc_method=\"tools/call\"",
                "transport=\"streamable_http\"",
                "server_address=\"example.com\"",
                "server_port=8443",
            ],
        );
    }

    #[test]
    fn mcp_tool_call_span_defaults_stdio_transport() {
        let transport = McpTransportConfig::Stdio(McpStdioServerConfig {
            command: "rmcp-server".to_string(),
            args: Vec::new(),
            working_directory: None,
        });

        let span = mcp_tool_call_span("filesystem", "read_file", &transport);
        assert_eq!(span.metadata().expect("metadata").name(), "mcp.tools.call");

        // Stdio has no network endpoint: the address must be empty and the port 0.
        let logs = capture_span_log("filesystem", "read_file", &transport);
        assert_log_contains(
            &logs,
            &[
                "mcp.tools.call",
                "provider=\"filesystem\"",
                "tool=\"read_file\"",
                "transport=\"stdio\"",
                "server_address=\"\"",
                "server_port=0",
            ],
        );
    }
}
755}