Skip to main content

vtcode_core/mcp/
connection_pool.rs

//! MCP connection pool for efficient provider management
//!
//! This module provides connection pooling and parallel initialization
//! for MCP providers to eliminate sequential connection bottlenecks.
5
6use hashbrown::HashMap;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::{RwLock, Semaphore};
10use tracing::{error, info, warn};
11
12use super::{McpElicitationHandler, McpProvider};
13use crate::config::mcp::{McpAllowListConfig, McpProviderConfig};
14use rmcp::model::{ClientCapabilities, InitializeRequestParams};
15
16/// MCP connection pool for efficient provider management
17pub struct McpConnectionPool {
18    /// Active provider connections
19    providers: Arc<RwLock<HashMap<String, Arc<McpProvider>>>>,
20    /// Connection semaphore to limit concurrent connections
21    connection_semaphore: Arc<Semaphore>,
22    /// Maximum connections allowed concurrently
23    max_concurrent_connections: usize,
24    /// Connection timeout
25    connection_timeout: Duration,
26}
27
28impl McpConnectionPool {
29    pub fn new(max_concurrent_connections: usize, connection_timeout_seconds: u64) -> Self {
30        Self {
31            providers: Arc::new(RwLock::new(HashMap::new())),
32            connection_semaphore: Arc::new(Semaphore::new(max_concurrent_connections)),
33            max_concurrent_connections,
34            connection_timeout: Duration::from_secs(connection_timeout_seconds),
35        }
36    }
37
38    /// Initialize multiple providers in parallel with controlled concurrency
39    pub async fn initialize_providers_parallel(
40        &self,
41        provider_configs: Vec<McpProviderConfig>,
42        elicitation_handler: Option<Arc<dyn McpElicitationHandler>>,
43        tool_timeout: Option<Duration>,
44        allowlist_snapshot: &McpAllowListConfig,
45    ) -> Result<Vec<(String, Arc<McpProvider>)>, McpPoolError> {
46        use futures::future::join_all;
47
48        // Create initialization tasks for each provider
49        let tasks: Vec<_> = provider_configs
50            .into_iter()
51            .map(|config| {
52                let elicitation_handler = elicitation_handler.clone();
53                let allowlist_snapshot = allowlist_snapshot.clone();
54
55                async move {
56                    self.initialize_provider(
57                        config,
58                        elicitation_handler,
59                        tool_timeout.unwrap_or(Duration::from_secs(30)),
60                        allowlist_snapshot,
61                    )
62                    .await
63                }
64            })
65            .collect();
66
67        // Execute all tasks in parallel
68        let results = join_all(tasks).await;
69
70        // Collect successful connections
71        let mut successful_providers = Vec::new();
72        let mut errors = Vec::new();
73
74        for result in results {
75            match result {
76                Ok((name, provider)) => {
77                    successful_providers.push((name, provider));
78                }
79                Err(error) => {
80                    errors.push(error);
81                }
82            }
83        }
84
85        if !errors.is_empty() {
86            warn!("Some MCP provider connections failed: {:?}", errors);
87        }
88
89        Ok(successful_providers)
90    }
91
92    /// Initialize a single provider with connection pooling
93    async fn initialize_provider(
94        &self,
95        config: McpProviderConfig,
96        elicitation_handler: Option<Arc<dyn McpElicitationHandler>>,
97        tool_timeout: Duration,
98        allowlist_snapshot: McpAllowListConfig,
99    ) -> Result<(String, Arc<McpProvider>), McpPoolError> {
100        // Acquire semaphore permit to limit concurrent connections
101        let _permit = self
102            .connection_semaphore
103            .acquire()
104            .await
105            .map_err(|e| McpPoolError::SemaphoreError(e.to_string()))?;
106
107        info!("Initializing MCP provider '{}'", config.name);
108
109        // Connect to provider with timeout
110        let provider = tokio::time::timeout(
111            self.connection_timeout,
112            McpProvider::connect(config.clone(), elicitation_handler),
113        )
114        .await
115        .map_err(|_| McpPoolError::ConnectionTimeout(config.name.clone()))?
116        .map_err(|e| McpPoolError::ConnectionError(config.name.clone(), e.to_string()))?;
117
118        // Initialize the provider with proper parameters
119        let provider_startup_timeout = self.resolve_startup_timeout(&config);
120        let initialize_params = build_pool_initialize_params(&provider);
121        let tool_timeout_opt = Some(tool_timeout);
122
123        if let Err(err) = provider
124            .initialize(
125                initialize_params,
126                provider_startup_timeout,
127                tool_timeout_opt,
128                &allowlist_snapshot,
129            )
130            .await
131        {
132            return Err(McpPoolError::InitializationError(
133                config.name.clone(),
134                err.to_string(),
135            ));
136        }
137
138        // Refresh tools
139        if let Err(err) = provider
140            .cached_tools_or_refresh(&allowlist_snapshot, tool_timeout_opt)
141            .await
142        {
143            warn!(
144                "Failed to fetch tools for provider '{}': {}",
145                config.name, err
146            );
147        }
148
149        info!("Successfully initialized MCP provider '{}'", config.name);
150
151        Ok((config.name.clone(), Arc::new(provider)))
152    }
153
154    /// Get a provider by name
155    pub async fn get_provider(&self, name: &str) -> Option<Arc<McpProvider>> {
156        let providers = self.providers.read().await;
157        providers.get(name).cloned()
158    }
159
160    /// Get all active providers
161    pub async fn get_all_providers(&self) -> Vec<Arc<McpProvider>> {
162        let providers = self.providers.read().await;
163        providers.values().cloned().collect()
164    }
165
166    /// Remove a provider from the pool
167    pub async fn remove_provider(&self, name: &str) -> Option<Arc<McpProvider>> {
168        let mut providers = self.providers.write().await;
169        providers.remove(name)
170    }
171
172    /// Check if a provider exists in the pool
173    pub async fn has_provider(&self, name: &str) -> bool {
174        let providers = self.providers.read().await;
175        providers.contains_key(name)
176    }
177
178    /// Get connection pool statistics
179    pub async fn stats(&self) -> ConnectionPoolStats {
180        let providers = self.providers.read().await;
181        let semaphore = self.connection_semaphore.available_permits();
182
183        ConnectionPoolStats {
184            active_connections: providers.len(),
185            available_permits: semaphore,
186            max_connections: self.max_concurrent_connections,
187        }
188    }
189
190    /// Shutdown all providers gracefully
191    pub async fn shutdown_all(&self) {
192        let providers: Vec<_> = {
193            let mut providers = self.providers.write().await;
194            providers.drain().collect()
195        };
196
197        for (name, provider) in providers {
198            if let Err(err) = provider.shutdown().await {
199                error!("Failed to shutdown MCP provider '{}': {}", name, err);
200            }
201        }
202    }
203
204    /// Check the health of all active providers.
205    ///
206    /// Returns a map of provider name → `true` (healthy) / `false` (unhealthy).
207    pub async fn health_check(&self) -> HashMap<String, bool> {
208        let providers: Vec<_> = {
209            let providers = self.providers.read().await;
210            providers
211                .iter()
212                .map(|(name, provider)| (name.clone(), Arc::clone(provider)))
213                .collect()
214        };
215        let mut results = HashMap::with_capacity(providers.len());
216        for (name, provider) in providers {
217            results.insert(name, provider.is_healthy().await);
218        }
219        results
220    }
221
222    /// Attempt to reconnect any unhealthy providers.
223    ///
224    /// Returns the names of providers that were successfully reconnected.
225    pub async fn reconnect_unhealthy(
226        &self,
227        startup_timeout: Option<Duration>,
228        tool_timeout: Option<Duration>,
229        allowlist: &McpAllowListConfig,
230    ) -> Vec<String> {
231        let providers: Vec<_> = {
232            let providers = self.providers.read().await;
233            providers
234                .iter()
235                .map(|(name, provider)| (name.clone(), Arc::clone(provider)))
236                .collect()
237        };
238        let mut reconnected = Vec::new();
239        for (name, provider) in providers {
240            if !provider.is_healthy().await {
241                info!("Provider '{}' is unhealthy, attempting reconnect", name);
242                match provider
243                    .reconnect(startup_timeout, tool_timeout, allowlist)
244                    .await
245                {
246                    Ok(()) => {
247                        info!("Successfully reconnected MCP provider '{}'", name);
248                        reconnected.push(name);
249                    }
250                    Err(err) => {
251                        error!("Failed to reconnect MCP provider '{}': {}", name, err);
252                    }
253                }
254            }
255        }
256        reconnected
257    }
258
259    /// Resolve startup timeout based on provider configuration
260    fn resolve_startup_timeout(&self, config: &McpProviderConfig) -> Option<Duration> {
261        config.startup_timeout_ms.map(Duration::from_millis)
262    }
263}
264
/// Point-in-time snapshot of connection pool usage.
///
/// Compared by value, so two snapshots are equal when all three counters match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectionPoolStats {
    /// Number of providers currently held in the pool.
    pub active_connections: usize,
    /// Semaphore permits that were free when the snapshot was taken.
    pub available_permits: usize,
    /// Configured limit on concurrent provider initializations.
    pub max_connections: usize,
}
272
273/// Enhanced MCP manager with connection pooling
274pub struct PooledMcpManager {
275    /// Connection pool for providers
276    pool: Arc<McpConnectionPool>,
277    /// Tool discovery cache
278    tool_cache: Arc<super::tool_discovery_cache::ToolDiscoveryCache>,
279}
280
281impl PooledMcpManager {
282    pub fn new(
283        max_concurrent_connections: usize,
284        connection_timeout_seconds: u64,
285        tool_cache_capacity: usize,
286    ) -> Self {
287        Self {
288            pool: Arc::new(McpConnectionPool::new(
289                max_concurrent_connections,
290                connection_timeout_seconds,
291            )),
292            tool_cache: Arc::new(super::tool_discovery_cache::ToolDiscoveryCache::new(
293                tool_cache_capacity,
294            )),
295        }
296    }
297
298    /// Initialize providers with pooling and caching
299    pub async fn initialize_providers(
300        &self,
301        provider_configs: Vec<McpProviderConfig>,
302        elicitation_handler: Option<Arc<dyn McpElicitationHandler>>,
303        tool_timeout: Option<Duration>,
304        allowlist_snapshot: &McpAllowListConfig,
305    ) -> Result<Vec<(String, Arc<McpProvider>)>, McpPoolError> {
306        // Initialize providers in parallel
307        let providers = self
308            .pool
309            .initialize_providers_parallel(
310                provider_configs,
311                elicitation_handler,
312                tool_timeout,
313                allowlist_snapshot,
314            )
315            .await?;
316
317        // Add providers to the pool
318        let mut pool_providers = self.pool.providers.write().await;
319        for (name, provider) in &providers {
320            pool_providers.insert(name.clone(), provider.clone());
321        }
322
323        Ok(providers)
324    }
325
326    /// Execute a tool on a specific provider
327    pub async fn execute_tool(
328        &self,
329        provider_name: &str,
330        tool_name: &str,
331        arguments: serde_json::Value,
332        allowlist: &McpAllowListConfig,
333        tool_timeout: Option<Duration>,
334    ) -> Result<serde_json::Value, McpPoolError> {
335        let provider = self
336            .pool
337            .get_provider(provider_name)
338            .await
339            .ok_or_else(|| McpPoolError::ProviderNotFound(provider_name.to_string()))?;
340
341        // Convert arguments to proper format
342        let args_ref = &arguments;
343
344        // Execute the tool with correct signature
345        let result = provider
346            .call_tool(tool_name, args_ref, tool_timeout, allowlist)
347            .await
348            .map_err(|e| {
349                McpPoolError::ToolExecutionError(provider_name.to_string(), e.to_string())
350            })?;
351
352        // Convert result to JSON value
353        Ok(serde_json::to_value(&result).unwrap_or(serde_json::Value::Null))
354    }
355
356    /// Check if a tool is read-only (safe to cache)
357    #[expect(dead_code)]
358    fn is_read_only_tool(&self, tool_name: &str) -> bool {
359        // This is a simple heuristic - in practice, you might want to
360        // check tool metadata or maintain a list of read-only tools
361        matches!(
362            tool_name,
363            "read_file"
364                | "list_directory"
365                | "search_files"
366                | "get_file_info"
367                | "read_environment"
368                | "get_system_info"
369                | "search_code"
370                | "analyze_code"
371        )
372    }
373
374    /// Get pool statistics
375    pub async fn stats(&self) -> PooledMcpStats {
376        let pool_stats = self.pool.stats().await;
377        let tool_cache_stats = self.tool_cache.stats();
378
379        PooledMcpStats {
380            connection_pool: pool_stats,
381            tool_cache: tool_cache_stats,
382        }
383    }
384
385    /// Shutdown all providers gracefully
386    pub async fn shutdown(&self) {
387        self.pool.shutdown_all().await;
388    }
389}
390
391/// Pooled MCP manager statistics
392#[derive(Debug, Clone)]
393pub struct PooledMcpStats {
394    pub connection_pool: ConnectionPoolStats,
395    pub tool_cache: super::tool_discovery_cache::ToolCacheStats,
396}
397
398/// Build initialize params for an MCP provider
399fn build_pool_initialize_params(_provider: &McpProvider) -> InitializeRequestParams {
400    InitializeRequestParams::new(
401        ClientCapabilities::default(),
402        super::utils::build_client_implementation(),
403    )
404    .with_protocol_version(rmcp::model::ProtocolVersion::V_2024_11_05)
405}
406
407/// MCP connection pool errors
408#[derive(Debug, thiserror::Error)]
409pub enum McpPoolError {
410    #[error("Connection timeout for provider '{0}'")]
411    ConnectionTimeout(String),
412
413    #[error("Connection error for provider '{0}': {1}")]
414    ConnectionError(String, String),
415
416    #[error("Initialization timeout for provider '{0}'")]
417    InitializationTimeout(String),
418
419    #[error("Initialization error for provider '{0}': {1}")]
420    InitializationError(String, String),
421
422    #[error("Provider not found: {0}")]
423    ProviderNotFound(String),
424
425    #[error("Tool execution error for provider '{0}': {1}")]
426    ToolExecutionError(String, String),
427
428    #[error("Semaphore error: {0}")]
429    SemaphoreError(String),
430}
431
#[cfg(test)]
pub mod tests {
    use super::*;

    /// A new pool holds no providers and has every permit available.
    #[tokio::test]
    async fn test_connection_pool_creation() {
        let stats = McpConnectionPool::new(5, 30).stats().await;

        assert_eq!(stats.active_connections, 0);
        assert_eq!(stats.max_connections, 5);
        assert_eq!(stats.available_permits, 5);
    }

    /// Permit accounting: taking all permits leaves zero, and a released
    /// permit can be re-acquired immediately.
    #[tokio::test]
    async fn test_connection_pool_semaphore_limits() {
        let pool = McpConnectionPool::new(3, 30);
        let semaphore = &pool.connection_semaphore;

        // Take every permit the pool has.
        let first = semaphore.acquire().await.unwrap();
        let _second = semaphore.acquire().await.unwrap();
        let _third = semaphore.acquire().await.unwrap();

        assert_eq!(pool.stats().await.available_permits, 0);

        // Releasing one permit makes room for exactly one more acquisition.
        drop(first);
        let _fourth = semaphore.acquire().await.unwrap();

        assert_eq!(pool.stats().await.available_permits, 0);
    }

    /// The manager forwards the connection limit to its pool.
    #[tokio::test]
    async fn test_pooled_manager_creation() {
        let stats = PooledMcpManager::new(10, 30, 100).stats().await;

        assert_eq!(stats.connection_pool.max_connections, 10);
        assert_eq!(stats.connection_pool.active_connections, 0);
    }

    /// Known read-only tool names are accepted; mutating ones are rejected.
    #[tokio::test]
    async fn test_read_only_tool_detection() {
        let manager = PooledMcpManager::new(5, 30, 50);

        for tool in ["read_file", "search_files", "get_system_info", "get_file_info"] {
            assert!(manager.is_read_only_tool(tool));
        }

        for tool in ["write_file", "edit_file", "execute_command", "delete_file"] {
            assert!(!manager.is_read_only_tool(tool));
        }
    }

    /// Error messages include the provider name and the underlying detail.
    #[test]
    fn test_connection_pool_error_display() {
        let timeout = McpPoolError::ConnectionTimeout("test_provider".to_string());
        assert!(timeout.to_string().contains("test_provider"));

        let init_failure = McpPoolError::InitializationError(
            "auth".to_string(),
            "invalid credentials".to_string(),
        );
        let rendered = init_failure.to_string();
        assert!(rendered.contains("auth"));
        assert!(rendered.contains("invalid credentials"));
    }

    /// Looking up an unknown provider yields `None`.
    #[tokio::test]
    async fn test_pool_provider_not_found() {
        let pool = McpConnectionPool::new(5, 30);
        assert!(pool.get_provider("nonexistent").await.is_none());
    }

    /// An empty pool reports no provider under any name.
    #[tokio::test]
    async fn test_pool_has_provider() {
        let pool = McpConnectionPool::new(5, 30);
        assert!(!pool.has_provider("test").await);
    }

    /// An empty pool returns an empty provider list.
    #[tokio::test]
    async fn test_pool_get_all_providers_empty() {
        let pool = McpConnectionPool::new(5, 30);
        assert!(pool.get_all_providers().await.is_empty());
    }

    /// Stats reflect the configured limit on a fresh pool.
    #[tokio::test]
    async fn test_pool_stats() {
        let stats = McpConnectionPool::new(7, 60).stats().await;

        assert_eq!(stats.max_connections, 7);
        assert_eq!(stats.available_permits, 7);
        assert_eq!(stats.active_connections, 0);
    }
}