1use hashbrown::HashMap;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::{RwLock, Semaphore};
10use tracing::{error, info, warn};
11
12use super::{McpElicitationHandler, McpProvider};
13use crate::config::mcp::{McpAllowListConfig, McpProviderConfig};
14use rmcp::model::{ClientCapabilities, InitializeRequestParams};
15
16pub struct McpConnectionPool {
18 providers: Arc<RwLock<HashMap<String, Arc<McpProvider>>>>,
20 connection_semaphore: Arc<Semaphore>,
22 max_concurrent_connections: usize,
24 connection_timeout: Duration,
26}
27
28impl McpConnectionPool {
29 pub fn new(max_concurrent_connections: usize, connection_timeout_seconds: u64) -> Self {
30 Self {
31 providers: Arc::new(RwLock::new(HashMap::new())),
32 connection_semaphore: Arc::new(Semaphore::new(max_concurrent_connections)),
33 max_concurrent_connections,
34 connection_timeout: Duration::from_secs(connection_timeout_seconds),
35 }
36 }
37
38 pub async fn initialize_providers_parallel(
40 &self,
41 provider_configs: Vec<McpProviderConfig>,
42 elicitation_handler: Option<Arc<dyn McpElicitationHandler>>,
43 tool_timeout: Option<Duration>,
44 allowlist_snapshot: &McpAllowListConfig,
45 ) -> Result<Vec<(String, Arc<McpProvider>)>, McpPoolError> {
46 use futures::future::join_all;
47
48 let tasks: Vec<_> = provider_configs
50 .into_iter()
51 .map(|config| {
52 let elicitation_handler = elicitation_handler.clone();
53 let allowlist_snapshot = allowlist_snapshot.clone();
54
55 async move {
56 self.initialize_provider(
57 config,
58 elicitation_handler,
59 tool_timeout.unwrap_or(Duration::from_secs(30)),
60 allowlist_snapshot,
61 )
62 .await
63 }
64 })
65 .collect();
66
67 let results = join_all(tasks).await;
69
70 let mut successful_providers = Vec::new();
72 let mut errors = Vec::new();
73
74 for result in results {
75 match result {
76 Ok((name, provider)) => {
77 successful_providers.push((name, provider));
78 }
79 Err(error) => {
80 errors.push(error);
81 }
82 }
83 }
84
85 if !errors.is_empty() {
86 warn!("Some MCP provider connections failed: {:?}", errors);
87 }
88
89 Ok(successful_providers)
90 }
91
92 async fn initialize_provider(
94 &self,
95 config: McpProviderConfig,
96 elicitation_handler: Option<Arc<dyn McpElicitationHandler>>,
97 tool_timeout: Duration,
98 allowlist_snapshot: McpAllowListConfig,
99 ) -> Result<(String, Arc<McpProvider>), McpPoolError> {
100 let _permit = self
102 .connection_semaphore
103 .acquire()
104 .await
105 .map_err(|e| McpPoolError::SemaphoreError(e.to_string()))?;
106
107 info!("Initializing MCP provider '{}'", config.name);
108
109 let provider = tokio::time::timeout(
111 self.connection_timeout,
112 McpProvider::connect(config.clone(), elicitation_handler),
113 )
114 .await
115 .map_err(|_| McpPoolError::ConnectionTimeout(config.name.clone()))?
116 .map_err(|e| McpPoolError::ConnectionError(config.name.clone(), e.to_string()))?;
117
118 let provider_startup_timeout = self.resolve_startup_timeout(&config);
120 let initialize_params = build_pool_initialize_params(&provider);
121 let tool_timeout_opt = Some(tool_timeout);
122
123 if let Err(err) = provider
124 .initialize(
125 initialize_params,
126 provider_startup_timeout,
127 tool_timeout_opt,
128 &allowlist_snapshot,
129 )
130 .await
131 {
132 return Err(McpPoolError::InitializationError(
133 config.name.clone(),
134 err.to_string(),
135 ));
136 }
137
138 if let Err(err) = provider
140 .cached_tools_or_refresh(&allowlist_snapshot, tool_timeout_opt)
141 .await
142 {
143 warn!(
144 "Failed to fetch tools for provider '{}': {}",
145 config.name, err
146 );
147 }
148
149 info!("Successfully initialized MCP provider '{}'", config.name);
150
151 Ok((config.name.clone(), Arc::new(provider)))
152 }
153
154 pub async fn get_provider(&self, name: &str) -> Option<Arc<McpProvider>> {
156 let providers = self.providers.read().await;
157 providers.get(name).cloned()
158 }
159
160 pub async fn get_all_providers(&self) -> Vec<Arc<McpProvider>> {
162 let providers = self.providers.read().await;
163 providers.values().cloned().collect()
164 }
165
166 pub async fn remove_provider(&self, name: &str) -> Option<Arc<McpProvider>> {
168 let mut providers = self.providers.write().await;
169 providers.remove(name)
170 }
171
172 pub async fn has_provider(&self, name: &str) -> bool {
174 let providers = self.providers.read().await;
175 providers.contains_key(name)
176 }
177
178 pub async fn stats(&self) -> ConnectionPoolStats {
180 let providers = self.providers.read().await;
181 let semaphore = self.connection_semaphore.available_permits();
182
183 ConnectionPoolStats {
184 active_connections: providers.len(),
185 available_permits: semaphore,
186 max_connections: self.max_concurrent_connections,
187 }
188 }
189
190 pub async fn shutdown_all(&self) {
192 let providers: Vec<_> = {
193 let mut providers = self.providers.write().await;
194 providers.drain().collect()
195 };
196
197 for (name, provider) in providers {
198 if let Err(err) = provider.shutdown().await {
199 error!("Failed to shutdown MCP provider '{}': {}", name, err);
200 }
201 }
202 }
203
204 pub async fn health_check(&self) -> HashMap<String, bool> {
208 let providers: Vec<_> = {
209 let providers = self.providers.read().await;
210 providers
211 .iter()
212 .map(|(name, provider)| (name.clone(), Arc::clone(provider)))
213 .collect()
214 };
215 let mut results = HashMap::with_capacity(providers.len());
216 for (name, provider) in providers {
217 results.insert(name, provider.is_healthy().await);
218 }
219 results
220 }
221
222 pub async fn reconnect_unhealthy(
226 &self,
227 startup_timeout: Option<Duration>,
228 tool_timeout: Option<Duration>,
229 allowlist: &McpAllowListConfig,
230 ) -> Vec<String> {
231 let providers: Vec<_> = {
232 let providers = self.providers.read().await;
233 providers
234 .iter()
235 .map(|(name, provider)| (name.clone(), Arc::clone(provider)))
236 .collect()
237 };
238 let mut reconnected = Vec::new();
239 for (name, provider) in providers {
240 if !provider.is_healthy().await {
241 info!("Provider '{}' is unhealthy, attempting reconnect", name);
242 match provider
243 .reconnect(startup_timeout, tool_timeout, allowlist)
244 .await
245 {
246 Ok(()) => {
247 info!("Successfully reconnected MCP provider '{}'", name);
248 reconnected.push(name);
249 }
250 Err(err) => {
251 error!("Failed to reconnect MCP provider '{}': {}", name, err);
252 }
253 }
254 }
255 }
256 reconnected
257 }
258
259 fn resolve_startup_timeout(&self, config: &McpProviderConfig) -> Option<Duration> {
261 config.startup_timeout_ms.map(Duration::from_millis)
262 }
263}
264
/// Point-in-time snapshot of a connection pool's occupancy, as returned by
/// `McpConnectionPool::stats`.
///
/// Derives `PartialEq`/`Eq` so that snapshots can be compared by value (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectionPoolStats {
    /// Number of providers currently registered in the pool.
    pub active_connections: usize,
    /// Semaphore permits not currently held by an in-progress provider initialization.
    pub available_permits: usize,
    /// The pool's configured concurrency limit.
    pub max_connections: usize,
}
272
273pub struct PooledMcpManager {
275 pool: Arc<McpConnectionPool>,
277 tool_cache: Arc<super::tool_discovery_cache::ToolDiscoveryCache>,
279}
280
281impl PooledMcpManager {
282 pub fn new(
283 max_concurrent_connections: usize,
284 connection_timeout_seconds: u64,
285 tool_cache_capacity: usize,
286 ) -> Self {
287 Self {
288 pool: Arc::new(McpConnectionPool::new(
289 max_concurrent_connections,
290 connection_timeout_seconds,
291 )),
292 tool_cache: Arc::new(super::tool_discovery_cache::ToolDiscoveryCache::new(
293 tool_cache_capacity,
294 )),
295 }
296 }
297
298 pub async fn initialize_providers(
300 &self,
301 provider_configs: Vec<McpProviderConfig>,
302 elicitation_handler: Option<Arc<dyn McpElicitationHandler>>,
303 tool_timeout: Option<Duration>,
304 allowlist_snapshot: &McpAllowListConfig,
305 ) -> Result<Vec<(String, Arc<McpProvider>)>, McpPoolError> {
306 let providers = self
308 .pool
309 .initialize_providers_parallel(
310 provider_configs,
311 elicitation_handler,
312 tool_timeout,
313 allowlist_snapshot,
314 )
315 .await?;
316
317 let mut pool_providers = self.pool.providers.write().await;
319 for (name, provider) in &providers {
320 pool_providers.insert(name.clone(), provider.clone());
321 }
322
323 Ok(providers)
324 }
325
326 pub async fn execute_tool(
328 &self,
329 provider_name: &str,
330 tool_name: &str,
331 arguments: serde_json::Value,
332 allowlist: &McpAllowListConfig,
333 tool_timeout: Option<Duration>,
334 ) -> Result<serde_json::Value, McpPoolError> {
335 let provider = self
336 .pool
337 .get_provider(provider_name)
338 .await
339 .ok_or_else(|| McpPoolError::ProviderNotFound(provider_name.to_string()))?;
340
341 let args_ref = &arguments;
343
344 let result = provider
346 .call_tool(tool_name, args_ref, tool_timeout, allowlist)
347 .await
348 .map_err(|e| {
349 McpPoolError::ToolExecutionError(provider_name.to_string(), e.to_string())
350 })?;
351
352 Ok(serde_json::to_value(&result).unwrap_or(serde_json::Value::Null))
354 }
355
356 #[expect(dead_code)]
358 fn is_read_only_tool(&self, tool_name: &str) -> bool {
359 matches!(
362 tool_name,
363 "read_file"
364 | "list_directory"
365 | "search_files"
366 | "get_file_info"
367 | "read_environment"
368 | "get_system_info"
369 | "search_code"
370 | "analyze_code"
371 )
372 }
373
374 pub async fn stats(&self) -> PooledMcpStats {
376 let pool_stats = self.pool.stats().await;
377 let tool_cache_stats = self.tool_cache.stats();
378
379 PooledMcpStats {
380 connection_pool: pool_stats,
381 tool_cache: tool_cache_stats,
382 }
383 }
384
385 pub async fn shutdown(&self) {
387 self.pool.shutdown_all().await;
388 }
389}
390
391#[derive(Debug, Clone)]
393pub struct PooledMcpStats {
394 pub connection_pool: ConnectionPoolStats,
395 pub tool_cache: super::tool_discovery_cache::ToolCacheStats,
396}
397
398fn build_pool_initialize_params(_provider: &McpProvider) -> InitializeRequestParams {
400 InitializeRequestParams::new(
401 ClientCapabilities::default(),
402 super::utils::build_client_implementation(),
403 )
404 .with_protocol_version(rmcp::model::ProtocolVersion::V_2024_11_05)
405}
406
407#[derive(Debug, thiserror::Error)]
409pub enum McpPoolError {
410 #[error("Connection timeout for provider '{0}'")]
411 ConnectionTimeout(String),
412
413 #[error("Connection error for provider '{0}': {1}")]
414 ConnectionError(String, String),
415
416 #[error("Initialization timeout for provider '{0}'")]
417 InitializationTimeout(String),
418
419 #[error("Initialization error for provider '{0}': {1}")]
420 InitializationError(String, String),
421
422 #[error("Provider not found: {0}")]
423 ProviderNotFound(String),
424
425 #[error("Tool execution error for provider '{0}': {1}")]
426 ToolExecutionError(String, String),
427
428 #[error("Semaphore error: {0}")]
429 SemaphoreError(String),
430}
431
#[cfg(test)]
pub mod tests {
    use super::*;

    /// A freshly created pool has no providers and every permit available.
    #[tokio::test]
    async fn test_connection_pool_creation() {
        let snapshot = McpConnectionPool::new(5, 30).stats().await;

        assert_eq!(snapshot.active_connections, 0);
        assert_eq!(snapshot.max_connections, 5);
        assert_eq!(snapshot.available_permits, 5);
    }

    /// Held permits are reflected in `available_permits`, and a released permit can be reused.
    #[tokio::test]
    async fn test_connection_pool_semaphore_limits() {
        let pool = McpConnectionPool::new(3, 30);
        let semaphore = &pool.connection_semaphore;

        let first = semaphore.acquire().await.unwrap();
        let _second = semaphore.acquire().await.unwrap();
        let _third = semaphore.acquire().await.unwrap();
        assert_eq!(pool.stats().await.available_permits, 0);

        // Releasing one permit lets another acquirer through, which exhausts the pool again.
        drop(first);
        let _fourth = semaphore.acquire().await.unwrap();
        assert_eq!(pool.stats().await.available_permits, 0);
    }

    /// The manager forwards its concurrency limit to the underlying pool.
    #[tokio::test]
    async fn test_pooled_manager_creation() {
        let PooledMcpStats {
            connection_pool, ..
        } = PooledMcpManager::new(10, 30, 100).stats().await;

        assert_eq!(connection_pool.max_connections, 10);
        assert_eq!(connection_pool.active_connections, 0);
    }

    /// Known read-only tools are recognised; mutating tools are not.
    #[tokio::test]
    async fn test_read_only_tool_detection() {
        let manager = PooledMcpManager::new(5, 30, 50);

        for tool in ["read_file", "search_files", "get_system_info", "get_file_info"] {
            assert!(manager.is_read_only_tool(tool));
        }
        for tool in ["write_file", "edit_file", "execute_command", "delete_file"] {
            assert!(!manager.is_read_only_tool(tool));
        }
    }

    /// Error messages include the provider name and, where present, the detail message.
    #[test]
    fn test_connection_pool_error_display() {
        let timeout_text = McpPoolError::ConnectionTimeout("test_provider".to_string()).to_string();
        assert!(timeout_text.contains("test_provider"));

        let init_text = McpPoolError::InitializationError(
            "auth".to_string(),
            "invalid credentials".to_string(),
        )
        .to_string();
        assert!(init_text.contains("auth"));
        assert!(init_text.contains("invalid credentials"));
    }

    /// Looking up an unknown provider yields `None`.
    #[tokio::test]
    async fn test_pool_provider_not_found() {
        let pool = McpConnectionPool::new(5, 30);
        assert!(pool.get_provider("nonexistent").await.is_none());
    }

    /// An empty pool reports no provider for an arbitrary name.
    #[tokio::test]
    async fn test_pool_has_provider() {
        let pool = McpConnectionPool::new(5, 30);
        assert!(!pool.has_provider("test").await);
    }

    /// An empty pool lists no providers.
    #[tokio::test]
    async fn test_pool_get_all_providers_empty() {
        let pool = McpConnectionPool::new(5, 30);
        assert!(pool.get_all_providers().await.is_empty());
    }

    /// `stats` reports the configured limit, full permits, and no active connections.
    #[tokio::test]
    async fn test_pool_stats() {
        let snapshot = McpConnectionPool::new(7, 60).stats().await;

        assert_eq!(
            (
                snapshot.max_connections,
                snapshot.available_permits,
                snapshot.active_connections,
            ),
            (7, 7, 0)
        );
    }
}