1use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::fmt;
10use thiserror::Error;
11use ultrafast_mcp_core::protocol::JsonRpcMessage;
12
13pub mod stdio;
14
15#[cfg(feature = "http")]
16pub mod streamable_http;
17
/// Result alias used by the transport layer; the error type is always [`TransportError`].
pub type Result<T> = std::result::Result<T, TransportError>;
20
/// Lifecycle state of a transport connection.
///
/// The `Display` implementation renders each state as a short lowercase label.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ConnectionState {
    /// No connection is established. This is the state of a default [`TransportHealth`].
    Disconnected,
    /// Not assigned anywhere in this file; presumably set by concrete transports while
    /// the first connection is being established (to be confirmed against implementors).
    Connecting,
    /// The connection is usable. This is the only state for which
    /// [`Transport::is_ready`] returns `true`.
    Connected,
    /// A recovery attempt is in progress (set by [`RecoveringTransport`]).
    Reconnecting,
    /// A close has been requested and is being carried out.
    ShuttingDown,
    /// The connection is considered failed; the payload is a human-readable reason.
    Failed(String),
}
37
38impl fmt::Display for ConnectionState {
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 match self {
41 ConnectionState::Disconnected => write!(f, "disconnected"),
42 ConnectionState::Connecting => write!(f, "connecting"),
43 ConnectionState::Connected => write!(f, "connected"),
44 ConnectionState::Reconnecting => write!(f, "reconnecting"),
45 ConnectionState::ShuttingDown => write!(f, "shutting down"),
46 ConnectionState::Failed(reason) => write!(f, "failed: {reason}"),
47 }
48 }
49}
50
/// Snapshot of a transport's state and traffic counters.
///
/// Returned by [`Transport::get_health`]. The default `get_health` implementation only
/// fills in `state`; [`RecoveringTransport`] maintains the counters described below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransportHealth {
    /// Current connection state.
    pub state: ConnectionState,
    /// Wall-clock time of the last successful send or receive, if any.
    pub last_activity: Option<std::time::SystemTime>,
    /// Number of messages sent successfully.
    pub messages_sent: u64,
    /// Number of messages received successfully.
    pub messages_received: u64,
    /// Never written in this file; presumably how long the current connection has been
    /// up, maintained by concrete transports (to be confirmed).
    pub connection_duration: Option<std::time::Duration>,
    /// Number of recorded errors. In this file it is incremented once per failed
    /// reconnect attempt.
    pub error_count: u64,
    /// Text of the most recently recorded error, if any.
    pub last_error: Option<String>,
}
62
63impl Default for TransportHealth {
64 fn default() -> Self {
65 Self {
66 state: ConnectionState::Disconnected,
67 last_activity: None,
68 messages_sent: 0,
69 messages_received: 0,
70 connection_duration: None,
71 error_count: 0,
72 last_error: None,
73 }
74 }
75}
76
/// Notifications delivered to a [`TransportEventHandler`].
#[derive(Debug, Clone)]
pub enum TransportEvent {
    /// The connection is established. In this file it is emitted after a successful
    /// reconnect.
    Connected,
    /// Not emitted anywhere in this file; presumably signals a lost connection.
    Disconnected,
    /// A reconnect attempt is about to start.
    Reconnecting,
    /// A message was sent successfully.
    MessageSent,
    /// A message was received successfully.
    MessageReceived,
    /// An operation failed; the payload is the error's display text.
    Error(String),
    /// Closing the transport has been requested.
    ShutdownRequested,
    /// The close operation has finished.
    ShutdownComplete,
}
89
/// Observer for [`TransportEvent`]s.
///
/// [`RecoveringTransport`] awaits `handle_event` inline, so a slow handler delays the
/// transport operation that produced the event.
#[async_trait]
pub trait TransportEventHandler: Send + Sync {
    /// Called once for every emitted event.
    async fn handle_event(&self, event: TransportEvent);
}
95
/// Reconnect policy used by [`RecoveringTransport`].
///
/// The delay before an attempt is `initial_delay * backoff_multiplier^retry_count`,
/// optionally scaled by a random jitter factor, and never longer than `max_delay`.
#[derive(Debug, Clone)]
pub struct RecoveryConfig {
    /// Number of failed reconnect attempts after which recovery gives up.
    pub max_retries: u32,
    /// Base delay, used as-is for the first attempt.
    pub initial_delay: std::time::Duration,
    /// Upper bound on the delay before any attempt.
    pub max_delay: std::time::Duration,
    /// Factor the delay grows by for each failed attempt.
    pub backoff_multiplier: f64,
    /// When `true`, each delay is multiplied by a random factor in `0.8..1.2`.
    pub enable_jitter: bool,
}
105
106impl Default for RecoveryConfig {
107 fn default() -> Self {
108 Self {
109 max_retries: 5,
110 initial_delay: std::time::Duration::from_millis(100),
111 max_delay: std::time::Duration::from_secs(30),
112 backoff_multiplier: 2.0,
113 enable_jitter: true,
114 }
115 }
116}
117
/// Parameters for [`Transport::shutdown`].
#[derive(Debug, Clone)]
pub struct ShutdownConfig {
    /// How long the default `shutdown` implementation waits for `close()` to finish
    /// before reporting [`TransportError::ShutdownTimeout`].
    pub graceful_timeout: std::time::Duration,
    /// Not read anywhere in this file; presumably a bound for forced shutdown, honoured
    /// by concrete transports (to be confirmed).
    pub force_timeout: std::time::Duration,
    /// Not read anywhere in this file; presumably asks implementors to flush queued
    /// messages before closing (to be confirmed).
    pub drain_pending_messages: bool,
}
125
126impl Default for ShutdownConfig {
127 fn default() -> Self {
128 Self {
129 graceful_timeout: std::time::Duration::from_secs(5),
130 force_timeout: std::time::Duration::from_secs(10),
131 drain_pending_messages: true,
132 }
133 }
134}
135
/// Errors produced by the transport layer.
///
/// [`RecoveringTransport`] treats `ConnectionClosed` and `ConnectionError` as
/// recoverable and returns every other variant to the caller unchanged.
#[derive(Debug, Error)]
pub enum TransportError {
    /// The connection could not be established or used. Also returned by the default
    /// [`Transport::reconnect`] to signal that reconnecting is unsupported.
    #[error("Connection error: {message}")]
    ConnectionError { message: String },

    /// The connection has been closed.
    #[error("Connection closed")]
    ConnectionClosed,

    /// A connection-level operation timed out.
    #[error("Connection timeout")]
    ConnectionTimeout,

    /// A message could not be serialized or deserialized.
    #[error("Serialization error: {message}")]
    SerializationError { message: String },

    /// A network-level failure.
    #[error("Network error: {message}")]
    NetworkError { message: String },

    /// Authentication failed.
    #[error("Authentication error: {message}")]
    AuthenticationError { message: String },

    /// The peer violated the protocol.
    #[error("Protocol error: {message}")]
    ProtocolError { message: String },

    /// The transport could not be initialized.
    #[error("Initialization error: {message}")]
    InitializationError { message: String },

    /// An unexpected internal failure.
    #[error("Internal error: {message}")]
    InternalError { message: String },

    /// Recovery gave up; `attempts` is the number of failed reconnect attempts so far.
    #[error("Recovery failed after {attempts} attempts: {message}")]
    RecoveryFailed { attempts: u32, message: String },

    /// A shutdown did not complete within its allotted time.
    #[error("Shutdown timeout: {message}")]
    ShutdownTimeout { message: String },

    /// An operation was attempted while the transport was in `state`, which does not
    /// allow it. Not constructed anywhere in this file.
    #[error("Transport not ready: current state is {state}")]
    NotReady { state: ConnectionState },
}
175
176#[async_trait]
178pub trait Transport: Send + Sync {
179 async fn send_message(&mut self, message: JsonRpcMessage) -> Result<()>;
181
182 async fn receive_message(&mut self) -> Result<JsonRpcMessage>;
184
185 async fn close(&mut self) -> Result<()>;
187
188 fn get_state(&self) -> ConnectionState {
190 ConnectionState::Connected }
192
193 fn get_health(&self) -> TransportHealth {
195 TransportHealth {
196 state: self.get_state(),
197 ..Default::default()
198 }
199 }
200
201 fn is_ready(&self) -> bool {
203 matches!(self.get_state(), ConnectionState::Connected)
204 }
205
206 async fn shutdown(&mut self, config: ShutdownConfig) -> Result<()> {
208 tokio::time::timeout(config.graceful_timeout, self.close())
210 .await
211 .map_err(|_| TransportError::ShutdownTimeout {
212 message: "Graceful shutdown timeout".to_string(),
213 })?
214 }
215
216 async fn force_shutdown(&mut self) -> Result<()> {
218 self.close().await
220 }
221
222 async fn reconnect(&mut self) -> Result<()> {
224 self.close().await?;
226 Err(TransportError::ConnectionError {
227 message: "Reconnection not supported by this transport".to_string(),
228 })
229 }
230
231 async fn reset(&mut self) -> Result<()> {
233 self.close().await
235 }
236}
237
/// Decorator that adds reconnect-on-failure, health counters and event notifications to
/// another [`Transport`].
pub struct RecoveringTransport {
    /// The wrapped transport that performs the actual I/O.
    inner: Box<dyn Transport>,
    /// Retry limit and backoff parameters.
    recovery_config: RecoveryConfig,
    /// Health record maintained by this wrapper and returned from `get_health`.
    health: TransportHealth,
    /// Optional observer notified of transport events.
    event_handler: Option<Box<dyn TransportEventHandler>>,
    /// Failed reconnect attempts since the last successful reconnect or reset.
    retry_count: u32,
    /// Text of the most recent reconnect failure, quoted when recovery gives up.
    last_error: Option<String>,
}
247
248impl RecoveringTransport {
249 pub fn new(transport: Box<dyn Transport>, recovery_config: RecoveryConfig) -> Self {
250 Self {
251 inner: transport,
252 recovery_config,
253 health: TransportHealth::default(),
254 event_handler: None,
255 retry_count: 0,
256 last_error: None,
257 }
258 }
259
260 pub fn with_event_handler(mut self, handler: Box<dyn TransportEventHandler>) -> Self {
261 self.event_handler = Some(handler);
262 self
263 }
264
265 async fn emit_event(&self, event: TransportEvent) {
266 if let Some(handler) = &self.event_handler {
267 handler.handle_event(event).await;
268 }
269 }
270
271 async fn attempt_recovery(&mut self) -> Result<()> {
272 if self.retry_count >= self.recovery_config.max_retries {
273 let error_msg = format!(
274 "Max retries ({}) exceeded. Last error: {}",
275 self.recovery_config.max_retries,
276 self.last_error.as_deref().unwrap_or("unknown")
277 );
278 self.health.state = ConnectionState::Failed(error_msg.clone());
279 return Err(TransportError::RecoveryFailed {
280 attempts: self.retry_count,
281 message: error_msg,
282 });
283 }
284
285 self.health.state = ConnectionState::Reconnecting;
286 self.emit_event(TransportEvent::Reconnecting).await;
287
288 let delay = self.calculate_retry_delay();
290 tokio::time::sleep(delay).await;
291
292 match self.inner.reconnect().await {
294 Ok(()) => {
295 self.health.state = ConnectionState::Connected;
296 self.retry_count = 0;
297 self.last_error = None;
298 self.emit_event(TransportEvent::Connected).await;
299 Ok(())
300 }
301 Err(e) => {
302 self.retry_count += 1;
303 self.last_error = Some(e.to_string());
304 self.health.error_count += 1;
305 self.health.last_error = Some(e.to_string());
306 Err(e)
307 }
308 }
309 }
310
311 fn calculate_retry_delay(&self) -> std::time::Duration {
312 let base_delay = self.recovery_config.initial_delay.as_millis() as f64;
313 let multiplier = self
314 .recovery_config
315 .backoff_multiplier
316 .powi(self.retry_count as i32);
317 let mut delay_ms = base_delay * multiplier;
318
319 if self.recovery_config.enable_jitter {
321 use rand::Rng;
322 let mut rng = rand::rng();
323 let jitter: f64 = rng.random_range(0.8..1.2);
324 delay_ms *= jitter;
325 }
326
327 let max_delay_ms = self.recovery_config.max_delay.as_millis() as f64;
329 delay_ms = delay_ms.min(max_delay_ms);
330
331 std::time::Duration::from_millis(delay_ms as u64)
332 }
333}
334
335#[async_trait]
336impl Transport for RecoveringTransport {
337 async fn send_message(&mut self, message: JsonRpcMessage) -> Result<()> {
338 loop {
339 match self.inner.send_message(message.clone()).await {
340 Ok(()) => {
341 self.health.messages_sent += 1;
342 self.health.last_activity = Some(std::time::SystemTime::now());
343 self.emit_event(TransportEvent::MessageSent).await;
344 return Ok(());
345 }
346 Err(e) => {
347 self.emit_event(TransportEvent::Error(e.to_string())).await;
348
349 if matches!(
351 e,
352 TransportError::ConnectionClosed | TransportError::ConnectionError { .. }
353 ) {
354 match self.attempt_recovery().await {
355 Ok(()) => continue, Err(recovery_err) => return Err(recovery_err),
357 }
358 } else {
359 return Err(e);
360 }
361 }
362 }
363 }
364 }
365
366 async fn receive_message(&mut self) -> Result<JsonRpcMessage> {
367 loop {
368 match self.inner.receive_message().await {
369 Ok(message) => {
370 self.health.messages_received += 1;
371 self.health.last_activity = Some(std::time::SystemTime::now());
372 self.emit_event(TransportEvent::MessageReceived).await;
373 return Ok(message);
374 }
375 Err(e) => {
376 self.emit_event(TransportEvent::Error(e.to_string())).await;
377
378 if matches!(
380 e,
381 TransportError::ConnectionClosed | TransportError::ConnectionError { .. }
382 ) {
383 match self.attempt_recovery().await {
384 Ok(()) => continue, Err(recovery_err) => return Err(recovery_err),
386 }
387 } else {
388 return Err(e);
389 }
390 }
391 }
392 }
393 }
394
395 async fn close(&mut self) -> Result<()> {
396 self.health.state = ConnectionState::ShuttingDown;
397 self.emit_event(TransportEvent::ShutdownRequested).await;
398
399 let result = self.inner.close().await;
400
401 self.health.state = ConnectionState::Disconnected;
402 self.emit_event(TransportEvent::ShutdownComplete).await;
403
404 result
405 }
406
407 fn get_state(&self) -> ConnectionState {
408 self.health.state.clone()
409 }
410
411 fn get_health(&self) -> TransportHealth {
412 self.health.clone()
413 }
414
415 async fn shutdown(&mut self, config: ShutdownConfig) -> Result<()> {
416 self.inner.shutdown(config).await
417 }
418
419 async fn force_shutdown(&mut self) -> Result<()> {
420 self.inner.force_shutdown().await
421 }
422
423 async fn reconnect(&mut self) -> Result<()> {
424 self.attempt_recovery().await
425 }
426
427 async fn reset(&mut self) -> Result<()> {
428 self.health = TransportHealth::default();
429 self.retry_count = 0;
430 self.last_error = None;
431 self.inner.reset().await
432 }
433}
434
/// Selects which transport [`create_transport`] builds.
#[derive(Debug, Clone)]
pub enum TransportConfig {
    /// Use the stdio transport (`stdio::StdioTransport`).
    Stdio,

    /// Use the streamable HTTP client. Only available with the `http` feature.
    #[cfg(feature = "http")]
    Streamable {
        /// Passed through as the client configuration's `base_url`.
        base_url: String,
        /// Passed through as the client configuration's `auth_token`.
        auth_token: Option<String>,
        /// Passed through as the client configuration's `session_id`.
        session_id: Option<String>,
    },
}
449
450pub async fn create_transport(config: TransportConfig) -> Result<Box<dyn Transport>> {
452 match config {
453 TransportConfig::Stdio => {
454 let transport = stdio::StdioTransport::new().await?;
455 Ok(Box::new(transport))
456 }
457
458 #[cfg(feature = "http")]
459 TransportConfig::Streamable {
460 base_url,
461 auth_token,
462 session_id,
463 } => {
464 let client_config = streamable_http::client::StreamableHttpClientConfig {
465 base_url,
466 auth_token,
467 session_id,
468 auth_method: None,
469 ..Default::default()
470 };
471
472 let mut client = streamable_http::client::StreamableHttpClient::new(client_config)?;
473 client.connect().await?;
474 Ok(Box::new(client))
475 }
476 }
477}
478
479pub async fn create_recovering_transport(
481 config: TransportConfig,
482 recovery_config: RecoveryConfig,
483) -> Result<Box<dyn Transport>> {
484 let transport = create_transport(config).await?;
485 let recovering_transport = RecoveringTransport::new(transport, recovery_config);
486 Ok(Box::new(recovering_transport))
487}