1use std::{
4 collections::{hash_map::DefaultHasher, HashMap},
5 hash::{Hash, Hasher},
6 sync::{
7 atomic::{AtomicUsize, Ordering},
8 Arc,
9 },
10};
11
12use lru::LruCache;
13use parking_lot::Mutex;
14use rmcp::{service::RunningService, RoleClient};
15
16use super::config::{McpProxyConfig, McpServerConfig, McpTransport};
17use crate::error::McpResult;
18
/// Running `rmcp` client session, with `()` as the client-side handler type.
type McpClient = RunningService<RoleClient, ()>;
/// Hook the pool invokes with the [`PoolKey`] of a connection entry displaced from its
/// LRU cache. The `Send + Sync` bounds keep a pool that stores it shareable across threads.
type EvictionCallback = Arc<dyn Fn(&PoolKey) + Send + Sync>;

/// Identity of a pooled connection.
///
/// `Eq`/`Hash` are derived over all three fields, so two configs share a pooled client only
/// when endpoint, auth hash and tenant all match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PoolKey {
    /// Server endpoint: the URL for HTTP-based transports, or `"command:args"` (args joined
    /// by spaces) for stdio, as built by `PoolKey::from_config`.
    pub url: String,
    /// Hash of the auth material (token and headers); `0` when there is none, and always `0`
    /// for stdio keys built by `from_config`.
    pub auth_hash: u64,
    /// Optional tenant scope; keys for different tenants never compare equal.
    pub tenant_id: Option<String>,
}
31
32impl PoolKey {
33 pub fn new(url: impl Into<String>, auth_hash: u64, tenant_id: Option<String>) -> Self {
34 Self {
35 url: url.into(),
36 auth_hash,
37 tenant_id,
38 }
39 }
40
41 pub fn from_config(config: &McpServerConfig, tenant_id: Option<String>) -> Self {
42 let (url, auth_hash) = match &config.transport {
43 McpTransport::Streamable {
44 url,
45 token,
46 headers,
47 } => (url.clone(), Self::hash_auth(token.as_ref(), headers)),
48 McpTransport::Sse {
49 url,
50 token,
51 headers,
52 } => (url.clone(), Self::hash_auth(token.as_ref(), headers)),
53 McpTransport::Stdio { command, args, .. } => {
54 (format!("{}:{}", command, args.join(" ")), 0)
55 }
56 };
57 Self {
58 url,
59 auth_hash,
60 tenant_id,
61 }
62 }
63
64 fn hash_auth(token: Option<&String>, headers: &HashMap<String, String>) -> u64 {
66 if token.is_none() && headers.is_empty() {
67 return 0;
68 }
69
70 let mut hasher = DefaultHasher::new();
71
72 if let Some(t) = token {
73 t.hash(&mut hasher);
74 }
75
76 if !headers.is_empty() {
77 let mut sorted_headers: Vec<_> = headers.iter().collect();
78 sorted_headers.sort_by_key(|(k, _)| *k);
79 for (key, value) in sorted_headers {
80 key.hash(&mut hasher);
81 value.hash(&mut hasher);
82 }
83 }
84
85 hasher.finish()
86 }
87
88 #[inline]
89 pub fn url(&self) -> &str {
90 &self.url
91 }
92}
93
94#[derive(Clone)]
96pub(crate) struct CachedConnection {
97 pub client: Arc<McpClient>,
98}
99
100impl CachedConnection {
101 pub fn new(client: Arc<McpClient>) -> Self {
102 Self { client }
103 }
104}
105
106pub struct McpConnectionPool {
108 connections: Arc<Mutex<LruCache<PoolKey, CachedConnection>>>,
109 connection_count: AtomicUsize,
111 max_connections: usize,
112 global_proxy: Option<McpProxyConfig>,
113 eviction_callback: Option<EvictionCallback>,
114}
115
116impl McpConnectionPool {
117 const DEFAULT_MAX_CONNECTIONS: usize = 200;
118
119 pub fn new() -> Self {
121 Self::with_full_config(Self::DEFAULT_MAX_CONNECTIONS, McpProxyConfig::from_env())
122 }
123
124 pub fn with_capacity(max_connections: usize) -> Self {
125 Self::with_full_config(max_connections, McpProxyConfig::from_env())
126 }
127
128 pub fn with_full_config(max_connections: usize, global_proxy: Option<McpProxyConfig>) -> Self {
129 let max_connections = max_connections.max(1);
130 let cache_cap =
131 std::num::NonZeroUsize::new(max_connections).unwrap_or(std::num::NonZeroUsize::MIN);
132 Self {
133 connections: Arc::new(Mutex::new(LruCache::new(cache_cap))),
134 connection_count: AtomicUsize::new(0),
135 max_connections,
136 global_proxy,
137 eviction_callback: None,
138 }
139 }
140
141 pub fn set_eviction_callback<F>(&mut self, callback: F)
142 where
143 F: Fn(&PoolKey) + Send + Sync + 'static,
144 {
145 self.eviction_callback = Some(Arc::new(callback));
146 }
147
148 pub async fn get_or_create<F, Fut>(
150 &self,
151 key: PoolKey,
152 server_config: McpServerConfig,
153 connect_fn: F,
154 ) -> McpResult<Arc<McpClient>>
155 where
156 F: FnOnce(McpServerConfig, Option<McpProxyConfig>) -> Fut,
157 Fut: std::future::Future<Output = McpResult<McpClient>>,
158 {
159 {
160 let mut connections = self.connections.lock();
161 if let Some(cached) = connections.get(&key) {
162 return Ok(Arc::clone(&cached.client));
163 }
164 }
165
166 let client = connect_fn(server_config.clone(), self.global_proxy.clone()).await?;
167 let client_arc = Arc::new(client);
168
169 let cached = CachedConnection::new(Arc::clone(&client_arc));
170 {
171 let mut connections = self.connections.lock();
172 match connections.push(key, cached) {
173 Some((evicted_key, _)) => {
174 if let Some(callback) = &self.eviction_callback {
176 callback(&evicted_key);
177 }
178 }
179 None => {
180 self.connection_count.fetch_add(1, Ordering::Relaxed);
182 }
183 }
184 }
185
186 Ok(client_arc)
187 }
188
189 pub fn len(&self) -> usize {
190 self.connection_count.load(Ordering::Relaxed)
191 }
192
193 pub fn is_empty(&self) -> bool {
194 self.connection_count.load(Ordering::Relaxed) == 0
195 }
196
197 pub fn clear(&self) {
198 let mut connections = self.connections.lock();
199 connections.clear();
200 self.connection_count.store(0, Ordering::Relaxed);
201 }
202
203 pub fn stats(&self) -> PoolStats {
204 PoolStats {
205 total_connections: self.connection_count.load(Ordering::Relaxed),
206 capacity: self.max_connections,
207 }
208 }
209
210 pub fn list_keys(&self) -> Vec<PoolKey> {
211 self.connections
212 .lock()
213 .iter()
214 .map(|(key, _)| key.clone())
215 .collect()
216 }
217
218 pub fn get(&self, key: &PoolKey) -> Option<Arc<McpClient>> {
220 self.connections
221 .lock()
222 .get(key)
223 .map(|cached| Arc::clone(&cached.client))
224 }
225
226 pub fn contains(&self, key: &PoolKey) -> bool {
227 self.connections.lock().contains(key)
228 }
229
230 pub fn get_by_url(&self, url: &str) -> Option<Arc<McpClient>> {
236 self.connections
237 .lock()
238 .iter()
239 .find(|(key, _)| key.url == url)
240 .map(|(_, cached)| Arc::clone(&cached.client))
241 }
242
243 pub fn contains_url(&self, url: &str) -> bool {
250 self.connections
251 .lock()
252 .iter()
253 .any(|(key, _)| key.url == url)
254 }
255}
256
257impl Default for McpConnectionPool {
258 fn default() -> Self {
259 Self::new()
260 }
261}
262
/// Point-in-time snapshot of pool occupancy, as returned by `McpConnectionPool::stats`.
#[derive(Debug, Clone)]
pub struct PoolStats {
    /// Number of connections cached when the snapshot was taken.
    pub total_connections: usize,
    /// Maximum number of connections the pool caches before evicting the least recently
    /// used one (at least 1).
    pub capacity: usize,
}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::config::McpTransport;

    const TEST_URL: &str = "http://localhost:3000";

    /// Wraps `transport` in an otherwise-minimal server config named `name`.
    fn config_with_transport(name: &str, transport: McpTransport) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            transport,
            proxy: None,
            required: false,
            tools: None,
            builtin_type: None,
            builtin_tool_name: None,
        }
    }

    /// Streamable-HTTP config for `url` with no token and no headers.
    fn create_test_config(url: &str) -> McpServerConfig {
        config_with_transport(
            "test_server",
            McpTransport::Streamable {
                url: url.to_string(),
                token: None,
                headers: HashMap::new(),
            },
        )
    }

    /// Streamable-HTTP config for `TEST_URL` authenticated with `token`.
    fn config_with_token(token: &str) -> McpServerConfig {
        config_with_transport(
            "test",
            McpTransport::Streamable {
                url: TEST_URL.to_string(),
                token: Some(token.to_string()),
                headers: HashMap::new(),
            },
        )
    }

    /// SSE config for `TEST_URL` sending `headers` and no token.
    fn sse_config_with_headers(headers: HashMap<String, String>) -> McpServerConfig {
        config_with_transport(
            "test",
            McpTransport::Sse {
                url: TEST_URL.to_string(),
                token: None,
                headers,
            },
        )
    }

    #[test]
    fn test_pool_creation() {
        let pool = McpConnectionPool::new();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
        assert!(pool.list_keys().is_empty());
    }

    #[test]
    fn test_pool_stats() {
        let pool = McpConnectionPool::with_capacity(10);

        let stats = pool.stats();
        assert_eq!(stats.total_connections, 0);
        assert_eq!(stats.capacity, 10);
    }

    #[test]
    fn test_pool_zero_capacity_is_clamped_to_one() {
        let pool = McpConnectionPool::with_capacity(0);
        assert_eq!(pool.stats().capacity, 1);
    }

    #[test]
    fn test_pool_clear() {
        let pool = McpConnectionPool::new();
        assert_eq!(pool.len(), 0);
        pool.clear();
        assert!(pool.is_empty());
    }

    #[test]
    fn test_empty_pool_lookups() {
        let pool = McpConnectionPool::new();
        let key = PoolKey::from_config(&create_test_config(TEST_URL), None);

        assert!(pool.get(&key).is_none());
        assert!(!pool.contains(&key));
        assert!(pool.get_by_url(TEST_URL).is_none());
        assert!(!pool.contains_url(TEST_URL));
    }

    #[test]
    fn test_pool_key_new() {
        let key = PoolKey::new(TEST_URL, 42, Some("tenant".to_string()));
        assert_eq!(key.url(), TEST_URL);
        assert_eq!(key.auth_hash, 42);
        assert_eq!(key.tenant_id.as_deref(), Some("tenant"));
    }

    #[test]
    fn test_pool_key_from_config() {
        let config = create_test_config(TEST_URL);
        let key = PoolKey::from_config(&config, None);
        assert_eq!(key.url, TEST_URL);
        assert_eq!(key.auth_hash, 0);
        assert_eq!(key.tenant_id, None);

        let key_with_token = PoolKey::from_config(&config_with_token("secret-token"), None);
        assert_eq!(key_with_token.url, TEST_URL);
        assert_ne!(key_with_token.auth_hash, 0);

        let key_with_tenant = PoolKey::from_config(&config, Some("tenant-123".to_string()));
        assert_eq!(key_with_tenant.tenant_id, Some("tenant-123".to_string()));
        assert_ne!(key, key_with_tenant);
    }

    #[test]
    fn test_pool_key_different_tokens() {
        let key1 = PoolKey::from_config(&config_with_token("token-a"), None);
        let key2 = PoolKey::from_config(&config_with_token("token-b"), None);

        assert_eq!(key1.url, key2.url);
        assert_ne!(key1.auth_hash, key2.auth_hash);
        assert_ne!(key1, key2);
    }

    #[test]
    fn test_pool_key_same_token_is_equal() {
        let key1 = PoolKey::from_config(&config_with_token("token-a"), None);
        let key2 = PoolKey::from_config(&config_with_token("token-a"), None);
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_pool_key_with_headers() {
        let mut headers1 = HashMap::new();
        headers1.insert("X-API-Key".to_string(), "key-1".to_string());

        let mut headers2 = HashMap::new();
        headers2.insert("X-API-Key".to_string(), "key-2".to_string());

        let key1 = PoolKey::from_config(&sse_config_with_headers(headers1), None);
        let key2 = PoolKey::from_config(&sse_config_with_headers(headers2), None);

        assert_eq!(key1.url, key2.url);
        assert_ne!(key1.auth_hash, key2.auth_hash);
        assert_ne!(key1, key2);
    }

    #[test]
    fn test_pool_key_headers_order_independent() {
        // Separate maps (each with its own random hasher state) filled in opposite orders.
        let mut headers1 = HashMap::new();
        headers1.insert("X-API-Key".to_string(), "key-1".to_string());
        headers1.insert("X-Trace".to_string(), "on".to_string());

        let mut headers2 = HashMap::new();
        headers2.insert("X-Trace".to_string(), "on".to_string());
        headers2.insert("X-API-Key".to_string(), "key-1".to_string());

        let key1 = PoolKey::from_config(&sse_config_with_headers(headers1), None);
        let key2 = PoolKey::from_config(&sse_config_with_headers(headers2), None);
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_pool_with_global_proxy() {
        let proxy = McpProxyConfig {
            http: Some("http://proxy.example.com:8080".to_string()),
            https: None,
            no_proxy: Some("localhost,127.0.0.1".to_string()),
            username: None,
            password: None,
        };

        let pool = McpConnectionPool::with_full_config(100, Some(proxy));

        let stored_proxy = pool
            .global_proxy
            .as_ref()
            .expect("global proxy passed to with_full_config should be stored");
        assert_eq!(
            stored_proxy.http.as_deref(),
            Some("http://proxy.example.com:8080")
        );
        assert_eq!(stored_proxy.no_proxy.as_deref(), Some("localhost,127.0.0.1"));
    }

    #[test]
    fn test_pool_proxy_from_env() {
        let pool = McpConnectionPool::new();
        // `new` must adopt whatever `from_env` yields for the current environment.
        assert_eq!(
            pool.global_proxy.is_some(),
            McpProxyConfig::from_env().is_some()
        );
    }
}