Skip to main content

smg_mcp/core/
pool.rs

1//! MCP Connection Pool for dynamic servers.
2
3use std::{
4    collections::{hash_map::DefaultHasher, HashMap},
5    hash::{Hash, Hasher},
6    sync::{
7        atomic::{AtomicUsize, Ordering},
8        Arc,
9    },
10};
11
12use lru::LruCache;
13use parking_lot::Mutex;
14use rmcp::{service::RunningService, RoleClient};
15
16use super::config::{McpProxyConfig, McpServerConfig, McpTransport};
17use crate::error::McpResult;
18
19type McpClient = RunningService<RoleClient, ()>;
20type EvictionCallback = Arc<dyn Fn(&PoolKey) + Send + Sync>;
21
/// Identity of a pooled connection: endpoint, credential fingerprint, and tenant.
///
/// Only a hash of the credentials is kept in the key; token and header values
/// are never stored here in plaintext.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq, Hash)]
pub struct PoolKey {
    /// Server URL (for stdio configs built by `from_config`, `command:args`).
    pub url: String,
    /// Hash of the token and headers; `0` when a config carries neither.
    pub auth_hash: u64,
    /// Owning tenant, when connections are tenant-scoped.
    pub tenant_id: Option<String>,
}
31
32impl PoolKey {
33    pub fn new(url: impl Into<String>, auth_hash: u64, tenant_id: Option<String>) -> Self {
34        Self {
35            url: url.into(),
36            auth_hash,
37            tenant_id,
38        }
39    }
40
41    pub fn from_config(config: &McpServerConfig, tenant_id: Option<String>) -> Self {
42        let (url, auth_hash) = match &config.transport {
43            McpTransport::Streamable {
44                url,
45                token,
46                headers,
47            } => (url.clone(), Self::hash_auth(token.as_ref(), headers)),
48            McpTransport::Sse {
49                url,
50                token,
51                headers,
52            } => (url.clone(), Self::hash_auth(token.as_ref(), headers)),
53            McpTransport::Stdio { command, args, .. } => {
54                (format!("{}:{}", command, args.join(" ")), 0)
55            }
56        };
57        Self {
58            url,
59            auth_hash,
60            tenant_id,
61        }
62    }
63
64    /// Hash token and headers. Returns 0 if no auth info.
65    fn hash_auth(token: Option<&String>, headers: &HashMap<String, String>) -> u64 {
66        if token.is_none() && headers.is_empty() {
67            return 0;
68        }
69
70        let mut hasher = DefaultHasher::new();
71
72        if let Some(t) = token {
73            t.hash(&mut hasher);
74        }
75
76        if !headers.is_empty() {
77            let mut sorted_headers: Vec<_> = headers.iter().collect();
78            sorted_headers.sort_by_key(|(k, _)| *k);
79            for (key, value) in sorted_headers {
80                key.hash(&mut hasher);
81                value.hash(&mut hasher);
82            }
83        }
84
85        hasher.finish()
86    }
87
88    #[inline]
89    pub fn url(&self) -> &str {
90        &self.url
91    }
92}
93
94/// Cached MCP connection.
95#[derive(Clone)]
96pub(crate) struct CachedConnection {
97    pub client: Arc<McpClient>,
98}
99
100impl CachedConnection {
101    pub fn new(client: Arc<McpClient>) -> Self {
102        Self { client }
103    }
104}
105
106/// Thread-safe LRU connection pool for dynamic MCP servers.
107pub struct McpConnectionPool {
108    connections: Arc<Mutex<LruCache<PoolKey, CachedConnection>>>,
109    /// Lock-free connection count for fast `len()` / `is_empty()` / `stats()`.
110    connection_count: AtomicUsize,
111    max_connections: usize,
112    global_proxy: Option<McpProxyConfig>,
113    eviction_callback: Option<EvictionCallback>,
114}
115
116impl McpConnectionPool {
117    const DEFAULT_MAX_CONNECTIONS: usize = 200;
118
119    /// Create pool with defaults (200 connections, proxy from env).
120    pub fn new() -> Self {
121        Self::with_full_config(Self::DEFAULT_MAX_CONNECTIONS, McpProxyConfig::from_env())
122    }
123
124    pub fn with_capacity(max_connections: usize) -> Self {
125        Self::with_full_config(max_connections, McpProxyConfig::from_env())
126    }
127
128    pub fn with_full_config(max_connections: usize, global_proxy: Option<McpProxyConfig>) -> Self {
129        let max_connections = max_connections.max(1);
130        let cache_cap =
131            std::num::NonZeroUsize::new(max_connections).unwrap_or(std::num::NonZeroUsize::MIN);
132        Self {
133            connections: Arc::new(Mutex::new(LruCache::new(cache_cap))),
134            connection_count: AtomicUsize::new(0),
135            max_connections,
136            global_proxy,
137            eviction_callback: None,
138        }
139    }
140
141    pub fn set_eviction_callback<F>(&mut self, callback: F)
142    where
143        F: Fn(&PoolKey) + Send + Sync + 'static,
144    {
145        self.eviction_callback = Some(Arc::new(callback));
146    }
147
148    /// Get existing connection or create via `connect_fn`.
149    pub async fn get_or_create<F, Fut>(
150        &self,
151        key: PoolKey,
152        server_config: McpServerConfig,
153        connect_fn: F,
154    ) -> McpResult<Arc<McpClient>>
155    where
156        F: FnOnce(McpServerConfig, Option<McpProxyConfig>) -> Fut,
157        Fut: std::future::Future<Output = McpResult<McpClient>>,
158    {
159        {
160            let mut connections = self.connections.lock();
161            if let Some(cached) = connections.get(&key) {
162                return Ok(Arc::clone(&cached.client));
163            }
164        }
165
166        let client = connect_fn(server_config.clone(), self.global_proxy.clone()).await?;
167        let client_arc = Arc::new(client);
168
169        let cached = CachedConnection::new(Arc::clone(&client_arc));
170        {
171            let mut connections = self.connections.lock();
172            match connections.push(key, cached) {
173                Some((evicted_key, _)) => {
174                    // Eviction: count stays the same (replaced one entry).
175                    if let Some(callback) = &self.eviction_callback {
176                        callback(&evicted_key);
177                    }
178                }
179                None => {
180                    // New entry without eviction: count increases.
181                    self.connection_count.fetch_add(1, Ordering::Relaxed);
182                }
183            }
184        }
185
186        Ok(client_arc)
187    }
188
189    pub fn len(&self) -> usize {
190        self.connection_count.load(Ordering::Relaxed)
191    }
192
193    pub fn is_empty(&self) -> bool {
194        self.connection_count.load(Ordering::Relaxed) == 0
195    }
196
197    pub fn clear(&self) {
198        let mut connections = self.connections.lock();
199        connections.clear();
200        self.connection_count.store(0, Ordering::Relaxed);
201    }
202
203    pub fn stats(&self) -> PoolStats {
204        PoolStats {
205            total_connections: self.connection_count.load(Ordering::Relaxed),
206            capacity: self.max_connections,
207        }
208    }
209
210    pub fn list_keys(&self) -> Vec<PoolKey> {
211        self.connections
212            .lock()
213            .iter()
214            .map(|(key, _)| key.clone())
215            .collect()
216    }
217
218    /// Get connection, promoting in LRU.
219    pub fn get(&self, key: &PoolKey) -> Option<Arc<McpClient>> {
220        self.connections
221            .lock()
222            .get(key)
223            .map(|cached| Arc::clone(&cached.client))
224    }
225
226    pub fn contains(&self, key: &PoolKey) -> bool {
227        self.connections.lock().contains(key)
228    }
229
230    /// Look up a connection by URL only (backward compatibility).
231    ///
232    /// **O(n)** — performs a linear scan of all pooled connections under the
233    /// lock. Callers on hot paths should prefer [`get()`](Self::get) with a
234    /// full [`PoolKey`] for O(1) lookup.
235    pub fn get_by_url(&self, url: &str) -> Option<Arc<McpClient>> {
236        self.connections
237            .lock()
238            .iter()
239            .find(|(key, _)| key.url == url)
240            .map(|(_, cached)| Arc::clone(&cached.client))
241    }
242
243    /// Check whether a connection with the given URL exists (backward
244    /// compatibility).
245    ///
246    /// **O(n)** — performs a linear scan of all pooled connections under the
247    /// lock. Callers on hot paths should prefer [`contains()`](Self::contains)
248    /// with a full [`PoolKey`] for O(1) lookup.
249    pub fn contains_url(&self, url: &str) -> bool {
250        self.connections
251            .lock()
252            .iter()
253            .any(|(key, _)| key.url == url)
254    }
255}
256
257impl Default for McpConnectionPool {
258    fn default() -> Self {
259        Self::new()
260    }
261}
262
/// Point-in-time snapshot of a connection pool's size.
#[derive(Debug)]
#[derive(Clone)]
pub struct PoolStats {
    /// Connections currently held by the pool.
    pub total_connections: usize,
    /// Maximum number of connections the pool keeps before evicting.
    pub capacity: usize,
}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::config::McpTransport;

    /// Wrap `transport` in a server config whose remaining fields are unset or false.
    fn config_with_transport(name: &str, transport: McpTransport) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            transport,
            proxy: None,
            required: false,
            tools: None,
            builtin_type: None,
            builtin_tool_name: None,
        }
    }

    /// Streamable-HTTP config for `url` with the given token and headers.
    fn streamable_config(
        url: &str,
        token: Option<&str>,
        headers: HashMap<String, String>,
    ) -> McpServerConfig {
        config_with_transport(
            "test",
            McpTransport::Streamable {
                url: url.to_string(),
                token: token.map(str::to_string),
                headers,
            },
        )
    }

    /// SSE config for `url` with no token and the given headers.
    fn sse_config(url: &str, headers: HashMap<String, String>) -> McpServerConfig {
        config_with_transport(
            "test",
            McpTransport::Sse {
                url: url.to_string(),
                token: None,
                headers,
            },
        )
    }

    /// Unauthenticated Streamable-HTTP config for `url`.
    fn create_test_config(url: &str) -> McpServerConfig {
        streamable_config(url, None, HashMap::new())
    }

    /// Build a header map from `(name, value)` pairs.
    fn headers_from(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    #[test]
    fn test_pool_creation() {
        let pool = McpConnectionPool::new();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
    }

    #[test]
    fn test_pool_stats() {
        let pool = McpConnectionPool::with_capacity(10);

        let stats = pool.stats();
        assert_eq!(stats.total_connections, 0);
        assert_eq!(stats.capacity, 10);
    }

    #[test]
    fn test_pool_capacity_clamped_to_one() {
        // The LRU cache needs a non-zero capacity, so 0 is raised to 1.
        let pool = McpConnectionPool::with_full_config(0, None);
        assert_eq!(pool.stats().capacity, 1);
        assert!(pool.is_empty());
    }

    #[test]
    fn test_pool_clear() {
        let pool = McpConnectionPool::new();
        assert_eq!(pool.len(), 0);
        // Clearing an empty pool must be a harmless no-op.
        pool.clear();
        assert!(pool.is_empty());
    }

    #[test]
    fn test_pool_key_new_and_url() {
        let key = PoolKey::new("http://localhost:3000", 42, Some("tenant-a".to_string()));
        assert_eq!(key.url(), "http://localhost:3000");
        assert_eq!(key.auth_hash, 42);
        assert_eq!(key.tenant_id.as_deref(), Some("tenant-a"));
    }

    #[test]
    fn test_pool_key_from_config() {
        // No token: auth hash is the 0 sentinel.
        let config = create_test_config("http://localhost:3000");
        let key = PoolKey::from_config(&config, None);
        assert_eq!(key.url, "http://localhost:3000");
        assert_eq!(key.auth_hash, 0);
        assert_eq!(key.tenant_id, None);

        // With token: the token is hashed, not stored.
        let config_with_token =
            streamable_config("http://localhost:3000", Some("secret-token"), HashMap::new());
        let key_with_token = PoolKey::from_config(&config_with_token, None);
        assert_eq!(key_with_token.url, "http://localhost:3000");
        assert_ne!(key_with_token.auth_hash, 0);

        // With tenant: same endpoint, different key.
        let key_with_tenant = PoolKey::from_config(&config, Some("tenant-123".to_string()));
        assert_eq!(key_with_tenant.tenant_id, Some("tenant-123".to_string()));
        assert_ne!(key, key_with_tenant);
    }

    #[test]
    fn test_pool_key_different_tokens() {
        let config1 = streamable_config("http://localhost:3000", Some("token-a"), HashMap::new());
        let config2 = streamable_config("http://localhost:3000", Some("token-b"), HashMap::new());

        let key1 = PoolKey::from_config(&config1, None);
        let key2 = PoolKey::from_config(&config2, None);

        // Same URL but different tokens = different keys.
        assert_eq!(key1.url, key2.url);
        assert_ne!(key1.auth_hash, key2.auth_hash);
        assert_ne!(key1, key2);
    }

    #[test]
    fn test_pool_key_with_headers() {
        let config1 = sse_config("http://localhost:3000", headers_from(&[("X-API-Key", "key-1")]));
        let config2 = sse_config("http://localhost:3000", headers_from(&[("X-API-Key", "key-2")]));

        let key1 = PoolKey::from_config(&config1, None);
        let key2 = PoolKey::from_config(&config2, None);

        // Same URL but different headers = different keys.
        assert_eq!(key1.url, key2.url);
        assert_ne!(key1.auth_hash, key2.auth_hash);
        assert_ne!(key1, key2);
    }

    #[test]
    fn test_pool_key_header_order_independent() {
        // Each HashMap has its own random iteration order; hashing must not depend on it.
        let forward = headers_from(&[("X-A", "1"), ("X-B", "2"), ("X-C", "3")]);
        let reverse = headers_from(&[("X-C", "3"), ("X-B", "2"), ("X-A", "1")]);

        let key1 = PoolKey::from_config(&sse_config("http://localhost:3000", forward), None);
        let key2 = PoolKey::from_config(&sse_config("http://localhost:3000", reverse), None);

        assert_ne!(key1.auth_hash, 0);
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_pool_with_global_proxy() {
        let proxy = McpProxyConfig {
            http: Some("http://proxy.example.com:8080".to_string()),
            https: None,
            no_proxy: Some("localhost,127.0.0.1".to_string()),
            username: None,
            password: None,
        };

        let pool = McpConnectionPool::with_full_config(100, Some(proxy.clone()));

        let stored_proxy = pool
            .global_proxy
            .as_ref()
            .expect("explicit proxy must be stored");
        assert_eq!(
            stored_proxy.http.as_ref().unwrap(),
            "http://proxy.example.com:8080"
        );
        assert_eq!(
            stored_proxy.no_proxy.as_ref().unwrap(),
            "localhost,127.0.0.1"
        );
    }

    #[test]
    fn test_pool_without_proxy() {
        let pool = McpConnectionPool::with_full_config(10, None);
        assert!(pool.global_proxy.is_none());
    }

    #[test]
    fn test_pool_proxy_from_env() {
        // The proxy itself depends on MCP_HTTP_PROXY / HTTP_PROXY in the test
        // environment, so only assert the environment-independent defaults.
        let pool = McpConnectionPool::new();
        assert_eq!(pool.stats().capacity, McpConnectionPool::DEFAULT_MAX_CONNECTIONS);
        assert!(pool.is_empty());

        let default_pool = McpConnectionPool::default();
        assert_eq!(
            default_pool.stats().capacity,
            McpConnectionPool::DEFAULT_MAX_CONNECTIONS
        );
    }
}