turbomcp_client/client/
manager.rs

1//! Multi-Server Session Manager for MCP Clients
2//!
3//! This module provides a `SessionManager` for coordinating multiple MCP server sessions.
4//! Unlike traditional HTTP connection pooling, MCP uses long-lived, stateful sessions
5//! where each connection maintains negotiated capabilities and subscription state.
6//!
7//! # Key Concepts
8//!
9//! - **Session**: A long-lived, initialized MCP connection to a server
10//! - **Multi-Server**: Manage connections to different MCP servers (GitHub, filesystem, etc.)
11//! - **Health Monitoring**: Automatic ping-based health checks per session
12//! - **Lifecycle Management**: Proper initialize → operate → shutdown for each session
13//!
14//! # Features
15//!
16//! - Multiple server sessions with independent state
17//! - Automatic health checking with configurable intervals
18//! - Per-session state tracking (healthy, degraded, unhealthy)
19//! - Session lifecycle management
20//! - Metrics and monitoring per session
21//!
22//! # When to Use
23//!
24//! Use `SessionManager` when your application needs to coordinate **multiple different
25//! MCP servers** (e.g., IDE with GitHub server + filesystem server + database server).
26//!
27//! For **single server** scenarios:
28//! - `Client<T>` is cheaply cloneable via Arc - share one session across multiple async tasks
29//! - `TurboTransport` - Add retry/circuit breaker to one session
30
31use std::collections::HashMap;
32use std::sync::Arc;
33use std::time::Duration;
34use tokio::sync::RwLock;
35use tokio::time::Instant;
36use turbomcp_protocol::{Error, Result};
37use turbomcp_transport::Transport;
38
39use super::core::Client;
40
/// Health/lifecycle state recorded for each managed session.
///
/// The background health monitor in this file moves sessions between
/// `Healthy`, `Degraded` and `Unhealthy`. Nothing in this file assigns
/// `Connecting` or `Disconnected`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum ConnectionState {
    /// The most recent health check succeeded; the session is ready for use
    Healthy,
    /// A health check failed, but fewer times than the configured threshold
    Degraded,
    /// Failed health checks reached the configured threshold; avoid this session
    Unhealthy,
    /// The session is still being established
    Connecting,
    /// The session is no longer connected
    Disconnected,
}
55
56/// Information about a managed connection
57#[derive(Debug, Clone)]
58pub struct ConnectionInfo {
59    /// Unique identifier for this connection
60    pub id: String,
61    /// Current state of the connection
62    pub state: ConnectionState,
63    /// When the connection was established
64    pub established_at: Instant,
65    /// Last successful health check
66    pub last_health_check: Option<Instant>,
67    /// Number of failed health checks
68    pub failed_health_checks: usize,
69    /// Number of successful requests
70    pub successful_requests: usize,
71    /// Number of failed requests
72    pub failed_requests: usize,
73}
74
/// Tunable settings for the session manager
#[derive(Clone, Debug)]
pub struct ManagerConfig {
    /// Upper bound on the number of sessions the manager will hold at once
    pub max_connections: usize,
    /// Period between two rounds of health checks
    pub health_check_interval: Duration,
    /// Failed health checks required before a session is marked unhealthy
    pub health_check_threshold: usize,
    /// Time allowed for a single health check before it counts as failed
    pub health_check_timeout: Duration,
    /// Whether automatic reconnection is enabled
    ///
    /// NOTE(review): no code in this file reads this flag yet — confirm where
    /// reconnection is implemented.
    pub auto_reconnect: bool,
    /// Delay before the first reconnection attempt
    pub reconnect_delay: Duration,
    /// Ceiling for the reconnection delay when backing off exponentially
    pub max_reconnect_delay: Duration,
    /// Factor applied to the reconnection delay for exponential backoff
    pub reconnect_backoff_multiplier: f64,
}

impl Default for ManagerConfig {
    /// Defaults: 10 sessions, health check every 30s with a 5s timeout and a
    /// threshold of 3, reconnection enabled with a 1s initial delay, 60s
    /// maximum delay and a 2.0 backoff multiplier.
    fn default() -> Self {
        let secs = Duration::from_secs;

        Self {
            max_connections: 10,
            health_check_interval: secs(30),
            health_check_threshold: 3,
            health_check_timeout: secs(5),
            auto_reconnect: true,
            reconnect_delay: secs(1),
            max_reconnect_delay: secs(60),
            reconnect_backoff_multiplier: 2.0,
        }
    }
}

impl ManagerConfig {
    /// Build a configuration populated with the default values
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}
117
118/// Managed connection wrapper
119struct ManagedConnection<T: Transport + 'static> {
120    client: Client<T>,
121    info: ConnectionInfo,
122    /// Number of reconnection attempts (reserved for future reconnection logic)
123    #[allow(dead_code)]
124    reconnect_attempts: usize,
125    /// Next reconnection delay (reserved for future reconnection logic)
126    #[allow(dead_code)]
127    next_reconnect_delay: Duration,
128    /// Transport factory for reconnection (reserved for future reconnection logic)
129    #[allow(dead_code)]
130    transport_factory: Option<Arc<dyn Fn() -> T + Send + Sync>>,
131}
132
/// Server group configuration for failover support
///
/// A group is a primary server ID followed by backup IDs in priority order.
/// The type only describes the ordering; it performs no I/O itself.
#[derive(Debug, Clone)]
pub struct ServerGroup {
    /// Primary server ID
    pub primary: String,
    /// Backup server IDs in priority order
    pub backups: Vec<String>,
    /// Minimum health check failures before failover
    pub failover_threshold: usize,
}

impl ServerGroup {
    /// Create a new server group with a primary and its backups
    ///
    /// `failover_threshold` starts at 3; use
    /// [`with_failover_threshold`](Self::with_failover_threshold) to change it.
    pub fn new(primary: impl Into<String>, backups: Vec<String>) -> Self {
        Self {
            primary: primary.into(),
            backups,
            failover_threshold: 3,
        }
    }

    /// Set the failover threshold (builder style: consumes and returns the group)
    #[must_use]
    pub fn with_failover_threshold(mut self, threshold: usize) -> Self {
        self.failover_threshold = threshold;
        self
    }

    /// Get all server IDs in priority order (primary first, then backups)
    pub fn all_servers(&self) -> Vec<&str> {
        self.server_ids().collect()
    }

    /// Get the server that follows `current` in priority order
    ///
    /// Returns `None` when `current` is the last server of the group or is not
    /// part of the group at all. If an ID appears more than once, the first
    /// occurrence is the one that is matched.
    pub fn next_server(&self, current: &str) -> Option<&str> {
        let mut ids = self.server_ids();
        // `find` consumes the iterator up to and including the first match, so
        // the following `next()` yields the successor without building a Vec.
        ids.find(|&id| id == current)?;
        ids.next()
    }

    /// Lazily iterate over all server IDs in priority order
    fn server_ids(&self) -> impl Iterator<Item = &str> {
        std::iter::once(self.primary.as_str()).chain(self.backups.iter().map(String::as_str))
    }
}
174
175/// Multi-Server Session Manager for MCP Clients
176///
177/// The `SessionManager` coordinates multiple MCP server sessions with automatic
178/// health monitoring and lifecycle management. Each session represents a long-lived,
179/// initialized connection to a different MCP server.
180///
181/// # Use Cases
182///
183/// - **Multi-Server Applications**: IDE with multiple tool servers
184/// - **Service Coordination**: Orchestrate operations across multiple MCP servers
185/// - **Health Monitoring**: Track health of all connected servers
186/// - **Failover**: Switch between primary/backup servers
187///
188/// # Examples
189///
190/// ```rust,no_run
191/// use turbomcp_client::SessionManager;
192/// use turbomcp_transport::stdio::StdioTransport;
193///
194/// # async fn example() -> turbomcp_protocol::Result<()> {
195/// let mut manager = SessionManager::with_defaults();
196///
197/// // Add sessions to different servers
198/// let github_transport = StdioTransport::new();
199/// let fs_transport = StdioTransport::new();
200/// manager.add_server("github", github_transport).await?;
201/// manager.add_server("filesystem", fs_transport).await?;
202///
203/// // Start health monitoring
204/// manager.start_health_monitoring().await;
205///
206/// // Get stats
207/// let stats = manager.session_stats().await;
208/// println!("Managing {} sessions", stats.len());
209/// # Ok(())
210/// # }
211/// ```
212pub struct SessionManager<T: Transport + 'static> {
213    config: ManagerConfig,
214    connections: Arc<RwLock<HashMap<String, ManagedConnection<T>>>>,
215    health_check_task: Option<tokio::task::JoinHandle<()>>,
216}
217
218impl<T: Transport + Send + 'static> SessionManager<T> {
219    /// Create a new connection manager with the specified configuration
220    pub fn new(config: ManagerConfig) -> Self {
221        Self {
222            config,
223            connections: Arc::new(RwLock::new(HashMap::new())),
224            health_check_task: None,
225        }
226    }
227
228    /// Create a new connection manager with default configuration
229    pub fn with_defaults() -> Self {
230        Self::new(ManagerConfig::default())
231    }
232
233    /// Add a new server session
234    ///
235    /// Creates and initializes a session to the specified MCP server.
236    ///
237    /// # Arguments
238    ///
239    /// * `id` - Unique identifier for this server (e.g., "github", "filesystem")
240    /// * `transport` - Transport implementation for connecting to the server
241    ///
242    /// # Errors
243    ///
244    /// Returns an error if:
245    /// - Maximum sessions limit is reached
246    /// - Server ID already exists
247    /// - Client initialization fails
248    ///
249    /// # Examples
250    ///
251    /// ```rust,no_run
252    /// # use turbomcp_client::SessionManager;
253    /// # use turbomcp_transport::stdio::StdioTransport;
254    /// # async fn example() -> turbomcp_protocol::Result<()> {
255    /// let mut manager = SessionManager::with_defaults();
256    /// let github_transport = StdioTransport::new();
257    /// let fs_transport = StdioTransport::new();
258    /// manager.add_server("github", github_transport).await?;
259    /// manager.add_server("filesystem", fs_transport).await?;
260    /// # Ok(())
261    /// # }
262    /// ```
263    pub async fn add_server(&mut self, id: impl Into<String>, transport: T) -> Result<()> {
264        let id = id.into();
265        let mut connections = self.connections.write().await;
266
267        // Check connection limit
268        if connections.len() >= self.config.max_connections {
269            return Err(Error::bad_request(format!(
270                "Maximum connections limit ({}) reached",
271                self.config.max_connections
272            )));
273        }
274
275        // Check for duplicate ID
276        if connections.contains_key(&id) {
277            return Err(Error::bad_request(format!(
278                "Connection with ID '{}' already exists",
279                id
280            )));
281        }
282
283        // Create client and initialize
284        let client = Client::new(transport);
285        client.initialize().await?;
286
287        let info = ConnectionInfo {
288            id: id.clone(),
289            state: ConnectionState::Healthy,
290            established_at: Instant::now(),
291            last_health_check: Some(Instant::now()),
292            failed_health_checks: 0,
293            successful_requests: 0,
294            failed_requests: 0,
295        };
296
297        connections.insert(
298            id,
299            ManagedConnection {
300                client,
301                info,
302                reconnect_attempts: 0,
303                next_reconnect_delay: self.config.reconnect_delay,
304                transport_factory: None, // No automatic reconnection for manual adds
305            },
306        );
307
308        Ok(())
309    }
310
311    /// Add a server session with automatic reconnection support
312    ///
313    /// This method accepts a factory function that creates new transport instances,
314    /// enabling automatic reconnection if the session fails.
315    ///
316    /// # Arguments
317    ///
318    /// * `id` - Unique identifier for this server
319    /// * `transport_factory` - Function that creates new transport instances
320    ///
321    /// # Examples
322    ///
323    /// ```rust,no_run
324    /// # use turbomcp_client::SessionManager;
325    /// # use turbomcp_transport::stdio::StdioTransport;
326    /// # async fn example() -> turbomcp_protocol::Result<()> {
327    /// let mut manager = SessionManager::with_defaults();
328    ///
329    /// // Transport with reconnection factory
330    /// manager.add_server_with_reconnect("api", || {
331    ///     StdioTransport::new()
332    /// }).await?;
333    /// # Ok(())
334    /// # }
335    /// ```
336    pub async fn add_server_with_reconnect<F>(
337        &mut self,
338        id: impl Into<String>,
339        transport_factory: F,
340    ) -> Result<()>
341    where
342        F: Fn() -> T + Send + Sync + 'static,
343    {
344        let id = id.into();
345        let factory = Arc::new(transport_factory);
346
347        // Create initial transport and client
348        let transport = (factory)();
349        let client = Client::new(transport);
350        client.initialize().await?;
351
352        let info = ConnectionInfo {
353            id: id.clone(),
354            state: ConnectionState::Healthy,
355            established_at: Instant::now(),
356            last_health_check: Some(Instant::now()),
357            failed_health_checks: 0,
358            successful_requests: 0,
359            failed_requests: 0,
360        };
361
362        let mut connections = self.connections.write().await;
363
364        // Check limits
365        if connections.len() >= self.config.max_connections {
366            return Err(Error::bad_request(format!(
367                "Maximum sessions limit ({}) reached",
368                self.config.max_connections
369            )));
370        }
371
372        if connections.contains_key(&id) {
373            return Err(Error::bad_request(format!(
374                "Server with ID '{}' already exists",
375                id
376            )));
377        }
378
379        connections.insert(
380            id,
381            ManagedConnection {
382                client,
383                info,
384                reconnect_attempts: 0,
385                next_reconnect_delay: self.config.reconnect_delay,
386                transport_factory: Some(factory),
387            },
388        );
389
390        Ok(())
391    }
392
393    /// Remove a managed connection
394    ///
395    /// # Arguments
396    ///
397    /// * `id` - ID of the connection to remove
398    ///
399    /// # Returns
400    ///
401    /// Returns `true` if the connection was removed, `false` if not found
402    pub async fn remove_server(&mut self, id: &str) -> bool {
403        let mut connections = self.connections.write().await;
404        connections.remove(id).is_some()
405    }
406
407    /// Get information about a specific connection
408    pub async fn get_session_info(&self, id: &str) -> Option<ConnectionInfo> {
409        let connections = self.connections.read().await;
410        connections.get(id).map(|conn| conn.info.clone())
411    }
412
413    /// List all managed connections
414    pub async fn list_sessions(&self) -> Vec<ConnectionInfo> {
415        let connections = self.connections.read().await;
416        connections.values().map(|conn| conn.info.clone()).collect()
417    }
418
419    /// Get a healthy connection, preferring the one with the fewest active requests
420    ///
421    /// # Returns
422    ///
423    /// Returns the ID of a healthy connection, or None if no healthy connections exist
424    pub async fn get_healthy_connection(&self) -> Option<String> {
425        let connections = self.connections.read().await;
426        connections
427            .iter()
428            .filter(|(_, conn)| conn.info.state == ConnectionState::Healthy)
429            .min_by_key(|(_, conn)| conn.info.successful_requests + conn.info.failed_requests)
430            .map(|(id, _)| id.clone())
431    }
432
433    /// Get count of connections by state
434    pub async fn session_stats(&self) -> HashMap<ConnectionState, usize> {
435        let connections = self.connections.read().await;
436        let mut stats = HashMap::new();
437
438        for conn in connections.values() {
439            *stats.entry(conn.info.state.clone()).or_insert(0) += 1;
440        }
441
442        stats
443    }
444
445    /// Start automatic health monitoring
446    ///
447    /// Spawns a background task that periodically checks the health of all connections
448    pub async fn start_health_monitoring(&mut self) {
449        if self.health_check_task.is_some() {
450            return; // Already running
451        }
452
453        let connections = Arc::clone(&self.connections);
454        let interval = self.config.health_check_interval;
455        let threshold = self.config.health_check_threshold;
456        let timeout = self.config.health_check_timeout;
457
458        let task = tokio::spawn(async move {
459            let mut interval_timer = tokio::time::interval(interval);
460
461            loop {
462                interval_timer.tick().await;
463
464                let mut connections = connections.write().await;
465
466                for (id, managed) in connections.iter_mut() {
467                    // Perform health check (ping)
468                    let health_result = tokio::time::timeout(timeout, managed.client.ping()).await;
469
470                    match health_result {
471                        Ok(Ok(_)) => {
472                            // Health check successful
473                            managed.info.last_health_check = Some(Instant::now());
474                            managed.info.failed_health_checks = 0;
475
476                            if managed.info.state != ConnectionState::Healthy {
477                                tracing::info!(
478                                    connection_id = %id,
479                                    "Connection recovered and is now healthy"
480                                );
481                                managed.info.state = ConnectionState::Healthy;
482                            }
483                        }
484                        Ok(Err(_)) | Err(_) => {
485                            // Health check failed
486                            managed.info.failed_health_checks += 1;
487
488                            if managed.info.failed_health_checks >= threshold {
489                                if managed.info.state != ConnectionState::Unhealthy {
490                                    tracing::warn!(
491                                        connection_id = %id,
492                                        failed_checks = managed.info.failed_health_checks,
493                                        "Connection marked as unhealthy"
494                                    );
495                                    managed.info.state = ConnectionState::Unhealthy;
496                                }
497                            } else if managed.info.state == ConnectionState::Healthy {
498                                tracing::debug!(
499                                    connection_id = %id,
500                                    failed_checks = managed.info.failed_health_checks,
501                                    "Connection degraded"
502                                );
503                                managed.info.state = ConnectionState::Degraded;
504                            }
505                        }
506                    }
507                }
508            }
509        });
510
511        self.health_check_task = Some(task);
512    }
513
514    /// Stop automatic health monitoring
515    pub fn stop_health_monitoring(&mut self) {
516        if let Some(task) = self.health_check_task.take() {
517            task.abort();
518        }
519    }
520
521    /// Get total number of managed connections
522    pub async fn session_count(&self) -> usize {
523        let connections = self.connections.read().await;
524        connections.len()
525    }
526}
527
528impl<T: Transport + 'static> Drop for SessionManager<T> {
529    fn drop(&mut self) {
530        self.stop_health_monitoring();
531    }
532}
533
534// ============================================================================
535// Specialized Implementation for TurboTransport
536// ============================================================================
537
538impl SessionManager<turbomcp_transport::resilience::TurboTransport> {
539    /// Add a server with automatic robustness (specialized for TurboTransport)
540    ///
541    /// This convenience method is only available when using `SessionManager<TurboTransport>`.
542    /// It wraps any transport in TurboTransport with the specified configurations.
543    ///
544    /// # Examples
545    ///
546    /// ```rust,no_run
547    /// # use turbomcp_client::SessionManager;
548    /// # use turbomcp_transport::stdio::StdioTransport;
549    /// # use turbomcp_transport::resilience::*;
550    /// # async fn example() -> turbomcp_protocol::Result<()> {
551    /// let mut manager: SessionManager<TurboTransport> = SessionManager::with_defaults();
552    ///
553    /// // Use explicit configuration for clarity
554    /// use std::time::Duration;
555    /// manager.add_resilient_server(
556    ///     "github",
557    ///     StdioTransport::new(),
558    ///     RetryConfig {
559    ///         max_attempts: 5,
560    ///         base_delay: Duration::from_millis(200),
561    ///         ..Default::default()
562    ///     },
563    ///     CircuitBreakerConfig {
564    ///         failure_threshold: 3,
565    ///         timeout: Duration::from_secs(30),
566    ///         ..Default::default()
567    ///     },
568    ///     HealthCheckConfig {
569    ///         interval: Duration::from_secs(15),
570    ///         timeout: Duration::from_secs(5),
571    ///         ..Default::default()
572    ///     },
573    /// ).await?;
574    /// # Ok(())
575    /// # }
576    /// ```
577    pub async fn add_resilient_server<BaseT>(
578        &mut self,
579        id: impl Into<String>,
580        transport: BaseT,
581        retry_config: turbomcp_transport::resilience::RetryConfig,
582        circuit_config: turbomcp_transport::resilience::CircuitBreakerConfig,
583        health_config: turbomcp_transport::resilience::HealthCheckConfig,
584    ) -> Result<()>
585    where
586        BaseT: Transport + 'static,
587    {
588        use turbomcp_transport::resilience::TurboTransport;
589
590        let robust = TurboTransport::new(
591            Box::new(transport),
592            retry_config,
593            circuit_config,
594            health_config,
595        );
596
597        self.add_server(id, robust).await
598    }
599}
600
#[cfg(test)]
mod tests {
    use super::*;

    /// Every default value is pinned so an accidental change is caught.
    #[test]
    fn test_manager_config_defaults() {
        let config = ManagerConfig::default();
        assert_eq!(config.max_connections, 10);
        assert_eq!(config.health_check_interval, Duration::from_secs(30));
        assert_eq!(config.health_check_threshold, 3);
        assert_eq!(config.health_check_timeout, Duration::from_secs(5));
        assert!(config.auto_reconnect);
        assert_eq!(config.reconnect_delay, Duration::from_secs(1));
        assert_eq!(config.max_reconnect_delay, Duration::from_secs(60));
        assert_eq!(config.reconnect_backoff_multiplier, 2.0);
    }

    /// `ManagerConfig::new` is documented as equivalent to the defaults.
    #[test]
    fn test_manager_config_new_matches_default() {
        let fresh = ManagerConfig::new();
        let default = ManagerConfig::default();
        assert_eq!(fresh.max_connections, default.max_connections);
        assert_eq!(fresh.health_check_interval, default.health_check_interval);
        assert_eq!(fresh.auto_reconnect, default.auto_reconnect);
    }

    #[test]
    fn test_connection_state_equality() {
        assert_eq!(ConnectionState::Healthy, ConnectionState::Healthy);
        assert_ne!(ConnectionState::Healthy, ConnectionState::Unhealthy);
    }

    /// The primary comes first, followed by the backups in the order given.
    #[test]
    fn test_server_group_priority_order() {
        let group = ServerGroup::new("primary", vec!["b1".to_string(), "b2".to_string()]);
        assert_eq!(group.all_servers(), vec!["primary", "b1", "b2"]);
        assert_eq!(group.failover_threshold, 3);
    }

    /// Failover walks the priority list and stops at its end or on unknown IDs.
    #[test]
    fn test_server_group_next_server() {
        let group = ServerGroup::new("primary", vec!["b1".to_string(), "b2".to_string()]);
        assert_eq!(group.next_server("primary"), Some("b1"));
        assert_eq!(group.next_server("b1"), Some("b2"));
        assert_eq!(group.next_server("b2"), None);
        assert_eq!(group.next_server("unknown"), None);

        let solo = ServerGroup::new("only", Vec::new());
        assert_eq!(solo.next_server("only"), None);
    }

    #[test]
    fn test_server_group_failover_threshold() {
        let group = ServerGroup::new("primary", Vec::new()).with_failover_threshold(7);
        assert_eq!(group.failover_threshold, 7);
    }
}