ultrafast_mcp_transport/
lib.rs

1//! UltraFast MCP Transport Layer
2//!
3//! This crate provides high-performance transport implementations for the Model Context Protocol (MCP).
4//! It supports multiple transport types including STDIO and HTTP with advanced features like
5//! connection pooling, rate limiting, and request optimization.
6
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::fmt;
10use thiserror::Error;
11use ultrafast_mcp_core::protocol::JsonRpcMessage;
12
13pub mod stdio;
14
15#[cfg(feature = "http")]
16pub mod streamable_http;
17
18/// Result type for transport operations
19pub type Result<T> = std::result::Result<T, TransportError>;
20
21/// Connection state for transport lifecycle management
22#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
23pub enum ConnectionState {
24    /// Transport is disconnected
25    Disconnected,
26    /// Transport is connecting
27    Connecting,
28    /// Transport is connected and ready
29    Connected,
30    /// Transport is reconnecting after a failure
31    Reconnecting,
32    /// Transport is shutting down gracefully
33    ShuttingDown,
34    /// Transport has failed and needs recovery
35    Failed(String),
36}
37
38impl fmt::Display for ConnectionState {
39    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40        match self {
41            ConnectionState::Disconnected => write!(f, "disconnected"),
42            ConnectionState::Connecting => write!(f, "connecting"),
43            ConnectionState::Connected => write!(f, "connected"),
44            ConnectionState::Reconnecting => write!(f, "reconnecting"),
45            ConnectionState::ShuttingDown => write!(f, "shutting down"),
46            ConnectionState::Failed(reason) => write!(f, "failed: {reason}"),
47        }
48    }
49}
50
51/// Transport health information
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct TransportHealth {
54    pub state: ConnectionState,
55    pub last_activity: Option<std::time::SystemTime>,
56    pub messages_sent: u64,
57    pub messages_received: u64,
58    pub connection_duration: Option<std::time::Duration>,
59    pub error_count: u64,
60    pub last_error: Option<String>,
61}
62
63impl Default for TransportHealth {
64    fn default() -> Self {
65        Self {
66            state: ConnectionState::Disconnected,
67            last_activity: None,
68            messages_sent: 0,
69            messages_received: 0,
70            connection_duration: None,
71            error_count: 0,
72            last_error: None,
73        }
74    }
75}
76
/// Lifecycle notifications delivered to a [`TransportEventHandler`].
#[derive(Debug, Clone)]
pub enum TransportEvent {
    /// A connection (or a reconnection) was established.
    Connected,
    /// The connection went away.
    Disconnected,
    /// A reconnection attempt is starting.
    Reconnecting,
    /// A message was sent successfully.
    MessageSent,
    /// A message was received successfully.
    MessageReceived,
    /// An operation failed; carries the error's display text.
    Error(String),
    /// A shutdown has started.
    ShutdownRequested,
    /// A shutdown has finished.
    ShutdownComplete,
}
89
90/// Callback trait for transport lifecycle events
91#[async_trait]
92pub trait TransportEventHandler: Send + Sync {
93    async fn handle_event(&self, event: TransportEvent);
94}
95
/// Retry and backoff settings used by [`RecoveringTransport`].
///
/// The delay before a reconnection attempt is
/// `initial_delay * backoff_multiplier^(failed attempts so far)`. When jitter is
/// enabled, that delay is scaled by a random factor. The result never exceeds
/// `max_delay`.
#[derive(Debug, Clone)]
pub struct RecoveryConfig {
    /// Consecutive failed reconnection attempts allowed before recovery gives up.
    pub max_retries: u32,
    /// Delay before the first reconnection attempt.
    pub initial_delay: std::time::Duration,
    /// Upper bound on any single delay.
    pub max_delay: std::time::Duration,
    /// Growth factor applied to the delay after each failed attempt.
    pub backoff_multiplier: f64,
    /// Randomize each delay by a factor in `[0.8, 1.2)` so that many clients do not
    /// retry in lockstep.
    pub enable_jitter: bool,
}

impl Default for RecoveryConfig {
    /// Five retries, starting at 100 ms and doubling up to 30 s, with jitter.
    fn default() -> Self {
        use std::time::Duration;

        Self {
            max_retries: 5,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(30),
            backoff_multiplier: 2.0,
            enable_jitter: true,
        }
    }
}
117
/// Settings passed to [`Transport::shutdown`].
#[derive(Debug, Clone)]
pub struct ShutdownConfig {
    /// Time budget for a graceful shutdown. The default `Transport::shutdown`
    /// returns [`TransportError::ShutdownTimeout`] once it is exceeded.
    pub graceful_timeout: std::time::Duration,
    /// Presumably a hard deadline for shutdown (by name). The default
    /// `Transport::shutdown` does not read it; individual transports may.
    pub force_timeout: std::time::Duration,
    /// Whether pending messages should be flushed before closing. The default
    /// `Transport::shutdown` does not read it; individual transports may.
    pub drain_pending_messages: bool,
}

impl Default for ShutdownConfig {
    /// 5 s graceful timeout, 10 s force timeout, draining enabled.
    fn default() -> Self {
        use std::time::Duration;

        Self {
            graceful_timeout: Duration::from_secs(5),
            force_timeout: Duration::from_secs(10),
            drain_pending_messages: true,
        }
    }
}
135
136/// Transport error types
137#[derive(Debug, Error)]
138pub enum TransportError {
139    #[error("Connection error: {message}")]
140    ConnectionError { message: String },
141
142    #[error("Connection closed")]
143    ConnectionClosed,
144
145    #[error("Connection timeout")]
146    ConnectionTimeout,
147
148    #[error("Serialization error: {message}")]
149    SerializationError { message: String },
150
151    #[error("Network error: {message}")]
152    NetworkError { message: String },
153
154    #[error("Authentication error: {message}")]
155    AuthenticationError { message: String },
156
157    #[error("Protocol error: {message}")]
158    ProtocolError { message: String },
159
160    #[error("Initialization error: {message}")]
161    InitializationError { message: String },
162
163    #[error("Internal error: {message}")]
164    InternalError { message: String },
165
166    #[error("Recovery failed after {attempts} attempts: {message}")]
167    RecoveryFailed { attempts: u32, message: String },
168
169    #[error("Shutdown timeout: {message}")]
170    ShutdownTimeout { message: String },
171
172    #[error("Transport not ready: current state is {state}")]
173    NotReady { state: ConnectionState },
174}
175
176/// Enhanced transport trait with lifecycle management
177#[async_trait]
178pub trait Transport: Send + Sync {
179    /// Send a message through the transport
180    async fn send_message(&mut self, message: JsonRpcMessage) -> Result<()>;
181
182    /// Receive a message from the transport
183    async fn receive_message(&mut self) -> Result<JsonRpcMessage>;
184
185    /// Close the transport connection gracefully
186    async fn close(&mut self) -> Result<()>;
187
188    /// Get current connection state
189    fn get_state(&self) -> ConnectionState {
190        ConnectionState::Connected // Default implementation for backward compatibility
191    }
192
193    /// Get transport health information
194    fn get_health(&self) -> TransportHealth {
195        TransportHealth {
196            state: self.get_state(),
197            ..Default::default()
198        }
199    }
200
201    /// Check if transport is ready for operations
202    fn is_ready(&self) -> bool {
203        matches!(self.get_state(), ConnectionState::Connected)
204    }
205
206    /// Initiate graceful shutdown
207    async fn shutdown(&mut self, config: ShutdownConfig) -> Result<()> {
208        // Default implementation just calls close()
209        tokio::time::timeout(config.graceful_timeout, self.close())
210            .await
211            .map_err(|_| TransportError::ShutdownTimeout {
212                message: "Graceful shutdown timeout".to_string(),
213            })?
214    }
215
216    /// Force immediate shutdown
217    async fn force_shutdown(&mut self) -> Result<()> {
218        // Default implementation just calls close()
219        self.close().await
220    }
221
222    /// Attempt to reconnect the transport
223    async fn reconnect(&mut self) -> Result<()> {
224        // Default implementation: close and let the caller handle reconnection
225        self.close().await?;
226        Err(TransportError::ConnectionError {
227            message: "Reconnection not supported by this transport".to_string(),
228        })
229    }
230
231    /// Reset transport state and clear any cached data
232    async fn reset(&mut self) -> Result<()> {
233        // Default implementation just calls close()
234        self.close().await
235    }
236}
237
238/// Enhanced transport with automatic recovery
239pub struct RecoveringTransport {
240    inner: Box<dyn Transport>,
241    recovery_config: RecoveryConfig,
242    health: TransportHealth,
243    event_handler: Option<Box<dyn TransportEventHandler>>,
244    retry_count: u32,
245    last_error: Option<String>,
246}
247
248impl RecoveringTransport {
249    pub fn new(transport: Box<dyn Transport>, recovery_config: RecoveryConfig) -> Self {
250        Self {
251            inner: transport,
252            recovery_config,
253            health: TransportHealth::default(),
254            event_handler: None,
255            retry_count: 0,
256            last_error: None,
257        }
258    }
259
260    pub fn with_event_handler(mut self, handler: Box<dyn TransportEventHandler>) -> Self {
261        self.event_handler = Some(handler);
262        self
263    }
264
265    async fn emit_event(&self, event: TransportEvent) {
266        if let Some(handler) = &self.event_handler {
267            handler.handle_event(event).await;
268        }
269    }
270
271    async fn attempt_recovery(&mut self) -> Result<()> {
272        if self.retry_count >= self.recovery_config.max_retries {
273            let error_msg = format!(
274                "Max retries ({}) exceeded. Last error: {}",
275                self.recovery_config.max_retries,
276                self.last_error.as_deref().unwrap_or("unknown")
277            );
278            self.health.state = ConnectionState::Failed(error_msg.clone());
279            return Err(TransportError::RecoveryFailed {
280                attempts: self.retry_count,
281                message: error_msg,
282            });
283        }
284
285        self.health.state = ConnectionState::Reconnecting;
286        self.emit_event(TransportEvent::Reconnecting).await;
287
288        // Calculate delay with exponential backoff
289        let delay = self.calculate_retry_delay();
290        tokio::time::sleep(delay).await;
291
292        // Attempt reconnection
293        match self.inner.reconnect().await {
294            Ok(()) => {
295                self.health.state = ConnectionState::Connected;
296                self.retry_count = 0;
297                self.last_error = None;
298                self.emit_event(TransportEvent::Connected).await;
299                Ok(())
300            }
301            Err(e) => {
302                self.retry_count += 1;
303                self.last_error = Some(e.to_string());
304                self.health.error_count += 1;
305                self.health.last_error = Some(e.to_string());
306                Err(e)
307            }
308        }
309    }
310
311    fn calculate_retry_delay(&self) -> std::time::Duration {
312        let base_delay = self.recovery_config.initial_delay.as_millis() as f64;
313        let multiplier = self
314            .recovery_config
315            .backoff_multiplier
316            .powi(self.retry_count as i32);
317        let mut delay_ms = base_delay * multiplier;
318
319        // Add jitter if enabled
320        if self.recovery_config.enable_jitter {
321            use rand::Rng;
322            let mut rng = rand::rng();
323            let jitter: f64 = rng.random_range(0.8..1.2);
324            delay_ms *= jitter;
325        }
326
327        // Cap at max delay
328        let max_delay_ms = self.recovery_config.max_delay.as_millis() as f64;
329        delay_ms = delay_ms.min(max_delay_ms);
330
331        std::time::Duration::from_millis(delay_ms as u64)
332    }
333}
334
335#[async_trait]
336impl Transport for RecoveringTransport {
337    async fn send_message(&mut self, message: JsonRpcMessage) -> Result<()> {
338        loop {
339            match self.inner.send_message(message.clone()).await {
340                Ok(()) => {
341                    self.health.messages_sent += 1;
342                    self.health.last_activity = Some(std::time::SystemTime::now());
343                    self.emit_event(TransportEvent::MessageSent).await;
344                    return Ok(());
345                }
346                Err(e) => {
347                    self.emit_event(TransportEvent::Error(e.to_string())).await;
348
349                    // Try recovery for connection errors
350                    if matches!(
351                        e,
352                        TransportError::ConnectionClosed | TransportError::ConnectionError { .. }
353                    ) {
354                        match self.attempt_recovery().await {
355                            Ok(()) => continue, // Retry the send
356                            Err(recovery_err) => return Err(recovery_err),
357                        }
358                    } else {
359                        return Err(e);
360                    }
361                }
362            }
363        }
364    }
365
366    async fn receive_message(&mut self) -> Result<JsonRpcMessage> {
367        loop {
368            match self.inner.receive_message().await {
369                Ok(message) => {
370                    self.health.messages_received += 1;
371                    self.health.last_activity = Some(std::time::SystemTime::now());
372                    self.emit_event(TransportEvent::MessageReceived).await;
373                    return Ok(message);
374                }
375                Err(e) => {
376                    self.emit_event(TransportEvent::Error(e.to_string())).await;
377
378                    // Try recovery for connection errors
379                    if matches!(
380                        e,
381                        TransportError::ConnectionClosed | TransportError::ConnectionError { .. }
382                    ) {
383                        match self.attempt_recovery().await {
384                            Ok(()) => continue, // Retry the receive
385                            Err(recovery_err) => return Err(recovery_err),
386                        }
387                    } else {
388                        return Err(e);
389                    }
390                }
391            }
392        }
393    }
394
395    async fn close(&mut self) -> Result<()> {
396        self.health.state = ConnectionState::ShuttingDown;
397        self.emit_event(TransportEvent::ShutdownRequested).await;
398
399        let result = self.inner.close().await;
400
401        self.health.state = ConnectionState::Disconnected;
402        self.emit_event(TransportEvent::ShutdownComplete).await;
403
404        result
405    }
406
407    fn get_state(&self) -> ConnectionState {
408        self.health.state.clone()
409    }
410
411    fn get_health(&self) -> TransportHealth {
412        self.health.clone()
413    }
414
415    async fn shutdown(&mut self, config: ShutdownConfig) -> Result<()> {
416        self.inner.shutdown(config).await
417    }
418
419    async fn force_shutdown(&mut self) -> Result<()> {
420        self.inner.force_shutdown().await
421    }
422
423    async fn reconnect(&mut self) -> Result<()> {
424        self.attempt_recovery().await
425    }
426
427    async fn reset(&mut self) -> Result<()> {
428        self.health = TransportHealth::default();
429        self.retry_count = 0;
430        self.last_error = None;
431        self.inner.reset().await
432    }
433}
434
/// Describes which transport [`create_transport`] should build.
///
/// `Debug` is implemented by hand so that the bearer token is never written to
/// logs: `auth_token` is shown only as `Some("<redacted>")` or `None`.
#[derive(Clone)]
pub enum TransportConfig {
    /// Standard input/output transport.
    Stdio,

    /// Streamable HTTP transport (PRD recommended).
    #[cfg(feature = "http")]
    Streamable {
        base_url: String,
        auth_token: Option<String>,
        session_id: Option<String>,
    },
}

impl fmt::Debug for TransportConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TransportConfig::Stdio => f.write_str("Stdio"),
            #[cfg(feature = "http")]
            TransportConfig::Streamable {
                base_url,
                auth_token,
                session_id,
            } => f
                .debug_struct("Streamable")
                .field("base_url", base_url)
                .field("auth_token", &auth_token.as_ref().map(|_| "<redacted>"))
                .field("session_id", session_id)
                .finish(),
        }
    }
}
449
450/// Create a transport from configuration
451pub async fn create_transport(config: TransportConfig) -> Result<Box<dyn Transport>> {
452    match config {
453        TransportConfig::Stdio => {
454            let transport = stdio::StdioTransport::new().await?;
455            Ok(Box::new(transport))
456        }
457
458        #[cfg(feature = "http")]
459        TransportConfig::Streamable {
460            base_url,
461            auth_token,
462            session_id,
463        } => {
464            let client_config = streamable_http::client::StreamableHttpClientConfig {
465                base_url,
466                auth_token,
467                session_id,
468                auth_method: None,
469                ..Default::default()
470            };
471
472            let mut client = streamable_http::client::StreamableHttpClient::new(client_config)?;
473            client.connect().await?;
474            Ok(Box::new(client))
475        }
476    }
477}
478
479/// Create a transport with automatic recovery
480pub async fn create_recovering_transport(
481    config: TransportConfig,
482    recovery_config: RecoveryConfig,
483) -> Result<Box<dyn Transport>> {
484    let transport = create_transport(config).await?;
485    let recovering_transport = RecoveringTransport::new(transport, recovery_config);
486    Ok(Box::new(recovering_transport))
487}