turbomcp_client/client/manager.rs

//! Multi-Server Session Manager for MCP Clients
//!
//! This module provides a `SessionManager` for coordinating multiple MCP server sessions.
//! Unlike traditional HTTP connection pooling, MCP uses long-lived, stateful sessions
//! where each connection maintains negotiated capabilities and subscription state.
//!
//! # Key Concepts
//!
//! - **Session**: A long-lived, initialized MCP connection to a server
//! - **Multi-Server**: Manage connections to different MCP servers (GitHub, filesystem, etc.)
//! - **Health Monitoring**: Automatic ping-based health checks per session
//! - **Lifecycle Management**: Proper initialize → operate → shutdown for each session
//!
//! # Features
//!
//! - Multiple server sessions with independent state
//! - Automatic health checking with configurable intervals
//! - Per-session state tracking (healthy, degraded, unhealthy)
//! - Session lifecycle management
//! - Metrics and monitoring per session
//!
//! # When to Use
//!
//! Use `SessionManager` when your application needs to coordinate **multiple different
//! MCP servers** (e.g., IDE with GitHub server + filesystem server + database server).
//!
//! For **single server** scenarios:
//! - `Client<T>` is cheaply cloneable (backed by `Arc`); share one session across multiple async tasks
//! - `TurboTransport` adds retry and circuit-breaker resilience to one session
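//!
//! A minimal single-server sketch (assuming `Client` is re-exported at the
//! crate root like `SessionManager`):
//!
//! ```rust,no_run
//! use turbomcp_client::Client;
//! use turbomcp_transport::stdio::StdioTransport;
//!
//! # async fn example() -> turbomcp_protocol::Result<()> {
//! let client = Client::new(StdioTransport::new());
//! client.initialize().await?;
//!
//! // Cloning is cheap and shares the same underlying session.
//! let task_client = client.clone();
//! tokio::spawn(async move {
//!     let _ = task_client.ping().await;
//! });
//! # Ok(())
//! # }
//! ```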

use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::time::Instant;
use turbomcp_protocol::{Error, Result};
use turbomcp_transport::Transport;

use super::core::Client;

/// Connection state for a managed client
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ConnectionState {
    /// Connection is healthy and ready
    Healthy,
    /// Connection is degraded but functional
    Degraded,
    /// Connection is unhealthy and should be avoided
    Unhealthy,
    /// Connection is being established
    Connecting,
    /// Connection is disconnected
    Disconnected,
}

/// Information about a managed connection
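///
/// A sketch of reading per-session stats through the manager (the "github"
/// session ID is illustrative):
///
/// ```rust,no_run
/// # use turbomcp_client::SessionManager;
/// # use turbomcp_transport::stdio::StdioTransport;
/// # async fn example() {
/// # let manager: SessionManager<StdioTransport> = SessionManager::with_defaults();
/// if let Some(info) = manager.get_session_info("github").await {
///     let total = info.successful_requests + info.failed_requests;
///     println!("{}: {:?} after {} requests", info.id, info.state, total);
/// }
/// # }
/// ```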
#[derive(Debug, Clone)]
pub struct ConnectionInfo {
    /// Unique identifier for this connection
    pub id: String,
    /// Current state of the connection
    pub state: ConnectionState,
    /// When the connection was established
    pub established_at: Instant,
    /// Last successful health check
    pub last_health_check: Option<Instant>,
    /// Number of failed health checks
    pub failed_health_checks: usize,
    /// Number of successful requests
    pub successful_requests: usize,
    /// Number of failed requests
    pub failed_requests: usize,
}

/// Configuration for the connection manager
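///
/// Every field has a default (see `Default`); a sketch of overriding a few with
/// struct-update syntax (assuming `ManagerConfig` is re-exported at the crate
/// root like `SessionManager`):
///
/// ```rust
/// use std::time::Duration;
/// use turbomcp_client::ManagerConfig;
///
/// let config = ManagerConfig {
///     health_check_interval: Duration::from_secs(10),
///     health_check_threshold: 2,
///     ..ManagerConfig::default()
/// };
/// assert_eq!(config.max_connections, 10); // untouched fields keep defaults
/// ```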
#[derive(Debug, Clone)]
pub struct ManagerConfig {
    /// Maximum number of concurrent connections
    pub max_connections: usize,
    /// Health check interval
    pub health_check_interval: Duration,
    /// Number of consecutive failures before marking unhealthy
    pub health_check_threshold: usize,
    /// Timeout for health checks
    pub health_check_timeout: Duration,
    /// Enable automatic reconnection
    pub auto_reconnect: bool,
    /// Initial reconnection delay
    pub reconnect_delay: Duration,
    /// Maximum reconnection delay (for exponential backoff)
    pub max_reconnect_delay: Duration,
    /// Reconnection backoff multiplier
    pub reconnect_backoff_multiplier: f64,
}

impl Default for ManagerConfig {
    fn default() -> Self {
        Self {
            max_connections: 10,
            health_check_interval: Duration::from_secs(30),
            health_check_threshold: 3,
            health_check_timeout: Duration::from_secs(5),
            auto_reconnect: true,
            reconnect_delay: Duration::from_secs(1),
            max_reconnect_delay: Duration::from_secs(60),
            reconnect_backoff_multiplier: 2.0,
        }
    }
}

impl ManagerConfig {
    /// Create a new manager configuration with default values
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}

/// Managed connection wrapper
struct ManagedConnection<T: Transport + 'static> {
    client: Client<T>,
    info: ConnectionInfo,
    /// Number of reconnection attempts (reserved for future reconnection logic)
    #[allow(dead_code)]
    reconnect_attempts: usize,
    /// Next reconnection delay (reserved for future reconnection logic)
    #[allow(dead_code)]
    next_reconnect_delay: Duration,
    /// Transport factory for reconnection (reserved for future reconnection logic)
    #[allow(dead_code)]
    transport_factory: Option<Arc<dyn Fn() -> T + Send + Sync>>,
}

/// Server group configuration for failover support
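///
/// A sketch of the failover priority order (assuming `ServerGroup` is
/// re-exported at the crate root like `SessionManager`):
///
/// ```rust
/// use turbomcp_client::ServerGroup;
///
/// let group = ServerGroup::new("primary", vec!["backup1".into(), "backup2".into()])
///     .with_failover_threshold(5);
/// assert_eq!(group.next_server("primary"), Some("backup1"));
/// assert_eq!(group.next_server("backup2"), None); // no servers left
/// ```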
#[derive(Debug, Clone)]
pub struct ServerGroup {
    /// Primary server ID
    pub primary: String,
    /// Backup server IDs in priority order
    pub backups: Vec<String>,
    /// Minimum health check failures before failover
    pub failover_threshold: usize,
}

impl ServerGroup {
    /// Create a new server group with primary and backups
    pub fn new(primary: impl Into<String>, backups: Vec<String>) -> Self {
        Self {
            primary: primary.into(),
            backups,
            failover_threshold: 3,
        }
    }

    /// Set the failover threshold
    #[must_use]
    pub fn with_failover_threshold(mut self, threshold: usize) -> Self {
        self.failover_threshold = threshold;
        self
    }

    /// Get all server IDs in priority order (primary first, then backups)
    #[must_use]
    pub fn all_servers(&self) -> Vec<&str> {
        std::iter::once(self.primary.as_str())
            .chain(self.backups.iter().map(|s| s.as_str()))
            .collect()
    }

    /// Get the next available server after the current one
    #[must_use]
    pub fn next_server(&self, current: &str) -> Option<&str> {
        let servers = self.all_servers();
        let current_idx = servers.iter().position(|&s| s == current)?;
        servers.get(current_idx + 1).copied()
    }
}

/// Multi-Server Session Manager for MCP Clients
///
/// The `SessionManager` coordinates multiple MCP server sessions with automatic
/// health monitoring and lifecycle management. Each session represents a long-lived,
/// initialized connection to a different MCP server.
///
/// # Use Cases
///
/// - **Multi-Server Applications**: IDE with multiple tool servers
/// - **Service Coordination**: Orchestrate operations across multiple MCP servers
/// - **Health Monitoring**: Track health of all connected servers
/// - **Failover**: Switch between primary/backup servers
///
/// # Examples
///
/// ```rust,no_run
/// use turbomcp_client::SessionManager;
/// use turbomcp_transport::stdio::StdioTransport;
///
/// # async fn example() -> turbomcp_protocol::Result<()> {
/// let mut manager = SessionManager::with_defaults();
///
/// // Add sessions to different servers
/// let github_transport = StdioTransport::new();
/// let fs_transport = StdioTransport::new();
/// manager.add_server("github", github_transport).await?;
/// manager.add_server("filesystem", fs_transport).await?;
///
/// // Start health monitoring
/// manager.start_health_monitoring().await;
///
/// // Get per-state session counts
/// let stats = manager.session_stats().await;
/// println!("Session states: {:?}", stats);
/// # Ok(())
/// # }
/// ```
pub struct SessionManager<T: Transport + 'static> {
    config: ManagerConfig,
    connections: Arc<RwLock<HashMap<String, ManagedConnection<T>>>>,
    health_check_task: Option<tokio::task::JoinHandle<()>>,
}

impl<T: Transport + Send + 'static> SessionManager<T> {
    /// Create a new connection manager with the specified configuration
    #[must_use]
    pub fn new(config: ManagerConfig) -> Self {
        Self {
            config,
            connections: Arc::new(RwLock::new(HashMap::new())),
            health_check_task: None,
        }
    }

    /// Create a new connection manager with default configuration
    #[must_use]
    pub fn with_defaults() -> Self {
        Self::new(ManagerConfig::default())
    }

    /// Add a new server session
    ///
    /// Creates and initializes a session to the specified MCP server.
    ///
    /// # Arguments
    ///
    /// * `id` - Unique identifier for this server (e.g., "github", "filesystem")
    /// * `transport` - Transport implementation for connecting to the server
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - Maximum sessions limit is reached
    /// - Server ID already exists
    /// - Client initialization fails
    ///
    /// # Examples
    ///
    /// ```rust,no_run
    /// # use turbomcp_client::SessionManager;
    /// # use turbomcp_transport::stdio::StdioTransport;
    /// # async fn example() -> turbomcp_protocol::Result<()> {
    /// let mut manager = SessionManager::with_defaults();
    /// let github_transport = StdioTransport::new();
    /// let fs_transport = StdioTransport::new();
    /// manager.add_server("github", github_transport).await?;
    /// manager.add_server("filesystem", fs_transport).await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn add_server(&mut self, id: impl Into<String>, transport: T) -> Result<()> {
        let id = id.into();
        let mut connections = self.connections.write().await;

        // Check session limit
        if connections.len() >= self.config.max_connections {
            return Err(Error::bad_request(format!(
                "Maximum sessions limit ({}) reached",
                self.config.max_connections
            )));
        }

        // Check for duplicate ID
        if connections.contains_key(&id) {
            return Err(Error::bad_request(format!(
                "Server with ID '{}' already exists",
                id
            )));
        }

        // Create client and initialize
        let client = Client::new(transport);
        client.initialize().await?;

        let info = ConnectionInfo {
            id: id.clone(),
            state: ConnectionState::Healthy,
            established_at: Instant::now(),
            last_health_check: Some(Instant::now()),
            failed_health_checks: 0,
            successful_requests: 0,
            failed_requests: 0,
        };

        connections.insert(
            id,
            ManagedConnection {
                client,
                info,
                reconnect_attempts: 0,
                next_reconnect_delay: self.config.reconnect_delay,
                transport_factory: None, // No automatic reconnection for manual adds
            },
        );

        Ok(())
    }

    /// Add a server session with automatic reconnection support
    ///
    /// This method accepts a factory function that creates new transport instances,
    /// enabling automatic reconnection if the session fails.
    ///
    /// # Arguments
    ///
    /// * `id` - Unique identifier for this server
    /// * `transport_factory` - Function that creates new transport instances
    ///
    /// # Examples
    ///
    /// ```rust,no_run
    /// # use turbomcp_client::SessionManager;
    /// # use turbomcp_transport::stdio::StdioTransport;
    /// # async fn example() -> turbomcp_protocol::Result<()> {
    /// let mut manager = SessionManager::with_defaults();
    ///
    /// // Transport with reconnection factory
    /// manager.add_server_with_reconnect("api", || {
    ///     StdioTransport::new()
    /// }).await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn add_server_with_reconnect<F>(
        &mut self,
        id: impl Into<String>,
        transport_factory: F,
    ) -> Result<()>
    where
        F: Fn() -> T + Send + Sync + 'static,
    {
        let id = id.into();
        let factory = Arc::new(transport_factory);

        let mut connections = self.connections.write().await;

        // Check limits before doing any connection work
        if connections.len() >= self.config.max_connections {
            return Err(Error::bad_request(format!(
                "Maximum sessions limit ({}) reached",
                self.config.max_connections
            )));
        }

        if connections.contains_key(&id) {
            return Err(Error::bad_request(format!(
                "Server with ID '{}' already exists",
                id
            )));
        }

        // Create initial transport and client
        let transport = (factory)();
        let client = Client::new(transport);
        client.initialize().await?;

        let info = ConnectionInfo {
            id: id.clone(),
            state: ConnectionState::Healthy,
            established_at: Instant::now(),
            last_health_check: Some(Instant::now()),
            failed_health_checks: 0,
            successful_requests: 0,
            failed_requests: 0,
        };

        connections.insert(
            id,
            ManagedConnection {
                client,
                info,
                reconnect_attempts: 0,
                next_reconnect_delay: self.config.reconnect_delay,
                transport_factory: Some(factory),
            },
        );

        Ok(())
    }

    /// Remove a managed connection
    ///
    /// # Arguments
    ///
    /// * `id` - ID of the connection to remove
    ///
    /// # Returns
    ///
    /// Returns `true` if the connection was removed, `false` if not found
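    ///
    /// # Examples
    ///
    /// ```rust,no_run
    /// # use turbomcp_client::SessionManager;
    /// # use turbomcp_transport::stdio::StdioTransport;
    /// # async fn example() {
    /// # let mut manager: SessionManager<StdioTransport> = SessionManager::with_defaults();
    /// // The "github" ID is illustrative; returns false if no such session exists
    /// if manager.remove_server("github").await {
    ///     println!("Session removed");
    /// }
    /// # }
    /// ```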
    pub async fn remove_server(&mut self, id: &str) -> bool {
        let mut connections = self.connections.write().await;
        connections.remove(id).is_some()
    }

    /// Get information about a specific connection
    pub async fn get_session_info(&self, id: &str) -> Option<ConnectionInfo> {
        let connections = self.connections.read().await;
        connections.get(id).map(|conn| conn.info.clone())
    }

    /// List all managed connections
    pub async fn list_sessions(&self) -> Vec<ConnectionInfo> {
        let connections = self.connections.read().await;
        connections.values().map(|conn| conn.info.clone()).collect()
    }

    /// Get a healthy connection, preferring the least-used session
    /// (fewest total recorded requests)
    ///
    /// # Returns
    ///
    /// Returns the ID of a healthy connection, or None if no healthy connections exist
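    ///
    /// # Examples
    ///
    /// ```rust,no_run
    /// # use turbomcp_client::SessionManager;
    /// # use turbomcp_transport::stdio::StdioTransport;
    /// # async fn example() -> turbomcp_protocol::Result<()> {
    /// # let mut manager = SessionManager::with_defaults();
    /// # manager.add_server("github", StdioTransport::new()).await?;
    /// // Route work to the least-used healthy session, if any
    /// if let Some(id) = manager.get_healthy_connection().await {
    ///     println!("Routing request to '{}'", id);
    /// }
    /// # Ok(())
    /// # }
    /// ```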
    pub async fn get_healthy_connection(&self) -> Option<String> {
        let connections = self.connections.read().await;
        connections
            .iter()
            .filter(|(_, conn)| conn.info.state == ConnectionState::Healthy)
            .min_by_key(|(_, conn)| conn.info.successful_requests + conn.info.failed_requests)
            .map(|(id, _)| id.clone())
    }

    /// Get count of connections by state
    pub async fn session_stats(&self) -> HashMap<ConnectionState, usize> {
        let connections = self.connections.read().await;
        let mut stats = HashMap::new();

        for conn in connections.values() {
            *stats.entry(conn.info.state.clone()).or_insert(0) += 1;
        }

        stats
    }

    /// Start automatic health monitoring
    ///
    /// Spawns a background task that periodically checks the health of all connections
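    ///
    /// Calling this again while a monitor task is running is a no-op.
    ///
    /// # Examples
    ///
    /// ```rust,no_run
    /// # use turbomcp_client::SessionManager;
    /// # use turbomcp_transport::stdio::StdioTransport;
    /// # async fn example() {
    /// let mut manager: SessionManager<StdioTransport> = SessionManager::with_defaults();
    /// manager.start_health_monitoring().await;
    /// // ... sessions are now pinged every `health_check_interval` ...
    /// manager.stop_health_monitoring();
    /// # }
    /// ```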
    pub async fn start_health_monitoring(&mut self) {
        if self.health_check_task.is_some() {
            return; // Already running
        }

        let connections = Arc::clone(&self.connections);
        let interval = self.config.health_check_interval;
        let threshold = self.config.health_check_threshold;
        let timeout = self.config.health_check_timeout;

        let task = tokio::spawn(async move {
            let mut interval_timer = tokio::time::interval(interval);

            loop {
                interval_timer.tick().await;

                let mut connections = connections.write().await;

                for (id, managed) in connections.iter_mut() {
                    // Perform health check (ping)
                    let health_result = tokio::time::timeout(timeout, managed.client.ping()).await;

                    match health_result {
                        Ok(Ok(_)) => {
                            // Health check successful
                            managed.info.last_health_check = Some(Instant::now());
                            managed.info.failed_health_checks = 0;

                            if managed.info.state != ConnectionState::Healthy {
                                tracing::info!(
                                    connection_id = %id,
                                    "Connection recovered and is now healthy"
                                );
                                managed.info.state = ConnectionState::Healthy;
                            }
                        }
                        Ok(Err(_)) | Err(_) => {
                            // Health check failed (ping error or timeout)
                            managed.info.failed_health_checks += 1;

                            if managed.info.failed_health_checks >= threshold {
                                if managed.info.state != ConnectionState::Unhealthy {
                                    tracing::warn!(
                                        connection_id = %id,
                                        failed_checks = managed.info.failed_health_checks,
                                        "Connection marked as unhealthy"
                                    );
                                    managed.info.state = ConnectionState::Unhealthy;
                                }
                            } else if managed.info.state == ConnectionState::Healthy {
                                tracing::debug!(
                                    connection_id = %id,
                                    failed_checks = managed.info.failed_health_checks,
                                    "Connection degraded"
                                );
                                managed.info.state = ConnectionState::Degraded;
                            }
                        }
                    }
                }
            }
        });

        self.health_check_task = Some(task);
    }

    /// Stop automatic health monitoring
    pub fn stop_health_monitoring(&mut self) {
        if let Some(task) = self.health_check_task.take() {
            task.abort();
        }
    }

    /// Get total number of managed connections
    pub async fn session_count(&self) -> usize {
        let connections = self.connections.read().await;
        connections.len()
    }
}

impl<T: Transport + 'static> Drop for SessionManager<T> {
    fn drop(&mut self) {
        self.stop_health_monitoring();
    }
}

// ============================================================================
// Specialized Implementation for TurboTransport
// ============================================================================

impl SessionManager<turbomcp_transport::resilience::TurboTransport> {
    /// Add a server with automatic robustness (specialized for TurboTransport)
    ///
    /// This convenience method is only available when using `SessionManager<TurboTransport>`.
    /// It wraps any transport in TurboTransport with the specified configurations.
    ///
    /// # Examples
    ///
    /// ```rust,no_run
    /// # use std::time::Duration;
    /// # use turbomcp_client::SessionManager;
    /// # use turbomcp_transport::stdio::StdioTransport;
    /// # use turbomcp_transport::resilience::*;
    /// # async fn example() -> turbomcp_protocol::Result<()> {
    /// let mut manager: SessionManager<TurboTransport> = SessionManager::with_defaults();
    ///
    /// // Use explicit configuration for clarity
    /// manager.add_resilient_server(
    ///     "github",
    ///     StdioTransport::new(),
    ///     RetryConfig {
    ///         max_attempts: 5,
    ///         base_delay: Duration::from_millis(200),
    ///         ..Default::default()
    ///     },
    ///     CircuitBreakerConfig {
    ///         failure_threshold: 3,
    ///         timeout: Duration::from_secs(30),
    ///         ..Default::default()
    ///     },
    ///     HealthCheckConfig {
    ///         interval: Duration::from_secs(15),
    ///         timeout: Duration::from_secs(5),
    ///         ..Default::default()
    ///     },
    /// ).await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn add_resilient_server<BaseT>(
        &mut self,
        id: impl Into<String>,
        transport: BaseT,
        retry_config: turbomcp_transport::resilience::RetryConfig,
        circuit_config: turbomcp_transport::resilience::CircuitBreakerConfig,
        health_config: turbomcp_transport::resilience::HealthCheckConfig,
    ) -> Result<()>
    where
        BaseT: Transport + 'static,
    {
        use turbomcp_transport::resilience::TurboTransport;

        let robust = TurboTransport::new(
            Box::new(transport),
            retry_config,
            circuit_config,
            health_config,
        );

        self.add_server(id, robust).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_manager_config_defaults() {
        let config = ManagerConfig::default();
        assert_eq!(config.max_connections, 10);
        assert!(config.auto_reconnect);
    }

    #[test]
    fn test_connection_state_equality() {
        assert_eq!(ConnectionState::Healthy, ConnectionState::Healthy);
        assert_ne!(ConnectionState::Healthy, ConnectionState::Unhealthy);
    }
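
    // Exercises the failover traversal implemented by `ServerGroup`.
    #[test]
    fn test_server_group_failover_order() {
        let group = ServerGroup::new("primary", vec!["b1".to_string(), "b2".to_string()]);
        assert_eq!(group.all_servers(), vec!["primary", "b1", "b2"]);
        assert_eq!(group.next_server("primary"), Some("b1"));
        assert_eq!(group.next_server("b1"), Some("b2"));
        assert_eq!(group.next_server("b2"), None);
        assert_eq!(group.next_server("unknown"), None);
    }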
}