Skip to main content

vtcode_core/mcp/
client.rs

1use crate::config::mcp::{
2    McpAllowListConfig, McpClientConfig, McpProviderConfig, McpTransportConfig,
3};
4use crate::utils::file_utils::{ensure_dir_exists, write_file_with_context};
5use anyhow::{Context, Result, anyhow, bail};
6use async_trait::async_trait;
7use chrono::Utc;
8use parking_lot::RwLock;
9use rmcp::model::{CallToolResult, ClientCapabilities, InitializeRequestParams, RootsCapabilities};
10use rustc_hash::FxHashMap;
11use serde_json::{Map, Value, json};
12use std::path::{Path, PathBuf};
13use std::sync::Arc;
14use std::time::Duration;
15use tracing::{debug, error, info, warn};
16
17use super::{
18    McpClientStatus, McpElicitationHandler, McpPromptDetail, McpPromptInfo, McpProvider,
19    McpResourceData, McpResourceInfo, McpToolExecutor, McpToolInfo, format_tool_markdown,
20    sanitize_filename,
21};
22
23struct McpClientState {
24    providers: FxHashMap<String, Arc<McpProvider>>,
25    allowlist: McpAllowListConfig,
26    tool_provider_index: FxHashMap<String, String>,
27    resource_provider_index: FxHashMap<String, String>,
28    prompt_provider_index: FxHashMap<String, String>,
29}
30
31pub struct McpClient {
32    config: McpClientConfig,
33    state: RwLock<McpClientState>,
34    elicitation_handler: Option<Arc<dyn McpElicitationHandler>>,
35}
36
37impl McpClient {
38    /// Create a new MCP client from the configuration.
39    pub fn new(config: McpClientConfig) -> Self {
40        let allowlist = config.allowlist.clone();
41
42        Self {
43            config,
44            state: RwLock::new(McpClientState {
45                providers: FxHashMap::default(),
46                allowlist,
47                tool_provider_index: FxHashMap::default(),
48                resource_provider_index: FxHashMap::default(),
49                prompt_provider_index: FxHashMap::default(),
50            }),
51            elicitation_handler: None,
52        }
53    }
54
55    /// Register a handler used to satisfy elicitation requests from providers.
56    pub fn set_elicitation_handler(&mut self, handler: Arc<dyn McpElicitationHandler>) {
57        self.elicitation_handler = Some(handler);
58    }
59
60    /// Establish connections to all configured providers and complete the
61    /// MCP handshake.
62    pub async fn initialize(&mut self) -> Result<()> {
63        if !self.config.enabled {
64            info!("MCP client is disabled in configuration");
65            return Ok(());
66        }
67
68        info!(
69            "Initializing MCP client with {} configured providers",
70            self.config.providers.len()
71        );
72
73        // Sequential initialization
74        Box::pin(self.initialize_sequential()).await
75    }
76
77    /// Initialize providers sequentially (fallback method)
78    async fn initialize_sequential(&mut self) -> Result<()> {
79        let tool_timeout = self.tool_timeout();
80        let allowlist_snapshot = self.state.read().allowlist.clone();
81
82        let mut initialized = FxHashMap::default();
83
84        for provider_config in &self.config.providers {
85            if !provider_config.enabled {
86                debug!(
87                    "MCP provider '{}' is disabled; skipping",
88                    provider_config.name
89                );
90                continue;
91            }
92
93            if let Some(reason) = self.requirement_mismatch_reason(provider_config) {
94                warn!(
95                    "Skipping MCP provider '{}' due to requirements policy: {}",
96                    provider_config.name, reason
97                );
98                continue;
99            }
100
101            if matches!(provider_config.transport, McpTransportConfig::Http(_))
102                && !self.config.experimental_use_rmcp_client
103            {
104                warn!(
105                    "Skipping MCP HTTP provider '{}' because experimental_use_rmcp_client is disabled",
106                    provider_config.name
107                );
108                continue;
109            }
110
111            match self
112                .connect_and_initialize_provider(provider_config, &allowlist_snapshot, tool_timeout)
113                .await
114            {
115                Ok(provider) => {
116                    if let Err(err) = provider
117                        .cached_tools_or_refresh(&allowlist_snapshot, tool_timeout)
118                        .await
119                    {
120                        error!(
121                            "Failed to fetch tools for provider '{}': {err}",
122                            provider_config.name
123                        );
124                    } else if let Some(cache) = provider.cached_tools().await {
125                        self.record_tool_provider(&provider.name, &cache);
126                    }
127
128                    initialized.insert(provider.name.clone(), Arc::new(provider));
129                    info!(
130                        "Successfully initialized MCP provider '{}'",
131                        provider_config.name
132                    );
133                }
134                Err(err) => {
135                    error!(
136                        "Failed to connect to MCP provider '{}': {err}",
137                        provider_config.name
138                    );
139                }
140            }
141        }
142
143        self.state.write().providers = initialized;
144        info!(
145            "MCP client initialization complete. Active providers: {}",
146            self.state.read().providers.len()
147        );
148
149        Ok(())
150    }
151
152    /// Validate tool arguments based on security configuration
153    fn validate_tool_arguments(&self, _tool_name: &str, args: &Value) -> Result<()> {
154        // Check argument size
155        if self.config.security.validation.max_argument_size > 0 {
156            let args_size = serde_json::to_string(args).map_or(0, |s| s.len()) as u32;
157
158            if args_size > self.config.security.validation.max_argument_size {
159                return Err(anyhow::anyhow!(
160                    "Tool arguments exceed maximum size of {} bytes",
161                    self.config.security.validation.max_argument_size
162                ));
163            }
164        }
165
166        // Check for path traversal in file-related arguments
167        if self.config.security.validation.path_traversal_protection
168            && let Some(path) = args.get("path").and_then(|v| v.as_str())
169            && (path.contains("../")
170                || path.starts_with("../")
171                || path.contains("..\\")
172                || path.starts_with("..\\"))
173        {
174            return Err(anyhow::anyhow!("Path traversal detected in arguments"));
175        }
176
177        Ok(())
178    }
179
180    /// Execute a tool call after validating arguments.
181    ///
182    /// Public-facing version that takes ownership of `args` for compatibility
183    /// with existing callers. Delegates to the reference-taking implementation
184    /// to avoid unnecessary cloning when the caller already has a reference.
185    pub async fn execute_tool_with_validation(
186        &self,
187        tool_name: &str,
188        args: Value,
189    ) -> Result<Value> {
190        self.execute_tool_with_validation_ref(tool_name, &args)
191            .await
192    }
193
194    // Internal reference-taking implementation to avoid cloning when not necessary.
195    async fn execute_tool_with_validation_ref(
196        &self,
197        tool_name: &str,
198        args: &Value,
199    ) -> Result<Value> {
200        if !self.config.enabled {
201            return Err(anyhow!(
202                "MCP support is disabled in the current configuration"
203            ));
204        }
205
206        self.validate_tool_arguments(tool_name, args)?;
207
208        let provider = self.resolve_provider_for_tool(tool_name).await?;
209        let allowlist_snapshot = self.state.read().allowlist.clone();
210        let result = provider
211            .call_tool(tool_name, args, self.tool_timeout(), &allowlist_snapshot)
212            .await?;
213
214        Self::format_tool_result(&provider.name, tool_name, result)
215    }
216
217    /// Refresh the internal allow list at runtime.
218    pub fn update_allowlist(&self, allowlist: McpAllowListConfig) {
219        let providers: Vec<Arc<McpProvider>> = {
220            let mut state = self.state.write();
221            state.allowlist = allowlist;
222            state.tool_provider_index.clear();
223            state.resource_provider_index.clear();
224            state.prompt_provider_index.clear();
225            state.providers.values().cloned().collect()
226        };
227
228        for provider in providers {
229            provider.invalidate_caches();
230        }
231    }
232
233    /// Current allow list snapshot.
234    pub fn current_allowlist(&self) -> McpAllowListConfig {
235        self.state.read().allowlist.clone()
236    }
237
238    /// Return the provider name serving the given tool if previously cached.
239    pub fn provider_for_tool(&self, tool_name: &str) -> Option<String> {
240        self.state
241            .read()
242            .tool_provider_index
243            .get(tool_name)
244            .cloned()
245    }
246
247    /// Return the provider responsible for the given resource URI if known.
248    pub fn provider_for_resource(&self, uri: &str) -> Option<String> {
249        self.state.read().resource_provider_index.get(uri).cloned()
250    }
251
252    /// Return the provider that exposes the given prompt if known.
253    pub fn provider_for_prompt(&self, prompt_name: &str) -> Option<String> {
254        self.state
255            .read()
256            .prompt_provider_index
257            .get(prompt_name)
258            .cloned()
259    }
260
261    /// Execute a tool call on the appropriate provider.
262    pub async fn execute_tool(&self, tool_name: &str, args: Value) -> Result<Value> {
263        self.execute_tool_with_validation(tool_name, args).await
264    }
265
266    /// List all tools from all active providers.
267    pub async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
268        self.collect_tools(false).await
269    }
270
271    /// List all resources exposed by connected MCP providers.
272    pub async fn list_resources(&self) -> Result<Vec<McpResourceInfo>> {
273        self.collect_resources(false).await
274    }
275
276    /// Force refresh and list resources from providers.
277    pub async fn refresh_resources(&self) -> Result<Vec<McpResourceInfo>> {
278        self.collect_resources(true).await
279    }
280
281    /// List all prompts advertised by connected MCP providers.
282    pub async fn list_prompts(&self) -> Result<Vec<McpPromptInfo>> {
283        self.collect_prompts(false).await
284    }
285
286    /// Force refresh and list prompts from providers.
287    pub async fn refresh_prompts(&self) -> Result<Vec<McpPromptInfo>> {
288        self.collect_prompts(true).await
289    }
290
291    /// Read a single resource from its originating provider.
292    pub async fn read_resource(&self, uri: &str) -> Result<McpResourceData> {
293        let provider = self.resolve_provider_for_resource(uri).await?;
294        let provider_name = provider.name.clone();
295        let allowlist_snapshot = self.state.read().allowlist.clone();
296        let data = provider
297            .read_resource(uri, self.request_timeout(), &allowlist_snapshot)
298            .await?;
299        self.state
300            .write()
301            .resource_provider_index
302            .insert(uri.into(), provider_name);
303        Ok(data)
304    }
305
306    /// Retrieve a rendered prompt from its originating provider.
307    pub async fn get_prompt(
308        &self,
309        prompt_name: &str,
310        arguments: Option<hashbrown::HashMap<String, String>>,
311    ) -> Result<McpPromptDetail> {
312        let provider = self.resolve_provider_for_prompt(prompt_name).await?;
313        let provider_name = provider.name.clone();
314        let allowlist_snapshot = self.state.read().allowlist.clone();
315        let prompt = provider
316            .get_prompt(
317                prompt_name,
318                arguments.unwrap_or_default(),
319                self.request_timeout(),
320                &allowlist_snapshot,
321            )
322            .await?;
323        self.state
324            .write()
325            .prompt_provider_index
326            .insert(prompt_name.into(), provider_name);
327        Ok(prompt)
328    }
329
330    /// Shutdown all active provider connections.
331    pub async fn shutdown(&self) -> Result<()> {
332        let providers: Vec<Arc<McpProvider>> = {
333            let mut state = self.state.write();
334            let values: Vec<_> = state.providers.values().cloned().collect();
335            state.providers.clear();
336            state.tool_provider_index.clear();
337            state.resource_provider_index.clear();
338            state.prompt_provider_index.clear();
339            values
340        };
341
342        if providers.is_empty() {
343            info!("No active MCP connections to shutdown");
344            return Ok(());
345        }
346
347        info!("Shutting down {} MCP providers", providers.len());
348        for provider in providers {
349            if let Err(err) = provider.shutdown().await {
350                warn!(
351                    "Provider '{}' shutdown returned error: {err}",
352                    provider.name
353                );
354            }
355        }
356        Ok(())
357    }
358
359    /// Current status snapshot for UI/debugging purposes.
360    pub fn get_status(&self) -> McpClientStatus {
361        let state = self.state.read();
362        let providers = &state.providers;
363        // Use iterator to collect keys directly without intermediate push
364        let configured_providers: Vec<String> = providers.keys().cloned().collect();
365        McpClientStatus {
366            enabled: self.config.enabled,
367            provider_count: providers.len(),
368            active_connections: providers.len(),
369            configured_providers,
370        }
371    }
372
373    /// Return configured MCP servers and their current connection state.
374    pub fn list_servers(&self) -> Vec<Value> {
375        let state = self.state.read();
376        self.config
377            .providers
378            .iter()
379            .map(|provider_config| {
380                let connected = state.providers.contains_key(&provider_config.name);
381                let (transport, target) = match &provider_config.transport {
382                    McpTransportConfig::Stdio(stdio) => {
383                        ("stdio", Value::String(stdio.command.clone()))
384                    }
385                    McpTransportConfig::Http(http) => {
386                        ("http", Value::String(http.endpoint.clone()))
387                    }
388                };
389
390                json!({
391                    "name": provider_config.name,
392                    "enabled": provider_config.enabled,
393                    "connected": connected,
394                    "connection_state": if connected { "connected" } else { "disconnected" },
395                    "transport": transport,
396                    "target": target,
397                })
398            })
399            .collect()
400    }
401
402    /// Return whether model-callable lifecycle tools are enabled by config.
403    pub fn allow_model_lifecycle_control(&self) -> bool {
404        self.config.lifecycle.allow_model_control
405    }
406
407    /// Connect one configured MCP server by name.
408    pub async fn connect_server(&self, server_name: &str) -> Result<()> {
409        if !self.config.enabled {
410            bail!("MCP support is disabled in the current configuration");
411        }
412
413        if self.state.read().providers.contains_key(server_name) {
414            return Ok(());
415        }
416
417        let provider_config = self
418            .config
419            .providers
420            .iter()
421            .find(|provider| provider.name == server_name)
422            .cloned()
423            .ok_or_else(|| anyhow!("MCP server '{}' is not configured", server_name))?;
424
425        if !provider_config.enabled {
426            bail!("MCP server '{}' is configured but disabled", server_name);
427        }
428
429        if let Some(reason) = self.requirement_mismatch_reason(&provider_config) {
430            bail!(
431                "Cannot connect MCP server '{}': {}",
432                provider_config.name,
433                reason
434            );
435        }
436
437        if matches!(provider_config.transport, McpTransportConfig::Http(_))
438            && !self.config.experimental_use_rmcp_client
439        {
440            bail!(
441                "Cannot connect MCP HTTP server '{}' while experimental_use_rmcp_client is disabled",
442                provider_config.name
443            );
444        }
445
446        let allowlist_snapshot = self.state.read().allowlist.clone();
447        let tool_timeout = self.tool_timeout();
448        let provider = self
449            .connect_and_initialize_provider(&provider_config, &allowlist_snapshot, tool_timeout)
450            .await?;
451
452        if let Err(err) = provider
453            .cached_tools_or_refresh(&allowlist_snapshot, tool_timeout)
454            .await
455        {
456            warn!(
457                "Connected MCP server '{}' but failed to refresh tools: {err}",
458                server_name
459            );
460        } else if let Some(cache) = provider.cached_tools().await {
461            self.record_tool_provider(&provider.name, &cache);
462        }
463
464        self.state
465            .write()
466            .providers
467            .insert(provider.name.clone(), Arc::new(provider));
468        Ok(())
469    }
470
471    /// Disconnect one active MCP server by name.
472    pub async fn disconnect_server(&self, server_name: &str) -> Result<()> {
473        let provider = {
474            let mut state = self.state.write();
475            let provider = state
476                .providers
477                .remove(server_name)
478                .ok_or_else(|| anyhow!("MCP server '{}' is not connected", server_name))?;
479            state
480                .tool_provider_index
481                .retain(|_, provider_name| provider_name != server_name);
482            state
483                .resource_provider_index
484                .retain(|_, provider_name| provider_name != server_name);
485            state
486                .prompt_provider_index
487                .retain(|_, provider_name| provider_name != server_name);
488            provider
489        };
490
491        provider.shutdown().await?;
492        Ok(())
493    }
494
495    /// Sync MCP tool descriptions to files for dynamic context discovery
496    ///
497    /// This implements Cursor-style dynamic context discovery:
498    /// - Tool descriptions are written to `.vtcode/mcp/tools/{provider}/{tool}.md`
499    /// - Status is written to `.vtcode/mcp/status.json`
500    /// - Agents can discover tools via grep/read_file without loading all schemas
501    ///
502    /// Returns the paths to written files (index path, tool count)
503    pub async fn sync_tools_to_files(&self, workspace_root: &Path) -> Result<(PathBuf, usize)> {
504        let tools = self.list_tools().await?;
505        let mcp_dir = workspace_root.join(".vtcode").join("mcp");
506        let tools_dir = mcp_dir.join("tools");
507
508        // Create directories
509        ensure_dir_exists(&tools_dir).await.with_context(|| {
510            format!(
511                "Failed to create MCP tools directory: {}",
512                tools_dir.display()
513            )
514        })?;
515
516        // Group tools by provider
517        let mut by_provider: FxHashMap<String, Vec<&McpToolInfo>> = FxHashMap::default();
518        for tool in &tools {
519            by_provider
520                .entry(tool.provider.clone())
521                .or_default()
522                .push(tool);
523        }
524
525        // Write tool files per provider
526        for (provider, provider_tools) in &by_provider {
527            let provider_dir = tools_dir.join(sanitize_filename(provider));
528            ensure_dir_exists(&provider_dir).await.with_context(|| {
529                format!(
530                    "Failed to create provider directory: {}",
531                    provider_dir.display()
532                )
533            })?;
534
535            for tool in provider_tools {
536                let tool_content = format_tool_markdown(tool);
537                let tool_path = provider_dir.join(format!("{}.md", sanitize_filename(&tool.name)));
538                write_file_with_context(&tool_path, &tool_content, "MCP tool file")
539                    .await
540                    .with_context(|| {
541                        format!("Failed to write tool file: {}", tool_path.display())
542                    })?;
543            }
544        }
545
546        // Write index file
547        let index_content = self.generate_tools_index(&tools, &by_provider);
548        let index_path = tools_dir.join("INDEX.md");
549        write_file_with_context(&index_path, &index_content, "MCP tools index")
550            .await
551            .with_context(|| {
552                format!("Failed to write MCP tools index: {}", index_path.display())
553            })?;
554
555        // Write status file
556        let status = self.generate_status_json();
557        let status_path = mcp_dir.join("status.json");
558        let status_json = serde_json::to_string_pretty(&status)?;
559        write_file_with_context(&status_path, &status_json, "MCP status")
560            .await
561            .with_context(|| format!("Failed to write MCP status: {}", status_path.display()))?;
562
563        info!(
564            tools = tools.len(),
565            providers = by_provider.len(),
566            index = %index_path.display(),
567            "Synced MCP tool descriptions to files"
568        );
569
570        Ok((index_path, tools.len()))
571    }
572
573    /// Generate INDEX.md content for MCP tools
574    fn generate_tools_index(
575        &self,
576        tools: &[McpToolInfo],
577        by_provider: &FxHashMap<String, Vec<&McpToolInfo>>,
578    ) -> String {
579        let mut content = String::new();
580        content.push_str("# MCP Tools Index\n\n");
581        content.push_str("This file lists all available MCP tools for dynamic discovery.\n");
582        content.push_str("Use `read_file` on individual tool files for full schema details.\n\n");
583
584        if tools.is_empty() {
585            content.push_str("*No MCP tools available.*\n\n");
586            content.push_str("Configure MCP servers in `vtcode.toml` or `.mcp.json`.\n");
587        } else {
588            content.push_str(&format!("**Total Tools**: {}\n\n", tools.len()));
589
590            // Summary table
591            content.push_str("## Quick Reference\n\n");
592            content.push_str("| Provider | Tool | Description |\n");
593            content.push_str("|----------|------|-------------|\n");
594
595            for tool in tools {
596                let desc = tool.description.lines().next().unwrap_or(&tool.description);
597                let desc_truncated =
598                    vtcode_commons::formatting::truncate_byte_budget(desc, 57, "...");
599                content.push_str(&format!(
600                    "| {} | `{}` | {} |\n",
601                    tool.provider,
602                    tool.name,
603                    desc_truncated.replace('|', "\\|")
604                ));
605            }
606
607            // Per-provider sections
608            content.push_str("\n## Tools by Provider\n\n");
609            for (provider, provider_tools) in by_provider {
610                content.push_str(&format!("### {}\n\n", provider));
611                for tool in provider_tools {
612                    content.push_str(&format!(
613                        "- **{}**: {}\n  - Path: `.vtcode/mcp/tools/{}/{}.md`\n",
614                        tool.name,
615                        tool.description.lines().next().unwrap_or(&tool.description),
616                        sanitize_filename(provider),
617                        sanitize_filename(&tool.name)
618                    ));
619                }
620                content.push('\n');
621            }
622        }
623
624        content.push_str("\n---\n");
625        content.push_str("*Generated automatically. Do not edit manually.*\n");
626
627        content
628    }
629
630    /// Generate status.json content
631    fn generate_status_json(&self) -> Value {
632        let status = self.get_status();
633        json!({
634            "enabled": status.enabled,
635            "provider_count": status.provider_count,
636            "active_connections": status.active_connections,
637            "configured_providers": status.configured_providers,
638            "last_updated": Utc::now().to_rfc3339(),
639        })
640    }
641
642    async fn collect_tools(&self, force_refresh: bool) -> Result<Vec<McpToolInfo>> {
643        // Collect provider references in one pass
644        let (providers, allowlist) = {
645            let state = self.state.read();
646            (
647                state.providers.values().cloned().collect::<Vec<_>>(),
648                state.allowlist.clone(),
649            )
650        };
651
652        if providers.is_empty() {
653            return Ok(Vec::new());
654        }
655
656        let timeout = self.tool_timeout();
657        let mut all_tools = Vec::with_capacity(128);
658        let mut index_updates: FxHashMap<String, String> =
659            FxHashMap::with_capacity_and_hasher(128, Default::default());
660
661        for provider in providers {
662            let provider_name = provider.name.clone();
663            let tools = if force_refresh {
664                provider.refresh_tools(&allowlist, timeout).await
665            } else {
666                provider.list_tools(&allowlist, timeout).await
667            };
668
669            match tools {
670                Ok(tools) => {
671                    for tool in &tools {
672                        index_updates.insert(tool.name.clone(), provider_name.clone());
673                    }
674                    all_tools.extend(tools);
675                }
676                Err(err) => {
677                    warn!(
678                        "Failed to list tools for provider '{}': {err}",
679                        provider_name
680                    );
681                }
682            }
683        }
684
685        if !index_updates.is_empty() || force_refresh {
686            let mut state = self.state.write();
687            if index_updates.is_empty() {
688                state.tool_provider_index.clear();
689            } else {
690                state.tool_provider_index = index_updates;
691            }
692        }
693
694        Ok(all_tools)
695    }
696
697    async fn collect_resources(&self, force_refresh: bool) -> Result<Vec<McpResourceInfo>> {
698        // Collect provider references in one pass
699        let (providers, allowlist) = {
700            let state = self.state.read();
701            (
702                state.providers.values().cloned().collect::<Vec<_>>(),
703                state.allowlist.clone(),
704            )
705        };
706
707        if providers.is_empty() {
708            self.state.write().resource_provider_index.clear();
709            return Ok(Vec::new());
710        }
711
712        let timeout = self.request_timeout();
713        let mut all_resources = Vec::with_capacity(64);
714
715        for provider in providers {
716            let resources = if force_refresh {
717                provider.refresh_resources(&allowlist, timeout).await
718            } else {
719                provider.list_resources(&allowlist, timeout).await
720            };
721
722            match resources {
723                Ok(resources) => {
724                    all_resources.extend(resources);
725                }
726                Err(err) => {
727                    warn!(
728                        "Failed to list resources for provider '{}': {err}",
729                        provider.name
730                    );
731                }
732            }
733        }
734
735        let mut state = self.state.write();
736        let index = &mut state.resource_provider_index;
737        index.clear();
738        for resource in &all_resources {
739            index.insert(resource.uri.clone(), resource.provider.clone());
740        }
741
742        Ok(all_resources)
743    }
744
745    async fn collect_prompts(&self, force_refresh: bool) -> Result<Vec<McpPromptInfo>> {
746        // Collect provider references in one pass
747        let (providers, allowlist) = {
748            let state = self.state.read();
749            (
750                state.providers.values().cloned().collect::<Vec<_>>(),
751                state.allowlist.clone(),
752            )
753        };
754
755        if providers.is_empty() {
756            self.state.write().prompt_provider_index.clear();
757            return Ok(Vec::new());
758        }
759
760        let timeout = self.request_timeout();
761        let mut all_prompts = Vec::with_capacity(32);
762
763        for provider in providers {
764            let prompts = if force_refresh {
765                provider.refresh_prompts(&allowlist, timeout).await
766            } else {
767                provider.list_prompts(&allowlist, timeout).await
768            };
769
770            match prompts {
771                Ok(prompts) => {
772                    all_prompts.extend(prompts);
773                }
774                Err(err) => {
775                    warn!(
776                        "Failed to list prompts for provider '{}': {err}",
777                        provider.name
778                    );
779                }
780            }
781        }
782
783        let mut state = self.state.write();
784        let index = &mut state.prompt_provider_index;
785        index.clear();
786        for prompt in &all_prompts {
787            index.insert(prompt.name.clone(), prompt.provider.clone());
788        }
789
790        Ok(all_prompts)
791    }
792
793    async fn resolve_provider_for_tool(&self, tool_name: &str) -> Result<Arc<McpProvider>> {
794        if !self.config.enabled {
795            return Err(anyhow!(
796                "MCP support is disabled in the current configuration"
797            ));
798        }
799
800        if let Some(provider) = self.provider_for_tool(tool_name)
801            && let Some(found) = self.state.read().providers.get(&provider)
802        {
803            return Ok(found.clone());
804        }
805
806        let (allowlist, providers) = {
807            let state = self.state.read();
808            (
809                state.allowlist.clone(),
810                state.providers.values().cloned().collect::<Vec<_>>(),
811            )
812        };
813        let timeout = self.tool_timeout();
814
815        if providers.is_empty() {
816            if self.config.providers.is_empty() {
817                return Err(anyhow!(
818                    "No MCP providers are configured. Use `vtcode mcp add` or update vtcode.toml to register one."
819                ));
820            }
821
822            return Err(anyhow!(
823                "No MCP providers are currently connected. Ensure MCP initialization completed successfully."
824            ));
825        }
826
827        for provider in providers {
828            match provider.has_tool(tool_name, &allowlist, timeout).await {
829                Ok(true) => {
830                    self.state
831                        .write()
832                        .tool_provider_index
833                        .insert(tool_name.into(), provider.name.clone());
834                    return Ok(provider);
835                }
836                Ok(false) => continue,
837                Err(err) => {
838                    warn!(
839                        "Error checking tool '{}' on provider '{}': {err}",
840                        tool_name, provider.name
841                    );
842                }
843            }
844        }
845
846        match self.collect_tools(true).await {
847            Ok(_) => {
848                if let Some(provider) = self.provider_for_tool(tool_name)
849                    && let Some(found) = self.state.read().providers.get(&provider)
850                {
851                    return Ok(found.clone());
852                }
853            }
854            Err(err) => {
855                warn!(
856                    "Failed to refresh MCP tool caches while resolving '{}': {err}",
857                    tool_name
858                );
859            }
860        }
861
862        Err(anyhow!(
863            "Tool '{}' not found on any MCP provider.\n\n\
864            To use this tool:\n\
865            1. Install the MCP server: `uv tool install mcp-server-{}`\n\
866            2. Add to vtcode.toml:\n   \
867               [[mcp.providers]]\n   \
868               name = \"{}\"\n   \
869               command = \"uvx\"\n   \
870               args = [\"mcp-server-{}\"]\n\
871            3. Restart VT Code\n\n\
872            Or use the built-in alternative if available (e.g., web_fetch instead of mcp_fetch)",
873            tool_name,
874            tool_name,
875            tool_name,
876            tool_name
877        ))
878    }
879
880    async fn resolve_provider_for_resource(&self, uri: &str) -> Result<Arc<McpProvider>> {
881        if let Some(provider) = self.provider_for_resource(uri)
882            && let Some(found) = self.state.read().providers.get(&provider)
883        {
884            return Ok(found.clone());
885        }
886
887        let (allowlist, providers) = {
888            let state = self.state.read();
889            (
890                state.allowlist.clone(),
891                state.providers.values().cloned().collect::<Vec<_>>(),
892            )
893        };
894        let timeout = self.request_timeout();
895
896        for provider in providers {
897            match provider.has_resource(uri, &allowlist, timeout).await {
898                Ok(true) => {
899                    self.state
900                        .write()
901                        .resource_provider_index
902                        .insert(uri.into(), provider.name.clone());
903                    return Ok(provider);
904                }
905                Ok(false) => continue,
906                Err(err) => {
907                    warn!(
908                        "Error checking resource '{}' on provider '{}': {err}",
909                        uri, provider.name
910                    );
911                }
912            }
913        }
914
915        Err(anyhow!("Resource '{}' not found on any MCP provider", uri))
916    }
917
918    async fn resolve_provider_for_prompt(&self, prompt_name: &str) -> Result<Arc<McpProvider>> {
919        if let Some(provider) = self.provider_for_prompt(prompt_name)
920            && let Some(found) = self.state.read().providers.get(&provider)
921        {
922            return Ok(found.clone());
923        }
924
925        let (allowlist, providers) = {
926            let state = self.state.read();
927            (
928                state.allowlist.clone(),
929                state.providers.values().cloned().collect::<Vec<_>>(),
930            )
931        };
932        let timeout = self.request_timeout();
933
934        for provider in providers {
935            match provider.has_prompt(prompt_name, &allowlist, timeout).await {
936                Ok(true) => {
937                    self.state
938                        .write()
939                        .prompt_provider_index
940                        .insert(prompt_name.into(), provider.name.clone());
941                    return Ok(provider);
942                }
943                Ok(false) => continue,
944                Err(err) => {
945                    warn!(
946                        "Error checking prompt '{}' on provider '{}': {err}",
947                        prompt_name, provider.name
948                    );
949                }
950            }
951        }
952
953        Err(anyhow!(
954            "Prompt '{}' not found on any MCP provider",
955            prompt_name
956        ))
957    }
958
959    fn record_tool_provider(&self, provider: &str, tools: &[McpToolInfo]) {
960        let mut state = self.state.write();
961        let index = &mut state.tool_provider_index;
962        for tool in tools {
963            index.insert(tool.name.clone(), provider.to_string());
964        }
965    }
966
967    async fn connect_and_initialize_provider(
968        &self,
969        provider_config: &McpProviderConfig,
970        allowlist_snapshot: &McpAllowListConfig,
971        tool_timeout: Option<Duration>,
972    ) -> Result<McpProvider> {
973        let total_attempts = self.provider_retry_attempts();
974        let mut last_error: Option<anyhow::Error> = None;
975
976        for attempt_idx in 0..total_attempts {
977            let attempt_number = attempt_idx + 1;
978            match self
979                .connect_and_initialize_provider_once(
980                    provider_config,
981                    allowlist_snapshot,
982                    tool_timeout,
983                )
984                .await
985            {
986                Ok(provider) => return Ok(provider),
987                Err(err) => {
988                    if attempt_number == total_attempts {
989                        return Err(err);
990                    }
991
992                    let retries_remaining = total_attempts - attempt_number;
993                    warn!(
994                        provider = provider_config.name.as_str(),
995                        attempt = attempt_number,
996                        retries_remaining,
997                        error = %err,
998                        "MCP provider initialization failed; retrying"
999                    );
1000                    last_error = Some(err);
1001                    tokio::time::sleep(Self::provider_retry_delay(attempt_idx)).await;
1002                }
1003            }
1004        }
1005
1006        Err(last_error.unwrap_or_else(|| {
1007            anyhow!(
1008                "Failed to initialize MCP provider '{}'",
1009                provider_config.name
1010            )
1011        }))
1012    }
1013
1014    async fn connect_and_initialize_provider_once(
1015        &self,
1016        provider_config: &McpProviderConfig,
1017        allowlist_snapshot: &McpAllowListConfig,
1018        tool_timeout: Option<Duration>,
1019    ) -> Result<McpProvider> {
1020        let provider =
1021            McpProvider::connect(provider_config.clone(), self.elicitation_handler.clone())
1022                .await
1023                .with_context(|| {
1024                    format!(
1025                        "Failed to connect to MCP provider '{}'",
1026                        provider_config.name
1027                    )
1028                })?;
1029        let provider_startup_timeout = self.resolve_startup_timeout(provider_config);
1030        provider
1031            .initialize(
1032                self.build_initialize_params(&provider),
1033                provider_startup_timeout,
1034                tool_timeout,
1035                allowlist_snapshot,
1036            )
1037            .await
1038            .with_context(|| {
1039                format!(
1040                    "Failed to initialize MCP provider '{}'",
1041                    provider_config.name
1042                )
1043            })?;
1044        Ok(provider)
1045    }
1046
1047    fn startup_timeout(&self) -> Option<Duration> {
1048        match self.config.startup_timeout_seconds {
1049            Some(0) => None,
1050            Some(value) => Some(Duration::from_secs(value)),
1051            None => self.request_timeout(),
1052        }
1053    }
1054
1055    fn requirement_mismatch_reason(&self, provider_config: &McpProviderConfig) -> Option<String> {
1056        let requirements = &self.config.requirements;
1057        if !requirements.enforce {
1058            return None;
1059        }
1060
1061        match &provider_config.transport {
1062            McpTransportConfig::Stdio(stdio) => {
1063                if requirements
1064                    .allowed_stdio_commands
1065                    .iter()
1066                    .any(|allowed| allowed == &stdio.command)
1067                {
1068                    None
1069                } else {
1070                    Some(format!(
1071                        "stdio command '{}' is not allowlisted",
1072                        stdio.command
1073                    ))
1074                }
1075            }
1076            McpTransportConfig::Http(http) => {
1077                if requirements
1078                    .allowed_http_endpoints
1079                    .iter()
1080                    .any(|allowed| allowed == &http.endpoint)
1081                {
1082                    None
1083                } else {
1084                    Some(format!(
1085                        "HTTP endpoint '{}' is not allowlisted",
1086                        http.endpoint
1087                    ))
1088                }
1089            }
1090        }
1091    }
1092
1093    fn resolve_startup_timeout(&self, provider_config: &McpProviderConfig) -> Option<Duration> {
1094        if let Some(timeout_ms) = provider_config.startup_timeout_ms {
1095            if timeout_ms == 0 {
1096                None
1097            } else {
1098                Some(Duration::from_millis(timeout_ms))
1099            }
1100        } else {
1101            self.startup_timeout()
1102        }
1103    }
1104
1105    fn provider_retry_attempts(&self) -> usize {
1106        self.config
1107            .retry_attempts
1108            .try_into()
1109            .unwrap_or(usize::MAX)
1110            .saturating_add(1)
1111    }
1112
1113    fn provider_retry_delay(attempt_idx: usize) -> Duration {
1114        let step = u64::try_from(attempt_idx)
1115            .unwrap_or(u64::MAX)
1116            .saturating_add(1);
1117        Duration::from_millis((step * 250).min(1_000))
1118    }
1119
1120    fn tool_timeout(&self) -> Option<Duration> {
1121        match self.config.tool_timeout_seconds {
1122            Some(0) => None,
1123            Some(value) => Some(Duration::from_secs(value)),
1124            None => self.request_timeout(),
1125        }
1126    }
1127
1128    fn request_timeout(&self) -> Option<Duration> {
1129        if self.config.request_timeout_seconds == 0 {
1130            None
1131        } else {
1132            Some(Duration::from_secs(self.config.request_timeout_seconds))
1133        }
1134    }
1135
1136    fn build_initialize_params(&self, _provider: &McpProvider) -> InitializeRequestParams {
1137        let mut capabilities = ClientCapabilities::default();
1138        capabilities.roots = Some(RootsCapabilities {
1139            list_changed: Some(true),
1140        });
1141
1142        if self.elicitation_handler.is_some() {
1143            // Elicitation is now a first-class capability in rmcp
1144            capabilities.elicitation = Some(rmcp::model::ElicitationCapability {
1145                form: Some(rmcp::model::FormElicitationCapability {
1146                    schema_validation: Some(true),
1147                }),
1148                ..Default::default()
1149            });
1150        }
1151
1152        InitializeRequestParams::new(capabilities, super::utils::build_client_implementation())
1153            .with_protocol_version(rmcp::model::ProtocolVersion::V_2024_11_05)
1154    }
1155
1156    pub(super) fn normalize_arguments(args: &Value) -> Map<String, Value> {
1157        match args {
1158            Value::Null => Map::new(),
1159            Value::Object(map) => map.clone(),
1160            other => {
1161                let mut map = Map::new();
1162                map.insert("value".to_owned(), other.clone());
1163                map
1164            }
1165        }
1166    }
1167
1168    fn format_tool_result(
1169        provider_name: &str,
1170        tool_name: &str,
1171        result: CallToolResult,
1172    ) -> Result<Value> {
1173        // Convert result to JSON to access fields flexibly
1174        let result_json = serde_json::to_value(&result)?;
1175        let result_obj = result_json.as_object();
1176
1177        // Check for error - handle both rmcp's is_error field and meta message
1178        let is_error = result_obj
1179            .and_then(|o| o.get("isError"))
1180            .or_else(|| result_obj.and_then(|o| o.get("is_error")))
1181            .and_then(Value::as_bool)
1182            .unwrap_or(false);
1183
1184        if is_error {
1185            let mut message = result_obj
1186                .and_then(|o| o.get("_meta"))
1187                .or_else(|| result_obj.and_then(|o| o.get("meta")))
1188                .and_then(|m| m.get("message"))
1189                .and_then(Value::as_str)
1190                .map(str::to_owned);
1191
1192            // Try to find text content in the content array
1193            if message.is_none()
1194                && let Some(content) = result_obj
1195                    .and_then(|o| o.get("content"))
1196                    .and_then(Value::as_array)
1197            {
1198                message = content
1199                    .iter()
1200                    .find_map(|block| block.get("text").and_then(Value::as_str).map(str::to_owned));
1201            }
1202
1203            let message = message.unwrap_or_else(|| "Unknown MCP tool error".to_owned());
1204            return Err(anyhow!(
1205                "MCP tool '{}' on provider '{}' reported an error: {}",
1206                tool_name,
1207                provider_name,
1208                message
1209            ));
1210        }
1211
1212        let mut payload = Map::new();
1213        payload.insert("provider".into(), Value::String(provider_name.to_string()));
1214        payload.insert("tool".into(), Value::String(tool_name.to_string()));
1215
1216        // Add meta if present
1217        if let Some(meta) = result_obj
1218            .and_then(|o| o.get("_meta"))
1219            .or_else(|| result_obj.and_then(|o| o.get("meta")))
1220            .and_then(Value::as_object)
1221            && !meta.is_empty()
1222        {
1223            payload.insert("meta".into(), Value::Object(meta.clone()));
1224        }
1225
1226        // Add content if present
1227        if let Some(content) = result_obj.and_then(|o| o.get("content"))
1228            && !content.is_null()
1229            && !content.as_array().map(|a| a.is_empty()).unwrap_or(true)
1230        {
1231            payload.insert("content".into(), content.clone());
1232        }
1233
1234        Ok(Value::Object(payload))
1235    }
1236}
1237
1238#[async_trait]
1239impl McpToolExecutor for McpClient {
1240    async fn execute_mcp_tool(&self, tool_name: &str, args: &Value) -> Result<Value> {
1241        self.execute_tool_with_validation_ref(tool_name, args).await
1242    }
1243
1244    async fn list_mcp_tools(&self) -> Result<Vec<McpToolInfo>> {
1245        self.collect_tools(false).await
1246    }
1247
1248    async fn has_mcp_tool(&self, tool_name: &str) -> Result<bool> {
1249        if !self.config.enabled {
1250            return Ok(false);
1251        }
1252
1253        if self.provider_for_tool(tool_name).is_some() {
1254            return Ok(true);
1255        }
1256
1257        if self.state.read().providers.is_empty() {
1258            if self.config.providers.is_empty() {
1259                return Ok(false);
1260            }
1261
1262            bail!(
1263                "No MCP providers are currently connected. Ensure MCP initialization completed successfully."
1264            );
1265        }
1266
1267        let tools = self.collect_tools(false).await?;
1268        Ok(tools.iter().any(|tool| tool.name == tool_name))
1269    }
1270
1271    fn get_status(&self) -> McpClientStatus {
1272        self.get_status()
1273    }
1274}
1275
#[cfg(test)]
mod tests {
    use super::McpClient;
    use crate::config::mcp::{
        McpClientConfig, McpHttpServerConfig, McpProviderConfig, McpRequirementsConfig,
        McpStdioServerConfig, McpTransportConfig,
    };
    use std::time::Duration;

    /// Enabled config that enforces requirements, allowlisting only the `uvx`
    /// stdio command and the `https://allowed.example/mcp` endpoint.
    fn base_config() -> McpClientConfig {
        McpClientConfig {
            enabled: true,
            requirements: McpRequirementsConfig {
                enforce: true,
                allowed_stdio_commands: vec!["uvx".to_string()],
                allowed_http_endpoints: vec!["https://allowed.example/mcp".to_string()],
            },
            ..McpClientConfig::default()
        }
    }

    /// Stdio provider called `name` that launches `command` with no arguments.
    fn stdio_provider(name: &str, command: &str) -> McpProviderConfig {
        McpProviderConfig {
            name: name.to_string(),
            transport: McpTransportConfig::Stdio(McpStdioServerConfig {
                command: command.to_string(),
                args: vec![],
                working_directory: None,
            }),
            ..McpProviderConfig::default()
        }
    }

    /// HTTP provider called `name` targeting `endpoint`.
    fn http_provider(name: &str, endpoint: &str) -> McpProviderConfig {
        McpProviderConfig {
            name: name.to_string(),
            transport: McpTransportConfig::Http(McpHttpServerConfig {
                endpoint: endpoint.to_string(),
                ..McpHttpServerConfig::default()
            }),
            ..McpProviderConfig::default()
        }
    }

    #[test]
    fn requirements_allow_matching_stdio_command() {
        let client = McpClient::new(base_config());
        let provider = stdio_provider("time", "uvx");

        assert!(client.requirement_mismatch_reason(&provider).is_none());
    }

    #[test]
    fn requirements_block_unmatched_stdio_command() {
        let client = McpClient::new(base_config());
        let provider = stdio_provider("time", "npx");

        assert!(
            client
                .requirement_mismatch_reason(&provider)
                .is_some_and(|reason| reason.contains("not allowlisted"))
        );
    }

    #[test]
    fn requirements_allow_matching_http_endpoint() {
        let client = McpClient::new(base_config());
        let provider = http_provider("remote", "https://allowed.example/mcp");

        assert!(client.requirement_mismatch_reason(&provider).is_none());
    }

    #[test]
    fn requirements_block_unmatched_http_endpoint() {
        let client = McpClient::new(base_config());
        let provider = http_provider("remote", "https://blocked.example/mcp");

        assert!(
            client
                .requirement_mismatch_reason(&provider)
                .is_some_and(|reason| reason.contains("not allowlisted"))
        );
    }

    #[test]
    fn requirements_are_ignored_when_not_enforced() {
        let mut config = base_config();
        config.requirements = McpRequirementsConfig {
            enforce: false,
            allowed_stdio_commands: Vec::new(),
            allowed_http_endpoints: Vec::new(),
        };
        let client = McpClient::new(config);

        assert!(
            client
                .requirement_mismatch_reason(&stdio_provider("time", "npx"))
                .is_none()
        );
        assert!(
            client
                .requirement_mismatch_reason(&http_provider("remote", "https://blocked.example/mcp"))
                .is_none()
        );
    }

    #[test]
    fn retry_attempts_include_initial_attempt() {
        let mut config = base_config();
        config.retry_attempts = 2;

        assert_eq!(McpClient::new(config).provider_retry_attempts(), 3);
    }

    #[test]
    fn retry_delay_grows_linearly_and_caps_at_one_second() {
        let expected_ms = [250, 500, 750, 1_000, 1_000];
        for (attempt_idx, millis) in expected_ms.into_iter().enumerate() {
            assert_eq!(
                McpClient::provider_retry_delay(attempt_idx),
                Duration::from_millis(millis)
            );
        }
        assert_eq!(
            McpClient::provider_retry_delay(50),
            Duration::from_millis(1_000)
        );
    }

    #[test]
    fn list_servers_includes_configured_provider_metadata() {
        let mut config = base_config();
        config.providers = vec![http_provider("calendar", "https://calendar.example/mcp")];

        let client = McpClient::new(config);
        let servers = client.list_servers();
        assert_eq!(servers.len(), 1);
        assert_eq!(servers[0]["name"], "calendar");
        assert_eq!(servers[0]["connected"], false);
        assert_eq!(servers[0]["connection_state"], "disconnected");
        assert_eq!(servers[0]["transport"], "http");
        assert_eq!(servers[0]["target"], "https://calendar.example/mcp");
    }

    #[tokio::test]
    async fn connect_server_rejects_unknown_server_name() {
        let client = McpClient::new(base_config());
        let err = client
            .connect_server("missing")
            .await
            .expect_err("missing server should error");
        assert!(err.to_string().contains("not configured"));
    }

    #[tokio::test]
    async fn disconnect_server_rejects_unknown_server_name() {
        let client = McpClient::new(base_config());
        let err = client
            .disconnect_server("missing")
            .await
            .expect_err("missing server should error");
        assert!(err.to_string().contains("not connected"));
    }
}